#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
A TaskManager which keeps track of which targets to process.
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
.. codeauthor:: Lindsey Carboneau
.. codeauthor:: Filipe Pereira
"""
import os.path
import sqlite3
import logging
import json
from numpy import atleast_1d
from . import STATUS
#--------------------------------------------------------------------------------------------------
def _build_constraints(priority=None, starid=None, sector=None, cadence=None,
camera=None, ccd=None, cbv_area=None, return_list=False):
constraints = []
if priority is not None:
constraints.append('todolist.priority IN (' + ','.join([str(int(c)) for c in atleast_1d(priority)]) + ')')
if starid is not None:
constraints.append('todolist.starid IN (' + ','.join([str(int(c)) for c in atleast_1d(starid)]) + ')')
if sector is not None:
constraints.append('todolist.sector IN (' + ','.join([str(int(c)) for c in atleast_1d(sector)]) + ')')
if cadence == 'ffi':
constraints.append("todolist.datasource='ffi'")
elif cadence is not None:
constraints.append(f'todolist.cadence={cadence:d}')
if camera is not None:
constraints.append('todolist.camera IN (' + ','.join([str(int(c)) for c in atleast_1d(camera)]) + ')')
if ccd is not None:
constraints.append('todolist.ccd IN (' + ','.join([str(int(c)) for c in atleast_1d(ccd)]) + ')')
if cbv_area is not None:
constraints.append('todolist.cbv_area IN (' + ','.join([str(int(c)) for c in atleast_1d(cbv_area)]) + ')')
# If asked for it, return the list if constraints otherwise return string
# which fits into the other queries done by the TaskManager:
if return_list:
return constraints
return ' AND ' + ' AND '.join(constraints) if constraints else ''
#--------------------------------------------------------------------------------------------------
class TaskManager(object):
	"""
	A TaskManager which keeps track of which targets to process.

	Wraps a TODO-file (SQLite database) and tracks the correction status of
	each target, handing out tasks and recording their results.
	"""
def __init__(self, todo_file, cleanup=False, overwrite=False, cleanup_constraints=None,
		summary=None, summary_interval=200):
	"""
	Initialize the TaskManager which keeps track of which targets to process.

	Parameters:
		todo_file (str): Path to the TODO-file, or to a directory containing
			a file named ``todo.sqlite``.
		cleanup (bool, optional): Perform cleanup/optimization of TODO-file
			during initialization. Default=False.
		overwrite (bool, optional): Overwrite any previously calculated results. Default=False.
		cleanup_constraints (dict or list, optional): Constraints for cleanup of the status of
			previous correction runs. If not specified, all bad results are cleaned up.
		summary (str, optional): Path to JSON file which will be periodically updated with
			a status summary of the corrections.
		summary_interval (int, optional): Interval at which summary file is updated.
			Default=200.

	Raises:
		FileNotFoundError: If TODO-file could not be found.
		ValueError: If ``cleanup_constraints`` is neither a dict nor a list, or if
			the cadence could not be determined for all targets in an old TODO-file.
	"""
	# Accept a directory, in which case the standard file name is assumed:
	if os.path.isdir(todo_file):
		todo_file = os.path.join(todo_file, 'todo.sqlite')
	if not os.path.isfile(todo_file):
		raise FileNotFoundError('Could not find TODO-file')

	if cleanup_constraints is not None and not isinstance(cleanup_constraints, (dict, list)):
		raise ValueError("cleanup_constraints should be dict or list")

	# Load the SQLite file:
	self.conn = sqlite3.connect(todo_file)
	self.conn.row_factory = sqlite3.Row
	self.cursor = self.conn.cursor()
	self.cursor.execute("PRAGMA foreign_keys=ON;")
	self.cursor.execute("PRAGMA locking_mode=EXCLUSIVE;")
	self.cursor.execute("PRAGMA journal_mode=TRUNCATE;")

	self.summary_file = summary
	self.summary_interval = summary_interval
	self.summary_counter = 0
	self.corrector = None

	# Setup logging:
	formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
	self.logger = logging.getLogger(__name__)
	self.logger.setLevel(logging.INFO)
	if not self.logger.hasHandlers():
		self._loghandler = logging.StreamHandler()
		self._loghandler.setFormatter(formatter)
		self.logger.addHandler(self._loghandler)

	# Add cadence to todolist, if it doesn't already exist:
	# This is only for backwards compatibility.
	self.cursor.execute("PRAGMA table_info(todolist)")
	existing_columns = [r['name'] for r in self.cursor.fetchall()]
	if 'cadence' not in existing_columns:
		self.logger.debug("Adding CADENCE column to todolist")
		self.cursor.execute("BEGIN TRANSACTION;")
		self.cursor.execute("ALTER TABLE todolist ADD COLUMN cadence INTEGER DEFAULT NULL;")
		# Fill out the cadence from the data source and sector number.
		# NOTE(review): no rule covers datasource != 'ffi' with sector >= 27,
		# so such rows stay NULL and trigger the error below — confirm intended.
		self.cursor.execute("UPDATE todolist SET cadence=1800 WHERE datasource='ffi' AND sector < 27;")
		self.cursor.execute("UPDATE todolist SET cadence=600 WHERE datasource='ffi' AND sector >= 27 AND sector <= 55;")
		self.cursor.execute("UPDATE todolist SET cadence=120 WHERE datasource!='ffi' AND sector < 27;")
		# If any cadences could not be determined, there is no safe way to continue:
		self.cursor.execute("SELECT COUNT(*) AS antal FROM todolist WHERE cadence IS NULL;")
		if self.cursor.fetchone()['antal'] > 0:
			self.close()
			raise ValueError("TODO-file does not contain CADENCE information and it could not be determined automatically. Please recreate TODO-file.")
		self.conn.commit()

	# Add status indicator for corrections to todolist, if it doesn't already exist:
	if 'corr_status' not in existing_columns:
		self.logger.debug("Adding corr_status column to todolist")
		self.cursor.execute("ALTER TABLE todolist ADD COLUMN corr_status INTEGER DEFAULT NULL")
		self.cursor.execute("CREATE INDEX corr_status_idx ON todolist (corr_status);")
		self.conn.commit()

	# Add method_used to the diagnostics table if it doesn't exist:
	self.cursor.execute("PRAGMA table_info(diagnostics)")
	if 'method_used' not in [r['name'] for r in self.cursor.fetchall()]:
		# Since this one is NOT NULL, we have to do some magic to fill out the
		# new column after creation, by finding keywords in other columns.
		# This can be a pretty slow process, but it only has to be done once.
		self.logger.debug("Adding method_used column to diagnostics")
		self.cursor.execute("BEGIN TRANSACTION;")
		self.cursor.execute("ALTER TABLE diagnostics ADD COLUMN method_used TEXT NOT NULL DEFAULT 'aperture';")
		for m in ('aperture', 'halo', 'psf', 'linpsf'):
			self.cursor.execute("UPDATE diagnostics SET method_used=? WHERE priority IN (SELECT priority FROM todolist WHERE method=?);", [m, m])
		self.cursor.execute("UPDATE diagnostics SET method_used='halo' WHERE method_used='aperture' AND errors LIKE '%Automatically switched to Halo photometry%';")
		self.conn.commit()

	# Create indices:
	self.cursor.execute("CREATE INDEX IF NOT EXISTS datavalidation_raw_approved_idx ON datavalidation_raw (approved);")
	self.conn.commit()

	# Create table for settings if it doesn't already exist:
	self.cursor.execute("""CREATE TABLE IF NOT EXISTS corr_settings (
		corrector TEXT NOT NULL
	);""")
	self.conn.commit()

	# Load settings from settings table:
	self.cursor.execute("SELECT * FROM corr_settings LIMIT 1;")
	row = self.cursor.fetchone()
	if row is not None:
		self.corrector = row['corrector']

	# Reset the status of everything for a new run:
	if overwrite:
		self.cursor.execute("UPDATE todolist SET corr_status=NULL;")
		self.cursor.execute("DROP TABLE IF EXISTS diagnostics_corr;")
		self.cursor.execute("DELETE FROM corr_settings;")
		self.conn.commit()
		self.corrector = None

	# Create table for diagnostics:
	self.cursor.execute("""CREATE TABLE IF NOT EXISTS diagnostics_corr (
		priority INTEGER PRIMARY KEY ASC NOT NULL,
		lightcurve TEXT,
		elaptime REAL,
		worker_wait_time REAL,
		variance DOUBLE PRECISION,
		rms_hour DOUBLE PRECISION,
		ptp DOUBLE PRECISION,
		errors TEXT,
		FOREIGN KEY (priority) REFERENCES todolist(priority) ON DELETE CASCADE ON UPDATE CASCADE
	);""")
	self.cursor.execute("CREATE UNIQUE INDEX IF NOT EXISTS diagnostics_corr_lightcurve_idx ON diagnostics_corr (lightcurve);")
	self.conn.commit()

	# The corrector is not stored, so try to infer it from the diagnostics information:
	# This is needed on older TODO-files created before the corr_settings table
	# was introduced.
	if self.corrector is None:
		self.cursor.execute("SELECT lightcurve FROM diagnostics_corr WHERE lightcurve IS NOT NULL LIMIT 1;")
		row = self.cursor.fetchone()
		if row is not None:
			# Identify the corrector from the lightcurve file-name suffix:
			if '-tasoc-cbv_lc' in row['lightcurve']:
				self.corrector = 'cbv'
			elif '-tasoc-ens_lc' in row['lightcurve']:
				self.corrector = 'ensemble'
			elif '-tasoc-kf_lc' in row['lightcurve']:
				self.corrector = 'kasoc_filter'
		if self.corrector is not None:
			self.save_settings()

	# Reset calculations with status STARTED, ABORT or ERROR:
	clear_status = str(STATUS.STARTED.value) + ',' + str(STATUS.ABORT.value) + ',' + str(STATUS.ERROR.value) + ',' + str(STATUS.SKIPPED.value)
	constraints = ['corr_status IN (' + clear_status + ')']

	# Add additional constraints from the user input and build SQL query:
	if cleanup_constraints:
		if isinstance(cleanup_constraints, dict):
			constraints += _build_constraints(**cleanup_constraints, return_list=True)
		else:
			constraints += cleanup_constraints

	constraints = ' AND '.join(constraints)
	self.logger.debug(constraints)
	self.cursor.execute("DELETE FROM diagnostics_corr WHERE priority IN (SELECT todolist.priority FROM todolist WHERE " + constraints + ");")
	self.cursor.execute("UPDATE todolist SET corr_status=NULL WHERE " + constraints + ";")
	self.conn.commit()

	# Set all targets that did not return good photometry or were not approved by the Data Validation to SKIPPED:
	self.cursor.execute(f"UPDATE todolist SET corr_status={STATUS.SKIPPED.value:d} WHERE corr_status IS NULL AND (status NOT IN ({STATUS.OK.value:d},{STATUS.WARNING.value:d}) OR todolist.priority IN (SELECT priority FROM datavalidation_raw WHERE approved=0));")
	self.conn.commit()

	# Analyze the tables for better query planning:
	self.logger.debug("Analyzing database...")
	self.cursor.execute("ANALYZE;")

	# Prepare summary object:
	self.summary = {
		'slurm_jobid': os.environ.get('SLURM_JOB_ID', None),
		'numtasks': 0,
		'tasks_run': 0,
		'last_error': None,
		'mean_elaptime': None,
		'mean_worker_waittime': None
	}
	# Make sure to add all the different status to summary:
	for s in STATUS:
		self.summary[s.name] = 0

	# If we are going to output summary, make sure to fill it up:
	if self.summary_file:
		# Ensure it is an absolute file path:
		self.summary_file = os.path.abspath(self.summary_file)
		# Extract information from database:
		self.cursor.execute("SELECT corr_status,COUNT(*) AS cnt FROM todolist GROUP BY corr_status;")
		for row in self.cursor.fetchall():
			self.summary['numtasks'] += row['cnt']
			if row['corr_status'] is not None:
				self.summary[STATUS(row['corr_status']).name] = row['cnt']
		# Make sure the containing directory exists:
		os.makedirs(os.path.dirname(self.summary_file), exist_ok=True)
		# Write summary to file:
		self.write_summary()

	# Run a cleanup/optimization of the database before we get started:
	if cleanup:
		self.logger.info("Cleaning TODOLIST before run...")
		# VACUUM must run outside a transaction, so temporarily disable the
		# sqlite3 module's implicit transaction handling:
		tmp_isolevel = self.conn.isolation_level
		try:
			self.conn.isolation_level = None
			self.cursor.execute("VACUUM;")
		finally:
			self.conn.isolation_level = tmp_isolevel
#----------------------------------------------------------------------------------------------
def __enter__(self):
	"""Context manager entry; returns the TaskManager itself."""
	return self
#----------------------------------------------------------------------------------------------
def __exit__(self, *args):
	"""Context manager exit; closes the TODO-file."""
	self.close()
#----------------------------------------------------------------------------------------------
def __del__(self):
	"""Destructor; ensures the TODO-file is closed when the object is garbage-collected."""
	self.close()
#----------------------------------------------------------------------------------------------
def close(self):
	"""
	Close the TaskManager and release the TODO-file.

	Safe to call multiple times and on partially-initialized objects
	(``__del__`` calls this, and ``__init__`` may fail before all
	attributes are set — hence the ``hasattr`` guards).
	"""
	if hasattr(self, 'cursor') and hasattr(self, 'conn') and self.conn:
		try:
			# Discard any uncommitted changes, then restore the journal mode
			# to the SQLite default so other programs can open the file.
			# NOTE(review): changing journal_mode requires no transaction to
			# be active, hence the rollback first — confirm against SQLite docs.
			self.conn.rollback()
			self.cursor.execute("PRAGMA journal_mode=DELETE;")
			self.conn.commit()
			self.cursor.close()
		except sqlite3.ProgrammingError: # pragma: no cover
			# Cursor/connection already closed — nothing more to do here:
			pass

	if hasattr(self, 'conn') and self.conn:
		self.conn.close()
		self.conn = None  # Mark as closed so repeated calls are no-ops

	if hasattr(self, '_loghandler') and hasattr(self, 'logger') and self._loghandler:
		self.logger.removeHandler(self._loghandler)
#----------------------------------------------------------------------------------------------
def save_settings(self):
"""
Save settings to TODO-file and create method-specific columns in ``diagnostics_corr`` table.
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
try:
self.cursor.execute("DELETE FROM corr_settings;")
self.cursor.execute("INSERT INTO corr_settings (corrector) VALUES (?);", [self.corrector])
# Create additional diagnostics columns based on which corrector we are running:
self.cursor.execute("PRAGMA table_info(diagnostics_corr)")
diag_columns = [r['name'] for r in self.cursor.fetchall()]
if self.corrector == 'cbv':
if 'cbv_num' not in diag_columns:
self.cursor.execute("ALTER TABLE diagnostics_corr ADD COLUMN cbv_num INTEGER DEFAULT NULL;")
elif self.corrector == 'ensemble':
if 'ens_num' not in diag_columns:
self.cursor.execute("ALTER TABLE diagnostics_corr ADD COLUMN ens_num INTEGER DEFAULT NULL;")
if 'ens_fom' not in diag_columns:
self.cursor.execute("ALTER TABLE diagnostics_corr ADD COLUMN ens_fom REAL DEFAULT NULL;")
self.conn.commit()
except: # noqa: E722, pragma: no cover
self.conn.rollback()
raise
#----------------------------------------------------------------------------------------------
def get_number_tasks(self, priority=None, starid=None, sector=None, cadence=None,
		camera=None, ccd=None, cbv_area=None):
	"""
	Get number of tasks due to be processed.

	Parameters:
		priority (int, optional): Only count tasks matching this priority.
		starid (int, optional): Only count tasks matching this starid.
		sector (int, optional): Only count tasks matching this Sector.
		cadence (int, optional): Only count tasks matching this cadence.
		camera (int, optional): Only count tasks matching this camera.
		ccd (int, optional): Only count tasks matching this CCD.
		cbv_area (int, optional): Only count tasks matching this CBV area.

	Returns:
		int: Number of tasks due to be processed.

	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	search = _build_constraints(priority=priority, starid=starid, sector=sector,
		cadence=cadence, camera=camera, ccd=ccd, cbv_area=cbv_area)
	query = "SELECT COUNT(*) AS num FROM todolist INNER JOIN diagnostics ON todolist.priority=diagnostics.priority WHERE corr_status IS NULL" + search + ";"
	self.cursor.execute(query)
	return int(self.cursor.fetchone()['num'])
#----------------------------------------------------------------------------------------------
def get_task(self, priority=None, starid=None, sector=None, cadence=None,
		camera=None, ccd=None, cbv_area=None, chunk=1):
	"""
	Get next task to be processed.

	Parameters:
		priority (int, optional): Only return task matching this priority.
		starid (int, optional): Only return tasks matching this starid.
		sector (int, optional): Only return tasks matching this Sector.
		cadence (int, optional): Only return tasks matching this cadence.
		camera (int, optional): Only return tasks matching this camera.
		ccd (int, optional): Only return tasks matching this CCD.
		cbv_area (int, optional): Only return tasks matching this CBV area.
		chunk (int, optional): Chunk of tasks to return. Default is to not chunk (=1).

	Returns:
		dict, list or None: Dictionary of settings for task.
		If ``chunk`` is larger than one, a list of dicts is returned instead.
		Returns None if no tasks are due to be processed.

	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	search = _build_constraints(priority=priority, starid=starid, sector=sector,
		cadence=cadence, camera=camera, ccd=ccd, cbv_area=cbv_area)
	self.cursor.execute(f"SELECT * FROM todolist INNER JOIN diagnostics ON todolist.priority=diagnostics.priority WHERE corr_status IS NULL {search:s} ORDER BY todolist.priority LIMIT {chunk:d};")
	rows = self.cursor.fetchall()
	if not rows:
		return None
	if chunk == 1:
		return dict(rows[0])
	return [dict(row) for row in rows]
#----------------------------------------------------------------------------------------------
def get_random_task(self, chunk=1):
"""
Get random task to be processed.
Parameters:
chunk (int, optional): Chunk of tasks to return. Default is to not chunk (=1).
Returns:
dict, list or None: Dictionary of settings for task.
If ``chunk`` is larger than one, a list of dicts is retuned instead.
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""
self.cursor.execute(f"SELECT * FROM todolist INNER JOIN diagnostics ON todolist.priority=diagnostics.priority WHERE corr_status IS NULL ORDER BY RANDOM() LIMIT {chunk:d};")
tasks = self.cursor.fetchall()
if tasks and chunk == 1:
return dict(tasks[0])
elif tasks:
return [dict(task) for task in tasks]
return None
#----------------------------------------------------------------------------------------------
def start_task(self, tasks):
	"""
	Mark tasks as STARTED in the TODO-list.

	Parameters:
		tasks (list or dict): Task or list of tasks coming from ``get_task``
			or ``get_random_task``.

	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	# Accept both a single task dict and a list of task dicts:
	task_list = [tasks] if isinstance(tasks, dict) else tasks
	params = [(int(t['priority']),) for t in task_list]
	self.cursor.executemany(f"UPDATE todolist SET corr_status={STATUS.STARTED.value:d} WHERE priority=?;", params)
	self.summary['STARTED'] += self.cursor.rowcount
	self.conn.commit()
#----------------------------------------------------------------------------------------------
def save_results(self, results):
	"""
	Save result, or list of results, to TaskManager.

	Parameters:
		results (list or dict): Result dictionary, or list of result dictionaries.
			Each must contain at least the keys ``'priority'``, ``'corrector'``
			and ``'status_corr'``.

	Raises:
		ValueError: If attempting to mix results from multiple correctors
			in the same TODO-file.

	.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
	"""
	# Allow a single result dict to be passed; treat everything as a list below:
	if isinstance(results, dict):
		results = [results]

	# If the corrector has not already been set for this TODO-file,
	# take it from the first result and persist it in the settings table:
	if self.corrector is None:
		self.corrector = results[0]['corrector']
		self.save_settings()

	# Corrector-specific extra columns of the diagnostics_corr table to fill:
	additional_diags_keys = ''
	Nadditional = 0
	if self.corrector == 'cbv':
		additional_diags_keys = ',cbv_num'
		Nadditional = 1
	elif self.corrector == 'ensemble':
		additional_diags_keys = ',ens_num,ens_fom'
		Nadditional = 2
	# Placeholders for the 8 fixed columns plus any corrector-specific ones:
	placeholders = ','.join(['?']*(8 + Nadditional))

	for result in results:
		# Extract details dictionary:
		details = result.get('details', {})

		# The status of this target returned by the photometry:
		my_status = result['status_corr']

		# Check that we are not mixing results from different
		# correctors in one TODO-file:
		if result['corrector'] != self.corrector:
			raise ValueError("Attempting to mix results from multiple correctors")

		# Calculate mean elapsed time using "streaming weighted mean" with (alpha=0.1):
		# https://dev.to/nestedsoftware/exponential-moving-average-on-streaming-data-4hhl
		if self.summary['mean_elaptime'] is None and result.get('elaptime_corr') is not None:
			self.summary['mean_elaptime'] = result['elaptime_corr']
		elif result.get('elaptime_corr') is not None:
			self.summary['mean_elaptime'] += 0.1 * (result['elaptime_corr'] - self.summary['mean_elaptime'])

		# Save additional diagnostics:
		# Errors may be a list/tuple of strings or a single string:
		error_msg = details.get('errors', None)
		if error_msg:
			error_msg = "\n".join(error_msg) if isinstance(error_msg, (list, tuple)) else error_msg.strip()
			self.summary['last_error'] = error_msg

		# Corrector-specific diagnostics values matching additional_diags_keys:
		additional_diags = ()
		if self.corrector == 'cbv':
			additional_diags = (
				details.get('cbv_num', None),
			)
		elif self.corrector == 'ensemble':
			additional_diags = (
				details.get('ens_num', None),
				details.get('ens_fom', None)
			)

		try:
			# Update the status in the TODO list:
			self.cursor.execute("UPDATE todolist SET corr_status=? WHERE priority=?;", (
				result['status_corr'].value,
				result['priority']
			))
			# Save additional diagnostics:
			self.cursor.execute("INSERT OR REPLACE INTO diagnostics_corr (priority,lightcurve,elaptime,worker_wait_time,variance,rms_hour,ptp,errors" + additional_diags_keys + ") VALUES (" + placeholders + ");", (
				result['priority'],
				result.get('lightcurve_corr', None),
				result.get('elaptime_corr', None),
				result.get('worker_wait_time', None),
				details.get('variance', None),
				details.get('rms_hour', None),
				details.get('ptp', None),
				error_msg
			) + additional_diags)
			self.conn.commit()
		except: # noqa: E722, pragma: no cover
			self.conn.rollback()
			raise

		# Update the summary bookkeeping for this task:
		self.summary['tasks_run'] += 1
		self.summary[my_status.name] += 1
		self.summary['STARTED'] -= 1

	# All the results should have the same worker_waittime.
	# So only update this once, using just that last result in the list:
	if self.summary['mean_worker_waittime'] is None and result.get('worker_wait_time') is not None:
		self.summary['mean_worker_waittime'] = result['worker_wait_time']
	elif result.get('worker_wait_time') is not None:
		self.summary['mean_worker_waittime'] += 0.1 * (result['worker_wait_time'] - self.summary['mean_worker_waittime'])

	# Write summary file, but only every summary_interval results:
	self.summary_counter += len(results)
	if self.summary_file and self.summary_counter >= self.summary_interval:
		self.summary_counter = 0
		self.write_summary()
#----------------------------------------------------------------------------------------------
def write_summary(self):
"""Write summary of progress to file. The summary file will be in JSON format."""
if self.summary_file:
try:
with open(self.summary_file, 'w') as fid:
json.dump(self.summary, fid)
except: # noqa: E722, pragma: no cover
self.logger.exception("Could not write summary file")