Source code for starclass.SortingHatClassifier.Sorting_Hat

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
The Sorting-Hat Classifier (Supervised randOm foRest variabiliTy classIfier using high-resolution pHotometry Attributes in TESS data).

.. codeauthor:: Jeroen Audenaert <jeroen.audenaert@kuleuven.be>
"""

import os
import numpy as np
from bottleneck import anynan
import scipy.stats as stat
from sklearn.ensemble import RandomForestClassifier
from . import Sorting_Hat_featcalc as fc
from .. import BaseClassifier, io
from ..utilities import get_periods
from ..exceptions import UntrainedClassifierError

# Number of frequencies used as features:
NFREQUENCIES = 3

#--------------------------------------------------------------------------------------------------
class Classifier_obj(RandomForestClassifier):
	"""
	Wrapper for sklearn RandomForestClassifier.
	"""
	def __init__(self, n_estimators=1000, max_features='auto', min_samples_split=2, class_weight='balanced', criterion='gini', max_depth=15, random_state=None):
		super().__init__(n_estimators=n_estimators,
			max_features=max_features,
			min_samples_split=min_samples_split,
			class_weight=class_weight, criterion=criterion, max_depth=max_depth,
			random_state=random_state)
		self.trained = False

#--------------------------------------------------------------------------------------------------
class SortingHatClassifier(BaseClassifier):
	"""
	Sorting-Hat Classifier.
	"""
	def __init__(self, clfile='sortinghat_classifier_v01.pickle', n_estimators=1000, max_features='auto', min_samples_split=2, *args, **kwargs):
		"""
		Initialize the classifier object.

		Parameters:
			clfile (str): Filepath to previously pickled Classifier_obj.
			n_estimators (int): Number of trees in the forest.
			max_features (int): See sklearn.ensemble.RandomForestClassifier.
			min_samples_split (int): See sklearn.ensemble.RandomForestClassifier.
		"""
		# Initialise parent
		super().__init__(*args, **kwargs)

		self.classifier = None

		if clfile is not None:
			self.clfile = os.path.join(self.data_dir, clfile)
		else:
			self.clfile = None

		if self.clfile is not None and os.path.exists(self.clfile):
			# Load pre-trained classifier
			self.load(self.clfile)

		self.features_names = ['f' + str(i+1) for i in range(NFREQUENCIES)]
		self.features_names += [
			'varrat',
			'number_significantharmonic',
			'skewness',
			'flux_ratio',
			'diff_entropy_lc',
			'diff_entropy_as',
			'mse_mean',
			'mse_max',
			'mse_std',
			'mse_power'
		]

		if self.classifier is None:
			# Create new untrained classifier
			self.classifier = Classifier_obj(
				n_estimators=n_estimators,
				max_features=max_features,
				min_samples_split=min_samples_split,
				random_state=self.random_state)

		# Link to the internal RandomForestClassifier classifier model,
		# which can be used for calculating feature importances:
		self._classifier_model = self.classifier
	#----------------------------------------------------------------------------------------------
	def save(self, outfile):
		"""
		Save the classifier object with pickle.
		"""
		io.savePickle(outfile, self.classifier)
	#----------------------------------------------------------------------------------------------
	def load(self, infile):
		"""
		Load classifier object.
		"""
		self.classifier = io.loadPickle(infile)
	#----------------------------------------------------------------------------------------------
	def featcalc(self, features, total=None, recalc=False):
		"""
		Calculate features for a set of lightcurves.
		"""
		if isinstance(features, dict): # Trick for single features
			features = [features]
		if total is None:
			total = len(features)

		featout = np.empty([total, len(self.features_names)], dtype='float32')
		for k, obj in enumerate(features):
			# Load features from the provided (cached) features if they exist:
			featout[k, :] = [obj.get(key, np.NaN) for key in self.features_names]

			# If not all features are already populated, we are going to recalculate them all:
			if recalc or anynan(featout[k, :]):
				lc = fc.prepLCs(obj['lightcurve'], linflatten=False)
				periods, _, _ = get_periods(obj, NFREQUENCIES, lc.time, in_days=False)
				featout[k, :NFREQUENCIES] = periods
				#EBper = EBperiod(lc.time, lc.flux, periods[0], linflatten=linflatten-1)
				#featout[k, 0] = EBper # overwrites top period
				featout[k, NFREQUENCIES:NFREQUENCIES+2] = fc.compute_varrat(obj)
				#featout[k, NFREQUENCIES+1:NFREQUENCIES+2] = fc.compute_lpf1pa11(obj)
				featout[k, NFREQUENCIES+2:NFREQUENCIES+3] = stat.skew(lc.flux)
				featout[k, NFREQUENCIES+3:NFREQUENCIES+4] = fc.compute_flux_ratio(lc.flux)
				featout[k, NFREQUENCIES+4:NFREQUENCIES+5] = fc.compute_differential_entropy(lc.flux)
				featout[k, NFREQUENCIES+5:NFREQUENCIES+6] = fc.compute_differential_entropy(obj['powerspectrum'].standard[1])
				featout[k, NFREQUENCIES+6:NFREQUENCIES+10] = fc.compute_multiscale_entropy(lc.flux)
				#featout[k, NFREQUENCIES+10:NFREQUENCIES+11] = fc.compute_max_lyapunov_exponent(lc.flux)

		return featout
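
	# For reference, the column layout of the array returned by featcalc above
	# (it mirrors self.features_names as set in __init__; NFREQUENCIES = 3):
	#   [0:3]   f1, f2, f3 - dominant frequencies from get_periods()
	#   [3:5]   varrat, number_significantharmonic - from fc.compute_varrat()
	#   [5]     skewness - scipy.stats.skew of the flux
	#   [6]     flux_ratio - fc.compute_flux_ratio()
	#   [7]     diff_entropy_lc - differential entropy of the lightcurve flux
	#   [8]     diff_entropy_as - differential entropy of obj['powerspectrum'].standard[1]
	#   [9:13]  mse_mean, mse_max, mse_std, mse_power - fc.compute_multiscale_entropy()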
	#----------------------------------------------------------------------------------------------
	def do_classify(self, features, recalc=False):
		"""
		Classify a single lightcurve.

		Parameters:
			features (dict): Dictionary of features.

		Returns:
			tuple:
				- dict: Dictionary of stellar classification probabilities.
				- ndarray: Array of features used for the classification.
		"""
		if not self.classifier.trained:
			raise UntrainedClassifierError('Classifier has not been trained. Exiting.')

		# If self.classifier.trained=True, calculate the features:
		self.logger.debug("Calculating features...")
		featarray = self.featcalc(features, total=1, recalc=recalc)
		#self.logger.info("Features calculated.")

		# Do the magic:
		classprobs = self.classifier.predict_proba(featarray)[0]
		self.logger.debug("Classification complete")

		result = {}
		for c, cla in enumerate(self.classifier.classes_):
			key = self.StellarClasses(cla)
			result[key] = classprobs[c]
		return result, featarray
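
	# Illustrative shape of the returned classification dictionary (class names
	# and probabilities made up for illustration; the actual keys depend on the
	# StellarClasses enum and on the classes seen during training):
	#   {StellarClasses.SOLARLIKE: 0.82, StellarClasses.DSCT_BCEP: 0.05, ...}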
	#----------------------------------------------------------------------------------------------
	def train(self, tset, savecl=True, recalc=False, overwrite=False):
		"""
		Train the classifier.

		Parameters:
			tset (TrainingSet): Training set to train the classifier on, providing
				labels and features (incl. lightcurves).
			savecl (bool): Save the trained classifier to file? (overwrite or recalc
				must be True for an existing classifier file to be overwritten)
			recalc (bool): Recalculate the features?
			overwrite (bool): Overwrite an existing saved classifier?
		"""
		if self.classifier.trained:
			return

		fitlabels = self.parse_labels(tset.labels())

		# Check for pre-calculated features:
		self.logger.info('Calculating/Loading Features.')
		featarray = self.featcalc(tset.features(), total=len(tset), recalc=recalc)
		self.logger.info('Features calculated/loaded.')

		self.classifier.oob_score = True
		self.classifier.fit(featarray, fitlabels)
		self.logger.info('Trained. OOB Score = %f', self.classifier.oob_score_)
		#logger.info([estimator.tree_.max_depth for estimator in self.classifier.estimators_])
		self.classifier.oob_score = False
		self.classifier.trained = True

		if savecl and self.clfile is not None:
			if not os.path.exists(self.clfile) or overwrite or recalc:
				self.logger.info("Saving pickled classifier instance to '%s'", self.clfile)
				self.save(self.clfile)
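
#--------------------------------------------------------------------------------------------------
# A minimal end-to-end usage sketch (hypothetical: 'make_training_set' and the
# contents of the 'features' dictionary are assumptions based on the interfaces
# used above, not a verified example):
#
#   tset = make_training_set()                       # any starclass TrainingSet
#   stcl = SortingHatClassifier(tset=tset)
#   stcl.train(tset)                                 # fits the random forest and pickles it
#   result, featarray = stcl.do_classify(features)   # 'features' dict incl. 'lightcurve'
#   best = max(result, key=result.get)               # most probable stellar class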