import logging
import warnings
import matplotlib.pyplot as plt
import numpy as np
from caat.utils import colors, convert_shifted_fluxes_to_shifted_mags
logger = logging.getLogger(__name__)
# NOTE(review): the two calls below run at import time and are process-wide:
# basicConfig() configures the *root* logger at INFO (it is a no-op if the root
# logger already has handlers), and filterwarnings("ignore") silences every
# warning from every module, not just this one. Consider moving both to an
# application entry point instead of a library module.
logging.basicConfig(level=logging.INFO)
warnings.filterwarnings("ignore")
class Plot:
    """
    Plot class allowing any number of diagnostic / step-by-step plots to be created and displayed
    across the SN, SNCollection, GP, and GP3D classes.

    This is a helper class called in the different classes to provide optional plotting
    support. These methods are not intended to be called outside of the routines in
    the other classes.
    """

    # ------------------------------------------------------------------
    # Private helpers shared by the plotting methods below
    # ------------------------------------------------------------------

    @staticmethod
    def _split_detection_indices(nondets):
        """
        Return ``(det_idx, nondet_idx)``: integer indices of the detections and
        of the nondetections in a per-point nondetection-flag sequence.

        ``nondets`` is passed through ``np.asarray`` so a plain Python list is
        compared elementwise rather than as a single object.
        """
        flags = np.asarray(nondets)
        # Explicit ``== False`` / ``== True`` instead of ``~flags``: ``~`` is a
        # bitwise NOT on integer arrays and fails on object arrays.
        det_idx = np.flatnonzero(flags == False)  # noqa: E712
        nondet_idx = np.flatnonzero(flags == True)  # noqa: E712
        return det_idx, nondet_idx

    @staticmethod
    def _photometry_columns(phot_list, keys, nondetection=False):
        """
        Extract columns from a list of photometry dicts.

        :param phot_list: iterable of dicts; a missing ``"nondetection"`` key
            counts as a detection
        :param keys: dict keys to extract, e.g. ``("mjd", "flux", "fluxerr")``
        :param nondetection: keep only nondetections when True, only
            detections when False
        :return: tuple with one ``np.ndarray`` per entry of ``keys``
        """
        rows = [
            phot
            for phot in phot_list
            if bool(phot.get("nondetection", False)) == bool(nondetection)
        ]
        return tuple(np.asarray([phot[key] for phot in rows]) for key in keys)

    @staticmethod
    def _resolve_filters(filts):
        """
        Return the filters ``plot_all_lcs`` should iterate over.

        ``None``, the string ``"all"``, or a sequence whose first element is
        ``"all"`` select every filter in ``colors``. A single filter name given
        as a string is wrapped in a list. Any other sequence is returned as-is.
        """
        if filts is None:
            return list(colors.keys())
        if isinstance(filts, str):
            filts = [filts]
        # ``==`` (not ``is``): identity of equal strings is not guaranteed.
        if len(filts) > 0 and filts[0] == "all":
            return list(colors.keys())
        return filts

    @staticmethod
    def _invert_yaxis_once(ax):
        """
        Invert ``ax``'s y-axis (magnitudes: brighter is up) unless it already is.

        ``Axes.invert_yaxis`` toggles, so calling it again on a reused axes
        would flip magnitudes back to the wrong orientation.
        """
        if not ax.yaxis_inverted():
            ax.invert_yaxis()

    # ------------------------------------------------------------------
    # Public plotting routines
    # ------------------------------------------------------------------

    def create_empty_subplot(self):
        """Create and return a new ``(fig, ax)`` pair from ``plt.subplots()``."""
        return plt.subplots()

    def plot_sn_data(
        self,
        sn_class,
        data_to_plot,
        filts_to_plot,
        plot_fluxes=False,
    ):
        """
        Plot the photometry of a single SN, one colour per filter, and show it.

        :param sn_class: object whose ``name`` is used as the title
        :param data_to_plot: mapping of filter name -> list of photometry dicts
            (``mjd`` plus ``flux``/``fluxerr`` or ``mag``/``err``, and an
            optional ``nondetection`` flag)
        :param filts_to_plot: filter names to draw; falsy entries and names not
            present in ``data_to_plot`` are skipped
        :param plot_fluxes: plot fluxes (nondetections as faint triangles)
            instead of magnitudes (detections only, inverted y-axis)
        """
        _, ax = plt.subplots()
        for filt in filts_to_plot:
            if not filt or filt not in data_to_plot:
                continue
            phot_list = data_to_plot[filt]
            color = colors.get(filt, "k")
            if plot_fluxes:
                flux_keys = ("mjd", "flux", "fluxerr")
                mjds, fluxes, errs = self._photometry_columns(phot_list, flux_keys)
                ax.errorbar(
                    mjds,
                    fluxes,
                    yerr=errs,
                    fmt="o",
                    mec="black",
                    color=color,
                    label=filt,
                )
                nd_mjds, nd_fluxes, nd_errs = self._photometry_columns(
                    phot_list, flux_keys, nondetection=True
                )
                ax.errorbar(
                    nd_mjds,
                    nd_fluxes,
                    yerr=nd_errs,
                    fmt="v",
                    alpha=0.5,
                    color=color,
                )
            else:
                mjds, mags, errs = self._photometry_columns(
                    phot_list, ("mjd", "mag", "err")
                )
                ax.errorbar(
                    mjds,
                    mags,
                    yerr=errs,
                    fmt="o",
                    mec="black",
                    color=color,
                    label=filt,
                )
        if plot_fluxes:
            ax.set_ylabel("Flux")
        else:
            self._invert_yaxis_once(ax)
            ax.set_ylabel("Apparent Magnitude")
        ax.legend()
        ax.set_xlabel("MJD")
        ax.set_title(sn_class.name)
        ax.minorticks_on()
        plt.show()

    def plot_fit_for_max(
        self,
        sn_class,
        mjd_array,
        mag_array,
        err_array,
        fit_mjds,
        fit_mags,
        fit_errs,
        inds_to_fit,
    ):
        """
        Plot a filter's light curve and highlight the points used to fit for
        the time of maximum.

        All points are drawn in black and the fitted subset in blue. When
        ``inds_to_fit`` selects any points, the y-limits are zoomed to that
        subset's magnitude range padded by 0.5 mag.

        This method only draws: it returns ``None``, does not invert the
        y-axis and does not show the figure, leaving it open for the caller.

        :param sn_class: object whose ``name`` is used as the title
        :param mjd_array, mag_array, err_array: full light curve
            (indexable by ``inds_to_fit``, e.g. numpy arrays)
        :param fit_mjds, fit_mags, fit_errs: points used in the fit
        :param inds_to_fit: index/mask into ``mjd_array``/``mag_array``
        """
        _, ax = plt.subplots()
        ax.errorbar(mjd_array, mag_array, yerr=err_array, fmt="o", color="black")
        ax.errorbar(
            fit_mjds,
            fit_mags,
            yerr=fit_errs,
            fmt="o",
            color="blue",
            label="Used in Fitting",
        )
        if len(mjd_array[inds_to_fit]) > 0:
            fit_window_mags = mag_array[inds_to_fit]
            ax.set_ylim(min(fit_window_mags) - 0.5, max(fit_window_mags) + 0.5)
        ax.set_xlabel("MJD")
        ax.set_ylabel("Apparent Magnitude")
        ax.set_title(sn_class.name)
        ax.legend()

    def plot_shift_to_max(self, sn_class, mjds, mags, errs, nondets, filt):
        """
        Plot one filter's peak-shifted light curve on the current axes and show it.

        Detections are drawn as error bars, nondetections as faint triangles.

        :param sn_class: object whose ``name`` is used in the title
        :param mjds, mags, errs: numpy arrays of shifted times, magnitudes, errors
        :param nondets: per-point nondetection flags
        :param filt: filter name (str), used for colour and legend label
        """
        det_idx, nondet_idx = self._split_detection_indices(nondets)
        color = colors.get(filt, "k")
        plt.errorbar(
            mjds[det_idx],
            mags[det_idx],
            yerr=errs[det_idx],
            fmt="o",
            mec="black",
            color=color,
            label=filt + "-band",
        )
        plt.scatter(
            mjds[nondet_idx],
            mags[nondet_idx],
            marker="v",
            color=color,
            alpha=0.2,
        )
        plt.xlabel("Shifted Time [days]")
        plt.ylabel("Shifted Magnitude")
        plt.title(sn_class.name + "-Shifted Data")
        self._invert_yaxis_once(plt.gca())
        plt.legend()
        plt.show()

    def plot_all_lcs(
        self,
        sn_class,
        filts=("all",),
        log_transform=False,
        plot_fluxes=False,
        ax=None,
        show=True,
    ):
        """
        Plot all light curves of a given subtype/collection, in one, several
        or all bands.

        :param sn_class: collection object exposing ``sne``, ``type`` and
            ``subtype``; each SN must provide ``shift_to_max`` and
            ``log_transform_time``
        :param filts: filter names to plot; ``"all"``, ``None`` or a sequence
            starting with ``"all"`` plot every filter in ``colors``
        :param log_transform: ``False`` for linear time, otherwise the
            ``phase_start`` passed to ``log_transform_time``
        :param plot_fluxes: plot shifted fluxes instead of shifted magnitudes
        :param ax: axes to draw on; a new figure is created when ``None``
        :param show: add per-filter figure text and a title, then show
        """
        sne = sn_class.sne
        logger.info(f"Plotting all {len(sne)} lightcurves in the collection")
        if ax is None:
            _, ax = plt.subplots()
        filts_to_plot = self._resolve_filters(filts)

        for i, filt in enumerate(filts_to_plot):
            color = colors.get(filt, "k")
            for sn in sne:
                phases, values, errs, nondets = sn.shift_to_max(
                    filt, shift_fluxes=plot_fluxes
                )
                if len(phases) == 0:
                    continue
                if log_transform is not False:
                    phases = sn.log_transform_time(phases, phase_start=log_transform)
                # Same treatment in flux and magnitude space: detections as
                # error bars, nondetections as faint triangles.
                det_idx, nondet_idx = self._split_detection_indices(nondets)
                ax.errorbar(
                    phases[det_idx],
                    values[det_idx],
                    yerr=errs[det_idx],
                    fmt="o",
                    mec="black",
                    color=color,
                )
                ax.scatter(
                    phases[nondet_idx],
                    values[nondet_idx],
                    marker="v",
                    alpha=0.2,
                    color=color,
                )
            # Empty artist: gives callers building a legend one entry per filter.
            ax.errorbar([], [], color=color, label=filt)
            if show:
                ax.figure.text(
                    0.95, 0.75 - (0.05 * i), filt + "\n", fontsize=14, color=color
                )

        if log_transform is False:
            ax.set_xlabel("Shifted Time [days]")
        else:
            ax.set_xlabel("Log(Shifted Time)")
        if plot_fluxes:
            ax.set_ylabel("Shifted Fluxes")
        else:
            ax.set_ylabel("Shifted Magnitudes")
            self._invert_yaxis_once(ax)
        if show:
            ax.set_title(
                "Lightcurves for collection of {} objects\nType:{}, Subtype:{}".format(
                    len(sne), sn_class.type, sn_class.subtype
                )
            )
            plt.show()

    def plot_gp_predict_gp(
        self,
        phases,
        mean_prediction,
        std_prediction,
        mags,
        errs,
        filt,
        use_fluxes=False,
    ):
        """
        Plot a single-filter GP fit with its ±1.96σ (95%) band over the data,
        and show it.

        :param phases: numpy array of observed phases (raveled before use)
        :param mean_prediction, std_prediction: GP mean and standard deviation,
            assumed to be evaluated at the sorted phases — TODO confirm with caller
        :param mags, errs: numpy arrays of observed values and errors
        :param filt: filter name used for colour and legend label
        :param use_fluxes: label the y-axis as fluxes instead of inverting it
            for magnitudes
        """
        _, ax = plt.subplots()
        flat_phases = np.ravel(phases)
        sorted_phases = np.sort(flat_phases)
        ax.plot(sorted_phases, mean_prediction, color="k", label="GP fit", zorder=10)
        ax.errorbar(
            flat_phases,
            mags.reshape(-1),
            errs.reshape(-1),
            fmt="o",
            color=colors.get(filt, "k"),
            alpha=0.2,
            label=filt,
            zorder=0,
        )
        half_width = 1.96 * std_prediction
        ax.fill_between(
            sorted_phases,
            mean_prediction - half_width,
            mean_prediction + half_width,
            alpha=0.5,
            color="lightgray",
            # 1.96 sigma is the 95% interval; "$\%$" renders with and without usetex.
            label=r"95$\%$ Confidence Interval",
            zorder=10,
        )
        ax.set_xlabel("Shifted Time [days]")
        if use_fluxes:
            ax.set_ylabel("Fluxes")
        else:
            self._invert_yaxis_once(ax)
            ax.set_ylabel("Shifted Magnitude")
        ax.set_title("Single-Filter GP Fit")
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.show()

    def plot_construct_grid(
        self,
        gp_class,
        X,
        Y,
        Z,
        ax=None,
        Z_lower=None,
        Z_upper=None,
        grid_type=None,
        phase_grid=None,
        mag_grid=None,
        wl_grid=None,
        err_grid=None,
        filtlist=None,
    ):
        """
        Draw a 3D surface of a phase x wavelength grid, with optional
        lower/upper bound surfaces in translucent blue.

        :input grid_type: selects the title; one of 'polynomial' (or 'poly'),
            'median' or 'final'; anything else (default None) gives "Grid"
        :input ax: 3D axes to draw on; a new figure is created when None

        ``gp_class``, ``phase_grid``, ``mag_grid``, ``wl_grid``, ``err_grid`` and
        ``filtlist`` are accepted for call compatibility but not used here.
        The figure is not shown.
        """
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(X, Y, Z)
        for bound in (Z_lower, Z_upper):
            if bound is not None:
                ax.plot_surface(X, Y, bound, color="blue", alpha=0.2)
        ax.set_xlabel("Phase Grid")
        ax.set_ylabel("Wavelengths [Angstroms]")
        ax.set_zlabel("Flux")
        if grid_type in ("polynomial", "poly"):
            ax.set_title("Polynomial Grid / Templates")
        elif grid_type == "median":
            ax.set_title("Median Grid / Templates")
        elif grid_type == "final":
            ax.set_title("Final Median GP Fit")
        else:
            ax.set_title("Grid")
        plt.tight_layout()

    def plot_subtract_data_from_grid(
        self,
        sn_class,
        phase_grid,
        mag_grid,
        wl_ind,
        filt,
        ax=None,
    ):
        """
        Draw the template column ``mag_grid[:, wl_ind]`` against ``phase_grid``
        with a y=0 reference line, and set up the legend.

        "residuals" and "data" get empty placeholder artists so they appear in
        the legend; presumably the caller draws the actual points on ``ax``.
        The figure is not shown.
        """
        if ax is None:
            _, ax = plt.subplots()
        color = colors.get(filt, "k")
        ax.plot(phase_grid, mag_grid[:, wl_ind], color=color, label="template")
        ax.axhline(y=0, linestyle="--", color="gray")
        ax.errorbar(
            [], [], yerr=[], marker="o", color="k", label="residuals", alpha=0.2
        )
        ax.errorbar(
            [],
            [],
            yerr=[],
            fmt="o",
            color=color,
            label="data",
            alpha=0.5,
        )
        ax.set_xlabel("Log(Time)")
        ax.set_ylabel("Flux relative to peak")
        ax.set_title("Template Subtraction for {} in {}-band".format(sn_class.name, filt))
        ax.legend()

    def plot_run_gp_overlay(
        self,
        ax,
        test_times,
        test_prediction,
        std_prediction,
        template_mags,
        residuals,
        log_transform,
        filt,
        sn=None,
    ):
        """
        Overlay a GP prediction (added to the template) with its ±1.96σ band,
        plus the residual data points, on ``ax``.

        :param test_times: times at which the prediction was evaluated
        :param test_prediction, std_prediction: GP mean and standard deviation
        :param template_mags: template values added to the prediction
        :param residuals: table with ``Phase``, ``Mag`` and ``MagErr`` columns
            (read via ``.values``)
        :param log_transform: offset subtracted after ``np.exp(Phase)``;
            presumably undoes a ``log(t + log_transform)`` time transform
        :param filt: filter name used for colour and legend label
        :param sn: when given, the curve and band are converted to shifted
            magnitudes with ``sn.zps[filt]`` and ``sn.name`` becomes the title
        """
        color = colors.get(filt, "k")
        log_fluxes = test_prediction + template_mags
        half_width = 1.96 * std_prediction
        if sn is not None:
            # Convert between log fluxes to shifted magnitudes
            zp = sn.zps[filt]
            curve = convert_shifted_fluxes_to_shifted_mags(log_fluxes, sn, zp)
            lower = convert_shifted_fluxes_to_shifted_mags(
                log_fluxes - half_width, sn, zp
            )
            upper = convert_shifted_fluxes_to_shifted_mags(
                log_fluxes + half_width, sn, zp
            )
        else:
            curve = log_fluxes
            lower = log_fluxes - half_width
            upper = log_fluxes + half_width
        ax.plot(test_times, curve, label=filt, color=color)
        ax.fill_between(test_times, lower, upper, alpha=0.2, color=color)
        # NOTE(review): the residual points are not converted when ``sn`` is
        # given, and the y-label below stays "Flux Relative to Peak" even though
        # the curve is then in magnitudes — confirm which space is intended.
        ax.errorbar(
            np.exp(residuals["Phase"].values) - log_transform,
            residuals["Mag"].values,
            yerr=residuals["MagErr"].values,
            fmt="o",
            color=color,
            mec="k",
        )
        ax.set_xlabel("Normalized Time [days]")
        ax.set_ylabel("Flux Relative to Peak")
        if sn is not None:
            ax.set_title(sn.name)
        ax.legend()

    def plot_run_gp_surface(self, gp_class, x, y, test_prediction_reshaped):
        """
        Show a 3D surface of a GP prediction over a phase x wavelength grid.

        ``gp_class`` is accepted for call compatibility but not used.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(x, y, test_prediction_reshaped)
        ax.set_xlabel("Phase Grid")
        ax.set_ylabel("Wavelengths")
        ax.set_zlabel("Fluxes")
        plt.tight_layout()
        plt.show()