import logging
import matplotlib.pyplot as plt
import numpy as np
from caat.utils import bin_spec, query_svo_service
# Root logging is configured at INFO level when this module is imported.
# NOTE(review): configuring the root logger from library code affects the
# whole application; consider leaving this to the caller.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by every diagnostic below.
logger = logging.getLogger(__name__)
class Diagnostic:
    """
    Class to implement a number of diagnostics
    to check quality of the GP fitting, including
    but not limited to plots and metrics.

    This is a helper class to provide validation and support
    during the Gaussian Process Regression fitting process.
    These methods are not intended to be called outside the routines
    in the GP3D class.
    """

    def identify_outlier_points(
        self,
        filt,
        gp_times,
        gp_prediction,
        gp_std_deviation,
        phases,
        mags,
        errs,
        nsigma=3,
    ):
        """
        Flag observed points that lie more than ``nsigma`` away from the GP fit.

        Each observed phase is matched to the nearest GP time sample and the
        point is flagged when
        ``|mag - gp_prediction| > nsigma * (err + gp_std_deviation)``.
        Note the observational and GP uncertainties are added linearly, not
        in quadrature.

        Parameters
        ----------
        filt : str
            Filter name; used only in the log message.
        gp_times : array-like
            Times at which the GP was evaluated.
        gp_prediction, gp_std_deviation : array-like
            GP mean and standard deviation, indexed like ``gp_times``.
        phases, mags, errs : sequence
            Observed phases, values and uncertainties for this filter
            (``mags`` and ``errs`` are indexed by position in ``phases``).
        nsigma : float, optional
            Rejection threshold in units of the combined uncertainty.

        Returns
        -------
        list of dict
            One ``{"phase", "mag", "err"}`` dict (values rounded to 2 d.p.)
            per flagged point, empty when nothing is flagged. A warning is
            also logged when the list is non-empty.
        """
        gp_times = np.asarray(gp_times)
        outliers = []
        for i, phase in enumerate(phases):
            # Nearest GP sample to this observation.
            gp_ind = np.argmin(np.abs(gp_times - phase))
            if abs(mags[i] - gp_prediction[gp_ind]) > nsigma * (
                errs[i] + gp_std_deviation[gp_ind]
            ):
                outliers.append(
                    {
                        "phase": round(phase, 2),
                        "mag": round(mags[i], 2),
                        "err": round(errs[i], 2),
                    }
                )
        if outliers:
            logger.warning(
                "WARNING: Outlier points identified for filter {}: {}".format(
                    filt, outliers
                )
            )
        return outliers

    def check_late_time_slope(
        self,
        filt,
        gp_times,
        gp_prediction,
        phases,
    ):
        """
        Warn if the GP fit rises at late times for a given filter.

        Two checks are made, both only when the last two distinct observed
        phases are positive:

        1. the GP slope between the last two distinct observed phases, and
        2. the GP slope from the last observed phase to the end of the GP
           grid (the extrapolation), if the grid extends past that phase.

        A positive slope is logged as "increasing". This presumes
        ``gp_prediction`` grows with brightness (e.g. flux); NOTE(review):
        confirm the sign convention if magnitudes are ever passed in.

        Parameters
        ----------
        filt : str
            Filter name; used only in the log messages.
        gp_times : array-like
            GP evaluation times; assumed ascending (the extrapolation uses
            ``gp_times[-1]`` as the grid end) — TODO confirm.
        gp_prediction : array-like
            GP mean, indexed like ``gp_times``.
        phases : array-like
            Observed phases for this filter.

        Returns
        -------
        None
        """
        # np.unique sorts *and* de-duplicates, so repeated observations at the
        # final epoch cannot yield a zero-width (0/0) slope.
        unique_phases = np.unique(phases)
        if len(unique_phases) < 2:
            return
        second_to_last_phase, last_phase = unique_phases[-2], unique_phases[-1]
        if not (last_phase > 0 and second_to_last_phase > 0):
            return

        gp_times = np.asarray(gp_times)
        last_phase_ind = np.argmin(np.abs(gp_times - last_phase))
        second_to_last_phase_ind = np.argmin(np.abs(gp_times - second_to_last_phase))

        ### Check that the GP fit between the last two data points is decreasing in brightness
        gp_slope = (
            gp_prediction[last_phase_ind] - gp_prediction[second_to_last_phase_ind]
        ) / (last_phase - second_to_last_phase)
        if gp_slope > 0.0:
            logger.warning(
                f"WARNING: Late-time slope of the GP is increasing for filter {filt}"
            )

        ### Check that the GP fit extrapolation after the last data point is decreasing in brightness
        # Only meaningful when the GP grid actually extends beyond the last
        # observation; otherwise the denominator is zero or negative.
        if gp_times[-1] > last_phase:
            gp_extrapolation = (gp_prediction[-1] - gp_prediction[last_phase_ind]) / (
                gp_times[-1] - last_phase
            )
            if gp_extrapolation > 0.0:
                logger.warning(
                    f"WARNING: GP extrapolation at late times is increasing for filter {filt}"
                )

    def check_gradient_between_filters(
        self,
        filt_wls,
        phase_grid,
        wl_grid,
        gp_grid,
        std_grid,
        phases_to_check,
        nsigma=3,
    ):
        """
        Check that the SED between adjacent filters is smooth at chosen phases.

        For each phase in ``phases_to_check`` (snapped to the nearest
        ``phase_grid`` sample) the full GP SED is plotted, then each segment
        of ``wl_grid`` between consecutive filter wavelengths is fit with a
        straight line. Grid points deviating from that line by more than
        ``nsigma`` times the GP standard deviation are logged as
        non-smooth, and each segment is plotted with its linear fit.

        Parameters
        ----------
        filt_wls : sequence of float
            Representative wavelengths of the fitted filters. Sorted
            internally so consecutive pairs run blue -> red.
        phase_grid : array-like
            Phases indexing axis 1 of ``gp_grid`` / ``std_grid``.
        wl_grid : array-like
            Wavelengths indexing axis 0 of ``gp_grid`` / ``std_grid``.
        gp_grid, std_grid : numpy.ndarray
            2-D GP mean and standard deviation, shape
            ``(len(wl_grid), len(phase_grid))``.
        phases_to_check : iterable of float
            Phases at which to inspect the SED.
        nsigma : float, optional
            Deviation threshold (default 3, as before).

        Returns
        -------
        None
        """
        if len(filt_wls) < 2:
            return
        # Unsorted wavelengths would produce empty (blue > red) slices, on
        # which np.polyfit raises.
        sorted_wls = sorted(filt_wls)
        wl_grid = np.asarray(wl_grid)
        phase_grid = np.asarray(phase_grid)

        for phase in phases_to_check:
            phase_ind = np.argmin(np.abs(phase_grid - phase))
            sed = gp_grid[:, phase_ind]
            std = std_grid[:, phase_ind]
            plt.plot(wl_grid, sed, color="k")
            plt.fill_between(wl_grid, sed - std, sed + std)
            plt.title("SED at {} days".format(round(phase_grid[phase_ind], 0)))
            plt.xlabel("Wavelength")
            plt.ylabel("Flux Relative to Peak")
            plt.show()

            for blue_wl, red_wl in zip(sorted_wls[:-1], sorted_wls[1:]):
                blue_wl_ind = np.argmin(np.abs(wl_grid - blue_wl))
                red_wl_ind = np.argmin(np.abs(wl_grid - red_wl))
                # Half-open segment [blue, red) of the wavelength grid.
                segment = slice(blue_wl_ind, red_wl_ind)
                segment_wl = wl_grid[segment]
                if len(segment_wl) < 2:
                    # Adjacent filters snap to (nearly) the same grid point;
                    # a straight-line fit is undefined.
                    continue
                filter_gradient = sed[segment]
                filter_gradient_std = std[segment]

                ### Check smoothness of filter gradient
                # Fit the gradient as a 1d function
                fit = np.poly1d(np.polyfit(segment_wl, filter_gradient, 1))
                fitted = fit(segment_wl)
                bad_fit_inds = np.where(
                    np.abs(filter_gradient - fitted)
                    > nsigma * np.abs(filter_gradient_std)
                )[0]
                if len(bad_fit_inds) > 0:
                    logger.warning(
                        "WARNING: gradient between filters not smooth at wavelengths {}".format(
                            segment_wl[bad_fit_inds]
                        )
                    )
                plt.plot(segment_wl, fitted, color="gray", linestyle="--")
                plt.plot(segment_wl, filter_gradient, color="k")
                plt.fill_between(
                    segment_wl,
                    filter_gradient - filter_gradient_std,
                    filter_gradient + filter_gradient_std,
                )
                plt.xlabel("Wavelength")
                plt.ylabel("Flux Relative to Peak")
                plt.show()

    def check_uvm2_flux(
        self,
        phase_grid,
        wl_grid,
        gp_grid,
        std_grid,
        phases_to_check,
        transmission_cutoff=0.2,
        max_tail_fraction=0.2,
    ):
        """
        Estimate how much UVM2-weighted GP flux falls in the filter's wings.

        The Swift/UVM2 transmission curve is fetched with
        ``query_svo_service`` and normalised to its maximum. It is binned to
        ``wl_grid`` with ``bin_spec`` once, then linearly resampled onto the
        in-band ``wl_grid`` points so the SED and the transmission share one
        index space.

        For each phase in ``phases_to_check`` (snapped to ``phase_grid``):

        - the in-band SED is shifted up by its minimum, so sums are not
          mixing positive and negative normalised values;
        - it is weighted by the transmission;
        - the weighted flux where transmission < ``transmission_cutoff`` is
          compared with the total.

        A warning is logged when that ratio exceeds ``max_tail_fraction``.
        The comparison is plotted for every phase.

        NOTE(review): the "tail" mask includes *both* low-transmission wings,
        although the log message refers to the red tail — confirm intent.

        Parameters
        ----------
        phase_grid : array-like
            Phases indexing axis 1 of ``gp_grid``.
        wl_grid : array-like
            Wavelengths indexing axis 0 of ``gp_grid``. Assumed to be in the
            same units as the SVO transmission curve — TODO confirm.
        gp_grid : numpy.ndarray
            2-D GP mean, shape ``(len(wl_grid), len(phase_grid))``.
        std_grid : numpy.ndarray
            Unused; kept for interface compatibility.
        phases_to_check : iterable of float
            Phases at which to run the check.
        transmission_cutoff : float, optional
            Normalised transmission below which a wavelength counts as
            "tail" (default 0.2).
        max_tail_fraction : float, optional
            Tail-flux fraction above which a warning is logged (default 0.2).

        Returns
        -------
        None
        """
        # The filter curve does not depend on phase: fetch and bin it once
        # rather than issuing a remote query per phase.
        trans_wl, trans_eff = query_svo_service("Swift", "UVM2")
        # Normalise to peak without mutating the returned array in place.
        trans_eff = np.asarray(trans_eff, dtype=float)
        trans_eff = trans_eff / np.max(trans_eff)
        ### Bin the transmission curve to common resolution as SED
        binned_trans_wl, binned_trans_eff = bin_spec(trans_wl, trans_eff, wl_grid)
        binned_trans_wl = np.asarray(binned_trans_wl)
        binned_trans_eff = np.asarray(binned_trans_eff)
        wl_grid = np.asarray(wl_grid)
        phase_grid = np.asarray(phase_grid)

        ### Get indices of SED that fall within the binned transmission curve
        # Assumes binned_trans_wl is ascending (its first/last elements are
        # used as the band limits).
        sed_inds = np.where(
            (wl_grid >= binned_trans_wl[0]) & (wl_grid <= binned_trans_wl[-1])
        )[0]
        if len(sed_inds) == 0:
            logger.warning(
                "WARNING: UVM2 transmission curve does not overlap the SED wavelength grid"
            )
            return
        band_wl = wl_grid[sed_inds]
        # Put the transmission on the same wavelengths as the SED. When
        # bin_spec already returns values at these wavelengths this
        # reproduces them exactly; otherwise it avoids indexing one array
        # with another array's indices.
        band_eff = np.interp(band_wl, binned_trans_wl, binned_trans_eff)
        ### Find where the transmission drops below the cutoff
        tail_mask = band_eff < transmission_cutoff

        for phase in phases_to_check:
            phase_ind = np.argmin(np.abs(phase_grid - phase))
            band_sed = gp_grid[sed_inds, phase_ind]
            ### Calculate flux in the transmission curve
            ### Here we're shifting the flux upward by the minimum value
            ### in the entire filter to avoid issues with summing positive
            ### and negative normalized fluxes
            shifted_sed = band_sed - np.nanmin(band_sed)
            flux = np.nansum(shifted_sed * band_eff)
            ### Calculate flux in the tail
            shifted_tail_sed = shifted_sed[tail_mask]
            tail_flux = np.nansum(shifted_tail_sed * band_eff[tail_mask])

            plt.plot(band_wl, shifted_sed, color="black", label="GP SED")
            plt.scatter(band_wl[tail_mask], shifted_tail_sed, marker="o", color="blue")
            plt.plot(
                binned_trans_wl,
                binned_trans_eff,
                color="red",
                label="UVM2 Transmission Efficiency",
            )
            plt.xlabel("Wavelength")
            plt.ylabel("Relative Flux / Transmission Efficiency")
            plt.legend()
            plt.show()

            if not np.isfinite(flux) or flux <= 0:
                # A flat or all-NaN SED leaves nothing to compare against.
                logger.warning(
                    f"WARNING: UVM2-weighted flux is non-positive at phase {phase}; skipping tail check"
                )
                continue

            ### Compare the two values--is tail flux > max_tail_fraction of the total?
            ratio = round(tail_flux / flux, 2)
            # ratio is a fraction; report it as a percentage.
            if ratio > max_tail_fraction:
                logger.warning(
                    f"WARNING: {ratio * 100:.0f} percent of UVM2 flux falls in red tail"
                )
            else:
                logger.info(
                    f"All good: Only {ratio * 100:.0f} percent of UVM2 flux falls in red tail"
                )