#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
The meta-classifier.
.. codeauthor:: James S. Kuszlewicz <kuszlewicz@mps.mpg.de>
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
import os
import itertools
import numpy as np
from bottleneck import allnan, anynan
from sklearn.ensemble import RandomForestClassifier
from .. import BaseClassifier, io
from ..constants import classifier_list
from ..exceptions import UntrainedClassifierError
#--------------------------------------------------------------------------------------------------
class Classifier_obj(RandomForestClassifier):
    """
    Wrapper for sklearn RandomForestClassifier.
    """
    def __init__(self, n_estimators=100, min_samples_split=2, random_state=None):
        super().__init__(
            n_estimators=n_estimators,
            min_samples_split=min_samples_split,
            class_weight='balanced',
            max_depth=7,
            random_state=random_state
        )
        self.trained = False
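
# A minimal usage sketch, not part of the original module: Classifier_obj is an
# ordinary RandomForestClassifier with fixed hyperparameters plus a `trained`
# flag, so the standard sklearn API applies. Values shown are hypothetical:
#
#   >>> clf = Classifier_obj(n_estimators=10, random_state=42)
#   >>> clf.trained
#   False
#   >>> (clf.max_depth, clf.class_weight)
#   (7, 'balanced')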
#--------------------------------------------------------------------------------------------------
class MetaClassifier(BaseClassifier):
    """
    The meta-classifier.

    Attributes:
        clfile (str): Path to the file where the classifier is saved.
        classifier (:class:`Classifier_obj`): Actual classifier object.
        features_used (list): List of features used for training.

    .. codeauthor:: James S. Kuszlewicz <kuszlewicz@mps.mpg.de>
    .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
    """
    def __init__(self, clfile='meta_classifier.pickle', *args, **kwargs):
        """
        Initialize the classifier object.

        Parameters:
            clfile (str): Filepath to a previously pickled :class:`Classifier_obj`.
        """
        # Initialize parent
        super().__init__(*args, **kwargs)

        self.clfile = None
        self.classifier = None
        self.features_used = None

        if clfile is not None:
            self.clfile = os.path.join(self.data_dir, clfile)

        # Check if a pre-trained classifier exists
        if self.clfile is not None and os.path.exists(self.clfile):
            # Load pre-trained classifier
            self.load(self.clfile)

        # Set up classifier
        if self.classifier is None:
            self.classifier = Classifier_obj(random_state=self.random_state)

        # Link to the internal RandomForestClassifier model,
        # which can be used for calculating feature importances:
        self._classifier_model = self.classifier

    #----------------------------------------------------------------------------------------------
    def save(self, outfile):
        """
        Saves the classifier object with pickle.
        """
        io.savePickle(outfile, [self.classifier, self.features_used])

    #----------------------------------------------------------------------------------------------
    def load(self, infile):
        """
        Loads the classifier object.
        """
        # Load the pickle file:
        self.classifier, self.features_used = io.loadPickle(infile)

        # Extract the feature names based on the loaded classifier:
        self.features_names = [f'{classifier:s}_{stcl.name:s}' for classifier, stcl in self.features_used]
        self.logger.debug("Feature names: %s", self.features_names)

    #----------------------------------------------------------------------------------------------
    def build_features_table(self, features, total=None):
        """
        Build table of features.

        Parameters:
            features (iterable): Features to build table from.
            total (int, optional): Number of features in ``features``. If not provided,
                the length of ``features`` is found using :func:`len`.

        Returns:
            ndarray: Two-dimensional float32 ndarray with probabilities from all classifiers.

        .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
        """
        if total is None:
            total = len(features)

        featarray = np.full((total, len(self.features_used)), np.NaN, dtype='float32')
        for k, feat in enumerate(features):
            tab = feat['other_classifiers']
            for j, (classifier, stcl) in enumerate(self.features_used):
                indx = (tab['classifier'] == classifier) & (tab['class'] == stcl)
                if any(indx):
                    featarray[k, j] = tab['prob'][indx]

        return featarray
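
    # Illustrative sketch, not in the original source, of the table this method
    # produces: one row per target and one column per (classifier, class) pair
    # in self.features_used. Classifier and class names below are hypothetical,
    # as are the probabilities; pairs missing from 'other_classifiers' stay NaN:
    #
    #   features_used: [('clfA', CLASS1), ('clfA', CLASS2), ('clfB', CLASS1), ...]
    #   featarray:     [[0.83,  0.02,  0.79, ...],   <- target 0
    #                   [0.10,  0.75,  nan,  ...]]   <- target 1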
    #----------------------------------------------------------------------------------------------
    def do_classify(self, features):
        """
        Classify a single lightcurve.

        Parameters:
            features (dict): Dictionary of features.

        Returns:
            tuple: Dictionary of stellar classifications and the feature array
                that was fed to the classifier.

        Raises:
            UntrainedClassifierError: If classifier has not been trained.
            ValueError: If any features are NaN.

        .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
        """
        if not self.classifier.trained:
            raise UntrainedClassifierError('Classifier has not been trained. Exiting.')

        # Build features array from the probabilities from the other classifiers:
        # TODO: What about NaN values?
        self.logger.debug("Importing features...")
        featarray = self.build_features_table([features], total=1)

        if anynan(featarray):
            raise ValueError("Features contain NaNs")

        self.logger.debug("We are starting the magic...")
        # predict_proba returns shape (1, n_classes), but we want shape (n_classes,), so squeeze:
        classprobs = self.classifier.predict_proba(featarray).squeeze()
        self.logger.debug("Classification complete")

        # Format the output:
        result = {}
        for c, cla in enumerate(self.classifier.classes_):
            key = self.StellarClasses(cla)
            result[key] = classprobs[c]

        return result, featarray
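
    # Hedged usage sketch, not in the original source; 'features' must carry the
    # 'other_classifiers' table, and all names and numbers are hypothetical:
    #
    #   >>> result, featarray = meta.do_classify(features)
    #   >>> max(result, key=result.get)   # most probable stellar class
    #   <StellarClasses.CLASS1: 'CLASS1'>
    #   >>> featarray.shape               # (1, number of features used)
    #   (1, 8)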
    #----------------------------------------------------------------------------------------------
    def train(self, tset, savecl=True, overwrite=False):
        """
        Train the Meta-classifier.

        Parameters:
            tset (:class:`TrainingSet`): Training set to train classifier on.
            savecl (bool, optional): Save the classifier to file?
            overwrite (bool, optional): Overwrite existing classifier save file.

        .. codeauthor:: James S. Kuszlewicz <kuszlewicz@mps.mpg.de>
        .. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
        """
        # Convert the training-set labels into the format expected by the classifier:
        fitlabels = self.parse_labels(tset.labels())

        # First create list of all possible classifiers:
        all_classifiers = list(classifier_list)
        all_classifiers.remove('meta')

        # Create list of all features.
        # Save this to the object, since it is used to keep track of which features
        # were used to train the classifier:
        self.features_used = list(itertools.product(all_classifiers, self.StellarClasses))
        self.features_names = [f'{classifier:s}_{stcl.name:s}' for classifier, stcl in self.features_used]

        # Create table of features.
        # Create as float32, since that is what RandomForestClassifier converts it to anyway.
        self.logger.info("Importing features...")
        features = self.build_features_table(tset.features(), total=len(tset))

        # Remove columns that are all NaN.
        # These can come from classifiers that never return a given class, or from
        # classifiers that have not been run at all.
        keepcols = ~allnan(features, axis=0)
        features = features[:, keepcols]
        self.features_used = [x for i, x in enumerate(self.features_used) if keepcols[i]]
        self.features_names = [x for i, x in enumerate(self.features_names) if keepcols[i]]

        # Throw an error if a classifier has not contributed at all:
        run_classifiers = set([fu[0] for fu in self.features_used])
        if run_classifiers != set(all_classifiers):
            raise RuntimeError("Classifier did not contribute at all: %s" % set(all_classifiers).difference(run_classifiers))

        # Raise an exception if there are NaNs left in the features:
        if anynan(features):
            raise ValueError("Features contain NaNs")

        self.logger.info("Features imported. Shape = %s", features.shape)

        # Run the actual training:
        self.classifier.oob_score = True
        self.logger.info("Fitting model.")
        self.classifier.fit(features, fitlabels)
        self.logger.info('Trained. OOB Score = %s', self.classifier.oob_score_)
        self.classifier.trained = True

        if savecl and self.classifier.trained and self.clfile is not None:
            if overwrite or not os.path.exists(self.clfile):
                self.logger.info("Saving pickled classifier instance to '%s'", self.clfile)
                self.save(self.clfile)
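
#--------------------------------------------------------------------------------------------------
# End-to-end sketch, not part of the original module, assuming only the API used
# above; the training set 'tset' and feature dict 'features' are hypothetical,
# and any additional keyword arguments accepted by BaseClassifier are omitted:
#
#   >>> meta = MetaClassifier(clfile='meta_classifier.pickle')
#   >>> meta.train(tset, savecl=True, overwrite=True)  # fits and pickles the forest
#   >>> result, featarray = meta.do_classify(features)
#   >>> sorted(result.items(), key=lambda kv: -kv[1])  # classes by probability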