#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
KASOC Filter for Asteroseismic Data Preparation
Corrects Kepler/K2/TESS data for instrumental effects and planetary signals
to create new datasets optimized for asteroseismic analysis.
.. codeauthor:: Rasmus Handberg <rasmush@phys.au.dk>
.. codeauthor:: Mikkel N. Lund
"""
import logging
import numpy as np
from numpy import zeros, empty, argsort, diff, mod, isfinite, array, append, searchsorted, NaN, Inf
from copy import deepcopy
import os.path
from bottleneck import nanmedian, nanstd, median, nanargmax, nansum, allnan, nanmin, nanmax
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from scipy.stats import norm
from statsmodels.nonparametric.smoothers_lowess import lowess
from .utilities import moving_nanmedian, moving_nanmedian_cyclic, smooth, smooth_cyclic, BIC, theil_sen, gap_fill
from ..plots import plt, matplotlib
from ..quality import TESSQualityFlags
from ..utilities import mad_to_sigma
import warnings
#==============================================================================
# Constants
#==============================================================================
# Directory to save output figures to:
_output_folder = None
_output_prefix = ''
_output_format = 'png'
#==============================================================================
# Settings:
#==============================================================================
[docs]
def set_output(folder=None, prefix='', fmt='png', native=False):
    """
    Configure where and how the diagnostic plots of the filter are written.

    Parameters:
        folder (str, optional): Directory figures are saved to. ``None`` disables plotting.
        prefix (str, optional): String prepended to every figure file name.
        fmt (str, optional): File format (extension) used when saving figures.
        native (bool, optional): If True, ``folder`` and ``fmt`` are ignored; the folder
            is set to the placeholder ``'dummy'`` and the format to ``'native'``, which
            the plotting code uses as a signal to neither save nor close the figures.
    """
    global _output_folder, _output_prefix, _output_format

    _output_prefix = prefix
    # In native mode the folder only has to be "not None" to enable plotting:
    _output_folder, _output_format = ('dummy', 'native') if native else (folder, fmt)
#==============================================================================
# Utility functions
#==============================================================================
[docs]
def scale_timescales(numax, min_value_long=30.0):
    """
    Scale the long filter timescale with the estimated nu_max of the star.

    The timescale is lengthened for low-numax stars so that the filter does
    not disturb the oscillation signal.

    Parameters:
        numax (float): nu_max in microHz. ``None``, non-finite or non-positive
            values are treated as invalid.
        min_value_long (float, optional): The shortest filter width to return in days. Default=30.

    Returns:
        float: New timescale (days) which will avoid the specified nu_max region.
        ``min_value_long`` is returned unchanged if ``numax`` is invalid.
    """
    # If invalid numax is given, return default filter value.
    # NaN/inf must be caught explicitly, as "NaN <= 0" is False and would
    # otherwise propagate NaN into the returned timescale:
    if numax is None or not np.isfinite(numax) or numax <= 0:
        return min_value_long

    # Estimate the FWHM and sigma of the oscillation envelope:
    fwhm = 0.66*numax**0.88
    sigma = fwhm/(2*np.sqrt(2*np.log(2)))

    # Calculate filter frequency as a constant factor below the
    # lower limit of the oscillation envelope, but never above the
    # frequency corresponding to the minimum timescale:
    nulong = 0.10*(numax - 2*sigma)
    nulong = np.minimum(nulong, 1e6/(min_value_long*86400))
    # NOTE(review): for extremely small numax (below roughly 0.008 microHz)
    # "numax - 2*sigma" becomes non-positive and the returned timescale is
    # negative or infinite. Left unchanged as the intended fallback is unclear.

    # Convert the frequency (microHz) to the corresponding timescale in days:
    return 1.0/(nulong*86400e-6)
#==============================================================================
# Custom exceptions and warnings
#==============================================================================
[docs]
class InvalidSigmasWarning(UserWarning):
    """Warning category issued by :py:func:`filter` when invalid point-to-point errors (sigmas) are extracted."""
#==============================================================================
# Main filtering functions
#==============================================================================
#--------------------------------------------------------------------------------------------------
[docs]
def remove_jumps(t, x, jumps, width=3.0, return_flags=False):
    """
    Remove jumps from timeseries.

    For each jump, the flux levels in a window of ``width`` days on either side
    are estimated, both as constant levels and as robust linear fits evaluated
    at the midpoint of the gap. The Bayesian Information Criterion (BIC) is used
    to choose between no correction, the constant correction and the linear
    correction. The chosen correction is applied to all data after the jump.

    Parameters:
        t (ndarray): Time vector (days). Must be sorted in time.
        x (ndarray): Flux vector. Can contain invalid points (NaN).
            Corrections are applied in place to this array.
        jumps (list): Jumps to be corrected. Each element is either a plain timestamp
            or a dict with the key ``time`` and optionally ``type``
            (``'additive'`` or ``'multiplicative'``; default ``'multiplicative'``)
            and ``force`` (apply the best correction even if the uncorrected data
            has the lowest BIC; default False).
        width (float): Width of the region on each side of jumps to compare (default=3 days).
        return_flags (boolean): Return two additional arrays with location of corrected jumps.

    Returns:
        tuple:
        - ndarray: Corrected flux vector (same object as ``x``).
        - list: Only if ``return_flags``. One boolean per jump, in order of increasing
          jump time, indicating if the particular jump was corrected.
        - ndarray: Only if ``return_flags``. Quality array with same length as ``x``,
          indicating where and which correction was performed
          (2/4 = constant/linear additive, 256/512 = constant/linear multiplicative).

    Raises:
        ValueError: If an element of ``jumps`` is neither a scalar nor a dict,
            or if a jump has an unknown type.
    """
    # Get the logger to use for printing messages:
    logger = logging.getLogger(__name__)

    # Number of points:
    N = len(t)
    dt = nanmedian(diff(t))

    # Convert the input to a list of jump-dictionaries.
    # This is built as a new Python list, since assigning dicts into the
    # (possibly float) array returned by atleast_1d is not possible:
    jumplist = []
    for jump in np.atleast_1d(jumps):
        if isinstance(jump, dict):
            jumplist.append(jump)
        elif np.isscalar(jump):
            jumplist.append({'time': jump})
        else:
            raise ValueError("Invalid input in JUMPS")

    # Important that we correct the jumps in the right order:
    jumps = sorted(jumplist, key=lambda jmp: jmp['time'])

    # Arrays needed for the following:
    correction = np.empty(2, dtype='float64')
    if return_flags:
        flag_jumps = [False]*len(jumps)
        flag_jumps2 = np.zeros(N, dtype='int64')

    # Correct jumps one after the other:
    kj = 0
    for k, jump in enumerate(jumps):
        logger.debug(jump)

        # Extract information about jump:
        tjump = jump.get('time')
        jumptype = jump.get('type', 'multiplicative')
        jumpforce = jump.get('force', False)

        # Make maps to central region and region after jump:
        kj_pre = kj
        kj = searchsorted(t, tjump)
        if kj == 0 or kj == N or kj == kj_pre:
            continue # Stop if first, last or same point as previous

        central1 = searchsorted(t, t[kj-1]-width)
        central2 = searchsorted(t, t[kj]+width)
        gapsize = t[kj] - t[kj-1] # The length of the jump

        # Make small timeseries around the gap:
        tcen = t[central1:central2]
        xcen = x[central1:central2]
        xmdl = np.empty_like(xcen)
        indx = searchsorted(tcen, tjump)

        # Do simple check to see if all datapoints are NaN:
        if allnan(x[central1:kj]) or allnan(x[kj:central2]):
            continue

        # Run LOWESS filter on two halves to eliminate effects of transit.
        # If a half contains too few points, the raw data is used instead:
        if kj-central1 < 0.5*int(width/dt):
            w1 = np.column_stack((t[central1:kj], x[central1:kj]))
        else:
            w1 = lowess(x[central1:kj], t[central1:kj], frac=1./3, is_sorted=True)

        if central2-kj < 0.5*int(width/dt):
            w2 = np.column_stack((t[kj:central2], x[kj:central2]))
        else:
            w2 = lowess(x[kj:central2], t[kj:central2], frac=1./3, is_sorted=True)

        # Calculate median levels before and after jump
        # and make these match up:
        level1_const = nanmedian(w1[:,1])
        level2_const = nanmedian(w2[:,1])

        # Do not try to use linear relation on very long gaps
        # it will in many cases not work.
        if gapsize < 2*width:
            # Do robust linear fit of part before and after jump:
            res1 = theil_sen(w1[:,0], w1[:,1], n_samples=100000)
            res2 = theil_sen(w2[:,0], w2[:,1], n_samples=100000)

            # Evaluate fitted lines at midpoint in the gap:
            tmid = (t[kj] + t[kj-1])/2 # Midpoint in gap
            level1_linear = np.polyval(res1, tmid)
            level2_linear = np.polyval(res2, tmid)
        else:
            # NaN levels make the linear correction non-finite below,
            # which disables the linear model for this jump:
            level1_linear = np.nan
            level2_linear = np.nan

        # Calculate Bayesian Information Criterion (BIC) for the different
        # models of the jump to decide which one should be applied to the data:
        if jumptype == 'additive':
            # Constant model:
            correction[0] = level1_const - level2_const
            if isfinite(correction[0]):
                # Calculate model:
                xmdl[:indx] = level1_const
                xmdl[indx:] = level2_const
                # Calculate BIC:
                s1 = BIC(xcen, xmdl, 2)
            else:
                s1 = np.inf

            # Linear model:
            correction[1] = level1_linear - level2_linear
            if isfinite(correction[1]):
                # Calculate model:
                xmdl[:indx] = np.polyval(res1, tcen[:indx])
                xmdl[indx:] = np.polyval(res2, tcen[indx:])
                # Calculate BIC:
                s2 = BIC(xcen, xmdl, 4)
            else:
                s2 = np.inf

        elif jumptype == 'multiplicative':
            # Constant model:
            correction[0] = level1_const / level2_const
            if isfinite(correction[0]) and correction[0] > 0:
                # Correct data:
                xcen2 = deepcopy(xcen) # take a deep copy, such that corrections doesn't affect xcen
                xcen2[indx:] *= correction[0]
                # Calculate model:
                xmdl[:] = level1_const
                # Calculate BIC:
                s1 = BIC(xcen2, xmdl, 2)
            else:
                s1 = np.inf

            # Linear model:
            correction[1] = level1_linear / level2_linear
            if isfinite(correction[1]) and correction[1] > 0:
                # Correct data:
                xcen2 = deepcopy(xcen) # take a deep copy, such that corrections doesn't affect xcen
                xcen2[indx:] *= correction[1]
                # Calculate model:
                xmdl[:indx] = np.polyval(res1, tcen[:indx])
                xmdl[indx:] = np.polyval(res2, tcen[indx:]) * correction[1]
                # Calculate BIC:
                s2 = BIC(xcen2, xmdl, 4)
            else:
                s2 = np.inf

        else:
            raise ValueError('Unknown jump type')

        # Select the model with the lowest BIC.
        # Index 0 = no correction, 1 = constant model, 2 = linear model:
        if jumpforce:
            # NOTE(review): if both models were rejected above (both BICs infinite),
            # a forced jump still applies correction[0] - confirm this is intended.
            i = np.argmin([s1, s2]) + 1
        else:
            # Calculate BIC of uncorrected central part:
            s0 = BIC(xcen, nanmedian(xcen), 1)
            i = np.argmin([s0, s1, s2])
        logger.debug(i)

        if i != 0: # Do not correct if unaltered data gives the best
            # Apply the best correction to everything to the right of the jump:
            if jumptype == 'additive':
                x[kj:] += correction[i-1]
            else:
                x[kj:] *= correction[i-1]

            # Set the flags, if required:
            if return_flags:
                flag_jumps[k] = True
                if jumptype == 'additive':
                    if i == 1:
                        flag_jumps2[kj] |= 2 # constant additive
                    elif i == 2:
                        flag_jumps2[kj] |= 4 # linear additive
                elif jumptype == 'multiplicative':
                    if i == 1:
                        flag_jumps2[kj] |= 256 # constant multiplicative
                    elif i == 2:
                        flag_jumps2[kj] |= 512 # linear multiplicative

    if return_flags:
        return x, flag_jumps, flag_jumps2
    return x
#--------------------------------------------------------------------------------------------------
[docs]
def filter_flags(t, x, quality, quality_remove=TESSQualityFlags.DEFAULT_BITMASK, return_flags=False):
    """
    Filter out flagged data from quality column.

    Sets bad datapoints in the flux vector to NaN and collects the timestamps
    of flagged jumps. The jumps can later be passed into :py:func:`remove_jumps`.

    Parameters:
        t (ndarray): Time vector (days).
        x (ndarray): Flux vector. Modified in place.
        quality (ndarray): Quality flags (integers).
        quality_remove (integer, optional): Flags that corresponds to bad data points.
        return_flags (boolean): Also return flags of removed points.

    Returns:
        tuple:
        - ndarray: Flux vector with bad data points set to NaN (same object as ``x``).
        - ndarray: Object array of jump dictionaries with the keys ``time``, ``type`` and ``force``.
        - ndarray: Only if ``return_flags``. Boolean array marking the points that
          were flagged as bad or had non-finite flux.
    """
    N = len(t)

    # Store jump positions as a simple list of dictionaries:
    jumplist = []

    # Attitude tweaks:
    # These often have times set as undefined in the files
    # so we need to find the next defined timestamp
    for k in np.where((quality & TESSQualityFlags.AttitudeTweak) != 0)[0]:
        # Find the first valid timestamp at or after the jump:
        for k_next in range(k, N):
            if isfinite(t[k_next]):
                jumplist.append({'time': t[k_next], 'type': 'additive', 'force': False})
                break

    # Detected jumps:
    indx = ((quality & TESSQualityFlags.SensitivityDropout) != 0)
    if indx.any():
        # Move detected indicies one to the right
        # as we need the first timestamp after the jump:
        indx = np.roll(indx, 1)
        indx[0] = False # The roll wraps the last element around to the start
        jumplist.extend({'time': t[i], 'type': 'additive', 'force': False} for i in np.where(indx)[0])

    # Remove points flagged as bad data:
    flag_removed = ((quality & quality_remove) != 0) | (~isfinite(x))
    x[flag_removed] = np.nan

    # Return the jumps as a one-dimensional object array of dictionaries:
    jumps = np.empty(len(jumplist), dtype=object)
    jumps[:] = jumplist

    # Return results:
    if return_flags:
        return x, jumps, flag_removed
    return x, jumps
#--------------------------------------------------------------------------------------------------
def _filter_single_phase(phase, x, width, dphase):
    """
    Smooth a single phase-curve.

    The curve is first passed through a cyclic moving median and the result
    is then smoothed with a cyclic smoothing of ``width/dphase`` points.

    Parameters:
        phase (ndarray): Phase of each data point. Callers in this module pass it sorted.
        x (ndarray): Flux values corresponding to ``phase``.
        width (float): Width of the moving median, in the same units as ``phase``.
        dphase (float): Spacing between phase points; passed on as ``dt`` to the moving median.

    Returns:
        ndarray: Smoothed phase curve.
    """
    median_curve = moving_nanmedian_cyclic(phase, x, width, dt=dphase)
    return smooth_cyclic(median_curve, width/dphase)
#--------------------------------------------------------------------------------------------------
[docs]
def filter_phase(t, x, Plist, smooth_factor=1000):
    """
    Filter out specific periods by smoothing the phase-curve.

    For every period the timeseries is folded, and a smoothed phase curve is
    built from the flux with the contribution of the other periods removed.
    When more than one period is given, the phase curves of the previously
    treated periods are re-derived after each new period is added.

    Parameters:
        t (ndarray): Time vector (days).
        x (ndarray): Flux vector.
        Plist (list): List of periods to remove. A single period is also accepted.
        smooth_factor (float, optional): The smoothing width is the period divided by this factor.

    Returns:
        ndarray: Filter flux vector, in the same order as ``t``, that can be
        removed from the timeseries.

    Note:
        Does not require time to be sorted.
        Can handle NaN in flux vector.
    """
    # Prepare arrays:
    periods = np.atleast_1d(Plist) # Also handles 0-dim input
    Np = len(periods)
    Nt = len(t)
    phase = zeros((Np, Nt), dtype='float64')
    order = zeros((Np, Nt), dtype='int')
    order_inv = zeros((Np, Nt), dtype='int')
    phase_tot = zeros(Nt, dtype='float64')
    phase_smooth_t = zeros((Np, Nt), dtype='float64')
    dphase = zeros(Np, dtype='float64')

    def _refit(m):
        # (Re-)calculate the smooth phase curve of period number "m" from the
        # data with the current total filter removed, and store it in time order:
        srt = order[m]
        curve = _filter_single_phase(phase[m, srt], x[srt]-phase_tot[srt], periods[m]/smooth_factor, dphase[m])
        phase_smooth_t[m] = curve[order_inv[m]]

    # Loop through periods to be removed:
    for k in range(Np):
        # Fold the time on the period and find the sorting in phase,
        # as well as the inverse mapping back to the time ordering:
        phase[k] = mod(t, periods[k])
        order[k] = argsort(phase[k])
        order_inv[k] = argsort(order[k])
        dphase[k] = median(diff(phase[k, order[k]]))

        # Smooth phase curve of this period, added to the total phase filter:
        _refit(k)
        phase_tot += phase_smooth_t[k, :]

        # For all previously treated periods, take the phase curve out of the
        # total filter, re-derive it and put it back in again.
        # This is done to avoid cross-talk between periods:
        for j in range(k):
            phase_tot -= phase_smooth_t[j, :]
            _refit(j)
            phase_tot += phase_smooth_t[j, :]

    # Make plots of phase curves:
    if _output_folder is not None:
        # Find the point on the smoothed curve that deviates the most from zero:
        imax = nanargmax(np.abs(phase_smooth_t), axis=1)
        s = nanstd(x)

        fig = plt.figure(num='phasecurve')
        fig.subplots_adjust(hspace=0.05)
        for k, P in enumerate(periods):
            srt = order[k]
            # One panel per period:
            ax = plt.subplot(Np, 1, k+1)
            ax.plot(phase[k]/P, x, 'k.', markersize=2) # Points only, so no sorting needed
            ax.plot(phase[k, srt]/P, phase_smooth_t[k, srt], 'r-')
            ax.axvline(phase[k, imax[k]]/P, color='b', linestyle='--') # Largest deviation (likely transit)
            ax.set_xlim(0, 1)
            ax.set_ylim(-6*s, 6*s)
            ax.text(0.02, 0.97, f'P = {P:f} d', horizontalalignment='left', verticalalignment='top', transform=ax.transAxes, backgroundcolor='w', color='k')
            # Only the bottom panel shows tick labels:
            if k != Np-1:
                plt.setp(ax.get_xticklabels(), visible=False)

        ax.set_xlabel('Phase')
        fig.text(0.03, 0.5, 'Flux (counts/s)', ha='center', va='center', rotation='vertical', transform=fig.transFigure)
        if _output_format != 'native':
            figpath = os.path.join(_output_folder, _output_prefix+'phasecurve.'+_output_format)
            fig.savefig(figpath, format=_output_format, bbox_inches='tight')
            plt.close(fig)

    # Return the total time-sorted phase curve:
    return phase_tot
#--------------------------------------------------------------------------------------------------
#--------------------------------------------------------------------------------------------------
[docs]
def spline_set_knots(x, num_knots, min_points_per_knot=3):
    """
    Select knot positions for a spline, discarding sparsely populated intervals.

    ``num_knots+2`` equally spaced candidate knots are placed between the
    minimum and maximum of the finite values of ``x``. A candidate is kept if
    more than ``min_points_per_knot`` data points lie strictly between it and
    the following candidate.

    Parameters:
        x (ndarray): Positions of the data points. Non-finite values are ignored.
        num_knots (int): Number of interior candidate knots.
        min_points_per_knot (int, optional): Number of points that must be exceeded
            in the interval following a knot for it to be kept. Default=3.

    Returns:
        ndarray: The retained knots. Empty if ``x`` contains no finite values.
        NOTE(review): the first candidate (the minimum of ``x``) can be among the
        returned knots - callers needing strictly interior knots should confirm.
    """
    x = np.asarray(x, dtype='float64')
    x = x[np.isfinite(x)]
    if x.size == 0:
        return np.array([], dtype='float64')

    # Candidate knots, including the two end-points of the data:
    knots = np.linspace(x.min(), x.max(), num_knots+2)

    # Keep the left knot of each interval holding enough data points:
    keep = np.array([
        np.sum((left < x) & (x < right)) > min_points_per_knot
        for left, right in zip(knots[:-1], knots[1:])
    ], dtype='bool')
    return knots[:-1][keep]
#--------------------------------------------------------------------------------------------------
[docs]
def filter_position_1d(time, flux, star_movement, timescale_position_smooth=None, dt=None):
    """
    Filter the lightcurve for correlations in the stars position on the CCD.

    For each chunk of the timeseries, the flux is smoothed with LOWESS as a
    function of the curve length of the star's movement, and the result is
    sorted back into time order.

    Parameters:
        time (ndarray): Time vector (days).
        flux (ndarray): Flux vector with the same length as ``time``.
        star_movement (dict): Description of the movement of the star. The keys
            ``chunks``, ``curvelength``, ``indx_possort`` and ``indx_timesort`` are used,
            each holding one entry per chunk.
        timescale_position_smooth (float, optional): If given, the result of each chunk is
            additionally passed through a moving median of this width in time.
        dt (float, optional): Cadence passed to the moving median. If not given, and
            ``timescale_position_smooth`` is, the median time step is used.

    Returns:
        ndarray: Position-correlated flux component in time order. Points that are
        not covered by any chunk, or in chunks without finite flux, are NaN.

    Raises:
        ValueError: If ``time`` and ``flux`` do not have the same length.
    """
    # Check input:
    if len(time) != len(flux):
        raise ValueError("TIME and FLUX should have the same number of elements.")

    if timescale_position_smooth is not None and dt is None:
        dt = median(diff(time))

    # Build up xpos chunk by chunk of the timeseries.
    # Initialized as NaN so points not written by any chunk are well-defined:
    xpos = np.full_like(time, np.nan, dtype='float64')
    for chk, chunk in enumerate(star_movement['chunks']):
        # Extract needed information:
        cl = star_movement['curvelength'][chk] # Sorted in position
        indx_possort = star_movement['indx_possort'][chk]
        indx_timesort = star_movement['indx_timesort'][chk]

        # Flux of this chunk, sorted by position:
        fl = flux[chunk][indx_possort]

        # Nothing can be fitted if the chunk holds no valid flux at all:
        good = np.isfinite(fl)
        if not good.any():
            continue

        # Create smooth curve as flux as a function of curvelength.
        # The LOWESS window corresponds to 0.1 in units of curvelength, expressed as
        # a fraction of the full range. LOWESS only accepts fractions up to one,
        # so for very short (or zero-length) curves all points are used:
        cl_range = nanmax(cl[good]) - nanmin(cl[good])
        lowess_frac = 1.0 if cl_range == 0 else np.minimum(0.1 / cl_range, 1.0)
        xp = lowess(fl, cl, frac=lowess_frac, it=3, is_sorted=True, return_sorted=False)

        # Sort back into time-sorting, then optionally low-pass filter the result:
        if timescale_position_smooth is None:
            xpos[chunk] = xp[indx_timesort]
        else:
            xpos[chunk] = moving_nanmedian(time[chunk], xp[indx_timesort], timescale_position_smooth, dt=dt)

    # Return the final time-sorted series:
    return xpos
#--------------------------------------------------------------------------------------------------
[docs]
def filter(t, x, quality=None, position=None, P=None, jumps=None, timescale_long=3.0,
    timescale_short=1/24, sigma_clip=4.5, scale_clip=5.0, scale_width=1.0,
    phase_smooth_factor=1000, transit_model=None, it=3):
    """
    Main filter function.

    The input is sorted in time, bad data points are removed, jumps are corrected
    and gaps are filled. A filter is then built iteratively from a long moving
    median, a position-dependent component and a transit component, and is
    combined with a short moving median using a turnover function. Finally the
    filter is applied, the result converted to ppm and sigma-clipped.

    Parameters:
        t (ndarray): Time vector (days).
        x (ndarray): Flux vector.
        quality (ndarray, None): Quality vector (bit-flags); default=None.
        position (ndarray, None): Centroid positions of star on CCD as two column list,
            or a dict holding these under the key ``pixels``; default=None.
        P (ndarray): Known planetary period (days); default=None.
        jumps (list): List of known jumps in the flux (timestamp in days); default=None.
        timescale_long (float): Timescale of long filter in days; default=3.
        timescale_short (float): Timescale of short filter in days; default=1/24.
            If ``None`` the short filter is not used.
        sigma_clip (float): Sigma-clip threshold; default=4.5. If ``None`` no clipping is done.
        scale_clip (float): Scale at which to switch between long and short filters; default=5.
        scale_width (float): Width of transition region between filters; default=1.
        phase_smooth_factor (float): Fraction of period to smooth phase curve with; default=1000.
        transit_model (ndarray): Full transit model, relative to one and given in the same
            order as ``t``, to be used instead of smoothed phase curve; default=None.
            Only used if ``P`` is not given.
        it (integer): Number of iterations between different filters. Default=3.

    Returns:
        tuple:
        - t: Time vector sorted in time, with the same length as the input vectors.
        - x: Filtered flux vector in ppm.
        - sigma: Vector of estimated errors on measurements (ppm).
        - quality_flags: Vector of KASOC flags (1: removed as bad, 2/4/256/512: corrected jump,
          8: sigma-clipped, 16: sharp negative feature, 32: bad position,
          64: position break, 128: invalid sigma).
        - filt: Vector with the final filter applied (after jump removal).
        - turnover: Turnover function with weights to long and short filter.
        - xlong: Long moving median component of the filter.
        - xpos: Position component of the filter.
        - xtransit: Transit component of the filter.
        - xshort: Short filter component, weighted by the turnover function.

    Raises:
        ValueError: If the inputs have inconsistent lengths or shapes, or ``it`` is less than one.
    """
    # Basic check of input:
    N = len(t)
    if len(x) != N:
        raise ValueError("TIME and DATA does not have the same length")
    if transit_model is not None and len(transit_model) != N:
        raise ValueError("TRANSIT_MODEL is wrong length")
    if quality is not None and len(quality) != N:
        raise ValueError("QUALITY is wrong length")
    if position is not None:
        if isinstance(position, dict):
            # Work on a shallow copy, as entries are replaced below and the
            # dictionary belonging to the caller should not be altered:
            position = dict(position)
        else:
            position = {'pixels': position, 'break': np.array([], dtype='float64')}
        if position['pixels'].shape != (N, 2):
            raise ValueError("POSITION must have the shape (N,2)")
    if it < 1:
        raise ValueError("IT must be at least one.")

    # Get the logger to use for printing messages:
    logger = logging.getLogger(__name__)

    # Sort the data in ascending order of time (This is needed for median filters to work)
    indx_sorttime = argsort(t)
    x = x[indx_sorttime] # data sorted after time
    t = t[indx_sorttime] # sorted time
    if quality is not None:
        quality = quality[indx_sorttime] # sorted quality
    if position is not None:
        position['pixels'] = position['pixels'][indx_sorttime, :] # sorted position
    if transit_model is not None:
        # The transit model has to follow the same ordering as the data:
        transit_model = np.asarray(transit_model)[indx_sorttime]

    # If not correcting position and transits, don't iterate:
    if position is None and transit_model is None and P is None:
        it = 1

    # Find median cadence:
    dt = median(np.diff(t))

    # Use the quality values to filter out bad values:
    if quality is not None:
        x, tmpJumps, flag_removed = filter_flags(t, x, quality, return_flags=True)
        if len(tmpJumps) > 0:
            if jumps is None:
                jumps = tmpJumps
            else:
                jumps = append(jumps, tmpJumps)
    else:
        flag_removed = ~isfinite(x)

    # Remove jumps:
    if jumps is not None:
        logger.info('Removing jumps...')
        x, jumps_flag, flag_jumps2 = remove_jumps(t, x, jumps, return_flags=True)

    # Fill gaps in timeseries with NaN
    # "ori" is a flag so xg[ori] will retrive the original points
    logger.info('Filling gaps...')
    tg, xg, ori = gap_fill(t, x, timescale_long)
    Ng = len(tg)

    # Calculate wide median filter and possibly filter out
    # flux changes correlated with stars position on CCD:
    if position is not None:
        logger.info('Extracting position information...')

        # Remove points that have been flagged as bad from positions:
        position['pixels'][flag_removed, :] = np.nan

        # Fill the gaps in the position timeseries with NaNs:
        posg = np.full((Ng, 2), np.nan, dtype='float64')
        posg[ori, :] = position['pixels']
        position['pixels'] = posg

        # Run subroutine which determines xlong and xpos using the positions:
        # NOTE(review): extract_star_movement_1d is neither defined nor imported
        # in the visible part of this module - confirm that it is available.
        flag_bad_pos, star_movement = extract_star_movement_1d(tg, xg, position, dt=dt)

        # Number of columns to plot on the "decorrelation" plot:
        # NOTE: Not "+2" as Nchunks is the number of breaks and not the number of chunks
        ncols = star_movement['Nchunks'] + 1
    else:
        flag_bad_pos = np.zeros(Ng, dtype='bool')
        ncols = 1

    flux_ylim = np.percentile(x[isfinite(x)], [0.25, 99.75])

    # Prepare the "decorrelation" figure, scaled with the number of panels:
    ax1 = None
    figsize = [8*1.7*max(ncols/3, 1), 6*1.7*max(it/3, 1)]
    fig = plt.figure(num='Decorrelation', figsize=figsize)
    fig.subplots_adjust(hspace=0.05)

    # Repeat the determination of xlong and xpos to better disentangle them:
    xpos = np.zeros(Ng, dtype='float64')
    xtransit = np.zeros(Ng, dtype='float64')
    xpos[flag_bad_pos] = np.nan # Set points found to be bad to NaN so they wont contribute in the following
    for i in range(it):
        logger.info("Running %d iteration:", i+1)

        # Create long moving median, by removing previously found xpos and xtransit:
        logger.info(' Calculating long moving median...')
        xinp = xg - xpos - xtransit
        xlong = moving_nanmedian(tg, xinp, timescale_long, dt=dt)
        xlong[flag_bad_pos] = np.nan

        # Create first column of plot with determination of xlong:
        ax1 = fig.add_subplot(it, ncols, ncols*i+1, sharex=ax1)
        ax1.scatter(tg, xinp, color='k', s=1, alpha=0.5)
        ax1.plot(tg, xlong, 'g-')
        ax1.set_xlim(tg[0], tg[-1])
        ax1.set_ylim(flux_ylim)
        ax1.set_ylabel(r'Flux (e$^-$/s)')
        plt.yticks(fontsize=10)
        ax1.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useOffset=False))
        if i == 0:
            ax1.set_title(r'$x_\mathrm{long}$')
        if i == it-1:
            ax1.set_xlabel('Time (days)', fontsize=10)
            plt.xticks(fontsize=10)
        else:
            plt.setp(ax1.get_xticklabels(), visible=False)

        # Filter the timeseries for the star movement:
        if position is not None:
            logger.info(' Filtering star movements...')
            xinp = xg - xlong - xtransit
            xpos = filter_position_1d(tg, xinp, star_movement, dt=dt)

            # One column for each chunk with flux as a function of position:
            for kc, chunk in enumerate(star_movement['chunks']):
                indx_possort = star_movement['indx_possort'][kc]
                curvelength_chunk = star_movement['curvelength'][kc]

                ax2 = fig.add_subplot(it, ncols, ncols*i+kc+2)
                ax2.scatter(curvelength_chunk, xinp[chunk][indx_possort], color='k', s=1, alpha=0.5)
                ax2.plot(curvelength_chunk, xpos[chunk][indx_possort], 'r-')
                plt.yticks(fontsize=10)
                if i == 0:
                    ax2.set_title(f'Position-flux #{kc+1:d}')
                if i == it-1:
                    ax2.set_xlabel('Curve length (pixels)', fontsize=10)
                    plt.xticks(fontsize=10)
                else:
                    plt.setp(ax2.get_xticklabels(), visible=False)

            # The next column with xpos as a function of time:
            ax3 = fig.add_subplot(it, ncols, ncols*(i+1), sharex=ax1)
            ax3.scatter(tg, xinp, color='k', s=1, alpha=0.5)
            ax3.plot(tg, xpos, 'r-')
            ax3.set_xlim(tg[0], tg[-1])
            plt.yticks(fontsize=10)
            ax3.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useOffset=False))
            if i == 0:
                ax3.set_title(r'$x_\mathrm{pos}$')
            if i == it-1:
                ax3.set_xlabel('Time (days)', fontsize=10)
                plt.xticks(fontsize=10)
            else:
                plt.setp(ax3.get_xticklabels(), visible=False)

        # Calculate phase-curve, if periods are provided:
        if P is not None:
            logger.info(" Calculating phase curve...")
            xtransit = filter_phase(tg, xg-xlong-xpos, P, smooth_factor=phase_smooth_factor)
        elif transit_model is not None:
            # Create filter using transit model
            # Do it in this way since transit model is relative with respect to 1
            # Fill gaps of transit model the same way as the data:
            xtransit = np.ones(Ng)
            xtransit[ori] = transit_model
            filt = (xlong+xpos) * xtransit
            xtransit = filt - (xlong+xpos)

    # Save the figure. In "native" mode figures are neither saved nor closed,
    # in the same way as for the other figures created by the filter:
    if _output_folder is not None and _output_format != 'native':
        fig.savefig(os.path.join(_output_folder, _output_prefix+'decorrelation.'+_output_format), format=_output_format, bbox_inches='tight')
    if _output_format != 'native':
        plt.close(fig)

    # Make sure we have removed the bad datapoints:
    xg[flag_bad_pos] = np.nan

    # Construct the final filter:
    filt = xlong + xtransit + xpos

    # Run the old KASOC filter to remove any potential unknown transits and sharp features:
    if timescale_short is not None:
        # Make a switch for long cadence data that puts a lower limit on timescale_short of 7 points (3.5 hours for LC)
        if timescale_short < 7*dt:
            logger.warning("WARNING: timescale_short is less than 7 points wide!")

        # Smooth the data with short moving median:
        logger.info("Calculating short moving median...")
        xshort = moving_nanmedian(tg, xg-filt, timescale_short, dt=dt)
        xshort_tilde = deepcopy(xshort)
        xshort = filt + xshort

        # Create timeseries of the long filter, divided by the short filter:
        w4 = filt/xshort - 1

        # Smooth the timeseries using a very short filter to remove any very high frequency noise:
        w4_smooth_width = int(timescale_short/dt)
        w4 = smooth(w4, w4_smooth_width)
        w4 = smooth(w4, w4_smooth_width)
        w4 = smooth(w4, w4_smooth_width)

        # Calculate moving standard deviation of timeseries
        # in units of sigmas:
        w5 = moving_nanmedian(tg, np.abs(w4), timescale_short)
        snr = w5/nanmedian(w5)

        # Create "flag"/weight indicating how much of the short filter and the long filter should
        # be used at each timestep. Is a number between 0 (long filter) and 1 (short filter).
        if scale_width > 0:
            with np.errstate(invalid='ignore'):
                turnover = norm.cdf(snr, scale_clip, scale_width)
        else:
            # For zero width, use the Heaviside function:
            turnover = 0.5*(np.sign(snr-scale_clip) + 1)

        # Create final filter as weighted mean of the long and short filters:
        filt = (1-turnover)*filt + turnover*xshort

        # Plot the derived filter compoments:
        if _output_folder is not None:
            fig = plt.figure(num='turnover')
            fig.subplots_adjust(hspace=0.05)
            ax1 = plt.subplot(211)
            ax1.axhspan(scale_clip-scale_width, scale_clip+scale_width, facecolor='0.5', edgecolor=None, alpha=0.5)
            ax1.plot(tg, snr, 'b-')
            ax1.set_ylabel(r'$\sigma_w$', fontsize=10)
            ax1.set_title('Filter turnover function', fontsize=12)
            ax1.set_xlim(t[0], t[-1])
            ax1.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useOffset=False))
            plt.yticks(fontsize=10)
            plt.setp(ax1.get_xticklabels(), visible=False)

            # Axes showing the derived weights:
            ax2 = plt.subplot(212, sharex=ax1)
            ax2.plot(tg, turnover, 'b-')
            ax2.set_ylim(0, 1)
            ax2.set_ylabel('$c$', fontsize=10)
            ax2.set_xlabel('Time', fontsize=10)
            ax2.set_xlim(t[0], t[-1])
            ax2.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useOffset=False))
            plt.xticks(fontsize=10)
            plt.yticks(fontsize=10)
            if _output_format != 'native':
                fig.savefig(os.path.join(_output_folder, _output_prefix+'turnover.'+_output_format), format=_output_format, bbox_inches='tight')
                plt.close(fig)
    else:
        # No short filter, so the long filter is used everywhere:
        xshort = np.zeros(Ng, dtype='float64')
        xshort_tilde = xshort
        turnover = np.zeros(Ng, dtype='float64')

    # Flag with significant sharp and negative features (transits?):
    with np.errstate(invalid='ignore'):
        flag_transit = (turnover > 0.5) & (xshort < xlong+xpos+xtransit)

    # Plot the final filter:
    if _output_folder is not None:
        mask_long = isfinite(xlong)
        mask_short = isfinite(xshort)
        mask_filt = isfinite(filt)

        fig = plt.figure(num='components')
        ax = fig.add_subplot(111)
        h1 = plt.scatter(t, x, color='k', s=2)
        h2, = plt.plot(tg[mask_long], xlong[mask_long], 'b-')
        h5, = plt.plot(tg, xlong+xpos+xtransit, 'y-')
        h3, = plt.plot(tg[mask_short], xshort[mask_short], 'g-')
        h4, = plt.plot(tg[mask_filt], filt[mask_filt], 'r-')
        ax.plot(tg[flag_transit], xg[flag_transit], 'go', markersize=2)
        plt.legend([h1, h2, h5, h3, h4], ['Data', r'$x_{\rm long}$', r'$x_{\rm pos}+x_{\rm transit}$', r'$x_{\rm short}$', 'Final filter'], fontsize=8, ncol=2, loc='best')
        ax.set_xlabel('Time', fontsize=10)
        ax.set_ylabel('Flux', fontsize=10)
        ax.set_xlim(t[0], t[-1])
        ax.set_ylim(flux_ylim)
        plt.xticks(fontsize=10)
        plt.yticks(fontsize=10)
        ax.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useOffset=False))
        if _output_format != 'native':
            fig.savefig(os.path.join(_output_folder, _output_prefix+'components.'+_output_format), format=_output_format, bbox_inches='tight')
            plt.close(fig)

    # Apply final filter and convert to ppm:
    xg = 1e6*(xg/filt - 1)

    # Remove outliers using sigma-clipping:
    # The mean is already taken out, so we only
    # need to calculate the deviation from zero.
    logger.info("Calculating sigma...")
    flag_bad = np.zeros(Ng, dtype='bool')
    absx = np.abs(xg)
    sigma = moving_nanmedian(tg, absx, timescale_long, dt=dt)
    if sigma_clip is not None:
        with np.errstate(invalid='ignore'):
            sigma_clip = mad_to_sigma * sigma_clip # less expensive to convert sigma_clip than sigma vector

            # Iteratively remove outliers and re-estimate the point-to-point
            # scatter, because bad data points might have biased the
            # previously calculated sigmas:
            while True:
                flag_rem = (absx > sigma_clip*sigma)
                if not flag_rem.any():
                    break
                # Remove bad data points from timeseries:
                flag_bad[flag_rem] = True
                absx[flag_rem] = np.nan
                sigma = moving_nanmedian(tg, absx, timescale_long, dt=dt)

    # Bad data points should also be NaN:
    xg[flag_bad] = np.nan

    # Convert to proper sigma instead of MAD:
    indx = ~isfinite(xg)
    sigma[indx] = np.nan
    sigma = mad_to_sigma * smooth(sigma, int(timescale_long/dt))
    sigma[indx] = np.nan

    # Remove the gap-filled data again:
    x = xg[ori]
    sigma = sigma[ori]
    filt = filt[ori]
    flag_bad = flag_bad[ori]
    turnover = turnover[ori]
    flag_transit = flag_transit[ori]
    flag_bad_pos = flag_bad_pos[ori]
    xlong = xlong[ori]
    xpos = xpos[ori]
    xtransit = xtransit[ori]
    # Return this instead of xshort, so the filter is easier to
    # disassemble into the components, since this means that the
    # filter can be written as:
    #   filter = xlong + xpos + xtransit + xshort
    xshort = turnover * xshort_tilde[ori]

    # Create KASOC flag vector:
    quality_flags = np.zeros(N, dtype='int64')
    quality_flags[flag_removed] |= 1
    if jumps is not None:
        quality_flags |= flag_jumps2 # Sets 2+4+256+512
    quality_flags[flag_bad] |= 8
    quality_flags[flag_transit] |= 16
    if position is not None:
        quality_flags[flag_bad_pos] |= 32
        # Find the indicies of points just after position breaks:
        if len(star_movement['tbreaks']) >= 3:
            ibreak = searchsorted(t, star_movement['tbreaks'][1:-1])
            quality_flags[ibreak] |= 64

    # Check that the extracted errorbars make sense:
    with np.errstate(invalid='ignore'):
        indx_invalid_sigma = (sigma < 1e-8)

    if np.any(indx_invalid_sigma):
        # Generate a warning message:
        number_invalid_sigma = np.sum(indx_invalid_sigma)
        try:
            logger.warning("Invalid SIGMAs extracted (%d points = %.2f%%). Timescales should maybe be adjusted.", number_invalid_sigma, 100*number_invalid_sigma/N)
            warnings.warn("Invalid SIGMAs extracted", InvalidSigmasWarning)
        except IOError:
            print("Something went wrong in the logging of invalid sigmas")

        # Set the timeseries to NaN where sigmas are invalid,
        # and add a flag (128) to the quality-flags:
        x[indx_invalid_sigma] = np.nan
        sigma[indx_invalid_sigma] = np.nan
        quality_flags[indx_invalid_sigma] |= 128

    # Plot the final filtered timeseries:
    if _output_folder is not None:
        fig = plt.figure(num='final filter')
        fig.subplots_adjust(hspace=0.05)
        ax1 = plt.subplot(211)
        ax1.plot(t, x, 'b.', markersize=2)
        ax1.set_xlim(t[0], t[-1])
        ax1.set_ylabel('Relative flux (ppm)', fontsize=10)
        ax1.set_title("Final timeseries", fontsize=12)
        ax1.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useOffset=False))
        plt.setp(ax1.get_xticklabels(), visible=False)
        plt.yticks(fontsize=10)

        ax2 = plt.subplot(212, sharex=ax1)
        ax2.plot(t, sigma, 'b-')
        ax2.set_ylabel(r'$\sigma$ (ppm)', fontsize=10)
        ax2.set_xlabel('Time', fontsize=10)
        ax2.set_xlim(t[0], t[-1])
        plt.xticks(fontsize=10)
        plt.yticks(fontsize=10)
        ax2.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter(useOffset=False))
        if _output_format != 'native':
            fig.savefig(os.path.join(_output_folder, _output_prefix+'final.'+_output_format), format=_output_format, bbox_inches='tight')
            plt.close(fig)

    # Return everything needed:
    return t, x, sigma, quality_flags, filt, turnover, xlong, xpos, xtransit, xshort