#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
The basic correction class for the TASOC Photometry pipeline.
All other specific correction classes will inherit from BaseCorrector.
.. codeauthor:: Lindsey Carboneau
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
.. codeauthor:: Filipe Pereira
"""
import os.path
import shutil
import enum
import logging
import sqlite3
import tempfile
import contextlib
import numpy as np
from timeit import default_timer
from bottleneck import nanmedian, nanvar, allnan
from astropy.io import fits
from lightkurve import TessLightCurve
from .plots import plt, save_figure
from .quality import TESSQualityFlags, CorrectorQualityFlags
from .utilities import rms_timescale, ptp, ListHandler, LoggerWriter, fix_fits_table_headers
from .manual_filters import manual_exclude
from .version import get_version
# Version string of the correction package; it is written to the CORRVER header
# (and the "Correction Version" line of ASCII output) when lightcurves are saved.
__version__ = get_version(pep440=False)
# Declares that the docstrings in this module are written in reStructuredText:
__docformat__ = 'restructuredtext'
#--------------------------------------------------------------------------------------------------
class STATUS(enum.Enum):
	"""
	Enumeration of the states a lightcurve correction can be in.

	Note that the numeric values are not in chronological order:
	``STARTED`` was added later and therefore carries the value 6.
	"""

	#: The status is unknown. The actual calculation has not started yet.
	UNKNOWN = 0
	#: The calculation has started, but not yet finished.
	STARTED = 6
	#: Everything has gone well.
	OK = 1
	#: Encountered a catastrophic error that I could not recover from.
	ERROR = 2
	#: Something is a bit fishy. Maybe we should try again with a different algorithm?
	WARNING = 3
	#: The calculation was aborted.
	ABORT = 4
	#: The target was skipped because the algorithm found that to be the best solution.
	SKIPPED = 5
#--------------------------------------------------------------------------------------------------
def _filter_fits_hdu(hdu):
"""
Filter FITS file for invalid data (undefined timestamps).
Parameters:
hdu (:class:`astropy.io.fits.HDUList`): FITS HDUList that needs to be filtered.
Returns:
:class:`astropy.io.fits.HDUList`: Modified FITS HDUList with invalid data removed.
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
# Remove non-finite timestamps
indx = np.isfinite(hdu['LIGHTCURVE'].data['TIME'])
# Remove where TIME, CADENCENO and FLUX_RAW are all exactly zero:
indx &= ~((hdu['LIGHTCURVE'].data['CADENCENO'] == 0)
& (hdu['LIGHTCURVE'].data['TIME'] == 0)
& (hdu['LIGHTCURVE'].data['FLUX_RAW'] == 0))
# Remove from in-memory FITS hdu:
hdu['LIGHTCURVE'].data = hdu['LIGHTCURVE'].data[indx]
return hdu
#--------------------------------------------------------------------------------------------------
class BaseCorrector(object):
	"""
	The basic correction class for the TASOC Photometry pipeline.
	All other specific correction classes will inherit from BaseCorrector.

	Attributes:
		plot (bool): Boolean indicating if plotting is enabled.
		input_folder (str): Directory containing the TODO-file and the input lightcurves.
		data_folder (str): Path to directory where auxiliary data for the corrector
			should be stored.
		CorrMethod (str): Short name of the correction method, derived from the class name.
		message_queue (list): Warnings and errors emitted on the ``corrections`` logger,
			collected so they can ultimately be stored in the TODO-file.
		conn (:class:`sqlite3.Connection`): Read-only connection to a temporary copy
			of the TODO-file.
		cursor (:class:`sqlite3.Cursor`): Cursor on ``conn``.

	.. codeauthor:: Lindsey Carboneau
	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""

	def __init__(self, input_folder, plot=False):
		"""
		Initialize the corrector.

		Parameters:
			input_folder (str): Directory with input files, or path to the TODO-file itself.
			plot (bool, optional): Enable plotting.

		Raises:
			FileNotFoundError: TODO-file not found in directory.

		.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
		"""
		logger = logging.getLogger(__name__)

		# Add a ListHandler to the logging of the corrections module.
		# This is needed to catch any errors and warnings made by the correctors
		# for ultimately storing them in the TODO-file.
		# The handler is kept on the instance so it can be detached again in
		# _close_basecorrector; otherwise every new instance would leave one more
		# handler attached to the shared "corrections" logger.
		# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
		self.message_queue = []
		handler = ListHandler(message_queue=self.message_queue, level=logging.WARNING)
		formatter = logging.Formatter('%(levelname)s: %(message)s')
		handler.setFormatter(formatter)
		self._message_handler = handler
		logging.getLogger('corrections').addHandler(handler)

		# Save inputs:
		self.plot = plot
		if os.path.isdir(input_folder):
			self.input_folder = input_folder
			todo_file = os.path.join(input_folder, 'todo.sqlite')
		else:
			self.input_folder = os.path.dirname(input_folder)
			todo_file = input_folder

		# The path to the TODO list:
		logger.debug("TODO file: %s", todo_file)
		if not os.path.isfile(todo_file):
			raise FileNotFoundError("TODO file not found")

		# NOTE(review): subclasses not listed here get CorrMethod=None, which makes the
		# data_folder construction below fail - confirm all correctors are listed.
		self.CorrMethod = {
			'BaseCorrector': 'base',
			'EnsembleCorrector': 'ensemble',
			'CBVCorrector': 'cbv',
			'CBVCreator': 'cbv',
			'KASOCFilterCorrector': 'kasoc_filter'
		}.get(self.__class__.__name__)

		# Find the auxiliary data directory based on which corrector is running:
		if self.CorrMethod == 'base':
			self.data_folder = os.path.join(os.path.dirname(__file__), 'data')
		else:
			# Create a data folder specific to this corrector:
			if self.CorrMethod == 'cbv':
				self.data_folder = os.path.join(self.input_folder, 'cbv-prepare')
			else:
				self.data_folder = os.path.join(os.path.dirname(__file__), 'data', self.CorrMethod)

			# Make sure that the folder exists:
			os.makedirs(self.data_folder, exist_ok=True)

		# Create readonly copy of the TODO-file:
		with tempfile.NamedTemporaryFile(dir=self.input_folder, suffix='.sqlite', delete=False) as tmpfile:
			self.todo_file_readonly = tmpfile.name
			with open(todo_file, 'rb') as fid:
				shutil.copyfileobj(fid, tmpfile)
			tmpfile.flush()

		# Open the SQLite file in read-only mode:
		self.conn = sqlite3.connect('file:' + self.todo_file_readonly + '?mode=ro', uri=True)
		self.conn.row_factory = sqlite3.Row
		self.cursor = self.conn.cursor()

	#----------------------------------------------------------------------------------------------
	def __enter__(self):
		return self

	#----------------------------------------------------------------------------------------------
	def __exit__(self, *args):
		self.close()
		self._close_basecorrector()

	#----------------------------------------------------------------------------------------------
	def __del__(self):
		self.close()
		self._close_basecorrector()

	#----------------------------------------------------------------------------------------------
	def close(self):
		"""Close correction object. Subclasses override this to release their own resources."""
		pass

	#----------------------------------------------------------------------------------------------
	def _close_basecorrector(self):
		"""
		Close BaseCorrector object.

		Detaches the logging handler, closes the database cursor and connection and
		deletes the temporary read-only copy of the TODO-file. Safe to call several
		times, and on partially initialized objects.
		"""
		# Detach the message handler from the shared logger.
		# Logger.removeHandler is a no-op if the handler is no longer attached.
		handler = getattr(self, '_message_handler', None)
		if handler is not None:
			logging.getLogger('corrections').removeHandler(handler)
			self._message_handler = None

		if hasattr(self, 'cursor') and hasattr(self, 'conn') and self.cursor is not None:
			try:
				self.conn.rollback()
				self.cursor.close()
				self.cursor = None
			except sqlite3.ProgrammingError:
				# The connection was already closed.
				pass

		if hasattr(self, 'conn') and self.conn is not None:
			self.conn.close()
			self.conn = None

		if hasattr(self, 'todo_file_readonly') and os.path.isfile(self.todo_file_readonly):
			os.remove(self.todo_file_readonly)

	#----------------------------------------------------------------------------------------------
	def plot_folder(self, lc):
		"""
		Return folder path where plots for a given lightcurve should be saved.

		The directory is only created when plotting is enabled.

		Parameters:
			lc (:class:`lightkurve.TessLightCurve`): Lightcurve to return plot path for.

		Returns:
			string: Path to directory where plots should be saved.

		.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
		"""
		lcfile = os.path.join(self.input_folder, lc.meta['task']['lightcurve'])
		plot_folder = os.path.join(os.path.dirname(lcfile), 'plots', '%011d' % lc.targetid)
		if self.plot:
			os.makedirs(plot_folder, exist_ok=True)
		return plot_folder

	#----------------------------------------------------------------------------------------------
	def do_correction(self, lightcurve):
		"""
		Apply corrections to target lightcurve.

		Must be implemented by the specific corrector subclasses.

		Parameters:
			lightcurve (:class:`lightkurve.TessLightCurve`): Lightcurve of the target star
				to be corrected.

		Returns:
			tuple:
			- :class:`lightkurve.TessLightCurve`: corrected lightcurve object.
			- :class:`STATUS`: The status of the corrections.

		Raises:
			NotImplementedError: Always, for the base class.
		"""
		raise NotImplementedError("do_correction must be implemented by the specific corrector subclass")

	#----------------------------------------------------------------------------------------------
	def correct(self, task, output_folder=None):
		"""
		Run correction.

		Parameters:
			task (dict): Dictionary defining a task/lightcurve to process.
			output_folder (str, optional): Path to directory where lightcurve should be saved.

		Returns:
			dict: Result dictionary containing information about the processing.

		Raises:
			ValueError: If ``do_correction`` left the status as ``STATUS.UNKNOWN``.

		.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
		"""
		logger = logging.getLogger(__name__)
		t1 = default_timer()

		error_msg = []
		details = {}
		save_file = None
		result = task.copy()
		# Initialized so later steps can tell whether loading/correction produced anything:
		lc = None
		lc_corr = None

		try:
			# Load the lightcurve
			lc = self.load_lightcurve(task)

			# Run the correction on this lightcurve:
			lc_corr, status = self.do_correction(lc)

		except (KeyboardInterrupt, SystemExit): # pragma: no cover
			status = STATUS.ABORT
			logger.warning("Correction was aborted (priority=%d)", task['priority'])
		except: # noqa: E722 pragma: no cover
			status = STATUS.ERROR
			logger.exception("Correction failed (priority=%d)", task['priority'])

		# Check that the status has been changed:
		if status == STATUS.UNKNOWN: # pragma: no cover
			raise ValueError("STATUS was not set by do_correction")

		# Do sanity checks:
		if status in (STATUS.OK, STATUS.WARNING):
			# Make sure all NaN fluxes have corresponding NaN errors:
			lc_corr.flux_err[np.isnan(lc_corr.flux)] = np.nan

			# Simple check that entire lightcurve is not NaN:
			if allnan(lc_corr.flux):
				logger.error("Final lightcurve is all NaNs")
				status = STATUS.ERROR
			if allnan(lc_corr.flux_err):
				logger.error("Final lightcurve errors are all NaNs")
				status = STATUS.ERROR
			if np.any(np.isinf(lc_corr.flux)):
				logger.error("Final lightcurve contains Inf")
				status = STATUS.ERROR
			if np.any(np.isinf(lc_corr.flux_err)):
				logger.error("Final lightcurve errors contains Inf")
				status = STATUS.ERROR

		# Calculate diagnostics:
		if status in (STATUS.OK, STATUS.WARNING):
			details['variance'] = nanvar(lc_corr.flux, ddof=1)
			details['rms_hour'] = rms_timescale(lc_corr, timescale=3600/86400)
			details['ptp'] = ptp(lc_corr)

			# Diagnostics specific to the method:
			if self.CorrMethod == 'cbv':
				details['cbv_num'] = lc_corr.meta['additional_headers']['CBV_NUM']
			elif self.CorrMethod == 'ensemble':
				details['ens_num'] = lc_corr.meta['additional_headers']['ENS_NUM']
				details['ens_fom'] = lc_corr.meta['FOM']

			# Save the lightcurve to file:
			try:
				save_file = self.save_lightcurve(lc_corr, output_folder=output_folder)
			except (KeyboardInterrupt, SystemExit): # pragma: no cover
				status = STATUS.ABORT
				logger.warning("Correction was aborted (priority=%d)", task['priority'])
			except: # noqa: E722 pragma: no cover
				status = STATUS.ERROR
				logger.exception("Could not save lightcurve file (priority=%d)", task['priority'])

		# Plot the final lightcurve.
		# Only possible when both the original and the corrected lightcurve exist;
		# a failed load/correction must not crash here and bypass the error handling above.
		if self.plot and lc is not None and lc_corr is not None:
			fig = plt.figure(dpi=200)
			ax = fig.add_subplot(111)
			ax.scatter(lc.time, 1e6*(lc.flux/nanmedian(lc.flux)-1), s=2, alpha=0.3, marker='o', label="Original")
			ax.scatter(lc_corr.time, lc_corr.flux, s=2, alpha=0.3, marker='o', label="Corrected")
			ax.set_xlabel('Time (TBJD)')
			ax.set_ylabel('Relative flux (ppm)')
			ax.legend()
			save_figure(os.path.join(self.plot_folder(lc), self.CorrMethod + '_final'), fig=fig)
			plt.close(fig)

		# Unpack any errors or warnings that were sent to the logger during the correction:
		if self.message_queue:
			error_msg += self.message_queue
			# clear() keeps the list object that the logging handler appends to:
			self.message_queue.clear()
		if not error_msg:
			error_msg = None

		# Update results:
		t2 = default_timer()
		details['errors'] = error_msg
		result.update({
			'corrector': self.CorrMethod,
			'status_corr': status,
			'elaptime_corr': t2-t1,
			'lightcurve_corr': save_file,
			'details': details
		})

		return result

	#----------------------------------------------------------------------------------------------
	def search_database(self, select=None, join=None, search=None, order_by=None, limit=None,
		distinct=False):
		"""
		Search list of lightcurves and return a list of tasks/stars matching the given criteria.

		Returned rows are restricted to targets not marked as ``STATUS.SKIPPED``. Targets that
		were deemed too bad to be worth correcting are also too bad to be used in any kind of
		correction.

		Parameters:
			select (list of strings or None): List of table columns to return.
			search (list of strings or None): Conditions to apply to the selection of stars
				from the database. Empty values mean no extra conditions.
			order_by (list, str or None): Column to order the database output by.
				Empty values mean no ordering.
			limit (int or None): Maximum number of rows to retrieve from the database.
				If limit is None, all the rows are retrieved.
			distinct (bool): Boolean indicating if the query should return unique elements only.
			join (list): Table join commands to merge several database tables together.

		Returns:
			list: All stars retrieved by the call to the database as dicts/tasks
				that can be consumed directly by load_lightcurve

		.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
		"""
		logger = logging.getLogger(__name__)

		if select is None:
			select = '*'
		elif isinstance(select, (list, tuple)):
			select = ",".join(select)

		joins = ['INNER JOIN diagnostics ON todolist.priority=diagnostics.priority']
		if join is None:
			pass
		elif isinstance(join, (list, tuple)):
			joins += list(join)
		else:
			joins.append(join)
		joins = ' '.join(joins)

		# NOTE: search conditions are inserted verbatim into the SQL;
		# they must come from trusted code, never from external input.
		# Empty values (None, '', [], ()) would otherwise produce a dangling "AND".
		if not search:
			search = ''
		elif isinstance(search, (list, tuple)):
			search = "AND " + " AND ".join(search)
		else:
			search = 'AND ' + search

		# Empty values would otherwise produce a dangling "ORDER BY":
		if not order_by:
			order_by = ''
		elif isinstance(order_by, (list, tuple)):
			order_by = " ORDER BY " + ",".join(order_by)
		elif isinstance(order_by, str):
			order_by = " ORDER BY " + order_by

		limit = '' if limit is None else " LIMIT %d" % limit

		query = "SELECT {distinct:s}{select:s} FROM todolist {join:s} WHERE (corr_status IS NULL OR corr_status!={skipped:d}) {search:s}{order_by:s}{limit:s};".format(
			distinct='DISTINCT ' if distinct else '',
			select=select,
			join=joins,
			skipped=STATUS.SKIPPED.value,
			search=search,
			order_by=order_by,
			limit=limit
		)
		logger.debug("Running query: %s", query)

		# Ask the database:
		self.cursor.execute(query)
		return [dict(row) for row in self.cursor.fetchall()]

	#----------------------------------------------------------------------------------------------
	def load_lightcurve(self, task):
		"""
		Load lightcurve from task ID or full task dictionary.

		Parameters:
			task (integer or dict): Priority of the task, or a task dictionary. If the
				dictionary does not contain a ``lightcurve`` entry, the task is looked up in
				the TODO-file by its ``priority``.

		Returns:
			:class:`lightkurve.TessLightCurve`: Lightcurve for the star in question.

		Raises:
			ValueError: On invalid file format, or if the priority is not in the TODO list.

		.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
		"""
		logger = logging.getLogger(__name__)

		# Find the relevant information in the TODO-list:
		if not isinstance(task, dict) or task.get("lightcurve") is None:
			if isinstance(task, dict):
				priority = int(task['priority'])
			else:
				priority = int(task)
			self.cursor.execute("SELECT * FROM todolist INNER JOIN diagnostics ON todolist.priority=diagnostics.priority WHERE todolist.priority=? LIMIT 1;", (priority, ))
			task = self.cursor.fetchone()
			if task is None:
				raise ValueError("Priority could not be found in the TODO list")
			task = dict(task)

		# Get the path of the FITS file:
		fname = os.path.join(self.input_folder, task.get('lightcurve'))
		logger.debug('Loading lightcurve: %s', fname)

		# Load lightcurve file and create a TessLightCurve object:
		if fname.endswith(('.fits.gz', '.fits')):
			with fits.open(fname, mode='readonly', memmap=True) as hdu:
				# Filter out invalid parts of the input lightcurve:
				hdu = _filter_fits_hdu(hdu)

				# Quality flags from the pixels:
				pixel_quality = np.asarray(hdu['LIGHTCURVE'].data['PIXEL_QUALITY'], dtype='int32')

				# Corrections applied to timestamps:
				timecorr = hdu['LIGHTCURVE'].data['TIMECORR']

				# Create the QUALITY column and fill it with flags of bad data points:
				quality = np.zeros_like(hdu['LIGHTCURVE'].data['TIME'], dtype='int32')
				bad_data = ~np.isfinite(hdu['LIGHTCURVE'].data['FLUX_RAW'])
				bad_data |= (pixel_quality & TESSQualityFlags.DEFAULT_BITMASK != 0)
				quality[bad_data] |= CorrectorQualityFlags.FlaggedBadData

				# Create lightkurve object:
				lc = TessLightCurve(
					time=hdu['LIGHTCURVE'].data['TIME'],
					flux=hdu['LIGHTCURVE'].data['FLUX_RAW'],
					flux_err=hdu['LIGHTCURVE'].data['FLUX_RAW_ERR'],
					centroid_col=hdu['LIGHTCURVE'].data['MOM_CENTR1'],
					centroid_row=hdu['LIGHTCURVE'].data['MOM_CENTR2'],
					quality=quality,
					cadenceno=np.asarray(hdu['LIGHTCURVE'].data['CADENCENO'], dtype='int32'),
					time_format='btjd',
					time_scale='tdb',
					targetid=hdu[0].header.get('TICID'),
					label=hdu[0].header.get('OBJECT'),
					camera=hdu[0].header.get('CAMERA'),
					ccd=hdu[0].header.get('CCD'),
					sector=hdu[0].header.get('SECTOR'),
					ra=hdu[0].header.get('RA_OBJ'),
					dec=hdu[0].header.get('DEC_OBJ'),
					quality_bitmask=CorrectorQualityFlags.DEFAULT_BITMASK,
					meta={'data_rel': hdu[0].header.get('DATA_REL')}
				)

				# Apply manual exclude flag:
				manexcl = manual_exclude(lc)
				lc.quality[manexcl] |= CorrectorQualityFlags.ManualExclude

		elif fname.endswith(('.noisy', '.sysnoise')): # pragma: no cover
			data = np.loadtxt(fname)

			# Quality flags from the pixels:
			pixel_quality = np.asarray(data[:,3], dtype='int32')

			# Corrections applied to timestamps:
			timecorr = np.zeros(data.shape[0], dtype='float32')

			# Change the Manual Exclude flag, since the simulated data
			# and the real TESS quality flags differ in the definition:
			indx = (pixel_quality & 256 != 0)
			pixel_quality[indx] -= 256
			pixel_quality[indx] |= TESSQualityFlags.ManualExclude

			# Create the QUALITY column and fill it with flags of bad data points:
			quality = np.zeros(data.shape[0], dtype='int32')
			bad_data = ~np.isfinite(data[:,1])
			bad_data |= (pixel_quality & TESSQualityFlags.DEFAULT_BITMASK != 0)
			quality[bad_data] |= CorrectorQualityFlags.FlaggedBadData

			# Create lightkurve object:
			lc = TessLightCurve(
				time=data[:,0],
				flux=data[:,1],
				flux_err=data[:,2],
				quality=quality,
				cadenceno=np.arange(1, data.shape[0]+1, dtype='int32'),
				time_format='jd',
				time_scale='tdb',
				targetid=task['starid'],
				label="Star%d" % task['starid'],
				camera=task['camera'],
				ccd=task['ccd'],
				sector=2,
				#ra=0,
				#dec=0,
				quality_bitmask=CorrectorQualityFlags.DEFAULT_BITMASK,
				meta={}
			)

		else:
			raise ValueError("Invalid file format")

		# Add additional attributes to lightcurve object:
		lc.pixel_quality = pixel_quality
		lc.timecorr = timecorr

		# Modify the "extra_columns" tuple of the lightkurve object:
		# This is used internally in lightkurve to keep track of the columns in the
		# object, and make sure they are propagated.
		lc.extra_columns = tuple(list(lc.extra_columns) + ['timecorr', 'pixel_quality'])

		# Keep the original task in the metadata:
		lc.meta['task'] = task
		lc.meta['additional_headers'] = fits.Header()

		if logger.isEnabledFor(logging.DEBUG):
			with contextlib.redirect_stdout(LoggerWriter(logger, logging.DEBUG)):
				lc.show_properties()

		return lc

	#----------------------------------------------------------------------------------------------
	def save_lightcurve(self, lc, output_folder=None):
		"""
		Save generated lightcurve to file.

		Parameters:
			lc (:class:`lightkurve.TessLightCurve`): Corrected lightcurve to save. Its
				``meta['task']['lightcurve']`` must point to the original input file.
			output_folder (str, optional): Path to directory where to save lightcurve.
				If ``None`` the directory specified in the attribute ``input_folder`` is used.

		Returns:
			str: Path to the generated file, relative to ``output_folder``.

		Raises:
			ValueError: If the file format is not supported, if the correction method has no
				FITS output naming, or if saving would overwrite the input lightcurve file.

		.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
		.. codeauthor:: Mikkel N. Lund <mikkelnl@phys.au.dk>
		"""
		logger = logging.getLogger(__name__)

		# Find the name of the correction method based on the class name:
		CorrMethod = {
			'EnsembleCorrector': 'Ensemble',
			'CBVCorrector': 'CBV',
			'KASOCFilterCorrector': 'KASOC Filter'
		}.get(self.__class__.__name__)

		# Decide where to save the finished lightcurve:
		if output_folder is None:
			output_folder = self.input_folder

		# Get the filename of the original file from the task:
		fname = lc.meta.get('task').get('lightcurve')

		if fname.endswith(('.fits.gz', '.fits')):
			logger.debug("Saving as FITS file")

			# Tag inserted into the output filename for each correction method:
			name_tag = {
				'cbv': '-tasoc-cbv_lc',
				'ensemble': '-tasoc-ens_lc',
				'kasoc_filter': '-tasoc-kf_lc',
			}.get(self.CorrMethod)
			if name_tag is None:
				raise ValueError("Saving FITS lightcurves is not supported for correction method '%s'" % self.CorrMethod)
			filename = os.path.basename(fname).replace('-tasoc_lc', name_tag)

			if output_folder != self.input_folder:
				save_file = os.path.join(output_folder, filename)
			else:
				save_file = os.path.join(output_folder, os.path.dirname(fname), filename)

			# If the input filename did not contain "-tasoc_lc", the output path could be
			# identical to the input file; never overwrite the original data:
			input_file = os.path.join(self.input_folder, fname)
			if os.path.abspath(save_file) == os.path.abspath(input_file):
				raise ValueError("Refusing to overwrite input lightcurve file: %s" % input_file)
			logger.debug("Saving lightcurve to '%s'", save_file)

			# Open the FITS file to overwrite the corrected flux columns:
			with fits.open(input_file, mode='readonly') as hdu:
				# Filter out invalid parts of the input lightcurve:
				hdu = _filter_fits_hdu(hdu)

				# Overwrite the corrected flux columns:
				hdu['LIGHTCURVE'].data['FLUX_CORR'] = lc.flux
				hdu['LIGHTCURVE'].data['FLUX_CORR_ERR'] = lc.flux_err
				hdu['LIGHTCURVE'].data['QUALITY'] = lc.quality

				# Set headers about the correction:
				hdu['LIGHTCURVE'].header['CORRMET'] = (CorrMethod, 'Lightcurve correction method')
				hdu['LIGHTCURVE'].header['CORRVER'] = (__version__, 'Version of correction pipeline')

				# Set additional headers provided by the individual methods:
				if lc.meta['additional_headers']:
					for key, value in lc.meta['additional_headers'].items():
						hdu['LIGHTCURVE'].header[key] = (value, lc.meta['additional_headers'].comments[key])

				# For Ensemble, also add the ensemble list to the FITS file:
				if self.CorrMethod == 'ensemble' and hasattr(self, 'ensemble_starlist'):
					# Create binary table to hold the list of ensemble stars:
					c1 = fits.Column(name='TIC', format='K', array=self.ensemble_starlist['starids'])
					c2 = fits.Column(name='BZETA', format='E', array=self.ensemble_starlist['bzetas'])
					wm = fits.BinTableHDU.from_columns([c1, c2], name='ENSEMBLE')

					wm.header['TDISP1'] = 'I10'
					wm.header['TDISP2'] = 'E26.17'
					fix_fits_table_headers(wm, {
						'TIC': 'TIC identifier',
						'BZETA': 'background scale'
					})

					# Add the new table to the list of HDUs:
					hdu.append(wm)

				# Write the modified HDUList to the new filename:
				hdu.writeto(save_file, checksum=True, overwrite=True)

		# For the simulated ASCII files, simply create a new ASCII files next to the original one,
		# with an extension ".corr":
		elif fname.endswith(('.noisy', '.sysnoise')): # pragma: no cover
			save_file = os.path.join(output_folder, os.path.dirname(fname), os.path.splitext(os.path.basename(fname))[0] + '.corr')

			# Create new ASCII file:
			with open(save_file, 'w') as fid:
				fid.write("# TESS Asteroseismic Science Operations Center\n")
				fid.write("# TIC identifier: %d\n" % lc.targetid)
				fid.write("# Sector: %s\n" % lc.sector)
				fid.write("# Correction method: %s\n" % CorrMethod)
				fid.write("# Correction Version: %s\n" % __version__)
				if lc.meta['additional_headers']:
					for key, value in lc.meta['additional_headers'].items():
						fid.write("# %-18s: %s\n" % (key, value))
				fid.write("#\n")
				fid.write("# Column 1: Time (days)\n")
				fid.write("# Column 2: Corrected flux (ppm)\n")
				fid.write("# Column 3: Corrected flux error (ppm)\n")
				fid.write("# Column 4: Quality flags\n")
				fid.write("#-------------------------------------------------\n")
				for k in range(len(lc.time)):
					fid.write("%f %.16e %.16e %d\n" % (
						lc.time[k],
						lc.flux[k],
						lc.flux_err[k],
						lc.quality[k]
					))
				fid.write("#-------------------------------------------------\n")

		else:
			raise ValueError("Invalid file format")

		# Store the output file in the details object for future reference:
		save_file = os.path.relpath(save_file, output_folder).replace('\\', '/')

		return save_file