Source code for run_tesscorr_mpi

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Scheduler using MPI for running the TESS lightcurve corrections
pipeline on a large scale multi-core computer.

The setup uses the task-pull paradigm for high-throughput computing
using ``mpi4py``. Task pull is an efficient way to perform a large number of
independent tasks when there are more tasks than processors, especially
when the run times vary for each task.

The basic example was inspired by
https://github.com/jbornschein/mpi4py-examples/blob/master/09-task-pull.py

Example
-------
To run the program using four processes (one master and three workers) you can
execute the following command:

>>> mpiexec -n 4 python run_tesscorr_mpi.py

.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
"""

from mpi4py import MPI
import argparse
import logging
import traceback
import os
import enum
from timeit import default_timer
import corrections
from corrections.utilities import CadenceType

#--------------------------------------------------------------------------------------------------
def main():
    """
    Run the TESS corrections pipeline in parallel using MPI (task-pull).

    The process with MPI rank 0 acts as the master: it opens the input
    TODO-file through ``corrections.TaskManager``, hands out tasks to the
    workers in chunks of 10 and saves the results they send back.
    All other ranks are workers: each one initializes the corrector class
    selected with ``--method`` and processes the tasks it receives until the
    master tells it to exit.

    The input folder is taken from the command line or, if not given, from the
    ``TESSCORR_INPUT`` environment variable. The output folder is taken from
    the command line or the ``TESSCORR_OUTPUT`` environment variable, falling
    back to ``os.path.join(os.path.dirname(input_folder), 'lightcurves')``.

    Raises:
        SystemExit: Through ``parser.error`` if no input folder could be determined.
    """
    # Parse command line arguments:
    parser = argparse.ArgumentParser(description='Run TESS Corrections in parallel using MPI.')
    parser.add_argument('-d', '--debug', help='Print debug messages.', action='store_true')
    parser.add_argument('-q', '--quiet', help='Only report warnings and errors.', action='store_true')
    parser.add_argument('-m', '--method', help='Corrector method to use.', default='cbv',
        choices=('ensemble', 'cbv', 'kasoc_filter'))
    parser.add_argument('-o', '--overwrite', help='Overwrite existing results.', action='store_true')
    parser.add_argument('-p', '--plot', help='Save plots when running.', action='store_true')

    group = parser.add_argument_group('Filter which targets to process')
    group.add_argument('--sector', type=int, default=None, help='TESS Sector.')
    group.add_argument('--cadence', type=CadenceType, choices=('ffi',1800,600,120,20), default=None,
        help='Cadence. Default is to run all.')
    group.add_argument('--camera', type=int, choices=(1,2,3,4), default=None,
        help='TESS Camera. Default is to run all cameras.')
    group.add_argument('--ccd', type=int, choices=(1,2,3,4), default=None,
        help='TESS CCD. Default is to run all CCDs.')

    parser.add_argument('input_folder', type=str,
        help='Input directory. '
        'This directory should contain a TODO-file and corresponding lightcurves.',
        nargs='?', default=None)
    parser.add_argument('output_folder', type=str, help='Directory to save output in.',
        nargs='?', default=None)
    args = parser.parse_args()

    # Set logging level (only applied to the TaskManager logger on the master):
    logging_level = logging.INFO
    if args.quiet:
        logging_level = logging.WARNING
    elif args.debug:
        logging_level = logging.DEBUG

    # Get input and output folder, falling back to environment variables:
    input_folder = args.input_folder
    if input_folder is None:
        input_folder = os.environ.get('TESSCORR_INPUT')
    if not input_folder:
        parser.error("Please specify an INPUT_FOLDER.")
    output_folder = args.output_folder
    if output_folder is None:
        output_folder = os.environ.get('TESSCORR_OUTPUT', os.path.join(os.path.dirname(input_folder), 'lightcurves'))

    # Define MPI message tags
    tags = enum.IntEnum('tags', ('INIT', 'READY', 'DONE', 'EXIT', 'START'))

    # Initializations and preliminaries
    comm = MPI.COMM_WORLD  # get MPI communicator object
    size = comm.size  # total number of processes
    rank = comm.rank  # rank of this process
    status = MPI.Status()  # get MPI status object

    if rank == 0:
        try:
            # Constraints on which targets to process:
            constraints = {
                'sector': args.sector,
                'cadence': args.cadence,
                'camera': args.camera,
                'ccd': args.ccd
            }

            # File path to write summary to:
            summary_file = os.path.join(output_folder, f'summary_corr_{args.method:s}.json')

            # Invoke the TaskManager to ensure that the input TODO-file has the correct columns
            # and indices, which is automatically created by the TaskManager init function.
            with corrections.TaskManager(input_folder, cleanup=True, overwrite=args.overwrite,
                cleanup_constraints=constraints):
                pass

            # Signal that workers are free to initialize:
            comm.Barrier()  # Barrier 1

            # Wait for all workers to initialize:
            comm.Barrier()  # Barrier 2

            # Start TaskManager, which keeps track of the task that needs to be performed:
            with corrections.TaskManager(input_folder, overwrite=args.overwrite,
                cleanup_constraints=constraints, summary=summary_file) as tm:

                # Set level of TaskManager logger:
                tm.logger.setLevel(logging_level)

                # Get list of tasks:
                numtasks = tm.get_number_tasks(**constraints)
                tm.logger.info("%d tasks to be run", numtasks)

                # Start the master loop that will assign tasks
                # to the workers:
                num_workers = size - 1
                closed_workers = 0
                tm.logger.info("Master starting with %d workers", num_workers)
                while closed_workers < num_workers:
                    # Get information from worker:
                    data = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status)
                    source = status.Get_source()
                    tag = status.Get_tag()

                    if tag == tags.DONE:
                        # The worker is done with a task
                        tm.logger.debug("Got data from worker %d: %s", source, data)
                        tm.save_results(data)

                    if tag in (tags.DONE, tags.READY):
                        # Worker is ready for a new task, so send it a task
                        tasks = tm.get_task(**constraints, chunk=10)
                        if tasks:
                            tm.start_task(tasks)
                            tm.logger.debug("Sending %d tasks to worker %d", len(tasks), source)
                            comm.send(tasks, dest=source, tag=tags.START)
                        else:
                            # Nothing left to do, so tell the worker to shut down:
                            comm.send(None, dest=source, tag=tags.EXIT)

                    elif tag == tags.EXIT:
                        # The worker has exited
                        tm.logger.info("Worker %d exited.", source)
                        closed_workers += 1

                    else:
                        # This should never happen, but just to
                        # make sure we don't run into an infinite loop:
                        raise RuntimeError(f"Master received an unknown tag: '{tag}'")

                tm.logger.info("Master finishing")

        except:  # noqa: E722, pragma: no cover
            # If something fails in the master, print the traceback and
            # abort the whole MPI job, so workers are not left waiting:
            print(traceback.format_exc().strip())
            comm.Abort(1)

    else:
        # Worker processes execute code below
        # Configure logging within corrections:
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger = logging.getLogger('corrections')
        logger.addHandler(console)
        logger.setLevel(logging.WARNING)

        # Get the class for the selected method:
        CorrClass = corrections.corrclass(args.method)

        try:
            # Wait for signal that we are okay to initialize:
            comm.Barrier()  # Barrier 1

            # We can now safely initialize the corrector on the input file:
            # NOTE(review): if this initialization raises in one worker, that worker
            # never reaches Barrier 2 while the remaining ranks wait in it - this
            # looks like it could hang the run; confirm.
            with CorrClass(input_folder, plot=args.plot) as corr:

                # Wait for all workers to be done initializing:
                comm.Barrier()  # Barrier 2

                # Send signal that we are ready for task:
                comm.send(None, dest=0, tag=tags.READY)

                while True:
                    # Receive a task from the master, timing how long we had to wait for it:
                    tic = default_timer()
                    tasks = comm.recv(source=0, tag=MPI.ANY_TAG, status=status)
                    tag = status.Get_tag()
                    toc = default_timer()

                    if tag == tags.START:
                        # Make sure we can loop through tasks,
                        # even in the case we have only gotten one.
                        # The single task has to be wrapped in a list: calling list() on it
                        # would iterate over the task itself instead of wrapping it.
                        results = []
                        if not isinstance(tasks, (list, tuple)):
                            tasks = [tasks]

                        # Loop through the tasks given to us:
                        for task in tasks:
                            # Start from a copy of the task, which is what gets
                            # reported back if the correction fails:
                            result = task.copy()

                            # Run the correction:
                            try:
                                result = corr.correct(task)
                            except:  # noqa: E722
                                # Something went wrong
                                error_msg = traceback.format_exc().strip()
                                result.update({
                                    'status_corr': corrections.STATUS.ERROR,
                                    'details': {'errors': [error_msg]},
                                })

                            # The same wait time is recorded for all tasks in the chunk:
                            result.update({'worker_wait_time': toc-tic})
                            results.append(result)

                        # Send the result back to the master:
                        comm.send(results, dest=0, tag=tags.DONE)

                    elif tag == tags.EXIT:
                        # We were told to EXIT, so lets do that
                        break

                    else:
                        # This should never happen, but just to
                        # make sure we don't run into an infinite loop:
                        raise RuntimeError(f"Worker received an unknown tag: '{tag}'")

        except:  # noqa: E722, pragma: no cover
            logger.exception("Something failed in worker")

        finally:
            # Always tell the master that this worker is gone, so that the
            # master loop can count it as closed:
            comm.send(None, dest=0, tag=tags.EXIT)
# Only start the scheduler when executed as a script (e.g. through mpiexec),
# not when the module is imported:
if __name__ == '__main__':
    main()