import logging
import os
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.convolution import convolve
from scipy.interpolate import interp1d
from scipy.signal import medfilt, savgol_filter
from scipy.stats import truncnorm
from sklearn.gaussian_process import GaussianProcessRegressor
from .DataCube import DataCube
from .Diagnostics import Diagnostic
from .GP import GP
from .Kernels import Kernel
from .Plot import Plot
from .SN import SN
from .SNCollection import SNCollection, SNType
from .SNModel import SNModel, SurfaceArray
from .utils import colors
# Logger for this module.
logger = logging.getLogger(name=__name__)
# NOTE(review): both statements below act process-wide at import time: the root
# logger is configured at INFO level and every warning category is silenced.
logging.basicConfig(level=logging.INFO)
warnings.simplefilter("ignore")
[docs]
class GP3D(GP):
"""
Class to perform GP fitting simultaneously across wavelength and phase
for a given collection of SNe.
Reads in a list of SNe to fit, as well as a collection of SNe to normalize
that sample against, along with a set of optional flags for different
fitting routines and parameters.
Each SN in the science and comparison samples should already have its photometric
data pre-processed using the routines available in the `SN` class, run within the
`DataCube` class, and read directly from the SN datacube files created by the latter.
"""
def __init__(
    self,
    collection: SNCollection | SNType,
    kernel: Kernel,
    filtlist: list,
    phasemin: int,
    phasemax: int,
    log_transform: float,
    set_to_normalize: SNCollection | SNType | None = None,
    mangle_sed: bool = False,
):
    """
    Set up a 3D (phase x wavelength) GP fit and load every SN's datacube.

    Args:
        collection (SNCollection | SNType): The transients to fit.
        kernel (Kernel): Kernel for the Gaussian Process Regression. Either a
            project `Kernel` wrapper (its `.kernel` is handed to scikit-learn)
            or a kernel object handed to scikit-learn directly.
        filtlist (list): Filters to fit.
        phasemin (int): Minimum phase, relative to peak brightness, to fit.
        phasemax (int): Maximum phase, relative to peak brightness, to fit.
        log_transform (float): Offset added to the phase before taking the log,
            i.e. ``log(phase + log_transform)``, so ``phasemin + log_transform``
            must be positive. Effectively controls the light curve "stretch"
            in log space.
        set_to_normalize (SNCollection | SNType | None, optional): The collection
            of transients to normalize `collection` to. If `None`, `collection`
            is used. Defaults to None.
        mangle_sed (bool, optional): Read each SN's mangled datacube
            (``<name>_datacube_mangled.csv``) instead of the default
            ``<name>_datacube.csv``. If the expected file is missing, a cube is
            constructed on the fly by `_prepare_data`. Defaults to False.
    """
    super().__init__(
        collection,
        kernel,
        filtlist,
        phasemin,
        phasemax,
        log_transform,
    )
    # Options consumed by _prepare_data, which runs immediately below.
    self.mangle_sed, self.set_to_normalize = mangle_sed, set_to_normalize
    self._prepare_data()
def _prepare_data(self):
"""
Use the flags set in __init__ to filter the pandas dataframes for each SN
in the science and control samples
"""
for collection in [self.collection, self.set_to_normalize]:
for sn in collection.sne:
# Read the correct cube based on self.mangle_sed
if self.mangle_sed:
data_cube_filename = os.path.join(
sn.base_path,
sn.classification,
sn.subtype,
sn.name,
sn.name + "_datacube_mangled.csv",
)
else:
data_cube_filename = os.path.join(
sn.base_path,
sn.classification,
sn.subtype,
sn.name,
sn.name + "_datacube.csv",
)
if os.path.exists(data_cube_filename):
cube = pd.read_csv(data_cube_filename)
else:
# For now, we'll just construct it
datacube = DataCube(sn=sn)
datacube.construct_cube()
cube = datacube.cube
# Drop rows that are out of the phase range
inds_to_drop_phase = cube.loc[
(cube["Phase"] < self.phasemin) | (cube["Phase"] > self.phasemax)
].index
cube = cube.drop(inds_to_drop_phase).reset_index(drop=True)
# Drop nondetections farther than 2 days away from first/last detection
try:
min_phase = min(
cube.loc[cube["Nondetection"] == False]["Phase"].values
)
inds_to_drop_nondets_before_first_det = cube.loc[
(cube["Nondetection"] == True)
& (cube["Phase"] < (min_phase - 2.0))
].index
cube = cube.drop(inds_to_drop_nondets_before_first_det).reset_index(
drop=True
)
except ValueError: # no values
pass
try:
max_phase = max(
cube.loc[cube["Nondetection"] == False]["Phase"].values
)
inds_to_drop_nondets_after_last_det = cube.loc[
(cube["Nondetection"] == True)
& (cube["Phase"] > (max_phase + 2.0))
].index
cube = cube.drop(inds_to_drop_nondets_after_last_det).reset_index(
drop=True
)
except ValueError: # no values
pass
# Drop rows corresponding to filters not in the user-provided filter list
inds_to_drop_filts = cube.loc[~cube["Filter"].isin(self.filtlist)].index
cube = cube.drop(inds_to_drop_filts).reset_index(drop=True)
# Log transform the data (as a separate column)
cube["LogPhase"] = np.log(
cube["Phase"].values.astype(float) + self.log_transform
)
cube["LogWavelength"] = np.log10(
cube["Wavelength"].values.astype(float)
)
cube["LogShiftedWavelength"] = np.log10(
cube["ShiftedWavelength"].values.astype(float)
)
# Drop nondetections that are within the phase range and less constraining
# than the first or last detection in each filter
try:
min_flux_before_peak = min(
cube.loc[(cube["Nondetection"] == False) & (cube["Phase"] < 0)][
"Flux"
].values
)
inds_to_drop_nondets_before_peak = cube.loc[
(cube["Nondetection"] == True)
& (cube["Phase"] < 0)
& (cube["Flux"] > min_flux_before_peak)
].index
cube = cube.drop(inds_to_drop_nondets_before_peak).reset_index(
drop=True
)
except ValueError: # No values pre-peak, so pass
pass
try:
min_flux_after_peak = min(
cube.loc[(cube["Nondetection"] == False) & (cube["Phase"] > 0)][
"Flux"
].values
)
inds_to_drop_nondets_after_peak = cube.loc[
(cube["Nondetection"] == True)
& (cube["Phase"] > 0)
& (cube["Flux"] > min_flux_after_peak)
].index
cube = cube.drop(inds_to_drop_nondets_after_peak).reset_index(
drop=True
)
except ValueError: # No values post-peak, so pass
pass
# Drop nondetections between the first and last detection
cube_only_dets = cube.loc[cube["Nondetection"] == False]
if len(cube_only_dets["Phase"].values) > 0:
first_detection = min(cube_only_dets["Phase"].values)
last_detection = max(cube_only_dets["Phase"].values)
inds_to_drop_nondets_between_dets = cube.loc[
(cube["Phase"] > first_detection)
& (cube["Phase"] < last_detection)
& (cube["Nondetection"] == True)
].index
cube = cube.drop(inds_to_drop_nondets_between_dets).reset_index(
drop=True
)
try:
cube["MagFromPeak"] = sn.info["peak_mag"] - cube["Mag"]
sn.cube = cube
except:
pass
[docs]
@staticmethod
def interpolate_grid(grid, interp_array, filter_window=171):
"""
Function to remove NaNs by interpolating between actual measurements in a phase/wl grid
Takes as input a grid to interpolate over, an array containing values of the grid
along the interpolation dimension, as well as a filter window for smoothing
using a Savitsky-Golay filter
"""
for i, row in enumerate(grid):
notnan_inds = np.where((~np.isnan(row)))[0]
if len(notnan_inds) > 4:
interp = interp1d(interp_array[notnan_inds], row[notnan_inds], "linear")
interp_row = interp(interp_array[min(notnan_inds) : max(notnan_inds)])
savgol_l = savgol_filter(interp_row, filter_window, 3, mode="mirror")
row[min(notnan_inds) : max(notnan_inds)] = savgol_l
grid[i] = row
return grid
def _build_samples(
    self,
    filt,
    sn_set=None,
):
    """
    Collect the photometry in one filter and attach a per-point uncertainty.

    The photometry comes from the parent class's ``_process_dataset``. Each
    point's uncertainty is the spread of the points sharing its phase bin.
    The phase range is split into ``len(phases)`` half-open bins centred on
    ``np.linspace(min_phase, max_phase, len(phases))`` and as wide as that
    spacing, so every point falls in exactly one bin:
      * in a bin with several points, each gets ``max(std(mags in bin), 0.01)``;
      * a lone point keeps its own measured error;
      * any NaN error is finally replaced with 0.01.

    Args:
        filt (str): Filter to collect.
        sn_set (optional): Sample to read from; passed through to the parent
            ``_process_dataset``. Defaults to None.

    Returns:
        tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
            ``(phases, wls, mags, errs)`` as float arrays, or four empty arrays
            if the filter has no photometry.
    """
    phases, mags, errs, wls = super()._process_dataset(
        filt,
        sn_set=sn_set,
    )
    if len(phases) == 0:
        return np.asarray([]), np.asarray([]), np.asarray([]), np.asarray([])
    phases = np.asarray(phases, dtype=float)
    mags = np.asarray(mags)
    errs = np.asarray(errs)

    n_points = len(phases)
    min_phase, max_phase = phases.min(), phases.max()
    # The bin width must equal the linspace step, (max - min) / (n - 1), so the
    # bins tile the range. The previous (max - min) / n width left gaps in which
    # points silently kept the initial error of 1.0.
    bin_width = (max_phase - min_phase) / (n_points - 1) if n_points > 1 else 0.0
    if bin_width > 0:
        bin_ids = np.floor((phases - min_phase) / bin_width + 0.5).astype(int)
    else:
        # Single point, or all points at the same phase: one bin
        bin_ids = np.zeros(n_points, dtype=int)

    err_grid = np.ones(n_points)
    for bin_id in np.unique(bin_ids):
        members = np.flatnonzero(bin_ids == bin_id)
        if len(members) > 1:
            err_grid[members] = max(np.std(mags[members]), 0.01)
        else:
            err_grid[members] = errs[members]
    err_grid = np.nan_to_num(err_grid, nan=0.01)
    return (
        phases.astype(float),
        np.asarray(wls).astype(float),
        mags.astype(float),
        err_grid.astype(float),
    )
def _process_dataset(self, set_to_normalize=None):
"""
Processes the data set for the GP3D object's SN collection or
(optionally) a SN set filter-by-filter and returns
dataframes of the SN collection's photometric details
or the photometric details of the SN set to normalize to
"""
### Create the template grid from the observations
if set_to_normalize is not None:
(
all_template_phases,
all_template_wls,
all_template_mags,
all_template_errs,
) = ([], [], [], [])
for filt in self.filtlist:
phases, wl_grid, mags, err_grid = self._build_samples(
filt,
sn_set=set_to_normalize,
)
all_template_phases = np.concatenate(
(all_template_phases, phases.flatten())
)
all_template_wls = np.concatenate((all_template_wls, wl_grid.flatten()))
all_template_mags = np.concatenate((all_template_mags, mags.flatten()))
all_template_errs = np.concatenate(
(all_template_errs, err_grid.flatten())
)
else:
# Create grid from the SN collection instead
all_phases, all_wls, all_mags, all_errs = [], [], [], []
for filt in self.filtlist:
phases, wl_grid, mags, err_grid = self._build_samples(filt)
all_phases = np.concatenate((all_phases, phases.flatten()))
all_wls = np.concatenate((all_wls, wl_grid.flatten()))
all_mags = np.concatenate((all_mags, mags.flatten()))
all_errs = np.concatenate((all_errs, err_grid.flatten()))
all_template_phases = all_phases
all_template_wls = all_wls
all_template_mags = all_mags
all_template_errs = all_errs
template_df = pd.DataFrame(
{
"Phase": all_template_phases,
"Wavelength": all_template_wls,
"Mag": all_template_mags,
"MagErr": all_template_errs,
}
)
return template_df
def _construct_median_grid(
self,
phasemin,
phasemax,
filtlist,
template_df,
log_transform,
plot=False,
):
"""
Takes as input the photometry from the sn set to normalize
and constructs a 2D template grid consisting of the median photometry
at each phase and wl step
"""
phase_grid_linear = np.arange(
phasemin, phasemax, 1 / 24.0
) # Grid of phases to iterate over, by hour
phase_grid = np.log(
phase_grid_linear + log_transform
) # Grid of phases in log space
wl_grid_linear = np.arange(
min(self.wle[f] for f in filtlist) - 500,
max(self.wle[f] for f in filtlist) + 500,
99.5,
) # Grid of wavelengths to iterate over, by 100 A
wl_grid = np.log10(wl_grid_linear)
mag_grid = np.empty((len(phase_grid), len(wl_grid)))
mag_grid[:] = np.nan
err_grid = np.copy(mag_grid)
for i in range(len(phase_grid)):
for j in range(len(wl_grid)):
### Get all data that falls within this phase + 5 days, and this wl +- 100 A
inds = template_df[
(
np.exp(template_df["Phase"]) - np.exp(phase_grid[i])
<= np.log(5.0)
)
& (np.exp(template_df["Phase"]) - np.exp(phase_grid[i] > 0.0))
& (abs(10 ** template_df["Wavelength"] - 10 ** wl_grid[j]) <= 500)
].index
if len(inds) > 0:
median_mag = np.median(template_df["Mag"][inds].values)
iqr = np.subtract(
*np.percentile(template_df["Mag"][inds], [75, 25])
)
mag_grid[i, j] = median_mag
err_grid[i, j] = iqr
mag_grid = self.interpolate_grid(mag_grid.T, phase_grid)
mag_grid = mag_grid.T
mag_grid = self.interpolate_grid(mag_grid, wl_grid, filter_window=31)
err_grid = self.interpolate_grid(err_grid.T, phase_grid)
err_grid = err_grid.T
err_grid = self.interpolate_grid(err_grid, wl_grid, filter_window=31)
X, Y = np.meshgrid(np.exp(phase_grid) - log_transform, 10**wl_grid)
Z = mag_grid.T
if plot:
Plot().plot_construct_grid(
gp_class=self,
X=X,
Y=Y,
Z=Z,
phase_grid=phase_grid,
mag_grid=mag_grid,
wl_grid=wl_grid,
err_grid=err_grid,
filtlist=filtlist,
grid_type="median",
)
return phase_grid, wl_grid, mag_grid, err_grid
def _construct_polynomial_grid(
self,
phasemin,
phasemax,
filtlist,
template_df,
log_transform,
plot=False,
):
"""
Takes as input the photometry from the sn set to normalize
and constructs a 2D template grid consisting of the polynomial fit
to the SN set to normalize photometry at each phase and wl step
"""
phase_grid_linear = np.arange(
phasemin, phasemax, 1 / 24.0
) # Grid of phases to iterate over, by hour
phase_grid = np.log(
phase_grid_linear + log_transform
) # Grid of phases in log space
wl_grid = np.arange(
np.log10(min(self.wle[f] for f in filtlist) - 500),
np.log10(max(self.wle[f] for f in filtlist) + 500),
0.01,
)
mag_grid = np.empty((len(phase_grid), len(wl_grid)))
mag_grid[:] = np.nan
err_grid = np.copy(mag_grid)
### Add an array of fake measurements to anchor the ends of the fit
anchor_inds_begin = np.where(
(abs(template_df["Phase"] - np.log(phasemin + log_transform)) < 1.0)
)[0]
if len(anchor_inds_begin) > 0:
anchor_mag_begin = min(template_df["Mag"][anchor_inds_begin].values)
else:
anchor_mag_begin = -4.0
anchor_inds_end = np.where(
(abs(template_df["Phase"] - np.log(phasemax + log_transform)) < 0.3)
)[0]
if len(anchor_inds_end) > 0:
anchor_mag_end = min(template_df["Mag"][anchor_inds_end].values)
else:
anchor_mag_end = -4.0
anchor_phases = np.asarray(
[
np.log(phasemin + log_transform),
np.log(phasemin + 2.5 + log_transform),
np.log(phasemax + log_transform),
]
)
anchor_mags = np.asarray(
[anchor_mag_begin - 1.0, anchor_mag_begin, anchor_mag_end - 1.0]
)
for j in range(len(wl_grid)):
### Get all data that falls within this wl +- 500 A
inds = template_df[
abs(10 ** template_df["Wavelength"] - 10 ** wl_grid[j]) <= 499
].index
if len(inds) > 0:
phases_to_fit = np.concatenate(
(template_df["Phase"][inds], anchor_phases)
)
mags_to_fit = np.concatenate((template_df["Mag"][inds], anchor_mags))
# if j == 0 or j == len(wl_grid) - 1:
# # At the wavelength boundary, anchor this with artificially lower mags
# mags_to_fit -= 2
errs_to_fit = np.concatenate(
(template_df["MagErr"][inds], np.ones(len(anchor_phases)) * 0.05)
)
fit_coeffs = np.polyfit(
phases_to_fit,
mags_to_fit,
3,
w=1
/ (
np.sqrt(
errs_to_fit**2
+ (
np.ones(
len(template_df["MagErr"][inds])
+ len(anchor_phases)
)
* 0.1
)
** 2
)
),
)
fit = np.poly1d(fit_coeffs)
grid_mags = fit(phase_grid)
mag_grid[:, j] = grid_mags
err_grid[:, j] = np.ones(len(phase_grid)) * np.median(
abs(template_df["Mag"][inds] - fit(template_df["Phase"][inds]))
)
### Interpolate over the wavelengths to get a complete 2D grid
mag_grid = self.interpolate_grid(mag_grid, wl_grid, filter_window=31)
err_grid = self.interpolate_grid(err_grid, wl_grid, filter_window=31)
X, Y = np.meshgrid(np.exp(phase_grid) - log_transform, 10**wl_grid)
Z = mag_grid.T
if plot:
Plot().plot_construct_grid(
gp_class=self, X=X, Y=Y, Z=Z, grid_type="polynomial"
)
return phase_grid, wl_grid, mag_grid, err_grid
def _subtract_data_from_grid(
self,
sn,
filtlist,
phase_grid,
wl_grid,
mag_grid,
err_grid,
plot=False,
):
"""
Takes the (shifted) photometry from a given SN and subtracts from it
the template grid constructed from either the median of all the normalization
SN photometry, or the polynomial fit to all the normalization SN photometry
Returns the phase and wavelength of each data point from the given SN, as well as
the residuals in its magnitude (or flux) and its uncertainty
"""
### Subtract off templates for each SN LC
residuals = []
for filt in filtlist:
if filt in sn.cube["ShiftedFilter"].values:
mags = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][
"ShiftedFlux"
].values
errs = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][
"ShiftedFluxerr"
].values
current_nondets = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][
"Nondetection"
].values
current_wls = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][
"LogShiftedWavelength"
].values
mags_from_peak = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][
"MagFromPeak"
].values
phases = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][
"LogPhase"
].values
else:
phases = []
# There's a bug that I have to track down where sometimes filters without data
# have a NaN as their redshift until data is read in, so the second clause
# of this if statement will skip over those filters
# (I don't know why this isn't caught above)
if len(phases) > 0 and not np.isnan(sn.info.get("z", 0)):
if plot:
_, ax = plt.subplots()
for i, phase in enumerate(phases):
### Get index of current phase in phase grid
### The phase corresponding to phase_ind is no more than the phase grid spacing away from the true phase being measured
phase_ind = np.argmin(abs(np.exp(phase_grid) - np.exp(phase)))
wl_ind = np.argmin(abs(wl_grid - current_wls[i]))
if np.isnan(mag_grid[phase_ind, wl_ind]):
logger.warning(
f"NaN Found: phase {np.exp(phase)}, wl {10 ** wl_grid[wl_ind]}"
)
continue
if np.isinf(mags[i] - mag_grid[phase_ind, wl_ind]):
logger.warning(
f"Infinity found: phase {np.exp(phase)}, wl {10 ** wl_grid[wl_ind]}"
)
continue
residuals.append(
{
"Filter": filt,
"Phase": phase,
"Wavelength": current_wls[i],
"MagResidual": mags[i] - mag_grid[phase_ind, wl_ind],
"MagErr": errs[i],
"Mag": mags_from_peak[i],
"Nondetection": current_nondets[i],
}
)
if plot:
plt.errorbar(
phase,
mags[i] - mag_grid[phase_ind, wl_ind],
yerr=np.sqrt(
errs[i] ** 2 + err_grid[phase_ind, wl_ind] ** 2
),
marker="o",
color="k",
)
plt.errorbar(
phase,
mags[i],
yerr=errs[i],
fmt="o",
color=colors.get(filt, "k"),
)
if plot:
# NOTE: The wl ind changes with each data point,
# therefore this is representation is not fully accurate
Plot().plot_subtract_data_from_grid(
sn_class=sn,
phase_grid=phase_grid,
mag_grid=mag_grid,
wl_ind=wl_ind,
filt=filt,
ax=ax,
)
plt.show()
return pd.DataFrame(residuals)
def _build_test_wavelength_phase_grid_from_photometry(
self, measured_wavelengths, measured_phases, wl_grid, phase_grid
):
"""
Function to build a uniform grid of wavelengths and phases given photometry
in the form of measured wavelengths and phases as well as a wavelength and phase grid
corresponding to the template SED grid
"""
waves_to_predict = np.unique(measured_wavelengths)
diffs = abs(
np.subtract.outer(10**wl_grid, 10**waves_to_predict)
) # The difference between our measurement wavelengths and the wl grid
phases_to_predict = np.unique(measured_phases)
### Compare the wavelengths of our measured filters to those in the wl grid
### and fit for those grid wls that are within 500 A of one of our measurements
wl_inds_fitted = np.unique(np.where((diffs < 500.0))[0])
phase_inds_fitted = np.unique(
np.where(
(phase_grid >= min(phases_to_predict))
& (phase_grid <= max(phases_to_predict))
)[0]
)
if len(phase_inds_fitted) == 0:
return [], [], [], [], None
linear_phases = np.exp(phase_grid[phase_inds_fitted]) - self.log_transform
phases = np.log(linear_phases - min(linear_phases) + 0.1)
x, y = np.meshgrid(phases, wl_grid[wl_inds_fitted])
return x, y, wl_inds_fitted, phase_inds_fitted, min(linear_phases)
[docs]
def run_gp_on_full_sample(
    self,
    plot=False,
    subtract_median=False,
    subtract_polynomial=False,
):
    """
    Run the Gaussian Process Regression fitting routine on the full sample
    at once. Does not individually fit each transient, but instead fits
    all photometry in the input sample simultaneously.

    Steps:
      1. A template grid (median or polynomial) is built from the
         normalization sample.
      2. Each SN's photometry is turned into residuals about that grid.
      3. One GP is fit to all residuals in (re-based log phase, log10
         wavelength) space.
      4. The prediction is added back onto the template over the covered
         part of the grid and converted to magnitudes from peak.
    A 3D figure of the final surface and the input photometry is always drawn
    and shown with ``plt.show()``, independently of ``plot``.

    Args:
        plot (bool, optional): Show intermediate (template grid) plots.
            Defaults to False.
        subtract_median (bool, optional): Fit and subtract off a median function
            to calculate residuals. Defaults to False.
        subtract_polynomial (bool, optional): Fit and subtract off a polynomial
            function to calculate residuals. One of `subtract_median` and
            `subtract_polynomial` must be True. Defaults to False.

    Raises:
        Exception: Must toggle either subtract_median or subtract_polynomial as True
            to run GP3D.
        ValueError: If no SN yields residuals to fit, or the fitted phases do
            not overlap the template phase grid.

    Returns:
        list: ``[gp_grid, gp_grid_std]``, both of shape
            ``(len(wl_grid), len(phase_grid))`` and NaN outside the fitted region.
        np.ndarray: The template grid (``mag_grid``).
        np.ndarray: The log-transformed phase grid.
        np.ndarray: The log10 wavelength grid.
    """
    template_df = self._process_dataset(set_to_normalize=self.set_to_normalize)
    if subtract_polynomial:
        construct_grid = self._construct_polynomial_grid
    elif subtract_median:
        construct_grid = self._construct_median_grid
    else:
        raise Exception(
            "Must toggle either subtract_median or subtract_polynomial as True to run GP3D"
        )
    phase_grid, wl_grid, mag_grid, err_grid = construct_grid(
        self.phasemin,
        self.phasemax,
        self.filtlist,
        template_df,
        log_transform=self.log_transform,
        plot=plot,
    )

    x_parts, y_parts, err_parts, raw_mag_parts = [], [], [], []
    for sn in self.collection.sne:
        if not hasattr(sn, "cube"):
            continue
        residuals = self._subtract_data_from_grid(
            sn,
            self.filtlist,
            phase_grid,
            wl_grid,
            mag_grid,
            err_grid,
            plot=False,  # TODO: are we sure we want to hard-code False for grid subtraction?
        )
        if len(residuals) == 0:
            continue
        phase_residuals_linear = np.exp(residuals["Phase"].values) - self.log_transform
        # Re-base each SN so its earliest point sits at log(0.1)
        phases_to_fit = np.log(
            phase_residuals_linear - min(phase_residuals_linear) + 0.1
        )
        x_parts.append(np.column_stack((phases_to_fit, residuals["Wavelength"].values)))
        y_parts.append(residuals["MagResidual"].values)
        raw_mag_parts.append(residuals["Mag"].values)
        err_parts.append(residuals["MagErr"].values)
    if not x_parts:
        raise ValueError("No SN photometry left to fit after template subtraction")
    x_input = np.concatenate(x_parts).astype(float)
    y_input = np.concatenate(y_parts).astype(float)
    raw_mags = np.concatenate(raw_mag_parts).astype(float)
    err = np.concatenate(err_parts).astype(float)

    kernel = self.kernel.kernel if isinstance(self.kernel, Kernel) else self.kernel
    # NOTE(review): sklearn adds `alpha` to the kernel diagonal (a variance);
    # MagErr is passed unsquared here — confirm this is intended.
    gaussian_process = GaussianProcessRegressor(
        kernel=kernel, alpha=err, n_restarts_optimizer=10
    )
    gaussian_process.fit(x_input, y_input)

    # NOTE(review): x_input[:, 0] holds re-based phases, while phase_grid is
    # log(phase + log_transform); the helper compares the two directly.
    x, y, wl_inds_fitted, phase_inds_fitted, phase_offset = (
        self._build_test_wavelength_phase_grid_from_photometry(
            x_input[:, 1], x_input[:, 0], wl_grid, phase_grid
        )
    )
    if len(phase_inds_fitted) == 0:
        raise ValueError("Fitted phases do not overlap the template phase grid")
    test_prediction, std_prediction = gaussian_process.predict(
        np.column_stack((x.ravel(), y.ravel())), return_std=True
    )
    n_wls = len(x)
    # Template values on the same (wavelength, phase) sub-grid as x / y
    template_mags = mag_grid[np.ix_(phase_inds_fitted, wl_inds_fitted)].T
    final_prediction = test_prediction.reshape((n_wls, -1)) + template_mags

    ### Convert to mags from peak: an affine map with slope 1 / 0.4 per row
    fitted_wls = 10 ** wl_grid[wl_inds_fitted]
    zp = (10**-23 * 3e18 / fitted_wls) * 1e11
    final_prediction = -1 * (
        (
            final_prediction
            + np.log10(zp * 1e-11)[:, None]
            - np.log10(10**-23 * 3e18 / 5000)
        )
        / -0.4
    )
    # The GP std is a spread about the prediction. The template is a
    # deterministic offset and must not be added to it (it previously was);
    # only the slope of the conversion above rescales it.
    final_std_prediction = std_prediction.reshape((n_wls, -1)) / 0.4

    ### Map predicted SED surface and uncertainty to the phase, wl grids
    fitted_cells = np.ix_(wl_inds_fitted, phase_inds_fitted)
    gp_grid = np.full((len(wl_grid), len(phase_grid)), np.nan)
    gp_grid[fitted_cells] = final_prediction
    gp_grid_std = np.full((len(wl_grid), len(phase_grid)), np.nan)
    gp_grid_std[fitted_cells] = final_std_prediction

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    Plot().plot_construct_grid(
        gp_class=self,
        X=np.exp(x) + phase_offset - 0.1,
        Y=10 ** (y),
        Z=final_prediction,
        ax=ax,
        Z_lower=final_prediction - 1.96 * final_std_prediction,
        Z_upper=final_prediction + 1.96 * final_std_prediction,
        grid_type="final",
    )
    ax.scatter(
        np.exp(x_input[:, 0]) - self.log_transform,
        10 ** x_input[:, 1],
        raw_mags,
        marker="o",
        color="k",
    )
    plt.show()
    return [gp_grid, gp_grid_std], mag_grid, phase_grid, wl_grid
[docs]
def optimize_hyperparams(self, subtract_median=False, subtract_polynomial=False):
    """
    Optimize the Gaussian Process Regression kernel hyperparameters by fitting
    each transient in the sample individually.

    This is normally run before the `predict` method to retrieve the optimized
    kernel hyperparameters for the input sample. These hyperparameters are normally
    then fixed in the kernel, and the `predict` method is run using the fixed kernel.

    Each SN with at least two residual points about the template grid gets its
    own GP fit. The per-SN ``kernel_.theta`` vectors (scikit-learn's
    log-transformed hyperparameters) are combined by taking the median of each
    entry, and the result is fixed on ``self.kernel``:
      * project `Kernel` wrapper: ``recursively_set_params(median_theta, "fixed")``;
      * plain scikit-learn kernel: ``length_scale = exp(median_theta)`` with
        fixed bounds.

    Args:
        subtract_median (bool, optional): Fit and subtract off a median function
            to calculate residuals. Defaults to False.
        subtract_polynomial (bool, optional): Fit and subtract off a polynomial
            function to calculate residuals. One of `subtract_median` and
            `subtract_polynomial` must be True. Defaults to False.

    Raises:
        Exception: Must toggle either subtract_median or subtract_polynomial as True
            to run GP3D.
        ValueError: If no SN had enough residual points to be fit.

    Returns:
        list: The per-SN fitted ``kernel_.theta`` arrays (log-transformed).
    """
    template_df = self._process_dataset(set_to_normalize=self.set_to_normalize)
    if subtract_polynomial:
        construct_grid = self._construct_polynomial_grid
    elif subtract_median:
        construct_grid = self._construct_median_grid
    else:
        raise Exception(
            "Must toggle either subtract_median or subtract_polynomial as True to run GP3D"
        )
    phase_grid, wl_grid, mag_grid, err_grid = construct_grid(
        self.phasemin,
        self.phasemax,
        self.filtlist,
        template_df,
        log_transform=self.log_transform,
        plot=False,
    )

    base_kernel = self.kernel.kernel if isinstance(self.kernel, Kernel) else self.kernel
    kernel_params = []
    for sn in self.collection.sne:
        if not hasattr(sn, "cube"):
            continue
        residuals = self._subtract_data_from_grid(
            sn,
            self.filtlist,
            phase_grid,
            wl_grid,
            mag_grid,
            err_grid,
            plot=False,  # TODO: are we sure we want to hard-code False for grid subtraction?
        )
        if len(residuals) < 2:
            # Need at least two points to fit
            continue
        phase_residuals_linear = np.exp(residuals["Phase"].values) - self.log_transform
        phases_to_fit = np.log(
            phase_residuals_linear - min(phase_residuals_linear) + 0.1
        )
        x = np.vstack((phases_to_fit, residuals["Wavelength"].values)).T
        y = residuals["MagResidual"].values
        gaussian_process = GaussianProcessRegressor(
            kernel=base_kernel,
            alpha=residuals["MagErr"].values,
            n_restarts_optimizer=10,
        )
        gaussian_process.fit(x, y)
        kernel_params.append(gaussian_process.kernel_.theta)

    if not kernel_params:
        raise ValueError(
            "No SN had enough residual points (at least 2) to optimize the kernel"
        )
    # Median of each hyperparameter across SNe (still in log space)
    optimized_kernel_hyperparams = np.median(np.vstack(kernel_params), axis=0)
    if isinstance(self.kernel, Kernel):
        # NOTE(review): log-space values are passed on; confirm that
        # recursively_set_params expects log(theta).
        self.kernel.recursively_set_params(optimized_kernel_hyperparams, "fixed")
    else:
        # scikit-learn's Kernel.theta is log-transformed, while set_params
        # expects the actual length scales.
        self.kernel.set_params(
            length_scale=np.exp(optimized_kernel_hyperparams),
            length_scale_bounds="fixed",
        )
    return kernel_params
def _iteratively_warp_sed(
    self,
    residuals: pd.DataFrame,
    test_prediction_reshaped: np.ndarray,
    wls_fit: np.ndarray,
    phases_fit: np.ndarray,
    sn: SN,
    convergence_threshold: float = 1.0,
    niter: int = 100,
):
    """
    Compare photometry to prediction, iteratively warping the
    predicted flux until it matches the photometry within
    convergence_threshold * error for each data point.

    The method steps through epochs on a 0.5-day linear grid spanning the
    residuals. At an epoch with points in at least two distinct filters
    within 0.1 (in ``residuals["Phase"]`` units), it takes the predicted SED
    slice at the nearest ``phases_fit`` column. That slice is repeatedly
    multiplied by a correction curve, linearly interpolated in log10
    wavelength from per-filter correction factors and padded with a factor
    of 1.0 at 500 A beyond the bluest and reddest filters. This continues
    until every factor satisfies
    ``|1 - factor| <= convergence_threshold * MagErr`` or ``niter`` passes
    have been made.

    Args:
        residuals (pd.DataFrame): Per-point table with at least ``Phase``,
            ``Filter``, ``Mag`` and ``MagErr`` columns.
        test_prediction_reshaped (np.ndarray): Predicted SED indexed as
            ``[wavelength_index, phase_index]``; modified in place.
        wls_fit (np.ndarray): log10 wavelengths of its rows.
        phases_fit (np.ndarray): Phases of its columns, compared directly with
            the epoch values below.
        sn (SN): Transient; only ``sn.info.get("z", 0)`` and ``sn.name`` are used.
        convergence_threshold (float, optional): Tolerance in units of each
            point's ``MagErr``. Defaults to 1.0.
        niter (int, optional): Maximum number of warping passes per epoch.
            Defaults to 100.

    Returns:
        np.ndarray: ``test_prediction_reshaped`` (the same object). Slices for
            which fewer than ``niter`` warps were made are replaced by their
            warped version; the others are left unchanged.
    """
    # NOTE(review): these epochs are log(linear phase), with no `+ self.log_transform`,
    # yet they are compared below with residuals["Phase"]. If residuals come from
    # _subtract_data_from_grid, that column is log(phase + log_transform), and the
    # debug message below also assumes that space. np.log of phases <= 0 yields
    # NaN/-inf. Confirm whether the offset is missing.
    phases_to_iterate_over = np.log(
        np.arange(
            min(np.exp(residuals["Phase"].values) - self.log_transform),
            max(np.exp(residuals["Phase"].values) - self.log_transform),
            0.5,
        )
    )
    for phase in phases_to_iterate_over:
        current_lc_inds = np.where((abs(residuals["Phase"] - phase) <= 0.1))[0]
        if (
            len(current_lc_inds) > 0
            and len({filt for filt in residuals["Filter"][current_lc_inds]}) > 1
        ):
            # We have data in at least two filters at this epoch
            current_lc_cube = residuals[abs(residuals["Phase"] - phase) <= 0.1]
            current_lc = current_lc_cube["Mag"].values
            # Pad both ends: 0.0 on the blue side, half the last value on the red side
            current_lc = np.concatenate(([0.0], current_lc, [current_lc[-1] / 2]))
            # Zero tolerance for the padded end points
            current_lc_errs = np.concatenate(
                ([0.0], current_lc_cube["MagErr"].values, [0.0])
            )
            # Correction factors: start far from 1 so the first pass never
            # counts as converged; the padded ends are fixed at 1.0 (no warp).
            errors = np.ones(len(current_lc_inds)) * 100.0
            errors = np.concatenate(([1.0], errors, [1.0]))
            # Redshifted log10 central wavelength of each point's filter
            wls = [
                np.log10(self.wle[filt] * (1 + sn.info.get("z", 0)))
                for filt in current_lc_cube["Filter"].values
            ]
            # Pad the wavelength axis by 500 A on each side to match the padding above
            wls = np.concatenate(
                (
                    [np.log10(10 ** min(wls) - 500)],
                    wls,
                    [np.log10(10 ** max(wls) + 500)],
                )
            )
            sed_inds_at_this_phase = np.where(
                (wls_fit > min(wls)) & (wls_fit < max(wls))
            )[0]
            if len(sed_inds_at_this_phase) == 0:
                ### Redshifted filter central wavelengths outside the bounds of the filters fit by the GP
                continue
            phase_ind = np.argmin(abs(phases_fit - phase))
            prediction_slice = np.copy(
                test_prediction_reshaped[sed_inds_at_this_phase, phase_ind]
            )
            n = 0  # number of warps applied to this slice
            for i in range(niter):
                # Converged when every factor is within tolerance of 1, or when
                # n was forced to niter by the blow-up check below.
                if (
                    all(
                        [
                            abs(1 - errors[i])
                            <= convergence_threshold * current_lc_errs[i]
                            for i in range(len(errors))
                        ]
                    )
                    or n == niter
                ):
                    break
                # Recompute one correction factor per data point (row j -> errors[j + 1])
                for j, filt in enumerate(current_lc_cube["Filter"]):
                    # +1 skips the blue-side padding entry of current_lc
                    real_flux_inds = (
                        np.where((current_lc_cube["Filter"] == filt))[0] + 1
                    )
                    # Average repeated measurements in the same filter
                    if len(real_flux_inds) > 1:
                        real_mag = np.average(current_lc[real_flux_inds])
                    else:
                        real_mag = current_lc[real_flux_inds][0]
                    wl_for_filt = np.log10(
                        self.wle[filt] * (1 + sn.info.get("z", 0))
                    )
                    wl_ind = np.argmin(
                        abs(wls_fit[sed_inds_at_this_phase] - wl_for_filt)
                    )
                    predicted_mag = prediction_slice[wl_ind]
                    diff = real_mag - predicted_mag
                    # Sign of the correction depends on the sign of the prediction
                    if predicted_mag > 0:
                        error = 1 + diff
                    else:
                        error = 1 - diff
                    errors[j + 1] = error
                if any(
                    [
                        abs(1 - errors[j])
                        > convergence_threshold * current_lc_errs[j]
                        for j in range(len(errors))
                    ]
                ):
                    ### Warp the SED by an interpolation of the error across wavelength
                    ### and rerun the loop
                    if n < niter:
                        n += 1
                    # Give up on diverging factors: n == niter marks the slice as failed
                    if any(errors > 1e3) or any(np.isinf(errors)):
                        n = niter
                    interp_error_fn = interp1d(wls, errors)
                    interp_to_warp_by = interp_error_fn(
                        wls_fit[sed_inds_at_this_phase]
                    )
                    prediction_slice *= interp_to_warp_by
            if n == niter:
                logger.debug(
                    f"Couldnt iterate to match flux: {sn.name}, {np.exp(phase) - self.log_transform}"
                )
            if n < niter:
                # Put the warped SED back in the predicted datacube
                test_prediction_reshaped[sed_inds_at_this_phase, phase_ind] = (
                    prediction_slice
                )
    return test_prediction_reshaped
def _sample_predicted_sed(self, mean_prediction, std_prediction):
"""Randomly sample the predicted SED for better uncertainty estimate in final GP surface"""
### Draw random sigma between -1 and 1
sigma = truncnorm.rvs(-1, 1, loc=0, scale=1, size=1)[0]
sampled_prediction = mean_prediction + (sigma * std_prediction)
return sampled_prediction
def _smooth_predicted_model(
    self, model_array: np.ndarray, window_size: int, transpose: bool = False
):
    """
    Boxcar-smooth each row of ``model_array`` (each column if ``transpose``).

    Each row is convolved with astropy's ``convolve`` using a normalized
    boxcar of ``window_size`` samples (bumped to the next odd number if even)
    and ``boundary="extend"``. astropy's convolve is used so that NaNs are
    handled rather than propagated.

    Args:
        model_array (np.ndarray): 2D model surface.
        window_size (int): Boxcar width in samples.
        transpose (bool, optional): Smooth along the other axis. Defaults to False.

    Returns:
        np.ndarray: Float array with the same shape as ``model_array``.
    """
    if window_size % 2 == 0:
        window_size += 1  # the convolution kernel must have odd length
    boxcar = np.ones(window_size) / window_size
    rows = model_array.T if transpose else model_array
    smoothed = np.empty(rows.shape)
    for row_index, row in enumerate(rows):
        smoothed[row_index, :] = convolve(row, boxcar, boundary="extend")
    return smoothed.T if transpose else smoothed
[docs]
def run_gp_individually(
    self,
    plot=False,
    subtract_median=False,
    subtract_polynomial=False,
    interactive=False,
    run_diagnostics=False,
    save_individual_fits=False,
):
    """
    Run the Gaussian Process Regression fitting routine on each transient
    in the full sample individually.

    For every SN that has a photometry datacube, the residuals between its
    photometry and a template grid (built from the normalization sample by a
    median or polynomial fit) are fit with a bespoke 2D Gaussian process over
    (log phase, log10 wavelength). The prediction is mapped back onto the
    template grid, converted to magnitudes relative to peak, warped, smoothed,
    and randomly sampled. Those samples are later combined into the final
    3-dimensional template surface.

    Args:
        plot (bool, optional): Show intermediate plots. Defaults to False.
        subtract_median (bool, optional): Fit and subtract off a median function
            to calculate residuals. Defaults to False.
        subtract_polynomial (bool, optional): Fit and subtract off a polynomial
            function to calculate residuals. One of `subtract_median` and
            `subtract_polynomial` must be True; if both are True the polynomial
            grid is used. Defaults to False.
        interactive (bool, optional): Interactively choose which fits to use in the
            creation of the final Gaussian process model. If True, sets `plot` to True
            as well. The user is prompted once per SN, after that SN's per-filter
            overlays have been drawn. Defaults to False.
        run_diagnostics (bool, optional): Run diagnostic tests on the fitting to identify
            data points or regions of poor fit quality. Defaults to False.
        save_individual_fits (bool, optional): Save the GP fits to each individual
            SN object separately. If True, will save the models to the default location.
            WARNING: This can be a lot of data. Defaults to False.

    Raises:
        ValueError: If neither `subtract_median` nor `subtract_polynomial` is True.
            (ValueError subclasses Exception, so existing broad handlers still work.)

    Returns:
        list: Random samples drawn from each accepted SN's GP surface,
            round(ln(N)) samples per SN, where N is that SN's number of residual
            points. Used to construct final 3-dimensional templates.
        np.ndarray: An array containing the processed magnitudes from
            the template grid.
        np.ndarray: The phase grid, stored as ln(phase + log_transform).
        np.ndarray: The wavelength grid, stored as log10(wavelength).
    """
    # Validate before doing any expensive dataset processing.
    if not (subtract_polynomial or subtract_median):
        raise ValueError(
            "Must toggle either subtract_median or subtract_polynomial as True to run GP3D"
        )
    if interactive:
        plot = True

    template_df = self._process_dataset(set_to_normalize=self.set_to_normalize)
    construct_grid = (
        self._construct_polynomial_grid
        if subtract_polynomial
        else self._construct_median_grid
    )
    phase_grid, wl_grid, mag_grid, err_grid = construct_grid(
        self.phasemin,
        self.phasemax,
        self.filtlist,
        template_df,
        log_transform=self.log_transform,
        plot=plot,
    )

    # Epochs (linear phase, days) at which the surface diagnostics are evaluated.
    diagnostic_phases = [-15.0, 0.0, 50.0]
    gaussian_processes = []

    for sn in self.collection.sne:
        if not hasattr(sn, "cube"):
            continue

        residuals = self._subtract_data_from_grid(
            sn,
            self.filtlist,
            phase_grid,
            wl_grid,
            mag_grid,
            err_grid,
            plot=False,  # TODO: are we sure we want to hard-code False for grid subtraction?
        )
        # A GP needs at least two points. Skipping here (rather than falling
        # through) guarantees we never reuse the GP fitted to a previous SN.
        if len(residuals) < 2:
            continue

        # Re-zero the phases so the earliest residual sits at ln(0.1).
        phase_residuals_linear = (
            np.exp(residuals["Phase"].values) - self.log_transform
        )
        phases_to_fit = np.log(
            phase_residuals_linear - np.min(phase_residuals_linear) + 0.1
        )
        x_train = np.vstack((phases_to_fit, residuals["Wavelength"].values)).T
        gaussian_process = self._fit_residual_gp(
            x_train,
            residuals["MagResidual"].values,
            residuals["MagErr"].values,
        )

        filts_fitted = self._check_fit_per_filter(
            sn,
            residuals,
            phase_residuals_linear,
            gaussian_process,
            phase_grid,
            wl_grid,
            mag_grid,
            plot=plot,
            run_diagnostics=run_diagnostics,
        )
        if not filts_fitted:
            # The wavelength smoothing window below divides by the number of
            # fitted filters, so there is nothing meaningful to build here.
            logger.warning(
                f"No filters with data in the phase range for {sn.name}; skipping"
            )
            continue

        use_for_template = (
            input("Use this fit to construct a template? y/n") if interactive else "y"
        )

        grid_phases, grid_waves, wl_inds_fitted, phase_inds_fitted, phase_offset = (
            self._build_test_wavelength_phase_grid_from_photometry(
                residuals["Wavelength"].values,
                residuals["Phase"].values,
                wl_grid,
                phase_grid,
            )
        )
        if len(grid_phases) == 0:
            continue
        try:
            test_prediction, std_prediction = gaussian_process.predict(
                np.vstack((grid_phases.ravel(), grid_waves.ravel())).T,
                return_std=True,
            )
        except Exception as e:
            # Best-effort: a single broken fit should not abort the whole sample.
            logger.warning(f"WARNING: BROKEN FIT FOR {sn.name}", exc_info=e)
            continue

        n_rows = len(grid_phases)
        # Rows follow wl_inds_fitted, columns follow phase_inds_fitted.
        # NOTE: some of these template mags are NaNs.
        template_mags = np.asarray(
            [
                mag_grid[phase_ind, wl_ind]
                for wl_ind in wl_inds_fitted
                for phase_ind in phase_inds_fitted
            ]
        ).reshape((n_rows, -1))

        # Add the template back to the predicted residuals, then convert to mags from peak.
        test_prediction_reshaped = (
            np.asarray(test_prediction).reshape((n_rows, -1)) + template_mags
        )
        test_prediction_reshaped = self._residual_prediction_to_peak_mags(
            test_prediction_reshaped, 10 ** wl_grid[wl_inds_fitted], sn
        )
        test_prediction_reshaped = self._iteratively_warp_sed(
            residuals,
            test_prediction_reshaped,
            wl_grid[wl_inds_fitted],
            phase_grid[phase_inds_fitted],
            sn,
            convergence_threshold=1.0,
        )
        # The GP standard deviation describes the spread of the residual; it is
        # shift-invariant and must NOT be offset by the template magnitudes
        # (doing so can produce negative or NaN "standard deviations").
        std_prediction_reshaped = np.asarray(std_prediction).reshape((n_rows, -1))

        # Smooth along wavelength: window of approximately half a filter length scale.
        wl_window = max(
            int(round(test_prediction_reshaped.shape[0] / (2 * len(filts_fitted)), 0)),
            5,
        )
        test_prediction_smoothed = self._smooth_predicted_model(
            test_prediction_reshaped, window_size=wl_window, transpose=True
        )
        # Smooth along phase. NOTE(review): this was documented as "approximately
        # a day", but the original expression `len(phase_grid / (5 * 24))` is just
        # `len(phase_grid)`. When phase_inds_fitted is a subset of the grid the
        # ratio is <= 1, so the window resolves to the floor of 5. Behavior is
        # preserved; confirm the intended scaling before changing it.
        phase_window = max(
            int(round(test_prediction_smoothed.shape[1] / len(phase_grid), 0)),
            5,
        )
        test_prediction_smoothed = self._smooth_predicted_model(
            test_prediction_smoothed, window_size=phase_window
        )
        std_window = max(
            int(round(std_prediction_reshaped.shape[0] / (2 * len(filts_fitted)), 0)),
            5,
        )
        std_prediction_smoothed = self._smooth_predicted_model(
            std_prediction_reshaped, window_size=std_window, transpose=True
        )

        # Put the fitted wavelengths/phases back in the right spot on the full grid.
        gp_grid = self._fill_full_grid(
            test_prediction_smoothed,
            wl_inds_fitted,
            phase_inds_fitted,
            len(wl_grid),
            len(phase_grid),
        )
        gp_grid_std = self._fill_full_grid(
            std_prediction_smoothed,
            wl_inds_fitted,
            phase_inds_fitted,
            len(wl_grid),
            len(phase_grid),
        )

        if plot:
            Plot().plot_run_gp_surface(
                gp_class=self,
                x=np.exp(grid_phases) + phase_offset - 0.1,
                y=10 ** (grid_waves),
                test_prediction_reshaped=test_prediction_smoothed,
            )
        if run_diagnostics:
            fitted_phases_linear = (
                np.exp(phase_grid[phase_inds_fitted]) - self.log_transform
            )
            fitted_wavelengths = 10 ** (wl_grid[wl_inds_fitted])
            d = Diagnostic()
            d.check_gradient_between_filters(
                [self.wle[f] for f in filts_fitted],
                fitted_phases_linear,
                fitted_wavelengths,
                test_prediction_smoothed,
                std_prediction_smoothed,
                diagnostic_phases,
            )
            if "UVM2" in filts_fitted:
                d.check_uvm2_flux(
                    fitted_phases_linear,
                    fitted_wavelengths,
                    test_prediction_smoothed,
                    std_prediction_smoothed,
                    diagnostic_phases,
                )

        if use_for_template == "y":
            for _ in range(round(np.log(len(residuals)))):
                gaussian_processes.append(
                    self._sample_predicted_sed(gp_grid, gp_grid_std)
                )

        if save_individual_fits:
            snmodel = SNModel(
                surface=gaussian_process,
                template_mags=template_mags.T,
                phase_grid=np.exp(phase_grid[phase_inds_fitted]) - self.log_transform,
                wl_grid=10 ** (wl_grid[wl_inds_fitted]),
                filters_fit=filts_fitted,
                sn=sn,
                norm_set=self.set_to_normalize,
                log_transform=self.log_transform,
            )
            snmodel.save_fits()

    return gaussian_processes, mag_grid, phase_grid, wl_grid

def _fit_residual_gp(self, x, y, err):
    """
    Fit a GaussianProcessRegressor to one SN's magnitude residuals.

    Args:
        x (np.ndarray): (n, 2) training inputs of (re-zeroed log phase, log10 wavelength).
        y (np.ndarray): Magnitude residuals to fit.
        err (np.ndarray): Per-point magnitude errors, passed to the regressor as `alpha`.

    Returns:
        GaussianProcessRegressor: The fitted regressor.
    """
    kernel = self.kernel.kernel if isinstance(self.kernel, Kernel) else self.kernel
    # NOTE(review): sklearn documents `alpha` as a variance added to the kernel
    # diagonal, whereas `err` looks like a 1-sigma error. Kept unchanged for
    # consistency with the other fitting paths; confirm whether err**2 is intended.
    gaussian_process = GaussianProcessRegressor(
        kernel=kernel, alpha=err, n_restarts_optimizer=10
    )
    gaussian_process.fit(x, y)
    return gaussian_process

def _check_fit_per_filter(
    self,
    sn,
    residuals,
    phase_residuals_linear,
    gaussian_process,
    phase_grid,
    wl_grid,
    mag_grid,
    plot=False,
    run_diagnostics=False,
):
    """
    Find the filters with photometry inside the fitted phase range and,
    optionally, overlay the per-filter GP fit and run per-filter diagnostics.

    Args:
        sn (SN): The transient being fit; must have a `cube` attribute.
        residuals (pd.DataFrame): Residuals returned by `_subtract_data_from_grid`.
        phase_residuals_linear (np.ndarray): Residual phases in linear days.
        gaussian_process (GaussianProcessRegressor): The fitted GP for `sn`.
        phase_grid, wl_grid, mag_grid (np.ndarray): The template grid.
        plot (bool, optional): Draw the per-filter overlay. Defaults to False.
        run_diagnostics (bool, optional): Run per-filter diagnostics. Defaults to False.

    Returns:
        list: Filters in `self.filtlist` with at least one observation strictly
        inside (phasemin, phasemax).
    """
    log_phase_min = np.log(self.phasemin + self.log_transform)
    log_phase_max = np.log(self.phasemax + self.log_transform)
    redshift_factor = 1 + sn.info.get("z", 0)

    ax = None
    if plot:
        _, ax = Plot().create_empty_subplot()

    make_overlays = plot or run_diagnostics
    if make_overlays:
        # Hourly sampling in linear phase, and the nearest template-grid phase
        # for each sample. Neither depends on the filter, so compute them once.
        phase_offset = np.min(phase_residuals_linear)
        test_times_linear = np.arange(
            phase_offset, np.max(phase_residuals_linear), 1.0 / 24
        )
        test_times = np.log(test_times_linear - phase_offset + 0.1)
        linear_phase_grid = np.exp(phase_grid) - self.log_transform
        nearest_phase_inds = np.asarray(
            [np.argmin(abs(linear_phase_grid - t)) for t in test_times_linear],
            dtype=int,
        )
        # A degenerate phase range yields no test points; the GP cannot predict on them.
        make_overlays = len(test_times_linear) > 0

    filts_fitted = []
    for filt in self.filtlist:
        filt_phases = sn.cube[sn.cube["Filter"] == filt]["Phase"].values.astype(float)
        filt_phases = sn.log_transform_time(filt_phases, phase_start=self.log_transform)
        if not np.any((filt_phases > log_phase_min) & (filt_phases < log_phase_max)):
            continue
        filts_fitted.append(filt)
        if not make_overlays:
            continue

        observed_wl = self.wle[filt] * redshift_factor
        test_waves = np.ones(len(test_times)) * np.log10(observed_wl)
        # Convert back to normalized magnitudes via the closest template wavelength.
        wl_ind = np.argmin(abs(10**wl_grid - observed_wl))
        template_mags = mag_grid[nearest_phase_inds, wl_ind]
        test_prediction, std_prediction = gaussian_process.predict(
            np.vstack((test_times, test_waves)).T, return_std=True
        )
        # Map the re-zeroed log phases back to linear phase.
        plot_times = np.exp(test_times) + phase_offset - 0.1
        residuals_for_filt = residuals[
            (residuals["Filter"] == filt)
            & (residuals["Phase"] > log_phase_min)
            & (residuals["Phase"] < log_phase_max)
            & (residuals["Nondetection"] == False)  # noqa: E712 (column dtype unknown)
        ]
        if plot:
            Plot().plot_run_gp_overlay(
                ax=ax,
                test_times=plot_times,
                test_prediction=test_prediction,
                std_prediction=std_prediction,
                template_mags=template_mags,
                residuals=residuals_for_filt,
                log_transform=self.log_transform,
                filt=filt,
                sn=sn,
            )
        if run_diagnostics:
            data_phases_linear = (
                np.exp(residuals_for_filt["Phase"].values) - self.log_transform
            )
            d = Diagnostic()
            d.identify_outlier_points(
                filt,
                plot_times,
                test_prediction + template_mags,
                std_prediction,
                data_phases_linear,
                residuals_for_filt["Mag"].values,
                residuals_for_filt["MagErr"].values,
            )
            d.check_late_time_slope(
                filt,
                plot_times,
                test_prediction + template_mags,
                data_phases_linear,
            )
    return filts_fitted

@staticmethod
def _residual_prediction_to_peak_mags(prediction, wavelengths, sn):
    """
    Convert each wavelength row of `prediction` (normalized mags) to magnitudes
    relative to the SN's peak, in place, and return it.

    Args:
        prediction (np.ndarray): 2D array whose rows correspond to `wavelengths`.
        wavelengths (np.ndarray): Linear wavelengths, one per row of `prediction`.
        sn (SN): Provides `zps` and `info["peak_filt"]`/`info["peak_mag"]`.

    Returns:
        np.ndarray: `prediction`, modified in place.
    """
    peak_mag = sn.info["peak_mag"]
    # Loop-invariant: log10 flux of the peak in the peak filter.
    shifted_peak_mag = np.log10(
        sn.zps[sn.info["peak_filt"]] * 1e-11 * 10 ** (-0.4 * peak_mag)
    )
    for i, row in enumerate(prediction):
        zp = (10**-23 * 3e18 / wavelengths[i]) * 1e11
        prediction[i] = -1 * (
            (np.log10(10 ** (row + shifted_peak_mag) / (zp * 1e-11)) / -0.4)
            - peak_mag
        )
    return prediction

@staticmethod
def _fill_full_grid(values, wl_inds, phase_inds, n_wl, n_phase):
    """
    Scatter a (fitted wavelength, fitted phase) array onto a NaN-filled
    (n_wl, n_phase) grid at rows `wl_inds` and columns `phase_inds`.
    """
    grid = np.full((n_wl, n_phase), np.nan)
    for i, row in enumerate(values):
        wl_ind = wl_inds[i]
        for j, value in enumerate(row):
            grid[wl_ind, phase_inds[j]] = value
    return grid
[docs]
def predict(
    self,
    plot=False,
    subtract_median=False,
    subtract_polynomial=False,
    run_diagnostics=False,
    fit_separately=True,
):
    """
    Generate a Gaussian Process Regression model of the input transient sample.

    Uses the specified normalization sample and all other initialized parameters
    to process and fit the data to produce the model.

    Args:
        plot (bool, optional): Show intermediate plots. Defaults to False.
            The final template surface is always drawn (and `plt.show()` called)
            when `fit_separately` is True, regardless of this flag.
        subtract_median (bool, optional): Fit and subtract off a median function
            to calculate residuals. Defaults to False.
        subtract_polynomial (bool, optional): Fit and subtract off a polynomial
            function to calculate residuals. One of `subtract_median` and
            `subtract_polynomial` must be True. Defaults to False.
        run_diagnostics (bool, optional): Run diagnostic tests on the fitting to identify
            data points or regions of poor fit quality. Only forwarded when
            `fit_separately` is True. Defaults to False.
        fit_separately (bool, optional): Fit each transient separately, or together
            as a group. Controls whether `run_gp_individually` or `run_gp_on_full_sample`
            is called to generate the predictive Gaussian Process Regression model.
            Defaults to True.

    Raises:
        ValueError: If `fit_separately` is True and no per-SN samples were produced
            (e.g. no SN had enough data, or every fit was rejected).

    Returns:
        SNModel: An SNModel object containing the final, 3-dimensional Gaussian Process
        Regression template model of the input transient sample.
    """
    self._prepare_data()

    if not fit_separately:
        gaussian_process, template_mags, phase_grid, wl_grid = (
            self.run_gp_on_full_sample(
                plot=plot,
                subtract_polynomial=subtract_polynomial,
                subtract_median=subtract_median,
            )
        )
        surface = SurfaceArray(
            surface=np.asarray(gaussian_process),
            phase_grid=phase_grid,
            wl_grid=wl_grid,
            kernel=self.kernel,
        )
        return SNModel(
            phase_grid=np.exp(phase_grid) - self.log_transform,
            wl_grid=10**wl_grid,
            filters_fit=self.filtlist,
            surface=surface,
            template_mags=template_mags,
            sncollection=self.collection,
            norm_set=self.set_to_normalize,
            log_transform=self.log_transform,
        )

    # Fit each SN individually, then median-combine the per-SN samples.
    gaussian_processes, mag_grid, phase_grid, wl_grid = self.run_gp_individually(
        plot=plot,
        subtract_median=subtract_median,
        subtract_polynomial=subtract_polynomial,
        run_diagnostics=run_diagnostics,
    )
    if not gaussian_processes:
        # np.dstack would otherwise fail with an opaque "need at least one array".
        raise ValueError(
            "No Gaussian process samples were produced for this sample; "
            "cannot construct a template."
        )

    samples = np.dstack(gaussian_processes)
    median_gp = self._smooth_final_grid(
        np.nanmedian(samples, -1), wl_grid, phase_grid
    )
    # NOTE: the spread is the standard deviation of the samples (historically
    # named `iqr_grid`), not an interquartile range.
    std_grid = self._smooth_final_grid(np.nanstd(samples, -1), wl_grid, phase_grid)

    X, Y = np.meshgrid(np.exp(phase_grid) - self.log_transform, 10**wl_grid)
    Plot().plot_construct_grid(
        gp_class=self,
        X=X,
        Y=Y,
        Z=median_gp,
        Z_lower=median_gp - std_grid,
        Z_upper=median_gp + std_grid,
        grid_type="final",
    )
    plt.show()

    surface = SurfaceArray(
        surface=np.asarray([median_gp, std_grid]),
        phase_grid=phase_grid,
        wl_grid=wl_grid,
        kernel=self.kernel,
    )
    return SNModel(
        surface=surface,
        template_mags=mag_grid,
        phase_grid=np.exp(phase_grid) - self.log_transform,
        wl_grid=10 ** (wl_grid),
        filters_fit=self.filtlist,
        sncollection=self.collection,
        norm_set=self.set_to_normalize,
        log_transform=self.log_transform,
    )

def _smooth_final_grid(self, grid, wl_grid, phase_grid):
    """
    Interpolate and median-filter a combined sample grid. The grid is transposed
    and interpolated against `wl_grid` (filter_window=31), with a 51-point median
    filter over each column. It is then transposed back and interpolated against
    `phase_grid` (filter_window=171), with a 51-point median filter over each row.
    """
    grid = self.interpolate_grid(grid.T, wl_grid, filter_window=31)
    for i, column in enumerate(grid.T):
        grid[:, i] = medfilt(column, kernel_size=51)
    grid = grid.T
    grid = self.interpolate_grid(grid, phase_grid, filter_window=171)
    for i, row in enumerate(grid):
        grid[i, :] = medfilt(row, kernel_size=51)
    return grid