# Source code for starclass.RFGCClassifier.RF_GC

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
The RF-GC classifier (general random forest).

.. codeauthor:: David Armstrong <d.j.armstrong@warwick.ac.uk>
"""

import os
import copy
import numpy as np
from bottleneck import anynan
from sklearn.ensemble import RandomForestClassifier
from . import RF_GC_featcalc as fc
from .. import BaseClassifier, io
from ..utilities import get_periods
from ..exceptions import UntrainedClassifierError

# Number of frequencies used as features:
NFREQUENCIES = 6

#--------------------------------------------------------------------------------------------------
class Classifier_obj(RandomForestClassifier):
	"""
	Thin wrapper around :class:`sklearn.ensemble.RandomForestClassifier`
	which additionally carries a trained-state flag and an attached
	Self-Organising Map (SOM).
	"""
	def __init__(self, n_estimators=1000, max_features=4, min_samples_split=2, random_state=None):
		# Fixed settings: balanced class weights and trees capped at depth 15.
		super().__init__(
			n_estimators=n_estimators,
			max_features=max_features,
			min_samples_split=min_samples_split,
			max_depth=15,
			class_weight='balanced',
			random_state=random_state)
		# Set to True once the forest has been fitted:
		self.trained = False
		# SOM used for feature calculation; attached separately after construction:
		self.som = None

#--------------------------------------------------------------------------------------------------
class RFGCClassifier(BaseClassifier):
	"""
	General Random Forest

	.. codeauthor:: David Armstrong <d.j.armstrong@warwick.ac.uk>
	"""
	def __init__(self, clfile='rfgc_classifier_v01.pickle', somfile='rfgc_som.txt',
			dimx=1, dimy=400, cardinality=64, n_estimators=1000, max_features=4,
			min_samples_split=2, *args, **kwargs):
		"""
		Initialize the classifier object.

		Parameters:
			clfile (str): Filepath to previously pickled Classifier_obj.
			somfile (str): Filepath to trained SOM saved using fc.kohonenSave.
			dimx (int): dimension 1 of SOM in somfile, if given.
			dimy (int): dimension 2 of SOM in somfile, if given.
			cardinality (int): N bins per SOM pixel in somfile, if given.
			n_estimators (int): number of trees in forest.
			max_features (int): see sklearn.RandomForestClassifier.
			min_samples_split (int): see sklearn.RandomForestClassifier.
		"""
		# Initialise the parent classifier first (provides data_dir,
		# random_state/random_seed, linfit, logger, ...):
		super().__init__(*args, **kwargs)

		self.classifier = None

		# Resolve the SOM and classifier file paths relative to the data directory:
		self.somfile = None if somfile is None else os.path.join(self.data_dir, somfile)
		self.clfile = None if clfile is None else os.path.join(self.data_dir, clfile)

		# Load a pre-trained classifier from disk, if one is available:
		if self.clfile is not None and os.path.exists(self.clfile):
			self.load(self.clfile, self.somfile)

		# Otherwise start from a fresh (untrained) random forest:
		if self.classifier is None:
			self.classifier = Classifier_obj(
				n_estimators=n_estimators,
				max_features=max_features,
				min_samples_split=min_samples_split,
				random_state=self.random_state)

		# Attach the SOM from file if the classifier does not already carry one:
		if self.classifier.som is None and self.somfile is not None \
			and os.path.exists(self.somfile):
			self.classifier.som = fc.loadSOM(self.somfile, random_seed=self.random_seed)

		# List of feature names used by the classifier.
		# The first NFREQUENCIES entries are periods (top one replaced by EBperiod):
		self.features_names = ['EBperiod']
		self.features_names += ['p' + str(j) for j in range(2, NFREQUENCIES + 1)]
		self.features_names += [
			'ampratio21', 'ampratio31',
			'phasediff21', 'phasediff31',
			'SOM_map', 'SOM_range',
			'p2p_98_phasefold', 'p2p_mean_phasefold',
			'p2p_98_lc', 'p2p_mean_lc',
			'psi', 'zc',
			'Fp07', 'Fp7', 'Fp20', 'Fp50'
		]
		if self.linfit:
			self.features_names.append('detrend_coeff_norm')

		# Link to the internal RandomForestClassifier classifier model,
		# which can be used for calculating feature importances:
		self._classifier_model = self.classifier
#----------------------------------------------------------------------------------------------
[docs] def save(self, outfile, somoutfile='som.txt'): """ Saves the classifier object with pickle. som object saved as this MUST be the one used to train the classifier. """ fc.kohonenSave(self.classifier.som.K, os.path.join(self.data_dir, somoutfile)) # overwrites tempsom = copy.deepcopy(self.classifier.som) self.classifier.som = None io.savePickle(outfile, self.classifier) self.classifier.som = tempsom
#----------------------------------------------------------------------------------------------
[docs] def load(self, infile, somfile=None): """ Loads classifier object. somfile MUST match the som used to train the classifier. """ self.classifier = io.loadPickle(infile) if somfile is not None and os.path.exists(somfile): self.classifier.som = fc.loadSOM(somfile) if self.classifier.som is None: self.classifier.trained = False
#--------------------------------------------------------------------------------------------------
[docs] def featcalc(self, features, total=None, cardinality=64, linflatten=False, recalc=False): """ Calculates features for set features. """ if isinstance(features, dict): # trick for single features features = [features] if total is None: total = len(features) # Loop through the provided features and build feature table: featout = np.empty([total, len(self.features_names)], dtype='float32') for k, obj in enumerate(features): # Load features from the provided (cached) features if they exist: featout[k, :] = [obj.get(key, np.NaN) for key in self.features_names] # If not all features are already populated, we are going to recalculate them all: if recalc or anynan(featout[k, :]): lc = fc.prepLCs(obj['lightcurve'], linflatten=linflatten) periods, n_usedfreqs, usedfreqs = get_periods(obj, NFREQUENCIES, lc.time, ignore_harmonics=True) featout[k, :NFREQUENCIES] = periods EBper = fc.EBperiod(lc.time, lc.flux, periods[0], linflatten=True) featout[k, 0] = EBper # overwrites top period amp21, amp31 = fc.freq_ampratios(obj, n_usedfreqs, usedfreqs) featout[k, NFREQUENCIES] = amp21 featout[k, NFREQUENCIES+1] = amp31 phi21, phi31 = fc.freq_phasediffs(obj, n_usedfreqs, usedfreqs) featout[k, NFREQUENCIES+2] = phi21 featout[k, NFREQUENCIES+3] = phi31 # Self Organising Map featout[k, NFREQUENCIES+4:NFREQUENCIES+6] = fc.SOMloc(self.classifier.som, lc.time, lc.flux, EBper, cardinality) featout[k, NFREQUENCIES+6:NFREQUENCIES+8] = fc.phase_features(lc.time, lc.flux, EBper) featout[k, NFREQUENCIES+8:NFREQUENCIES+10] = fc.p2p_features(lc.flux) # Higher Order Crossings: psi, zc = fc.compute_hocs(lc.time, lc.flux, 5) featout[k, NFREQUENCIES+10:NFREQUENCIES+12] = psi, zc[0] # FliPer: featout[k, NFREQUENCIES+12:NFREQUENCIES+16] = obj['Fp07'], obj['Fp7'], obj['Fp20'], obj['Fp50'] # If we are running with linfit enabled, add an extra feature # which is the absoulte value of the fitted linear trend, divided # with the point-to-point scatter: if self.linfit: slope_feature = 
np.abs(obj['detrend_coeff'][0]) / obj['ptp'] featout[k, NFREQUENCIES+16] = slope_feature return featout
#----------------------------------------------------------------------------------------------
[docs] def do_classify(self, features, recalc=False): """ Classify a single lightcurve. Parameters: features (dict): Dictionary of features. Returns: dict: Dictionary of stellar classifications. """ if not self.classifier.trained: raise UntrainedClassifierError('Classifier has not been trained. Exiting.') # Assumes that if self.classifier.trained=True, # ...then self.classifier.som is not None self.logger.debug("Calculating features...") featarray = self.featcalc(features, total=1, recalc=recalc) #logger.info("Features calculated.") # Do the magic: classprobs = self.classifier.predict_proba(featarray)[0] self.logger.debug("Classification complete") result = {} for c, cla in enumerate(self.classifier.classes_): key = self.StellarClasses(cla) result[key] = classprobs[c] return result, featarray
#----------------------------------------------------------------------------------------------
[docs] def train(self, tset, savecl=True, recalc=False, overwrite=False): """ Train the classifier. Parameters: tset (``TrainingSet``): labels for training set lightcurves. features (iterable of dict): features, inc lightcurves. savecl (bool, optional): Save classifier? (``overwrite`` or ``recalc`` must be true for an old classifier to be overwritten). overwrite (bool, optional): Reruns SOM. recalc (bool, optional): Recalculates features. """ if self.classifier.trained: return # Check for pre-calculated features fitlabels = self.parse_labels(tset.labels()) self.logger.info('Calculating features...') # Check for pre-calculated som if self.classifier.som is None: self.logger.info("No SOM loaded. Creating new SOM, saving to '%s'.", self.somfile) self.classifier.som = fc.makeSOM(tset.features(), outfile=self.somfile, overwrite=overwrite, random_seed=self.random_seed) self.logger.info('SOM created and saved.') self.logger.info('Calculating/Loading Features.') featarray = self.featcalc(tset.features(), total=len(tset), recalc=recalc) self.logger.info('Features calculated/loaded.') self.classifier.oob_score = True self.classifier.fit(featarray, fitlabels) self.logger.info('Trained. OOB Score = %f', self.classifier.oob_score_) #logger.info([estimator.tree_.max_depth for estimator in self.classifier.estimators_]) self.classifier.oob_score = False self.classifier.trained = True if savecl and self.clfile is not None: if not os.path.exists(self.clfile) or overwrite or recalc: self.logger.info("Saving pickled classifier instance to '%s'", self.clfile) self.logger.info("Saving SOM to '%s'", self.somfile) self.save(self.clfile, self.somfile)
#----------------------------------------------------------------------------------------------
[docs] def loadsom(self, somfile): """ Loads a SOM, if not done at init. """ self.classifier.som = fc.loadSOM(somfile)