Source code for starclass.plots

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Plotting utilities for stellar classification.

.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""

import logging
import warnings
import os.path
import numpy as np
import matplotlib
from matplotlib.ticker import MultipleLocator
import matplotlib.pyplot as plt

with warnings.catch_warnings():
	warnings.filterwarnings('ignore', module='shap', message="IPython could not be loaded!")
	from shap import summary_plot

#--------------------------------------------------------------------------------------------------
def plots_interactive(backend=('QtAgg', 'Qt5Agg', 'MacOSX', 'Qt4Agg', 'Qt5Cairo', 'TkAgg', 'GTK4Agg')):
	"""
	Change plotting to using an interactive backend.

	The candidate backends are tried in the order given and the first one that can
	be activated is used. If none of them can be activated, the current backend is
	left untouched and a warning is logged.

	Parameters:
		backend (str or list): Backend to change to. If not provided, will try different
			interactive backends and use the first one that works.

	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	logger = logging.getLogger(__name__)

	# Prefer the backend registry when this Matplotlib provides it, and fall back
	# to the older ``rcsetup.interactive_bk`` list otherwise:
	try:
		from matplotlib.backends import BackendFilter, backend_registry
	except ImportError:
		valid_backends = matplotlib.rcsetup.interactive_bk
	else:
		valid_backends = backend_registry.list_builtin(BackendFilter.INTERACTIVE)
	logger.debug("Valid interactive backends: %s", valid_backends)

	# Backend names are matched case-insensitively, since the casing reported by
	# Matplotlib is not guaranteed to match the spelling used in the defaults above:
	valid_lower = {str(name).lower() for name in valid_backends}

	candidates = [backend] if isinstance(backend, str) else list(backend)

	for bckend in candidates:
		if bckend.lower() not in valid_lower:
			logger.debug("Interactive backend '%s' is not found", bckend)
			continue

		# Try to change the backend, and catch errors if it didn't work.
		# ModuleNotFoundError is a subclass of ImportError, so one clause is enough:
		try:
			plt.switch_backend(bckend)
		except ImportError as err:
			logger.debug("Interactive backend '%s' could not be loaded: %s", bckend, err)
		else:
			logger.debug("Interactive backend selected: %s", bckend)
			break
	else:
		# Only reached when the loop never hit ``break``, i.e. nothing was activated:
		logger.warning("No interactive backend could be selected from: %s", candidates)
#--------------------------------------------------------------------------------------------------
def plots_noninteractive():
	"""
	Switch matplotlib over to the non-interactive 'Agg' backend.

	This is useful where no display is available, e.g. when running on a cluster.

	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	noninteractive_backend = 'Agg'
	plt.switch_backend(noninteractive_backend)
#--------------------------------------------------------------------------------------------------
def plot_confusion_matrix(diagnostics=None, cfmatrix=None, ticklabels=None, ax=None,
	cmap='Blues', style=None):
	"""
	Plot a confusion matrix.

	Each row of the matrix is normalised by its sum before plotting, and non-zero
	cells are annotated with the resulting percentage. Rows are shown along the
	"True Class" axis and columns along the "Predicted Class" axis.

	If both ``diagnostics`` and ``cfmatrix`` or ``ticklabels`` are provided,
	the last two will take precedence.

	Parameters:
		diagnostics (dict, optional): Diagnostics to load confusion matrix from.
			Is created during testing and can be loaded from the diagnostics JSON files.
		cfmatrix (ndarray, [n_classes x n_classes]): Confusion matrix.
			The input is copied and never modified.
		ticklabels (list, [n_classes]): labels for plot axes. If no labels can be
			found, dummy labels ``#0``, ``#1``, ... are used and a warning is logged.
		ax (:py:class:`matplotlib.pyplot.Axes`): Axes to draw into.
			If not provided, a new figure is created.
		cmap (str, optional): Colormap passed on to ``imshow``. Default is 'Blues'.
		style (str, optional): Matplotlib style to use. Defaults to the
			``starclass.mplstyle`` file located next to this module.

	Returns:
		:py:class:`matplotlib.pyplot.Figure`: Figure containing the plot.

	Raises:
		ValueError: If no confusion matrix is available from either
			``cfmatrix`` or ``diagnostics``.

	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	logger = logging.getLogger(__name__)

	if diagnostics is None and cfmatrix is None:
		raise ValueError("One of DIAGNOSTICS or CFMATRIX must be provided.")

	# Pull things out from the diagnostics dict:
	if diagnostics:
		if cfmatrix is None:
			cfmatrix = diagnostics['confusion_matrix']
		if ticklabels is None:
			ticklabels = [s['value'] for s in diagnostics['classes']]

	# An empty diagnostics dict without an explicit matrix leaves us with nothing to plot:
	if cfmatrix is None:
		raise ValueError("One of DIAGNOSTICS or CFMATRIX must be provided.")

	# Work on a copy (np.array, not np.asarray), so the in-place normalisation
	# below can never modify the array passed in by the caller:
	cfmatrix = np.array(cfmatrix, dtype='float64')
	N = cfmatrix.shape[0]

	# Normalise each row by its sum. Rows without a positive sum are left untouched:
	norms = np.sum(cfmatrix, axis=1, keepdims=True)
	np.divide(cfmatrix, norms, out=cfmatrix, where=(norms > 0))

	# Warn if we don't have any labels to show, and create some dummy labels:
	if ticklabels is None:
		logger.warning("No class names were provided for confusion matrix. Assigning dummy-labels for plotting.")
		ticklabels = [f'#{k:d}' for k in range(N)]

	if style is None:
		style = os.path.abspath(os.path.join(os.path.dirname(__file__), 'starclass.mplstyle'))

	with plt.style.context(style):
		if ax is None:
			fig, ax = plt.subplots()
		else:
			fig = ax.figure

		ax.imshow(cfmatrix, interpolation='nearest', origin='lower', cmap=cmap)

		# Annotate the non-empty cells with percentages.
		# White text is used on the darkest cells to keep it readable:
		text_settings = {'va': 'center', 'ha': 'center', 'fontsize': 14}
		for x in range(N):
			for y in range(N):
				value = cfmatrix[y, x]
				if value > 0.7:
					ax.text(x, y, "%d" % np.round(value*100), color='w', **text_settings)
				elif value < 0.01 and value > 0:
					ax.text(x, y, "<1", **text_settings)
				elif value > 0:
					ax.text(x, y, "%d" % np.round(value*100), **text_settings)

		# Thin dotted lines separating the cells:
		for x in range(N):
			ax.plot([x+0.5, x+0.5], [-0.5, N-0.5], ':', color='0.5', lw=0.5)
			ax.plot([-0.5, N-0.5], [x+0.5, x+0.5], ':', color='0.5', lw=0.5)

		ax.set_xlim(-0.5, N-0.5)
		ax.set_ylim(-0.5, N-0.5)
		ax.set_xlabel('Predicted Class', fontsize=18)
		ax.set_ylabel('True Class', fontsize=18)
		if diagnostics is not None:
			ax.set_title(diagnostics.get('classifier', '') + ' - ' + diagnostics.get('tset', '') + ' - ' + diagnostics.get('level', ''))

		# Class labels. These are set on the axes object itself (not through pyplot),
		# so they end up on the right axes also when ``ax`` is not the current one:
		ticks = np.arange(N)
		ax.set_xticks(ticks)
		ax.set_xticklabels(ticklabels, rotation='vertical')
		ax.set_yticks(ticks)
		ax.set_yticklabels(ticklabels)
		ax.tick_params(axis='both', which='major', labelsize=18)

	return fig
#--------------------------------------------------------------------------------------------------
def plot_roc_curve(diagnostics, ax=None, style=None):
	"""
	Draw Receiver Operating Characteristic (ROC) curves for all classes.

	One curve is drawn per entry in ``diagnostics['classes']``, with a marker at the
	index given in ``diagnostics['roc_threshold_index']``, followed by the
	micro-averaged curve. A dashed diagonal marks a purely random classifier.

	Parameters:
		diagnostics (dict): Diagnostics coming from :py:func:`utilities.roc_curve`
			or saved to file during :py:func:`BaseClassifier.test`.
		ax (:py:class:`matplotlib.pyplot.Axes`): Axes to draw into.
			A new figure is created if not provided.
		style (str, optional): Matplotlib style to use. Defaults to the
			``starclass.mplstyle`` file located next to this module.

	Returns:
		:py:class:`matplotlib.pyplot.Figure`: Figure containing the plot.

	See also:
		:py:func:`utilities.roc_curve`

	.. codeauthor:: Jeroen Audenaert <jeroen.audenaert@kuleuven.be>
	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	# Unpack the quantities needed from the diagnostics:
	false_pos = diagnostics['false_positive_rate']
	true_pos = diagnostics['true_positive_rate']
	area = diagnostics['roc_auc']
	threshold_index = diagnostics['roc_threshold_index']
	class_list = diagnostics['classes']

	if style is None:
		module_dir = os.path.dirname(__file__)
		style = os.path.abspath(os.path.join(module_dir, 'starclass.mplstyle'))

	with plt.style.context(style):
		if ax is not None:
			fig = ax.figure
		else:
			fig, ax = plt.subplots()

		# Diagonal showing the expectation for a pure random classifier:
		ax.plot([0, 1], [0, 1], color='k', lw=0.5, linestyle='--')

		# One curve per class, with a marker at the stored threshold index:
		linewidth = 1
		for entry in class_list:
			key = entry['name']
			display_name = entry['value']
			ax.plot(
				false_pos[key],
				true_pos[key],
				label=f'{display_name:s} (area = {area[key]:.4f})',
				lw=linewidth
			)
			k = threshold_index[key]
			ax.scatter(false_pos[key][k], true_pos[key][k], marker='o')

		# Micro-averaged curve:
		micro_label = f"micro avg (area = {area['micro']:.4f})"
		ax.plot(false_pos['micro'], true_pos['micro'], lw=linewidth, label=micro_label)

		ax.set_xlim(-0.05, 1.05)
		ax.set_ylim(-0.05, 1.05)
		ax.set_xlabel('False Positive Rate')
		ax.set_ylabel('True Positive Rate')
		ax.set_title(' - '.join([
			'ROC Curve',
			diagnostics.get('classifier', ''),
			diagnostics.get('tset', ''),
			diagnostics.get('level', '')
		]))
		ax.legend(loc="lower right")

		# Same tick spacing on both axes:
		for axis in (ax.xaxis, ax.yaxis):
			axis.set_major_locator(MultipleLocator(0.1))
			axis.set_minor_locator(MultipleLocator(0.05))

	return fig
#--------------------------------------------------------------------------------------------------
def plot_feature_importance(shap_values, features, features_names, class_names, ax=None, style=None):
	"""
	Create a SHAP summary bar-plot showing all features.

	The plot is produced by :py:func:`shap.summary_plot` with ``plot_type='bar'``,
	displaying as many entries as there are names in ``features_names``.

	Parameters:
		shap_values: SHAP values, passed unchanged to :py:func:`shap.summary_plot`.
		features: Feature values, passed unchanged to :py:func:`shap.summary_plot`.
		features_names (list): Names of the features. Its length determines how many
			features are displayed.
		class_names: Class names, passed unchanged to :py:func:`shap.summary_plot`.
		ax: Currently not used. SHAP creates its own figure, and the value
			passed in is never read.
		style (str, optional): Matplotlib style to use. Defaults to the
			``starclass.mplstyle`` file located next to this module.

	Returns:
		:py:class:`matplotlib.pyplot.Figure`: The figure which is current after SHAP has plotted.
	"""
	if style is None:
		module_dir = os.path.dirname(__file__)
		style = os.path.abspath(os.path.join(module_dir, 'starclass.mplstyle'))

	with plt.style.context(style):
		with warnings.catch_warnings():
			# Silence this specific matplotlib warning while SHAP is plotting:
			warnings.filterwarnings(
				'ignore',
				message='Attempted to set non-positive bottom ylim on a log-scaled axis'
			)
			summary_plot(
				shap_values,
				features=features,
				feature_names=features_names,
				class_names=class_names,
				max_display=len(features_names),
				plot_type='bar',
				show=False
			)

		# SHAP does not hand back the figure it created,
		# so pick up whichever figure matplotlib considers current:
		fig = plt.gcf()
		main_axes = fig.axes[0]
		for side in ('right', 'top', 'left'):
			main_axes.spines[side].set_visible(True)

	return fig
#--------------------------------------------------------------------------------------------------
def plot_feature_scatter_density(shap_values, features, features_names, class_name, ax=None, style=None):
	"""
	Create a SHAP summary dot-plot, titled with the given class name.

	The plot is produced by :py:func:`shap.summary_plot` with ``plot_type='dot'``.

	Parameters:
		shap_values: SHAP values, passed unchanged to :py:func:`shap.summary_plot`.
		features: Feature values, passed unchanged to :py:func:`shap.summary_plot`.
		features_names (list): Names of the features.
		class_name: Name of the class. Used as the title of the plot, and also
			forwarded to SHAP as ``class_names``.
		ax: Currently not used. SHAP creates its own figure, and the value
			passed in is never read.
		style (str, optional): Matplotlib style to use. Defaults to the
			``starclass.mplstyle`` file located next to this module.

	Returns:
		:py:class:`matplotlib.pyplot.Figure`: The figure which is current after SHAP has plotted.
	"""
	if style is None:
		module_dir = os.path.dirname(__file__)
		style = os.path.abspath(os.path.join(module_dir, 'starclass.mplstyle'))

	with plt.style.context(style):
		summary_plot(
			shap_values,
			features=features,
			feature_names=features_names,
			class_names=class_name,
			plot_type='dot',
			show=False
		)

		# SHAP does not hand back the figure it created,
		# so pick up whichever figure matplotlib considers current:
		fig = plt.gcf()
		main_axes = fig.axes[0]
		main_axes.set_title(class_name)
		for side in ('right', 'top', 'left'):
			main_axes.spines[side].set_visible(True)

	return fig