import logging
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d
from caat import SN
from caat.utils import FILT_TEL_CONVERSION, bin_spec, query_svo_service
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class DataCube:
"""
Class containing routines to handle pre-processing of data.
The purpose of this class is to process and store photometry
of a SN object in a pandas dataframe, containing shifted,
logarithmic, linear, magnitude, and flux values, as well as
metadata such as filter effective wavelengths and nondetections.
This enables efficient loading, sorting, and mathematical operations
on the data during the Gaussian Process Regression fitting routines.
    Runs SN methods such as `load_<>_data`, `shift_to_max`,
    and `convert_to_fluxes` to initialize the data set.
Also optionally performs iterative warping on SED at each
epoch to better match observed photometry.
Constructs a `pandas` dataframe with this information to be
saved, read, and manipulated as part of the fitting routine.
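    Example (a minimal usage sketch; the SN name below is a hypothetical
    placeholder and assumes its photometry is available locally)::

        dc = DataCube(name="SN2011fe")
        dc.construct_cube()
        print(dc.cube.head())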
"""
def __init__(
self,
sn: SN = None,
name: str | None = None,
data: dict | None = None,
):
"""
Initialize a DataCube object.
        Store the photometry of a SN object in this `DataCube` instance.
The SN can be passed directly, by name, or with a dictionary of data.
If no data is found, this method attempts to load the data using the
SN object methods.
Args:
sn (SN, optional): The SN object. Defaults to None.
            name (str | None, optional): The name of a SN object to load,
                if one exists. Defaults to None.
data (dict | None, optional): A dictionary of data to load with
the SN object. Defaults to None.
"""
if sn:
self.sn = sn
else:
sn = SN(name=name, data=data)
self.sn = sn
if not self.sn.data:
self.sn.load_json_data()
self.sn.load_swift_data()
self.sn.correct_for_galactic_extinction()
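        # Shift each band relative to its peak and drop any filter without a
        # known effective wavelength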
for filt in list(self.sn.data.keys()):
self.sn.shift_to_max(filt)
if filt not in self.sn.wle.keys():
del self.sn.data[filt]
self.sn.convert_all_mags_to_fluxes()
def construct_cube(self):
"""
Construct a data cube using the stored photometry.
A data cube is a pandas dataframe which stores the
photometry of a SN object, in both magnitude and flux
space as well as both linear and log-transformed values.
        This method gathers all photometry from the SN data dictionaries,
        applies the (1 + z) correction to the filter wavelengths, records
        epochs relative to the light-curve peak, and stores all associated values.
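        A brief access sketch (``dc`` is an illustrative variable name)::

            dc.construct_cube()
            fluxes = dc.cube["ShiftedFlux"].values
            phases = dc.cube["Phase"].values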
"""
        # If either the raw or the shifted photometry is missing, build an empty cube
        if not self.sn.data or not self.sn.shifted_data:
cube = np.asarray([[], [], [], [], [], [], [], [], [], [], [], [], [], []])
else:
cube = np.array(
[
np.hstack(
[
[d["mjd"] for d in self.sn.data[filt]]
for filt in self.sn.data.keys()
]
),
np.hstack(
[
[d["mjd"] for d in self.sn.shifted_data[filt]]
for filt in self.sn.shifted_data.keys()
]
),
np.hstack(
[
np.repeat([filt], len(self.sn.data[filt]))
for filt in self.sn.data.keys()
]
),
np.hstack(
[
np.repeat([filt], len(self.sn.shifted_data[filt]))
for filt in self.sn.shifted_data.keys()
]
),
np.hstack(
[
np.repeat(
[self.sn.wle[filt] * (1 + self.sn.info.get("z", 0.0))],
len(self.sn.data[filt]),
)
for filt in self.sn.data.keys()
]
),
np.hstack(
[
np.repeat(
[self.sn.wle[filt] * (1 + self.sn.info.get("z", 0.0))],
len(self.sn.shifted_data[filt]),
)
for filt in self.sn.shifted_data.keys()
]
),
np.hstack(
[
[d["flux"] for d in self.sn.data[filt]]
for filt in self.sn.data.keys()
]
),
np.hstack(
[
[d["flux"] for d in self.sn.shifted_data[filt]]
for filt in self.sn.shifted_data.keys()
]
),
np.hstack(
[
[d["fluxerr"] * d["flux"] for d in self.sn.data[filt]]
for filt in self.sn.data.keys()
]
),
np.hstack(
[
[d["fluxerr"] for d in self.sn.shifted_data[filt]]
for filt in self.sn.shifted_data.keys()
]
),
np.hstack(
[
[d["mag"] for d in self.sn.data[filt]]
for filt in self.sn.data.keys()
]
),
np.hstack(
[
[d["shiftedmag"] for d in self.sn.shifted_data[filt]]
for filt in self.sn.shifted_data.keys()
]
),
np.hstack(
[
[d["err"] for d in self.sn.data[filt]]
for filt in self.sn.data.keys()
]
),
np.hstack(
[
[d.get("nondetection", False) for d in self.sn.data[filt]]
for filt in self.sn.data.keys()
]
),
],
dtype=object,
)
if len(cube.shape) != 2:
logger.warning(f"WARNING: Construct cube failed for {self.sn.name}")
cube = np.asarray([[], [], [], [], [], [], [], [], [], [], [], [], [], []])
self.cube = pd.DataFrame(
data=cube.T,
columns=[
"MJD",
"Phase",
"Filter",
"ShiftedFilter",
"Wavelength",
"ShiftedWavelength",
"Flux",
"ShiftedFlux",
"Fluxerr",
"ShiftedFluxerr",
"Mag",
"ShiftedMag",
"Magerr",
"Nondetection",
],
).dropna()
def deconstruct_cube(self):
"""
Takes as input a data cube and turns it into a dictionary
        of photometry to be used in GP fitting.
Assumes a data cube of size (n, 5).
NOTE: This method is not currently implemented.
"""
# TODO: Make this compatible with change to pandas df
raise NotImplementedError
# data = {}
# for i in range(len(self.cube[0,:])):
# data.setdefault(self.cube[4][i], []).append(
# {
# 'mjd': self.cube[0][i],
# 'wle': self.cube[1][i],
# 'flux': np.log10(self.cube[2][i]),
# 'fluxerr': self.cube[3][i]/10**(self.cube[2][i]),
# 'mag': self.cube[5][i],
# 'err': self.cube[6][i],
# 'nondetection': self.cube[7][i]
# }
# )
# if self.shift:
# self.sn.shifted_data = data
# else:
# self.sn.data = data
def plot_cube(self):
"""
Plots the photometry stored in the DataCube.
        Plots in flux space as a 3D scatter of phase, wavelength, and flux.
"""
if not hasattr(self, "cube"):
self.construct_cube()
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
ax.errorbar(
self.cube["Phase"],
self.cube["Wavelength"],
self.cube["Flux"],
zerr=self.cube["Fluxerr"],
fmt="o",
)
ax.set_xlabel("Phase Grid")
ax.set_ylabel("Wavelengths")
ax.set_zlabel("Fluxes")
plt.tight_layout()
plt.show()
def measure_flux_in_filter(
self,
niter: int = 100,
convergence_threshold: float = 1.1,
plot: bool = False,
verbose: bool = False,
save: bool = False,
overwrite: bool = False,
):
"""
Construct a "warped" data cube using the input photometry.
A "warped" data cube is one that has been iteratively mangled until
the interpolated SED at a given epoch convolved with the filter functions
of the filters at that epoch match the input flux values. This process
shifts the effective wavelengths of the filters until the convolved flux
(aka synthetic flux) matches the observed values, within some convergence
threshold. This step is necessary to ensure that the shape of the input SED
is true to the actual SN SED.
Args:
niter (int, optional): The number of iterations to iteratively
warp the SED at each phase. Defaults to 100.
            convergence_threshold (float, optional): The multiplicative factor
                within which the convolved flux is considered "converged" with
                the input flux; must be greater than 1.0. Defaults to 1.1.
plot (bool, optional): Plot the SED and interpolated photometry.
Defaults to False.
verbose (bool, optional): Optionally output statistics and values
about the iterative mangling. This is to aid in debugging.
Defaults to False.
save (bool, optional): Save the final warped data cube to file.
The file path is built automatically from the SN object.
Defaults to False.
overwrite (bool, optional): Overwrite any existing warped data cube
if attempted to save. Defaults to False.
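        Example (a sketch; the keyword values shown are illustrative rather
        than recommended settings)::

            dc = DataCube(name="SN2011fe")
            dc.measure_flux_in_filter(convergence_threshold=1.05, plot=True)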
"""
if convergence_threshold <= 1.0:
logger.error("Convergence threshold must be greater than 1")
return
if (
save
and not overwrite
and os.path.exists(
os.path.join(
self.sn.base_path,
self.sn.classification,
self.sn.subtype,
self.sn.name,
self.sn.name + "_datacube_mangled.csv",
)
)
):
logger.warning(
f"Already saved mangled datacube for {self.sn.name}, skipping"
)
return
self.construct_cube()
if len(self.cube["MJD"]) == 0: # No data, so return nothing
return
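        # Query, normalize, and cache the transmission curve for each filter in
        # the cube, recording the wavelength range where the throughput exceeds
        # ~10% (padded by 500 Angstroms on either side)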
trans_fns = {}
filts_to_ignore = []
for filt in np.unique(self.cube["Filter"]):
svo_filt = filt.replace("'", "").replace("s", "")
trans_wl, trans_eff = query_svo_service(
FILT_TEL_CONVERSION[svo_filt], svo_filt
)
trans_eff /= max(trans_eff)
# Get min and max wavelength for this filter, let's define it as where eff < 10%
center_of_filt = trans_wl[np.argmax(trans_eff)]
tail_wls = trans_wl[np.where((trans_eff < 0.1))[0]]
try:
min_trans_wl = np.max(
tail_wls[np.where((tail_wls < center_of_filt))[0]]
)
max_trans_wl = np.min(
tail_wls[np.where((tail_wls > center_of_filt))[0]]
)
trans_fns[filt] = {
"wl": trans_wl,
"eff": trans_eff,
"min_wl": min_trans_wl - 500.0,
"max_wl": max_trans_wl + 500.0,
}
            except ValueError:
                # np.max/np.min raise on an empty array if no tails fall on
                # one side of the filter's peak throughput
                logger.warning(
                    f"Transmission function failed for {filt}, ignoring"
                )
filts_to_ignore.append(filt)
inds_to_drop = self.cube.loc[self.cube["Filter"].isin(filts_to_ignore)].index
self.cube = self.cube.drop(inds_to_drop).reset_index(drop=True)
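        # Step through the light curve in 1-day phase bins; at each epoch with
        # data in at least two filters, interpolate an SED and iteratively warp it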
for phase in np.arange(min(self.cube["Phase"]), max(self.cube["Phase"]), 1.0):
current_lc_inds = np.where((abs(self.cube["Phase"] - phase) <= 0.5))[0]
if (
len(current_lc_inds) > 0
and len({filt for filt in self.cube["Filter"][current_lc_inds]}) > 1
): # Have data in at least two filters at this epoch
current_lc_cube = self.cube[abs(self.cube["Phase"] - phase) <= 0.5]
bluest_wavelength = np.min(current_lc_cube["Wavelength"].values)
reddest_wavelength = np.max(current_lc_cube["Wavelength"].values)
bluest_filt = np.unique(
current_lc_cube[current_lc_cube["Wavelength"] == bluest_wavelength][
"Filter"
]
)[0]
reddest_filt = np.unique(
current_lc_cube[
current_lc_cube["Wavelength"] == reddest_wavelength
]["Filter"]
)[0]
current_lc = current_lc_cube["Flux"].values
# TODO: Eval
current_lc = np.concatenate(([0.0], current_lc, [current_lc[-1] / 2]))
current_lc_err = np.concatenate(
([0.0], current_lc_cube["Fluxerr"], [0.0])
)
current_lc_wls = np.concatenate(
(
[
trans_fns[bluest_filt]["min_wl"]
* (1 + self.sn.info.get("z", 0.0))
],
current_lc_cube["Wavelength"],
[
trans_fns[reddest_filt]["max_wl"]
* (1 + self.sn.info.get("z", 0.0))
],
)
)
wl_grid = np.linspace(current_lc_wls[0], current_lc_wls[-1], 50)
errors = np.ones(len(current_lc_inds)) * 100.0
n = 0
# central_wls = np.copy(current_lc_wls[1:-1])
measured_wls = np.copy(current_lc_wls)
measured_flux = np.copy(current_lc)
### Construct SED by interpolating over this LC
interp = interp1d(measured_wls, measured_flux, kind="linear")
binned_sed = interp(wl_grid)
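                # Iteratively convolve the SED with each filter's transmission
                # curve and rescale it until the synthetic fluxes converge on
                # the observed fluxes (or niter is reached)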
for i in range(niter):
if all(errors <= convergence_threshold) or n == niter:
break
residuals = []
for j, filt in enumerate(current_lc_cube["Filter"]):
### Bin the transmission curve and SED to common resolution
binned_trans_wl, binned_trans_eff = bin_spec(
trans_fns[filt]["wl"], trans_fns[filt]["eff"], wl_grid
)
if n == 0 and plot:
plt.plot(
binned_trans_wl, binned_trans_eff * max(interp(wl_grid))
)
### Get overlap between the filter and the SED
sed_inds = np.where(
(wl_grid >= min(binned_trans_wl))
& (wl_grid <= max(binned_trans_wl))
)[0]
if len(sed_inds) > 0:
interp_filt = interp1d(binned_trans_wl, binned_trans_eff)
interp_trans_wl = np.linspace(
wl_grid[sed_inds[0]],
wl_grid[sed_inds[-1]],
len(sed_inds),
)
interp_trans_eff = interp_filt(interp_trans_wl)
flux = np.nansum(
binned_sed[sed_inds] * interp_trans_eff
) / len(interp_trans_eff)
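                            # Take the implied central wavelength at the peak of the
                            # throughput-weighted SED, clipped to the filter's padded range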
implied_central_wl = min(
interp_trans_wl[
np.argmax(binned_sed[sed_inds] * interp_trans_eff)
],
trans_fns[filt]["max_wl"]
* (1 + self.sn.info.get("z", 0.0)),
)
implied_central_wl = max(
implied_central_wl,
trans_fns[filt]["min_wl"]
* (1 + self.sn.info.get("z", 0.0)),
)
real_flux_inds = (
np.where((current_lc_cube["Filter"] == filt))[0] + 1
)
if len(real_flux_inds) > 1:
real_flux = np.average(current_lc[real_flux_inds])
else:
real_flux = current_lc[real_flux_inds][0]
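                            # Compare synthetic vs. observed flux; note that with numpy
                            # floats a zero denominator yields inf (flagged later via
                            # np.isinf) rather than raising ZeroDivisionError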
try:
error = max(flux / real_flux, real_flux / flux)
except ZeroDivisionError:
error = 100
try:
resid = flux / real_flux
except ZeroDivisionError:
resid = 100
if verbose:
logger.info(
f"Filter: {filt}, convolved flux: {flux}, measured flux: {real_flux}, error: {error}"
)
logger.info(
f"Filter: {filt}, real wavelength: {current_lc_wls[j + 1]}, warped wl: {implied_central_wl}"
)
errors[j] = error
residuals.append(resid)
measured_flux[j + 1] = flux
measured_wls[j + 1] = implied_central_wl
else:
### No overlap between SED and filter, so break
n = niter
if any(errors > convergence_threshold):
### Make a residual interpolated SED from the convolved fluxes at the
### implied wavelengths, warp the SED using this residual, and rerun the loop
if n < niter:
residual_interp = interp1d(
measured_wls,
np.concatenate(([0.0], residuals, [residuals[-1] / 2])),
)
residual = residual_interp(wl_grid)
binned_sed /= residual
n += 1
if any(errors > 1e3) or any(np.isinf(errors)):
n = niter
                if n == niter and verbose:
                    logger.warning("Couldn't iterate to match flux!")
if plot:
plt.errorbar(
current_lc_wls,
current_lc,
yerr=current_lc_err,
fmt="o",
alpha=0.3,
)
plt.errorbar(
measured_wls, measured_flux, yerr=current_lc_err, fmt="o"
)
plt.plot(
wl_grid,
interp1d(measured_wls, measured_flux, kind="linear")(wl_grid),
)
plt.show()
if n < niter:
### Put the new effective wavelengths back into the data cube in the right spots
current_lc_cube["Wavelength"] = measured_wls[1:-1]
current_lc_cube["ShiftedWavelength"] = measured_wls[1:-1]
self.cube.update(current_lc_cube)
if verbose:
logger.info(f"Done warping SED for {self.sn.name}")
if save:
self.cube.to_csv(
os.path.join(
self.sn.base_path,
self.sn.classification,
self.sn.subtype,
self.sn.name,
self.sn.name + "_datacube_mangled.csv",
),
)