Source code for caat.GP3D

import logging
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.convolution import convolve
from scipy.interpolate import interp1d
from scipy.signal import medfilt, savgol_filter
from scipy.stats import truncnorm
from sklearn.gaussian_process import GaussianProcessRegressor

from .DataCube import DataCube
from .Diagnostics import Diagnostic
from .GP import GP
from .Kernels import Kernel
from .Plot import Plot
from .SN import SN
from .SNCollection import SNCollection, SNType
from .SNModel import SNModel, SurfaceArray
from .utils import colors

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

warnings.filterwarnings("ignore")


[docs] class GP3D(GP): """ Class to perform GP fitting simultaneously across wavelength and phase for a given collection of SNe. Reads in a list of SNe to fit, as well as a collection of SNe to normalize that sample against, along with a set of optional flags for different fitting routines and parameters. Each SN in the science and comparison samples should already have its photometric data pre-processed using the routines available in the `SN` class, run within the `DataCube` class, and read directly from the SN datacube files created by the latter. """ def __init__( self, collection: SNCollection | SNType, kernel: Kernel, filtlist: list, phasemin: int, phasemax: int, log_transform: float, set_to_normalize: SNCollection | SNType | None = None, mangle_sed: bool = False, ): """ Initialize a GP3D object with the following arguments. Args: collection (SNCollection | SNType): A collection of SN objects to fit kernel (Kernel): A Kernel object used in the Gaussian Process Regression. filtlist (list): A list of filters to fit. phasemin (int): The minimum phase, relative to peak brightness, to be fit. phasemax (int): The maximum phase, relative to peak brightness, to be fit. log_transform (float): The offset in the log transform. Must be larger than `phasemin`. Effectively controls the light curve "stretch" in log space. set_to_normalize (SNCollection | SNType | None, optional): The colleciton of transients to normalize `collection` to. If `None`, will use `collection`. Defaults to None. mangle_sed (bool, optional): Use the calculated mangled `DataCube` for each `SN` object in `collection` and `set_to_normalize`. If True, will exclude `SN` objects without a mangled `DataCube`. Defaults to False. """ super().__init__( collection, kernel, filtlist, phasemin, phasemax, log_transform ) self.set_to_normalize = set_to_normalize self.mangle_sed = mangle_sed self._prepare_data() def _prepare_data(self): """ Use the flags set in __init__ to filter the pandas dataframes for each SN in the science and control samples """ for collection in [self.collection, self.set_to_normalize]: for sn in collection.sne: # Read the correct cube based on self.mangle_sed if self.mangle_sed: data_cube_filename = os.path.join( sn.base_path, sn.classification, sn.subtype, sn.name, sn.name + "_datacube_mangled.csv", ) else: data_cube_filename = os.path.join( sn.base_path, sn.classification, sn.subtype, sn.name, sn.name + "_datacube.csv", ) if os.path.exists(data_cube_filename): cube = pd.read_csv(data_cube_filename) else: # For now, we'll just construct it datacube = DataCube(sn=sn) datacube.construct_cube() cube = datacube.cube # Drop rows that are out of the phase range inds_to_drop_phase = cube.loc[ (cube["Phase"] < self.phasemin) | (cube["Phase"] > self.phasemax) ].index cube = cube.drop(inds_to_drop_phase).reset_index(drop=True) # Drop nondetections farther than 2 days away from first/last detection try: min_phase = min( cube.loc[cube["Nondetection"] == False]["Phase"].values ) inds_to_drop_nondets_before_first_det = cube.loc[ (cube["Nondetection"] == True) & (cube["Phase"] < (min_phase - 2.0)) ].index cube = cube.drop(inds_to_drop_nondets_before_first_det).reset_index( drop=True ) except ValueError: # no values pass try: max_phase = max( cube.loc[cube["Nondetection"] == False]["Phase"].values ) inds_to_drop_nondets_after_last_det = cube.loc[ (cube["Nondetection"] == True) & (cube["Phase"] > (max_phase + 2.0)) ].index cube = cube.drop(inds_to_drop_nondets_after_last_det).reset_index( drop=True ) except ValueError: # no values pass # Drop rows corresponding to filters not in the user-provided filter list inds_to_drop_filts = cube.loc[~cube["Filter"].isin(self.filtlist)].index cube = cube.drop(inds_to_drop_filts).reset_index(drop=True) # Log transform the data (as a separate column) cube["LogPhase"] = np.log( cube["Phase"].values.astype(float) + self.log_transform ) cube["LogWavelength"] = np.log10( cube["Wavelength"].values.astype(float) ) cube["LogShiftedWavelength"] = np.log10( cube["ShiftedWavelength"].values.astype(float) ) # Drop nondetections that are within the phase range and less constraining # than the first or last detection in each filter try: min_flux_before_peak = min( cube.loc[(cube["Nondetection"] == False) & (cube["Phase"] < 0)][ "Flux" ].values ) inds_to_drop_nondets_before_peak = cube.loc[ (cube["Nondetection"] == True) & (cube["Phase"] < 0) & (cube["Flux"] > min_flux_before_peak) ].index cube = cube.drop(inds_to_drop_nondets_before_peak).reset_index( drop=True ) except ValueError: # No values pre-peak, so pass pass try: min_flux_after_peak = min( cube.loc[(cube["Nondetection"] == False) & (cube["Phase"] > 0)][ "Flux" ].values ) inds_to_drop_nondets_after_peak = cube.loc[ (cube["Nondetection"] == True) & (cube["Phase"] > 0) & (cube["Flux"] > min_flux_after_peak) ].index cube = cube.drop(inds_to_drop_nondets_after_peak).reset_index( drop=True ) except ValueError: # No values post-peak, so pass pass # Drop nondetections between the first and last detection cube_only_dets = cube.loc[cube["Nondetection"] == False] if len(cube_only_dets["Phase"].values) > 0: first_detection = min(cube_only_dets["Phase"].values) last_detection = max(cube_only_dets["Phase"].values) inds_to_drop_nondets_between_dets = cube.loc[ (cube["Phase"] > first_detection) & (cube["Phase"] < last_detection) & (cube["Nondetection"] == True) ].index cube = cube.drop(inds_to_drop_nondets_between_dets).reset_index( drop=True ) try: cube["MagFromPeak"] = sn.info["peak_mag"] - cube["Mag"] sn.cube = cube except: pass
[docs] @staticmethod def interpolate_grid(grid, interp_array, filter_window=171): """ Function to remove NaNs by interpolating between actual measurements in a phase/wl grid Takes as input a grid to interpolate over, an array containing values of the grid along the interpolation dimension, as well as a filter window for smoothing using a Savitsky-Golay filter """ for i, row in enumerate(grid): notnan_inds = np.where((~np.isnan(row)))[0] if len(notnan_inds) > 4: interp = interp1d(interp_array[notnan_inds], row[notnan_inds], "linear") interp_row = interp(interp_array[min(notnan_inds) : max(notnan_inds)]) savgol_l = savgol_filter(interp_row, filter_window, 3, mode="mirror") row[min(notnan_inds) : max(notnan_inds)] = savgol_l grid[i] = row return grid
def _build_samples( self, filt, sn_set=None, ): """ Builds the data set from the SN collection for a given filter and returns, along with the phases, wls, and mags, the uncertainty in the measurements as the standard deviation of the photometry at each phase step """ phases, mags, errs, wls = super()._process_dataset( filt, sn_set=sn_set, ) if len(phases) == 0: return np.asarray([]), np.asarray([]), np.asarray([]), np.asarray([]) min_phase, max_phase = sorted(phases)[0], sorted(phases)[-1] phase_grid = np.linspace(min_phase, max_phase, len(phases)) phase_grid_space = (max_phase - min_phase) / len(phases) err_grid = np.ones(len(phase_grid)) for mjd in phase_grid: ind = np.where( (phases < mjd + phase_grid_space / 2) & (phases > mjd - phase_grid_space / 2) )[0] mags_at_this_phase = mags[ind] if len(mags_at_this_phase) > 1: std_mag = max(np.std(mags_at_this_phase), 0.01) elif len(mags_at_this_phase) == 1: std_mag = errs[ind] else: std_mag = 0.01 err_grid[ind] = std_mag err_grid = np.nan_to_num(err_grid, nan=0.01) return ( phases.astype(float), wls.astype(float), mags.astype(float), err_grid.astype(float), ) def _process_dataset(self, set_to_normalize=None): """ Processes the data set for the GP3D object's SN collection or (optionally) a SN set filter-by-filter and returns dataframes of the SN collection's photometric details or the photometric details of the SN set to normalize to """ ### Create the template grid from the observations if set_to_normalize is not None: ( all_template_phases, all_template_wls, all_template_mags, all_template_errs, ) = ([], [], [], []) for filt in self.filtlist: phases, wl_grid, mags, err_grid = self._build_samples( filt, sn_set=set_to_normalize, ) all_template_phases = np.concatenate( (all_template_phases, phases.flatten()) ) all_template_wls = np.concatenate((all_template_wls, wl_grid.flatten())) all_template_mags = np.concatenate((all_template_mags, mags.flatten())) all_template_errs = np.concatenate( (all_template_errs, err_grid.flatten()) ) else: # Create grid from the SN collection instead all_phases, all_wls, all_mags, all_errs = [], [], [], [] for filt in self.filtlist: phases, wl_grid, mags, err_grid = self._build_samples(filt) all_phases = np.concatenate((all_phases, phases.flatten())) all_wls = np.concatenate((all_wls, wl_grid.flatten())) all_mags = np.concatenate((all_mags, mags.flatten())) all_errs = np.concatenate((all_errs, err_grid.flatten())) all_template_phases = all_phases all_template_wls = all_wls all_template_mags = all_mags all_template_errs = all_errs template_df = pd.DataFrame( { "Phase": all_template_phases, "Wavelength": all_template_wls, "Mag": all_template_mags, "MagErr": all_template_errs, } ) return template_df def _construct_median_grid( self, phasemin, phasemax, filtlist, template_df, log_transform, plot=False, ): """ Takes as input the photometry from the sn set to normalize and constructs a 2D template grid consisting of the median photometry at each phase and wl step """ phase_grid_linear = np.arange( phasemin, phasemax, 1 / 24.0 ) # Grid of phases to iterate over, by hour phase_grid = np.log( phase_grid_linear + log_transform ) # Grid of phases in log space wl_grid_linear = np.arange( min(self.wle[f] for f in filtlist) - 500, max(self.wle[f] for f in filtlist) + 500, 99.5, ) # Grid of wavelengths to iterate over, by 100 A wl_grid = np.log10(wl_grid_linear) mag_grid = np.empty((len(phase_grid), len(wl_grid))) mag_grid[:] = np.nan err_grid = np.copy(mag_grid) for i in range(len(phase_grid)): for j in range(len(wl_grid)): ### Get all data that falls within this phase + 5 days, and this wl +- 100 A inds = template_df[ ( np.exp(template_df["Phase"]) - np.exp(phase_grid[i]) <= np.log(5.0) ) & (np.exp(template_df["Phase"]) - np.exp(phase_grid[i] > 0.0)) & (abs(10 ** template_df["Wavelength"] - 10 ** wl_grid[j]) <= 500) ].index if len(inds) > 0: median_mag = np.median(template_df["Mag"][inds].values) iqr = np.subtract( *np.percentile(template_df["Mag"][inds], [75, 25]) ) mag_grid[i, j] = median_mag err_grid[i, j] = iqr mag_grid = self.interpolate_grid(mag_grid.T, phase_grid) mag_grid = mag_grid.T mag_grid = self.interpolate_grid(mag_grid, wl_grid, filter_window=31) err_grid = self.interpolate_grid(err_grid.T, phase_grid) err_grid = err_grid.T err_grid = self.interpolate_grid(err_grid, wl_grid, filter_window=31) X, Y = np.meshgrid(np.exp(phase_grid) - log_transform, 10**wl_grid) Z = mag_grid.T if plot: Plot().plot_construct_grid( gp_class=self, X=X, Y=Y, Z=Z, phase_grid=phase_grid, mag_grid=mag_grid, wl_grid=wl_grid, err_grid=err_grid, filtlist=filtlist, grid_type="median", ) return phase_grid, wl_grid, mag_grid, err_grid def _construct_polynomial_grid( self, phasemin, phasemax, filtlist, template_df, log_transform, plot=False, ): """ Takes as input the photometry from the sn set to normalize and constructs a 2D template grid consisting of the polynomial fit to the SN set to normalize photometry at each phase and wl step """ phase_grid_linear = np.arange( phasemin, phasemax, 1 / 24.0 ) # Grid of phases to iterate over, by hour phase_grid = np.log( phase_grid_linear + log_transform ) # Grid of phases in log space wl_grid = np.arange( np.log10(min(self.wle[f] for f in filtlist) - 500), np.log10(max(self.wle[f] for f in filtlist) + 500), 0.01, ) mag_grid = np.empty((len(phase_grid), len(wl_grid))) mag_grid[:] = np.nan err_grid = np.copy(mag_grid) ### Add an array of fake measurements to anchor the ends of the fit anchor_inds_begin = np.where( (abs(template_df["Phase"] - np.log(phasemin + log_transform)) < 1.0) )[0] if len(anchor_inds_begin) > 0: anchor_mag_begin = min(template_df["Mag"][anchor_inds_begin].values) else: anchor_mag_begin = -4.0 anchor_inds_end = np.where( (abs(template_df["Phase"] - np.log(phasemax + log_transform)) < 0.3) )[0] if len(anchor_inds_end) > 0: anchor_mag_end = min(template_df["Mag"][anchor_inds_end].values) else: anchor_mag_end = -4.0 anchor_phases = np.asarray( [ np.log(phasemin + log_transform), np.log(phasemin + 2.5 + log_transform), np.log(phasemax + log_transform), ] ) anchor_mags = np.asarray( [anchor_mag_begin - 1.0, anchor_mag_begin, anchor_mag_end - 1.0] ) for j in range(len(wl_grid)): ### Get all data that falls within this wl +- 500 A inds = template_df[ abs(10 ** template_df["Wavelength"] - 10 ** wl_grid[j]) <= 499 ].index if len(inds) > 0: phases_to_fit = np.concatenate( (template_df["Phase"][inds], anchor_phases) ) mags_to_fit = np.concatenate((template_df["Mag"][inds], anchor_mags)) # if j == 0 or j == len(wl_grid) - 1: # # At the wavelength boundary, anchor this with artificially lower mags # mags_to_fit -= 2 errs_to_fit = np.concatenate( (template_df["MagErr"][inds], np.ones(len(anchor_phases)) * 0.05) ) fit_coeffs = np.polyfit( phases_to_fit, mags_to_fit, 3, w=1 / ( np.sqrt( errs_to_fit**2 + ( np.ones( len(template_df["MagErr"][inds]) + len(anchor_phases) ) * 0.1 ) ** 2 ) ), ) fit = np.poly1d(fit_coeffs) grid_mags = fit(phase_grid) mag_grid[:, j] = grid_mags err_grid[:, j] = np.ones(len(phase_grid)) * np.median( abs(template_df["Mag"][inds] - fit(template_df["Phase"][inds])) ) ### Interpolate over the wavelengths to get a complete 2D grid mag_grid = self.interpolate_grid(mag_grid, wl_grid, filter_window=31) err_grid = self.interpolate_grid(err_grid, wl_grid, filter_window=31) X, Y = np.meshgrid(np.exp(phase_grid) - log_transform, 10**wl_grid) Z = mag_grid.T if plot: Plot().plot_construct_grid( gp_class=self, X=X, Y=Y, Z=Z, grid_type="polynomial" ) return phase_grid, wl_grid, mag_grid, err_grid def _subtract_data_from_grid( self, sn, filtlist, phase_grid, wl_grid, mag_grid, err_grid, plot=False, ): """ Takes the (shifted) photometry from a given SN and subtracts from it the template grid constructed from either the median of all the normalization SN photometry, or the polynomial fit to all the normalization SN photometry Returns the phase and wavelength of each data point from the given SN, as well as the residuals in its magnitude (or flux) and its uncertainty """ ### Subtract off templates for each SN LC residuals = [] for filt in filtlist: if filt in sn.cube["ShiftedFilter"].values: mags = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][ "ShiftedFlux" ].values errs = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][ "ShiftedFluxerr" ].values current_nondets = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][ "Nondetection" ].values current_wls = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][ "LogShiftedWavelength" ].values mags_from_peak = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][ "MagFromPeak" ].values phases = sn.cube.loc[sn.cube["ShiftedFilter"] == filt][ "LogPhase" ].values else: phases = [] # There's a bug that I have to track down where sometimes filters without data # have a NaN as their redshift until data is read in, so the second clause # of this if statement will skip over those filters # (I don't know why this isn't caught above) if len(phases) > 0 and not np.isnan(sn.info.get("z", 0)): if plot: _, ax = plt.subplots() for i, phase in enumerate(phases): ### Get index of current phase in phase grid ### The phase corresponding to phase_ind is no more than the phase grid spacing away from the true phase being measured phase_ind = np.argmin(abs(np.exp(phase_grid) - np.exp(phase))) wl_ind = np.argmin(abs(wl_grid - current_wls[i])) if np.isnan(mag_grid[phase_ind, wl_ind]): logger.warning( f"NaN Found: phase {np.exp(phase)}, wl {10 ** wl_grid[wl_ind]}" ) continue if np.isinf(mags[i] - mag_grid[phase_ind, wl_ind]): logger.warning( f"Infinity found: phase {np.exp(phase)}, wl {10 ** wl_grid[wl_ind]}" ) continue residuals.append( { "Filter": filt, "Phase": phase, "Wavelength": current_wls[i], "MagResidual": mags[i] - mag_grid[phase_ind, wl_ind], "MagErr": errs[i], "Mag": mags_from_peak[i], "Nondetection": current_nondets[i], } ) if plot: plt.errorbar( phase, mags[i] - mag_grid[phase_ind, wl_ind], yerr=np.sqrt( errs[i] ** 2 + err_grid[phase_ind, wl_ind] ** 2 ), marker="o", color="k", ) plt.errorbar( phase, mags[i], yerr=errs[i], fmt="o", color=colors.get(filt, "k"), ) if plot: # NOTE: The wl ind changes with each data point, # therefore this is representation is not fully accurate Plot().plot_subtract_data_from_grid( sn_class=sn, phase_grid=phase_grid, mag_grid=mag_grid, wl_ind=wl_ind, filt=filt, ax=ax, ) plt.show() return pd.DataFrame(residuals) def _build_test_wavelength_phase_grid_from_photometry( self, measured_wavelengths, measured_phases, wl_grid, phase_grid ): """ Function to build a uniform grid of wavelengths and phases given photometry in the form of measured wavelengths and phases as well as a wavelength and phase grid corresponding to the template SED grid """ waves_to_predict = np.unique(measured_wavelengths) diffs = abs( np.subtract.outer(10**wl_grid, 10**waves_to_predict) ) # The difference between our measurement wavelengths and the wl grid phases_to_predict = np.unique(measured_phases) ### Compare the wavelengths of our measured filters to those in the wl grid ### and fit for those grid wls that are within 500 A of one of our measurements wl_inds_fitted = np.unique(np.where((diffs < 500.0))[0]) phase_inds_fitted = np.unique( np.where( (phase_grid >= min(phases_to_predict)) & (phase_grid <= max(phases_to_predict)) )[0] ) if len(phase_inds_fitted) == 0: return [], [], [], [], None linear_phases = np.exp(phase_grid[phase_inds_fitted]) - self.log_transform phases = np.log(linear_phases - min(linear_phases) + 0.1) x, y = np.meshgrid(phases, wl_grid[wl_inds_fitted]) return x, y, wl_inds_fitted, phase_inds_fitted, min(linear_phases)
[docs] def run_gp_on_full_sample( self, plot=False, subtract_median=False, subtract_polynomial=False, ): """ Run the Gaussian Process Regression fitting routine on the full sample at once. Does not individually fit each transient, but instead fits all photometry in the input sample simultaneously. Args: plot (bool, optional): Show intermediate plots. Defaults to False. subtract_median (bool, optional): Fit and subtract off a median function to calculate residuals. Defaults to False. subtract_polynomial (bool, optional): Fit and subtract off a polynomial function to calculate residuals. One of `subtract_median` and `subtract_polynomial` must be True. Defaults to False. Raises: Exception: Must toggle either subtract_median or subtract_polynomial as True to run GP3D. Returns: list: A list containing the Gaussian process model and its uncertainty, both as an mxn array np.ndarray: An array containing the processed magnitudes from the template grid np.ndarray: An array containing the log-transformed phases np.ndarray: An array containing the log-transformed wavelengths """ template_df = self._process_dataset(set_to_normalize=self.set_to_normalize) if subtract_polynomial: phase_grid, wl_grid, mag_grid, err_grid = self._construct_polynomial_grid( self.phasemin, self.phasemax, self.filtlist, template_df, log_transform=self.log_transform, plot=plot, ) elif subtract_median: phase_grid, wl_grid, mag_grid, err_grid = self._construct_median_grid( self.phasemin, self.phasemax, self.filtlist, template_df, log_transform=self.log_transform, plot=plot, ) else: raise Exception( "Must toggle either subtract_median or subtract_polynomial as True to run GP3D" ) x_input, y_input, err, raw_mags = [], [], [], [] for sn in self.collection.sne: if hasattr(sn, "cube"): residuals = self._subtract_data_from_grid( sn, self.filtlist, phase_grid, wl_grid, mag_grid, err_grid, plot=False, # TODO: are we sure we want to hard-code False for grid subtraction? ) if len(residuals) == 0: continue phase_residuals_linear = ( np.exp(residuals["Phase"].values) - self.log_transform ) phases_to_fit = np.log( phase_residuals_linear - min(phase_residuals_linear) + 0.1 ) if not len(x_input): x_input = np.vstack( (phases_to_fit, residuals["Wavelength"].values) ).T else: x_input = np.concatenate( [ x_input, np.vstack( (phases_to_fit, residuals["Wavelength"].values) ).T, ] ) y_input = np.concatenate([y_input, residuals["MagResidual"].values]) raw_mags = np.concatenate([raw_mags, residuals["Mag"].values]) err = np.concatenate([err, residuals["MagErr"].values]) if isinstance(self.kernel, Kernel): gaussian_process = GaussianProcessRegressor( kernel=self.kernel.kernel, alpha=err, n_restarts_optimizer=10 ) else: gaussian_process = GaussianProcessRegressor( kernel=self.kernel, alpha=err, n_restarts_optimizer=10 ) gaussian_process.fit(x_input, y_input) x, y, wl_inds_fitted, phase_inds_fitted, phase_offset = ( self._build_test_wavelength_phase_grid_from_photometry( x_input[:, 1], x_input[:, 0], wl_grid, phase_grid ) ) test_prediction, std_prediction = gaussian_process.predict( np.vstack((x.ravel(), y.ravel())).T, return_std=True ) template_mags = [] for wl_ind in wl_inds_fitted: for phase_ind in phase_inds_fitted: template_mags.append(mag_grid[phase_ind, wl_ind]) template_mags = np.asarray(template_mags).reshape((len(x), -1)) final_prediction = test_prediction.reshape((len(x), -1)) + template_mags final_std_prediction = std_prediction.reshape((len(x), -1)) + template_mags ### Convert to mags from peak for i, col in enumerate(final_prediction[:,]): wl = 10 ** wl_grid[wl_inds_fitted][i] zp = (10**-23 * 3e18 / wl) * 1e11 shifted_mags = -1 * ( (col + np.log10(zp * 1e-11) - np.log10(10**-23 * 3e18 / 5000)) / -0.4 ) final_prediction[i] = shifted_mags ### Map predicted SED surface and uncertainty to the phase, wl grids gp_grid = np.empty((len(wl_grid), len(phase_grid))) gp_grid[:] = np.nan for i, col in enumerate(final_prediction[:,]): current_wl_grid_ind = wl_inds_fitted[i] for j in range(len(col)): current_phase_grid_ind = phase_inds_fitted[j] gp_grid[current_wl_grid_ind, current_phase_grid_ind] = col[j] gp_grid_std = np.empty((len(wl_grid), len(phase_grid))) gp_grid_std[:] = np.nan for i, col in enumerate(final_std_prediction[:,]): current_wl_grid_ind = wl_inds_fitted[i] for j in range(len(col)): current_phase_grid_ind = phase_inds_fitted[j] gp_grid_std[current_wl_grid_ind, current_phase_grid_ind] = col[j] fig = plt.figure() ax = fig.add_subplot(111, projection="3d") Plot().plot_construct_grid( gp_class=self, X=np.exp(x) + phase_offset - 0.1, Y=10 ** (y), Z=final_prediction, ax=ax, Z_lower=final_prediction - 1.96 * final_std_prediction, Z_upper=final_prediction + 1.96 * final_std_prediction, grid_type="final", ) ax.scatter( np.exp(x_input[:, 0]) - self.log_transform, 10 ** x_input[:, 1], raw_mags, marker="o", color="k", ) plt.show() return [gp_grid, gp_grid_std], mag_grid, phase_grid, wl_grid
[docs] def optimize_hyperparams(self, subtract_median=False, subtract_polynomial=False): """ Optimize the Gaussian Process Regression kernel hyperparameters by fitting each transient in the sample individually. This is normally run before the `predict` method to retrieve the optimized kernel hyperparameters for the input sample. These hyperparameters are normally then fixed in the kernel, and the `predict` method is run using the fixed kernel. Args: subtract_median (bool, optional): Fit and subtract off a median function to calculate residuals. Defaults to False. subtract_polynomial (bool, optional): Fit and subtract off a polynomial function to calculate residuals. One of `subtract_median` and `subtract_polynomial` must be True. Defaults to False. Raises: Exception: Must toggle either subtract_median or subtract_polynomial as True to run GP3D. Returns: list: List of optimized kernel hyperparameters. """ template_df = self._process_dataset(set_to_normalize=self.set_to_normalize) kernel_params = [] if subtract_polynomial: phase_grid, wl_grid, mag_grid, err_grid = self._construct_polynomial_grid( self.phasemin, self.phasemax, self.filtlist, template_df, log_transform=self.log_transform, plot=False, ) elif subtract_median: phase_grid, wl_grid, mag_grid, err_grid = self._construct_median_grid( self.phasemin, self.phasemax, self.filtlist, template_df, log_transform=self.log_transform, plot=False, ) else: raise Exception( "Must toggle either subtract_median or subtract_polynomial as True to run GP3D" ) for sn in self.collection.sne: if hasattr(sn, "cube"): residuals = self._subtract_data_from_grid( sn, self.filtlist, phase_grid, wl_grid, mag_grid, err_grid, plot=False, # TODO: are we sure we want to hard-code False for grid subtraction? ) if len(residuals) == 0: continue phase_residuals_linear = ( np.exp(residuals["Phase"].values) - self.log_transform ) phases_to_fit = np.log( phase_residuals_linear - min(phase_residuals_linear) + 0.1 ) x = np.vstack((phases_to_fit, residuals["Wavelength"].values)).T y = residuals["MagResidual"].values if len(y) > 1: # We have enough points to fit err = residuals["MagErr"].values if isinstance(self.kernel, Kernel): gaussian_process = GaussianProcessRegressor( kernel=self.kernel.kernel, alpha=err, n_restarts_optimizer=10, ) else: gaussian_process = GaussianProcessRegressor( kernel=self.kernel, alpha=err, n_restarts_optimizer=10 ) gaussian_process.fit(x, y) kernel_params.append(gaussian_process.kernel_.theta) optimized_kernel_hyperparams = np.asarray( [ np.median([k[i] for k in kernel_params]) for i in range(len(kernel_params[0])) ] ) if isinstance(self.kernel, Kernel): self.kernel.recursively_set_params(optimized_kernel_hyperparams, "fixed") else: self.kernel.set_params( **{ "length_scale": optimized_kernel_hyperparams, "length_scale_bounds": "fixed", } ) return kernel_params
def _iteratively_warp_sed( self, residuals: pd.DataFrame, test_prediction_reshaped: np.ndarray, wls_fit: np.ndarray, phases_fit: np.ndarray, sn: SN, convergence_threshold: float = 1.0, niter: float = 100, ): """ Compare photometry to prediction, iteratively warping the predicted flux until it matches the photometry within convergence_threshold * error for each data point """ phases_to_iterate_over = np.log( np.arange( min(np.exp(residuals["Phase"].values) - self.log_transform), max(np.exp(residuals["Phase"].values) - self.log_transform), 0.5, ) ) for phase in phases_to_iterate_over: current_lc_inds = np.where((abs(residuals["Phase"] - phase) <= 0.1))[0] if ( len(current_lc_inds) > 0 and len({filt for filt in residuals["Filter"][current_lc_inds]}) > 1 ): # We have data in at least two filters at this epoch current_lc_cube = residuals[abs(residuals["Phase"] - phase) <= 0.1] current_lc = current_lc_cube["Mag"].values current_lc = np.concatenate(([0.0], current_lc, [current_lc[-1] / 2])) current_lc_errs = np.concatenate( ([0.0], current_lc_cube["MagErr"].values, [0.0]) ) errors = np.ones(len(current_lc_inds)) * 100.0 errors = np.concatenate(([1.0], errors, [1.0])) wls = [ np.log10(self.wle[filt] * (1 + sn.info.get("z", 0))) for filt in current_lc_cube["Filter"].values ] wls = np.concatenate( ( [np.log10(10 ** min(wls) - 500)], wls, [np.log10(10 ** max(wls) + 500)], ) ) sed_inds_at_this_phase = np.where( (wls_fit > min(wls)) & (wls_fit < max(wls)) )[0] if len(sed_inds_at_this_phase) == 0: ### Redshifted filter central wavelengths outside the bounds of the filters fit by the GP continue phase_ind = np.argmin(abs(phases_fit - phase)) prediction_slice = np.copy( test_prediction_reshaped[sed_inds_at_this_phase, phase_ind] ) n = 0 for i in range(niter): if ( all( [ abs(1 - errors[i]) <= convergence_threshold * current_lc_errs[i] for i in range(len(errors)) ] ) or n == niter ): break for j, filt in enumerate(current_lc_cube["Filter"]): real_flux_inds = ( np.where((current_lc_cube["Filter"] == filt))[0] + 1 ) if len(real_flux_inds) > 1: real_mag = np.average(current_lc[real_flux_inds]) else: real_mag = current_lc[real_flux_inds][0] wl_for_filt = np.log10( self.wle[filt] * (1 + sn.info.get("z", 0)) ) wl_ind = np.argmin( abs(wls_fit[sed_inds_at_this_phase] - wl_for_filt) ) predicted_mag = prediction_slice[wl_ind] diff = real_mag - predicted_mag if predicted_mag > 0: error = 1 + diff else: error = 1 - diff errors[j + 1] = error if any( [ abs(1 - errors[j]) > convergence_threshold * current_lc_errs[j] for j in range(len(errors)) ] ): ### Warp the SED by an interpolation of the error across wavelength ### and rerun the loop if n < niter: n += 1 if any(errors > 1e3) or any(np.isinf(errors)): n = niter interp_error_fn = interp1d(wls, errors) interp_to_warp_by = interp_error_fn( wls_fit[sed_inds_at_this_phase] ) prediction_slice *= interp_to_warp_by if n == niter: logger.debug( f"Couldnt iterate to match flux: {sn.name}, {np.exp(phase) - self.log_transform}" ) if n < niter: # Put the warped SED back in the predicted datacube test_prediction_reshaped[sed_inds_at_this_phase, phase_ind] = ( prediction_slice ) return test_prediction_reshaped def _sample_predicted_sed(self, mean_prediction, std_prediction): """Randomly sample the predicted SED for better uncertainty estimate in final GP surface""" ### Draw random sigma between -1 and 1 sigma = truncnorm.rvs(-1, 1, loc=0, scale=1, size=1)[0] sampled_prediction = mean_prediction + (sigma * std_prediction) return sampled_prediction def _smooth_predicted_model( self, model_array: np.ndarray, window_size: int, transpose: bool = False ): if transpose: model_array = model_array.T model_array_smoothed = np.empty(model_array.shape) for i, col in enumerate(model_array): if window_size % 2 == 0: window_size += 1 # Must be odd # Use astropy convolve function to handle NaNs model_array_smoothed[i, :] = convolve( col, np.ones(window_size) / window_size, boundary="extend" ) # Boxcar smoothing if transpose: model_array_smoothed = model_array_smoothed.T return model_array_smoothed
[docs] def run_gp_individually( self, plot=False, subtract_median=False, subtract_polynomial=False, interactive=False, run_diagnostics=False, save_individual_fits=False, ): """ Run the Gaussian Process Regression fitting routine on each transient in the full sample individually. This produces a bespoke Gaussian process model for each transient, constructing a full 3-dimension SED surface. These surfaces are then randomly sampled and to be used in the construction of the final template model surface. Args: plot (bool, optional): Show intermediate plots. Defaults to False. subtract_median (bool, optional): Fit and subtract off a median function to calculate residuals. Defaults to False. subtract_polynomial (bool, optional): Fit and subtract off a polynomial function to calculate residuals. One of `subtract_median` and `subtract_polynomial` must be True. Defaults to False. interactive (bool, optional): Interactively choose which fits to use in the creation of the final Gaussian process model. If True, sets `plot` to True as well. Defaults to False. run_diagnostics (bool, optional): Run diagnostic tests on the fitting to identify data points or regions of poor fit quality. Defaults to False. save_individual_fits (bool, optional): Save the GP fits to each individual SN object separately. If True, will save the models to the default location. WARNING: This can be a lot of data. Defualts to False. Raises: Exception: Must toggle either subtract_median or subtract_polynomial as True to run GP3D. Returns: list: A list of random samples from the Gaussian process distribution. Used to construct final 3-dimensional templates. np.ndarray: An array containing the processed magnitudes from the template grid np.ndarray: An array containing the log-transformed phases np.ndarray: An array containing the log-transformed wavelengths """ if interactive: plot = True template_df = self._process_dataset(set_to_normalize=self.set_to_normalize) gaussian_processes = [] if subtract_polynomial: phase_grid, wl_grid, mag_grid, err_grid = self._construct_polynomial_grid( self.phasemin, self.phasemax, self.filtlist, template_df, log_transform=self.log_transform, plot=plot, ) elif subtract_median: phase_grid, wl_grid, mag_grid, err_grid = self._construct_median_grid( self.phasemin, self.phasemax, self.filtlist, template_df, log_transform=self.log_transform, plot=plot, ) else: raise Exception( "Must toggle either subtract_median or subtract_polynomial as True to run GP3D" ) for sn in self.collection.sne: if not hasattr(sn, "cube"): continue residuals = self._subtract_data_from_grid( sn, self.filtlist, phase_grid, wl_grid, mag_grid, err_grid, plot=False, # TODO: are we sure we want to hard-code False for grid subtraction? ) if len(residuals) == 0: continue phase_residuals_linear = ( np.exp(residuals["Phase"].values) - self.log_transform ) phases_to_fit = np.log( phase_residuals_linear - min(phase_residuals_linear) + 0.1 ) x = np.vstack((phases_to_fit, residuals["Wavelength"].values)).T y = residuals["MagResidual"].values if len(y) > 1: # We have enough points to fit err = residuals["MagErr"].values if isinstance(self.kernel, Kernel): gaussian_process = GaussianProcessRegressor( kernel=self.kernel.kernel, alpha=err, n_restarts_optimizer=10 ) else: gaussian_process = GaussianProcessRegressor( kernel=self.kernel, alpha=err, n_restarts_optimizer=10 ) gaussian_process.fit(x, y) if plot: _, ax = Plot().create_empty_subplot() filts_fitted = [] for filt in self.filtlist: shifted_mjd = sn.cube[sn.cube["Filter"] == filt][ "Phase" ].values.astype(float) shifted_mjd = sn.log_transform_time( shifted_mjd, phase_start=self.log_transform ) inds_to_fit = np.where( (shifted_mjd > np.log(self.phasemin + self.log_transform)) & (shifted_mjd < np.log(self.phasemax + self.log_transform)) )[0] if len(inds_to_fit) > 0: filts_fitted.append(filt) if plot or run_diagnostics: test_times_linear = np.arange( min(phase_residuals_linear), max(phase_residuals_linear), 1.0 / 24, ) test_times = np.log( test_times_linear - min(phase_residuals_linear) + 0.1 ) test_waves = np.ones(len(test_times)) * np.log10( self.wle[filt] * (1 + sn.info.get("z", 0)) ) ### Trying to convert back to normalized magnitudes here wl_ind = np.argmin( abs( 10**wl_grid - self.wle[filt] * (1 + sn.info.get("z", 0)) ) ) template_mags = [] for i in range(len(test_times_linear)): j = np.argmin( abs( np.exp(phase_grid) - self.log_transform - test_times_linear[i] ) ) template_mags.append(mag_grid[j, wl_ind]) template_mags = np.asarray(template_mags) test_prediction, std_prediction = gaussian_process.predict( np.vstack((test_times, test_waves)).T, return_std=True ) test_times = ( np.exp(test_times) + min(phase_residuals_linear) - 0.1 ) residuals_for_filt = residuals[ (residuals["Filter"] == filt) & ( residuals["Phase"] > np.log(self.phasemin + self.log_transform) ) & ( residuals["Phase"] < np.log(self.phasemax + self.log_transform) ) & (residuals["Nondetection"] == False) ] if len(inds_to_fit) > 0: if plot: Plot().plot_run_gp_overlay( ax=ax, test_times=test_times, test_prediction=test_prediction, std_prediction=std_prediction, template_mags=template_mags, residuals=residuals_for_filt, log_transform=self.log_transform, filt=filt, sn=sn, ) if run_diagnostics: d = Diagnostic() d.identify_outlier_points( filt, test_times, test_prediction + template_mags, std_prediction, np.exp(residuals_for_filt["Phase"].values) - self.log_transform, residuals_for_filt["Mag"].values, residuals_for_filt["MagErr"].values, ) d.check_late_time_slope( filt, test_times, test_prediction + template_mags, np.exp(residuals_for_filt["Phase"].values) - self.log_transform, ) if interactive: use_for_template = input( "Use this fit to construct a template? y/n" ) x, y, wl_inds_fitted, phase_inds_fitted, phase_offset = ( self._build_test_wavelength_phase_grid_from_photometry( residuals["Wavelength"].values, residuals["Phase"].values, wl_grid, phase_grid, ) ) if len(x) == 0: continue try: test_prediction, std_prediction = gaussian_process.predict( np.vstack((x.ravel(), y.ravel())).T, return_std=True ) except Exception as e: logger.warning(f"WARNING: BROKEN FIT FOR {sn.name}", exc_info=e) continue test_prediction = np.asarray(test_prediction) template_mags = [] for wl_ind in wl_inds_fitted: for phase_ind in phase_inds_fitted: template_mags.append(mag_grid[phase_ind, wl_ind]) ###NOTE: Some of these template mags are NaNs template_mags = np.asarray(template_mags).reshape((len(x), -1)) ### Put the fitted wavelengths back in the right spot on the grid ### and append to the gaussian processes array test_prediction_reshaped = ( test_prediction.reshape((len(x), -1)) + template_mags ) ### Convert to mags from peak for i, col in enumerate(test_prediction_reshaped[:,]): wl = 10 ** wl_grid[wl_inds_fitted][i] zp = (10**-23 * 3e18 / wl) * 1e11 shifted_peak_mag = np.log10( sn.zps[sn.info["peak_filt"]] * 1e-11 * 10 ** (-0.4 * sn.info["peak_mag"]) ) shifted_mags = -1 * ( (np.log10(10 ** (col + shifted_peak_mag) / (zp * 1e-11)) / -0.4) - sn.info["peak_mag"] ) test_prediction_reshaped[i] = shifted_mags test_prediction_reshaped = self._iteratively_warp_sed( residuals, test_prediction_reshaped, wl_grid[wl_inds_fitted], phase_grid[phase_inds_fitted], sn, convergence_threshold=1.0, ) std_prediction_reshaped = ( std_prediction.reshape((len(x), -1)) + template_mags ) test_prediction_smoothed = self._smooth_predicted_model( test_prediction_reshaped, window_size=max( int( round( test_prediction_reshaped.shape[0] / (2 * len(filts_fitted)), 0, ) ), 5, ), # Window size of approximately half a filter length scale transpose=True, ) test_prediction_smoothed = self._smooth_predicted_model( test_prediction_smoothed, window_size=max( int( round( test_prediction_smoothed.shape[1] / (len(phase_grid / (5 * 24))), 0, ) ), 5, ), # Window size of approximate a day ) std_prediction_smoothed = self._smooth_predicted_model( std_prediction_reshaped, window_size=max( int( round( std_prediction_reshaped.shape[0] / (2 * len(filts_fitted)), 0, ) ), 5, ), transpose=True, ) gp_grid = np.empty((len(wl_grid), len(phase_grid))) gp_grid[:] = np.nan for i, col in enumerate(test_prediction_smoothed[:,]): # for i, col in enumerate(test_prediction_reshaped[:,]): current_wl_grid_ind = wl_inds_fitted[i] for j in range(len(col)): current_phase_grid_ind = phase_inds_fitted[j] gp_grid[current_wl_grid_ind, current_phase_grid_ind] = col[j] gp_grid_std = np.empty((len(wl_grid), len(phase_grid))) gp_grid_std[:] = np.nan # for i, col in enumerate(std_prediction_reshaped[:,]): for i, col in enumerate(std_prediction_smoothed[:,]): current_wl_grid_ind = wl_inds_fitted[i] for j in range(len(col)): current_phase_grid_ind = phase_inds_fitted[j] gp_grid_std[current_wl_grid_ind, current_phase_grid_ind] = col[ j ] if plot: Plot().plot_run_gp_surface( gp_class=self, x=np.exp(x) + phase_offset - 0.1, y=10 ** (y), # test_prediction_reshaped=test_prediction_reshaped, test_prediction_reshaped=test_prediction_smoothed, ) if run_diagnostics: d = Diagnostic() d.check_gradient_between_filters( [self.wle[f] for f in filts_fitted], np.exp(phase_grid[phase_inds_fitted]) - self.log_transform, 10 ** (wl_grid[wl_inds_fitted]), # test_prediction_reshaped, test_prediction_smoothed, # std_prediction_reshaped, std_prediction_smoothed, [-15.0, 0.0, 50.0], ) if "UVM2" in filts_fitted: d.check_uvm2_flux( np.exp(phase_grid[phase_inds_fitted]) - self.log_transform, 10 ** (wl_grid[wl_inds_fitted]), # test_prediction_reshaped, test_prediction_smoothed, # std_prediction_reshaped, std_prediction_smoothed, [-15.0, 0.0, 50.0], ) if not interactive: use_for_template = "y" if use_for_template == "y": for i in range(round(np.log(len(residuals)))): random_sample = self._sample_predicted_sed(gp_grid, gp_grid_std) gaussian_processes.append(random_sample) if save_individual_fits: snmodel = SNModel( surface=gaussian_process, template_mags=template_mags.T, phase_grid=np.exp(phase_grid[phase_inds_fitted]) - self.log_transform, wl_grid=10**(wl_grid[wl_inds_fitted]), filters_fit=filts_fitted, sn=sn, norm_set=self.set_to_normalize, log_transform=self.log_transform, ) snmodel.save_fits() return gaussian_processes, mag_grid, phase_grid, wl_grid
[docs] def predict( self, plot=False, subtract_median=False, subtract_polynomial=False, run_diagnostics=False, fit_separately=True, ): """ Generate a Gaussian Process Regression model of the input transient sample. Uses the specified normalization sample and all other initialized parameters to process and fit the data to produce the model. Args: plot (bool, optional): Show intermediate plots. Defaults to False. subtract_median (bool, optional): Fit and subtract off a median function to calculate residuals. Defaults to False. subtract_polynomial (bool, optional): Fit and subtract off a polynomial function to calculate residuals. One of `subtract_median` and `subtract_polynomial` must be True. Defaults to False. run_diagnostics (bool, optional): Run diagnostic tests on the fitting to identify data points or regions of poor fit quality. Defaults to False. fit_separately (bol, optional): Fit each transient separately, or together as a group. Controls whether `run_gp_individually` or `run_gp_on_full_sample` is called to generate the predictive Gaussian Process Regression model. Defaults to True. Returns: SNModel: An SNModel object containing the final, 3-dimensional Gaussian Process Regression template model of the input transient sample. """ self._prepare_data() if not fit_separately: gaussian_process, template_mags, phase_grid, wl_grid = ( self.run_gp_on_full_sample( plot=plot, subtract_polynomial=subtract_polynomial, subtract_median=subtract_median, ) ) surface = SurfaceArray( surface=np.asarray(gaussian_process), phase_grid=phase_grid, wl_grid=wl_grid, kernel=self.kernel, ) snmodel = SNModel( phase_grid=np.exp(phase_grid) - self.log_transform, wl_grid=10**wl_grid, filters_fit=self.filtlist, surface=surface, template_mags=template_mags, sncollection=self.collection, norm_set=self.set_to_normalize, log_transform=self.log_transform, ) return snmodel else: ### We're fitting each SN individually and then median combining the full 2D GP gaussian_processes, mag_grid, phase_grid, wl_grid = ( self.run_gp_individually( plot=plot, subtract_median=subtract_median, subtract_polynomial=subtract_polynomial, run_diagnostics=run_diagnostics, ) ) X, Y = np.meshgrid(np.exp(phase_grid) - self.log_transform, 10**wl_grid) median_gp = np.nanmedian(np.dstack(gaussian_processes), -1) median_gp = self.interpolate_grid(median_gp.T, wl_grid, filter_window=31) for i, col in enumerate(median_gp.T): median_gp[:, i] = medfilt(col, kernel_size=51) median_gp = median_gp.T median_gp = self.interpolate_grid(median_gp, phase_grid, filter_window=171) for i, col in enumerate(median_gp): median_gp[i, :] = medfilt(col, kernel_size=51) iqr_grid = np.nanstd(np.dstack(gaussian_processes), -1) iqr_grid = self.interpolate_grid(iqr_grid.T, wl_grid, filter_window=31) for i, col in enumerate(iqr_grid.T): iqr_grid[:, i] = medfilt(col, kernel_size=51) iqr_grid = iqr_grid.T iqr_grid = self.interpolate_grid(iqr_grid, phase_grid, filter_window=171) for i, col in enumerate(iqr_grid): iqr_grid[i, :] = medfilt(col, kernel_size=51) Z = median_gp Plot().plot_construct_grid( gp_class=self, X=X, Y=Y, Z=Z, Z_lower=Z - iqr_grid, Z_upper=Z + iqr_grid, grid_type="final", ) plt.show() surface = SurfaceArray( surface=np.asarray([median_gp, iqr_grid]), phase_grid=phase_grid, wl_grid=wl_grid, kernel=self.kernel, ) snmodel = SNModel( surface=surface, template_mags=mag_grid, phase_grid=np.exp(phase_grid) - self.log_transform, wl_grid=10 ** (wl_grid), filters_fit=self.filtlist, sncollection=self.collection, norm_set=self.set_to_normalize, log_transform=self.log_transform, ) return snmodel