Source code for starclass.SLOSH.SLOSH

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
The SLOSH method for detecting solar-like oscillations (2D deep learning methods).

.. codeauthor:: Marc Hon <mtyh555@uowmail.edu.au>
.. codeauthor:: James Kuszlewicz <kuszlewicz@mps.mpg.de>
"""

import numpy as np
import os
import logging
from tqdm import tqdm
import h5py
import tempfile
import tensorflow
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from sklearn.metrics import classification_report
from . import SLOSH_prepro as preprocessing
from .. import BaseClassifier
from ..exceptions import UntrainedClassifierError

#--------------------------------------------------------------------------------------------------
class SLOSHClassifier(BaseClassifier):
    """
    Solar-like Oscillation Shape Hunter (SLOSH) Classifier.

    The power density spectrum of a star is converted into a 128x128 image,
    which is then classified by a Keras model.

    .. codeauthor:: Marc Hon <mtyh555@uowmail.edu.au>
    .. codeauthor:: James Kuszlewicz <kuszlewicz@mps.mpg.de>
    .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
    """

    def __init__(self, clfile='SLOSH_Classifier_Model.hdf5', mc_iterations=10, *args, **kwargs):
        """
        Initialize the classifier and load a previously trained model if one exists.

        Parameters:
            clfile (str or None): File name of the model, relative to the data
                directory of the classifier (``self.data_dir``). If ``None``, or if
                the file does not exist, no model is loaded and predictions are
                disabled until :meth:`train` or :meth:`load` has been called.
            mc_iterations (int): Number of times the model is evaluated per star
                in :meth:`do_classify`. The individual predictions are averaged.
            *args: Passed on to the parent class.
            **kwargs: Passed on to the parent class.
        """
        # Initialize parent:
        super().__init__(*args, **kwargs)

        self.classifier_list = []
        self.mc_iterations = mc_iterations
        self.num_labels = len(self.StellarClasses)
        self.features_names = []  # SLOSH have no features as such

        # Set the global random seeds:
        np.random.seed(self.random_seed)
        tensorflow.random.set_seed(self.random_seed)

        # Find model file:
        if clfile is not None:
            self.model_file = os.path.join(self.data_dir, clfile)
        else:
            self.model_file = None

        if self.model_file is not None and os.path.exists(self.model_file):
            self.logger.debug("Loading pre-trained model...")
            # Only flag the classifier as usable once the model was actually loaded:
            self.classifier_list.append(tensorflow.keras.models.load_model(self.model_file))
            self.predictable = True
        else:
            self.logger.info('No saved models provided. Predict functions are disabled.')
            self.predictable = False

    #----------------------------------------------------------------------------------------------
    def do_classify(self, features):
        """
        Prediction for a star, producing output determining if it is a solar-like oscillator.

        Parameters:
            features (dict): Dictionary of features. Only the ``powerspectrum`` entry
                is used here: the first two elements of its ``standard`` attribute
                are converted into the image passed to the model.

        Returns:
            tuple: Two elements:

            - dict: Mean prediction of the model for each member of ``StellarClasses``.
            - list: Always empty, since SLOSH have no features as such.

        Raises:
            UntrainedClassifierError: If no trained model is available.
        """
        if not self.predictable:
            raise UntrainedClassifierError('No saved models provided. Predict functions are disabled.')

        # Pre-calculated power density spectrum:
        psd = features['powerspectrum'].standard

        self.logger.debug('Generating Image...')
        img_array = preprocessing.generate_single_image(psd[0], psd[1])
        # Add batch and channel dimensions around the single 128x128 image:
        img_array = img_array.reshape(1, 128, 128, 1)

        self.logger.debug('Making Predictions...')
        # NOTE(review): the model is called with training=False. If that disables
        # dropout, every iteration yields the same prediction - confirm whether
        # Monte Carlo dropout (training=True) was intended here.
        model = self.classifier_list[0]
        pred_array = np.zeros((self.mc_iterations, self.num_labels))
        for i in range(self.mc_iterations):
            pred_array[i, :] = model(img_array, training=False)
        pred = np.mean(pred_array, axis=0)

        # Convert the integer labels used by SLOSH to StellarClasses again
        # and put it all together in the result dict:
        result = {stcl: pred[k] for k, stcl in enumerate(self.StellarClasses)}
        return result, []

    #----------------------------------------------------------------------------------------------
    def train(self, tset):
        """
        Trains a fresh classifier using a default NN architecture and parameters
        as of the Hon et al. (2018) paper.

        Nothing is done if a model is already available. Otherwise an image is
        generated for each target in the training set and stored in a HDF5 file,
        located in the features cache if one is configured and in a temporary
        directory if not. The trained model is saved to ``self.model_file`` and
        becomes the active classifier.

        Parameters:
            tset: Training set. Its ``features()`` and ``labels()`` methods and its
                length are used.

        Raises:
            ValueError: If no model file was configured, in which case the trained
                model could not be saved.
        """
        if self.predictable:
            return

        # Fail early with a clear message instead of deep inside Keras
        # after all the images have already been generated:
        if self.model_file is None:
            raise ValueError('No model file specified. Unable to save trained classifier.')

        # Convert classification labels to integers:
        intlookup = {stcl.value: idx for idx, stcl in enumerate(self.StellarClasses)}
        intlabels = [intlookup[lbl] for lbl in self.parse_labels(tset.labels())]

        self.logger.info('Generating Train Images...')
        tmpdir = None
        if self.features_cache:
            train_folder = os.path.join(self.features_cache, 'SLOSH_Train_Images')
            os.makedirs(train_folder, exist_ok=True)
        else:
            tmpdir = tempfile.TemporaryDirectory()
            train_folder = tmpdir.name

        try:
            # Go through the training-set and ensure that all images are created:
            hdf5_file = os.path.join(train_folder, 'SLOSH_Train_Images.hdf5')
            datasets = self._generate_train_images(tset, hdf5_file)

            # Find the level of verbosity to add to tensorflow calls:
            verbose = 2 if self.logger.isEnabledFor(logging.INFO) else 0

            # Open the HDF5 file containing the features cache in read-only mode
            # so it can be passed to the generators:
            with h5py.File(hdf5_file, 'r') as hdf:
                train_generator = preprocessing.npy_generator(datasets, intlabels,
                    hdf5_file=hdf, subset='train', random_seed=self.random_seed)
                valid_generator = preprocessing.npy_generator(datasets, intlabels,
                    hdf5_file=hdf, subset='valid', random_seed=self.random_seed)

                reduce_lr = ReduceLROnPlateau(factor=0.5, patience=5, verbose=verbose)
                early_stop = EarlyStopping(monitor='val_loss', patience=10,
                    restore_best_weights=True)
                checkpoint = ModelCheckpoint(self.model_file, monitor='val_loss',
                    verbose=verbose, save_best_only=True)

                model = preprocessing.default_classifier_model(num_classes=len(self.StellarClasses))

                self.logger.info('Training Classifier...')
                epochs = 50
                model.fit(train_generator,
                    epochs=epochs,
                    steps_per_epoch=len(train_generator),
                    validation_data=valid_generator,
                    validation_steps=len(valid_generator),
                    callbacks=[reduce_lr, early_stop, checkpoint],
                    verbose=verbose)

            # Save the model to file:
            self.save_model(model, self.model_file)
        finally:
            # Remove the temporary directory explicitly instead of
            # leaving it to the garbage collector:
            if tmpdir is not None:
                tmpdir.cleanup()

    #----------------------------------------------------------------------------------------------
    def _generate_train_images(self, tset, hdf5_file):
        """
        Ensure that an image exists in the HDF5 file for each target in the training set.

        Images are stored as individual datasets in the ``images`` group, named
        after the priority of the target. Existing images are not regenerated.

        Parameters:
            tset: Training set. Its ``features()`` method and its length are used.
            hdf5_file (str): Path to the HDF5 file. Created if it does not exist.

        Returns:
            list: Names of the datasets, in the order of the training set.
        """
        # Only show the progress bar when logging at INFO level:
        tqdm_settings = {
            'disable': None if self.logger.isEnabledFor(logging.INFO) else True
        }
        dset_settings = {
            'compression': 'lzf',
            'shuffle': True,
            'fletcher32': True,
            'chunks': (128, 128),
            'dtype': 'float32'
        }

        datasets = []
        with h5py.File(hdf5_file, 'a') as hdf:
            images = hdf.require_group('images')
            for feat in tqdm(tset.features(), total=len(tset), **tqdm_settings):
                dset_name = str(feat['priority'])
                datasets.append(dset_name)
                if dset_name not in images:
                    # Power density spectrum from pre-calculated features:
                    psd = feat['powerspectrum'].standard

                    # Generate and save image to file:
                    img = preprocessing.generate_single_image(psd[0], psd[1])
                    images.create_dataset(dset_name, data=img, **dset_settings)
                    hdf.flush()

        return datasets

    #----------------------------------------------------------------------------------------------
    def save_model(self, model, model_file):
        """
        Save a trained model to file and make it the active classifier.

        Any models already in the list of classifiers are replaced by ``model``.

        Parameters:
            model: Trained model. Its ``save`` method is used.
            model_file (str): Output file name for model.
        """
        # Save out model
        model.save(model_file)
        self.classifier_list = [model]
        # Set predictable to true so can predict
        self.predictable = True

    #----------------------------------------------------------------------------------------------
    def save(self, outfile):
        """
        Saves all loaded classifier models.

        The model at position ``i`` in the list of classifiers is saved to
        ``<outfile>-<i>.h5``.

        Parameters:
            outfile (str): Base output file name.

        Raises:
            ValueError: If no models are available.
        """
        if not self.predictable:
            raise ValueError('No saved models in memory.')

        for i, model in enumerate(self.classifier_list):
            model.save(outfile + ('-%s.h5' % i))

    #----------------------------------------------------------------------------------------------
    def load(self, infile):
        """
        Loads a classifier model and adds it to the list of classifiers.

        Parameters:
            infile (str): Path to trained model.
        """
        self.classifier_list.append(tensorflow.keras.models.load_model(infile))
        self.predictable = True

    #----------------------------------------------------------------------------------------------
    def clear_model_list(self):
        """
        Helper function to clear classifiers in the classifier list.

        The list is emptied in place and predictions are disabled.
        """
        self.classifier_list.clear()
        self.predictable = False
#--------------------------------------------------------------------------------------------------
class TestCallback(tensorflow.keras.callbacks.Callback):
    """
    Keras callback printing a classification report for the validation data
    at the end of each epoch.
    """

    def __init__(self, val_data, classes):
        """
        Parameters:
            val_data: Iterable of ``(inputs, targets)`` batches. The targets are
                reduced to class indices with ``argmax`` along the second axis.
            classes: Collection of classes. Its length and the ``name`` attribute
                of each member are used.
        """
        # Let the base class set up its own state before adding ours:
        super().__init__()
        self.validation_data = val_data
        self.batch_size = 32
        self.num_classes = len(classes)
        self.class_names = [cl.name for cl in classes]

    def on_train_begin(self, logs=None):
        """Print the validation data object when training starts."""
        print(self.validation_data)

    def on_epoch_end(self, epoch, logs=None):
        """
        Predict on all validation batches and print a per-class report.

        Parameters:
            epoch (int): Index of the epoch. Not used.
            logs (dict, optional): Metrics of the epoch. Not used.
        """
        # Collect class indices per batch and join them once at the end,
        # rather than re-copying a growing array for every batch.
        # The empty seed arrays keep this working for empty validation data.
        pred_batches = [np.empty(0, dtype=np.intp)]
        true_batches = [np.empty(0, dtype=np.intp)]
        for x_val, y_val in self.validation_data:
            pred_batches.append(np.argmax(self.model.predict(x_val), axis=1))
            true_batches.append(np.argmax(y_val, axis=1))

        val_pred = np.concatenate(pred_batches)
        val_true = np.concatenate(true_batches)

        print(np.shape(val_pred), np.shape(val_true))
        print(classification_report(val_true, val_pred, target_names=self.class_names))