import math
import logging
from pathlib import Path
import torch
import torch.nn.functional as F
import torch.utils.data
import torchinfo
import kdai
import kdai._logging
import kdai.train
import kdai.datasets
import kdtpp.vis as vis
import kdtpp.prob as prob
import kdtpp.inferspikes
import matplotlib as mpl
import matplotlib.pyplot as plt
import einops
import numpy as np
import functools
import typing
from typing import Literal
import pdb
from collections import defaultdict

# Module-level logger, named after this module.
_logger = logging.getLogger(name=__name__)


class ZuoTHPTrainable(kdai.train.BaseTrainable):
    """
    Zuo et al. Transformer Hawkes Process model.

    The underlying NN model outputs the parameters of a softplus hazard function
    of the form:

    h(t) = softplus( α*t + m_out, β)

    where α and β are learned scalars and m_out is the output of the NN model;
    just prior to the output we have m_out = W*h + b, for a learnable matrix W
    and bias b.

    On reproducing THP
    ==================
    The THP paper's results are known to not be reproducible, and the
    reference implementation by Zuo et al. is known to have bugs. The authors
    confirmed this in a private conversation with Yang et al., who described
    it in their paper "Transformer Embeddings of Irregularly Spaced Events and
    Their Participants". In particular, there are issues and discrepancies in
    the implementation with respect to the likelihood calculation. As such, it
    is very difficult to know how to faithfully reproduce this paper. Here, we
    try to follow the spirit of the paper, which we interpret as follows:

      - event times (not time deltas) are encoded using sinusoidal embeddings
      - a transformer model transforms encoded event times into a hidden state
      - the hidden state, along with learnable scalars, is used to
        parametrize a hazard function.
      - the hazard function has the form (Equation 6 of the paper):
            h(t) = softplus( α*(t - t_{n-1})/t_{n-1} + m_out, β)
        i.e. softplus( α*dt/t_{n-1} + m_out, β) with dt = t - t_{n-1}.
      - the integral of the hazard function is done by calculating h(x) for
        many samples evenly spaced between 0 and dt.
      - in addition to likelihood based loss, a second time prediction head is
        used from which an MSE loss is calculated. The total loss is:
           0.8 * likelihood_loss + 0.2 * MSE_loss.
      - when making time predictions (not likelihood predictions), the model
        uses the time prediction head.

    Things that we change:

      - we use h(t) = softplus( α*dt + m_out, β)
      - we include normalization for the input data and output distribution
        (the original implementation hard-codes a multiplication factor of 1/100,
         which is motivated by the scale of the Stack Overflow dataset:
         https://github.com/SimiaoZuo/Transformer-Hawkes-Process/blob/9f2e563909c538302733f048131ae8c535ce02ec/Main.py#L72
         )
      - for the time prediction head, we use:
            t_pred = W h + b
        instead of the original implementation's:
            t_pred = W h
        as without a bias, there is no way to deal with datasets that have
        time deltas with a large offset from 0 compared to their spread.
      - the loss actually computed by `loss_fn` is
            NLL + mse_weight * MSE   (mse_weight = 0.2)
        i.e. the likelihood term is not down-weighted by 0.8.


    EasyTPP, Zuo and Yang's implementation
    ======================================
    There are 3 implementations of the Transformer Hawkes Process model that
    we reference here. EasyTPP is a library that claims to implement an improved
    version of Zuo's model, where the improvements are bug fixes that were
    discussed with the original author. This implementation seems to have
    originated from an implementation by Chenghao Yang, but was modified to
    fit a more general framework that works with multiple other models. At the
    time of developing our implementation, there were a few serious issues
    with EasyTPP's implementation; for example, the transformer model had no
    MLP block in any of its layers. It seems that many bugs may have slipped
    in when refactoring the code to be more general. Therefore, Yang's
    implementation seems more promising as a reference implementation, as it
    includes some unspecified fixes that the original authors agreed were
    needed.

    All three implementations have similarities and differences. Most notably,
    they differ in the form of the hazard function. The paper, in Equation 6,
    describes the hazard function as:
        λ(t) = softplus( α*(t-t_j)/t_j + w^T h + b, β )

    The first term can be written in terms of the time delta, dt, as
    α * dt / t_j. Interestingly, t_j is not a time delta, so as time
    progresses the factor α/t_j approaches zero. This will be heavily
    affected by the scale of the time data, and the paper does not describe
    any normalisation considerations.


    In EasyTPP, as they also handle marks, they have m_out = h, and have
    separate parameters named `layer_intensity_hidden` and
    `factor_intensity_base` for W and b. However, their
    `layer_intensity_hidden` seems to include a bias, which might be an
    oversight, as it makes `factor_intensity_base` redundant. To enable
    comparison, it is worth noting that their α parameter is called
    `factor_intensity_decay`. EasyTPP doesn't seem to have a β parameter,
    whereas Zuo's and Yang's implementations do. It's not clear whether this
    is a simplification or an oversight. We will keep the β parameter. Another
    departure from Zuo is that EasyTPP ... (TODO: this note was left
    unfinished.)

    Simiao Zuo's implementation:    https://github.com/SimiaoZuo/Transformer-Hawkes-Process/tree/master
    Chenghao Yang's implementation: https://github.com/yangalan123/anhp-andtt/tree/master/thp
    EasyTPP's implementation:       https://github.com/ant-research/EasyTemporalPointProcess/blob/main/easy_tpp/model/torch_model/torch_thp.py


    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        init_weights=True,
    ):
        super().__init__(ds_mgr, model, label)
        # Dispatches to `_eval_loss` or `_eval_info`.
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        # Number of evenly spaced points used to numerically integrate h.
        self.n_integration_samples = 256
        # Weight of the time-prediction MSE term in the total loss.
        self.mse_weight = 0.2
        if init_weights:
            self.init_weights()

    def init_weights(self):
        """Initialise the model's input normalisation from dataset stats."""
        self.model.input_norm.set_mean_sd(
            self.ds_mgr.dt_mean, self.ds_mgr.dt_sd
        )

    def loss_fn(self, log_prob, m_out_t, mask, target_t):
        """With causal loss training trick.

        Zuo et al. aren't very clear about the exact loss function. They
        present two:
          1. timing likelihood from softplus hazard function and mark likelihood
          2. add a MSE loss to the above

        In their implementation, they use 2) only. The presence of the two
        losses is a symptom of likelihood and MSE loss being fundamentally
        different tasks, and which task is to be tackled is a choice that
        should be made from the onset.

        Args:
            log_prob: per-event log-likelihood, shape "b s".
            m_out_t: time-prediction head output (model scale), "b s 1".
            mask: 1 for real events, 0 for padding, "b s".
            target_t: target time deltas in data scale, "b s 1".

        Returns:
            (loss, nll, weighted_mse) where loss = nll + weighted_mse and
            weighted_mse already includes `self.mse_weight`.
        """
        # Target is only time. target_t is array of ts (causal training trick).
        assert len(target_t.shape) == 3
        b, s, _ = target_t.shape
        assert mask.shape[0] == target_t.shape[0] == m_out_t.shape[0] == b
        assert mask.shape[1] == target_t.shape[1] == m_out_t.shape[1] == s
        # Easiest to just make everything "b s", as log_prob is "b s".
        m_out_t = einops.rearrange(m_out_t, "b s 1 -> b s")
        target_t = einops.rearrange(target_t, "b s 1 -> b s")
        m_out_t = mask * m_out_t
        target_t = mask * target_t

        # Zuo's implementation uses sum()/100, where 100 is hard-coded. It
        # seems more reasonable to use mean(), which will provide a better
        # way to scale the loss that is responsive to batch size and sequence
        # length.
        # https://github.com/SimiaoZuo/Transformer-Hawkes-Process/blob/9f2e563909c538302733f048131ae8c535ce02ec/Main.py#L72
        n_valid = mask.sum(dim=1, keepdim=True)
        # A row with no valid events would otherwise give 0/0 = NaN weights;
        # give it zero weight instead.
        n_valid = torch.where(n_valid > 0, n_valid, torch.ones_like(n_valid))
        ave_mask = mask / n_valid

        def masked_mean(x):
            # ave_mask averages over the sequence dimension, ignoring padding;
            # then we take the mean over the batch dimension.
            per_sample_mean = (x * ave_mask).sum(dim=1)
            batch_mean = per_sample_mean.mean()
            assert batch_mean.ndim == 0, batch_mean.shape
            return batch_mean

        mean_log_prob = masked_mean(log_prob)
        # There is an option: do MSE in data scale or model scale. Zuo does
        # it in the model scale, so we will copy that.
        target_t = self._dt_norm(target_t)
        mse_loss = masked_mean(F.mse_loss(m_out_t, target_t, reduction="none"))
        mse_loss = mse_loss * self.mse_weight
        loss = -mean_log_prob + mse_loss
        return loss, -mean_log_prob, mse_loss

    def _dt_norm(self, dt):
        """Normalize time deltas (data scale -> model scale).

        We scale times as opposed to rescaling the hazard function as the
        form of the hazard function makes it difficult to rescale.
        """
        return dt / self.ds_mgr.dt_sd

    def _dt_denorm(self, dt):
        """Inverse of `_dt_norm` (model scale -> data scale)."""
        return dt * self.ds_mgr.dt_sd

    def _logprob_denorm(self, log_prob):
        """Convert a log-density over normalized time to one over data time.

        With t' = t / sd, the change of variables gives
        f_t(t) = f_t'(t / sd) / sd, i.e. subtract log(sd) in log space. The
        scale is the same `ds_mgr.dt_sd` that `_dt_norm` divides by.

        Note: `softplus_hazard` already returns data-scale log-probs, so it
        does not need this conversion.
        """
        return log_prob - math.log(self.ds_mgr.dt_sd)

    def _logprob_norm(self, log_prob):
        """Inverse of `_logprob_denorm` (data-scale -> normalized log-density)."""
        return log_prob + math.log(self.ds_mgr.dt_sd)

    def _integration_grid(self, device):
        """`n_integration_samples` points evenly spaced in [0, 1], as "1 1 s"."""
        return einops.rearrange(
            torch.linspace(
                0,
                1,
                self.n_integration_samples,
                dtype=torch.float,
                device=device,
            ),
            "s -> 1 1 s",
        )

    def _softplus_h(self, h_out, sampled_dts):
        """Evaluate β·softplus((α·dt + h_out) / β) at model-scale `sampled_dts`.

        Equation 6 in the paper, modified to use dt directly.
        """
        return self.model.beta * F.softplus(
            (self.model.alpha * sampled_dts + h_out) / self.model.beta
        )

    def softplus_hazard(self, h_out, target_dt):
        """Hazard function used by Zuo et al.

        Arguably softplus is not a good functional form for the hazard
        function. The argument for it is that it turns (-inf, inf) into
        (0, inf); however, it isn't motivated any further than this, and the
        need to sample the integration is a downside.

        Another drawback of the softplus hazard function is that it is very
        difficult to position and scale the distribution over the data. We
        end up simply scaling the input data, which is not ideal, as the
        data may simply have a large offset, and the scale can cause the
        spread of the hazard function for the rescaled data to be very small.

        Args:
            h_out: model hazard output, "B S 1".
            target_dt: time deltas in data scale, "B S 1".

        Returns:
            (log_prob, log_h, int_h), each "B S": the log-density at
            target_dt, log h(target_dt), and ∫_0^target_dt h, all expressed
            in the data scale.
        """
        B, S, C = target_dt.shape
        assert C == 1
        assert h_out.shape == target_dt.shape
        # Record the integration length before normalizing.
        # The integration length must be in the data scale.
        integration_len = target_dt[:, :, 0]
        target_dt = self._dt_norm(target_dt)
        sampled_dts = target_dt * self._integration_grid(h_out.device)
        assert sampled_dts.shape == (B, S, self.n_integration_samples)
        hs = self._softplus_h(h_out, sampled_dts)
        assert hs.shape == (B, S, self.n_integration_samples)
        # The last sample is h evaluated at target_dt itself.
        log_h = torch.log(hs[:, :, -1])
        # One integral for every timestep.
        # If we only cared about log_prob, then we wouldn't need to use the
        # data scale for the integration length, as the scale would cancel out.
        # But as int_h is being returned, it is better to scale both terms,
        # int_h and log_h.
        int_h = hs.mean(dim=-1) * integration_len
        log_prob = log_h - int_h
        return log_prob, log_h, int_h

    def softplus_inth(self, h_out, from_dt, to_dt):
        """Integral of the softplus hazard over [from_dt, to_dt].

        Args:
            h_out: model hazard output, "B S 1".
            from_dt, to_dt: interval bounds in data scale, "B S 1".

        Returns:
            The integral in the data scale, shape "B S".
        """
        B, S, C = from_dt.shape
        assert C == 1
        assert h_out.shape == from_dt.shape == to_dt.shape
        # Integration length must be in the data scale. Drop the channel dim
        # so it multiplies the "B S" mean elementwise; a "B S 1" length would
        # otherwise broadcast against "B S" into a 3-D tensor.
        int_len = (to_dt - from_dt)[:, :, 0]
        from_dt = self._dt_norm(from_dt)
        to_dt = self._dt_norm(to_dt)
        sampled_dts = from_dt + (to_dt - from_dt) * self._integration_grid(
            h_out.device
        )
        assert sampled_dts.shape == (B, S, self.n_integration_samples)
        hs = self._softplus_h(h_out, sampled_dts)
        assert hs.shape == (B, S, self.n_integration_samples)
        # One integral for every timestep.
        int_h = hs.mean(dim=-1) * int_len
        return int_h

    def cforward(self, sample):
        """Forward call that may be compiled.

        Route model calls that also need loss calculated through this fn.

        Returns:
            (h_out, t_out, log_prob, int_h, loss, nll_loss, mse_loss)
        """
        x, mask, y = sample
        x = x.float().cuda()
        mask = mask.float().cuda()
        y = y.float().cuda()
        h_out, t_out = self.model(x, mask)
        log_prob, _, int_h = self.softplus_hazard(h_out, y)
        loss, nll_loss, mse_loss = self.loss_fn(log_prob, t_out, mask, y)
        return h_out, t_out, log_prob, int_h, loss, nll_loss, mse_loss

    def forward(self, sample):
        """Return the last-timestep (h_out, t_out) stack and the loss."""
        h_out, t_out, _, _, loss, *_ = self.cforward(sample)
        m_out = torch.stack([h_out, t_out], dim=2)
        last_m_out = m_out[:, -1]
        return last_m_out, loss

    def last_forward(self, x, mask, y):
        """Just a convenience function. cforward has too many outputs.

        Returns the last-timestep (log_prob, int_h). Both are already in the
        data scale (see `softplus_hazard`); no further denormalization is
        needed.
        """
        _, _, log_prob, int_h, *_ = self.cforward((x, mask, y))
        return log_prob[:, -1], int_h[:, -1]

    def interval_log_prob(self, x, mask, y, interval_len):
        """Log probability mass of an interval of width `interval_len` at y.

        The interval is [y - len/2, y + len/2], shifted right to start at 0
        if it would go negative. Currently, this only works for the last
        timestep. x, y and interval_len are all in the original data scale.

        Returns:
            Log interval mass, shape "b 1".
        """
        x = x.float().cuda()
        mask = mask.float().cuda()
        y = y.float().cuda()
        assert torch.all(mask[:, -1] == 1), "Last timestep should not be masked."
        # The model needs the full "b s" mask, the same as in cforward.
        h_out, _ = self.model(x, mask)
        h_out = einops.rearrange(h_out[:, -1, :], "b c -> b 1 c")
        t_min = torch.clip(y - interval_len / 2, 0)
        t_min = einops.rearrange(t_min[:, -1, :], "b c -> b 1 c")
        t_max = t_min + interval_len
        _, _, int_h = self.softplus_hazard(h_out, t_min)
        int_ab_h = self.softplus_inth(h_out, t_min, t_max)
        # log_prob of interval from t_0 to t_1 is:
        # log(S_0 - S_1)
        #  = log(exp(-int_h0) - exp(-int_h1))
        #  = log(exp(-int_h0) - exp(-int_h0 - int_ab_h))
        #  = log(exp(-int_h0) * (1 - exp(-int_ab_h)))
        #  = -int_h0 + log(1 - exp(-int_ab_h))
        #  = -int_h0 + (-expm1(-int_ab_h)).log()   (numerically stable)
        log_interval_prob = -int_h + (-torch.expm1(-int_ab_h)).log()
        return log_interval_prob

    def eval_metrics(self, dl, dtype=torch.float):
        """Compute average loss, NLL and prediction metrics over `dl`."""
        loss_meter = kdai._logging.Meter("loss")
        nll_loss_meter = kdai._logging.Meter("nll_loss")
        mse_loss_meter = kdai._logging.Meter("mse_loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        # Interval pred is the probability mass in [t - len/2, t + len/2].
        interval_len = getattr(self.ds_mgr, "density_interval_len", 1)
        for x, mask, y in dl:
            y = y.to(dtype).cuda()
            x = x.to(dtype).cuda()
            mask = mask.to(dtype).cuda()
            _, t_out, log_prob, _, loss, nll_loss, mse_loss = self.cforward(
                (x, mask, y)
            )
            N = y.shape[0]
            loss_meter.update(loss.item(), N)
            nll_loss_meter.update(nll_loss.item(), N)
            mse_loss_meter.update(mse_loss.item(), N)
            # loss and pred_nll are not the same: pred_nll only uses the
            # log_prob for the last timestep.
            pred_nll.update(-log_prob[:, -1].mean().item(), N)
            interval_pred_nll.update(
                -self.interval_log_prob(x, mask, y, interval_len=interval_len)
                .mean()
                .item(),
                N,
            )

            target = y[:, -1, 0]
            # THP has its own head for time prediction. It is trained against
            # _dt_norm(target) in loss_fn, so map it back to the data scale
            # before comparing with the (data-scale) target.
            pred = self._dt_denorm(t_out[:, -1, 0])
            assert pred.shape == target.shape
            mean_abs_err.update((pred - target).abs().mean().item(), N)

        return {
            "loss": loss_meter.avg,
            "nll_loss": nll_loss_meter.avg,
            "mse_loss": mse_loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }

    def _eval_loss(self, dl):
        """Evaluate loss only."""
        loss_meter = kdai._logging.Meter("loss")
        for x, mask, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, mask, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Evaluate with additional metrics."""
        metrics = self.eval_metrics(dl)

        def to_log_metric(m, v):
            if m == "loss":
                return kdai._logging.loss_metric(v)
            else:
                return kdai._logging.Metric(m, v, increasing=False)

        results = {
            "metrics": [
                *[to_log_metric(m, v) for m, v in metrics.items()],
                kdai._logging.Metric("model.alpha", self.model.alpha.item()),
                kdai._logging.Metric("model.beta", self.model.beta.item()),
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_figs(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Evaluate on at most `eval_len` items of `dl`."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl)
        return results

    def plot_figs(self, dl, max_n_plots):
        """Plot f, S, h and ∫h for up to `max_n_plots` samples of one batch."""
        sample = next(iter(dl))
        x, mask, y = sample
        dt_mean, dt_std = self.ds_mgr.dt_mean, self.ds_mgr.dt_sd
        max_t = dt_mean + 5 * dt_std
        figs = []
        for i in range(min(max_n_plots, x.shape[0])):
            figs.extend(self.plot_f_h_S(x[i], mask[i], max_t))
        return figs

    @torch.no_grad()
    def plot_f_h_S(self, x_0, mask, max_t):
        """Plot density, survival, hazard and cumulative hazard over [0, max_t].

        Args:
            x_0: one input sequence (no batch dim), "s c".
            mask: its mask, "s".
            max_t: right end of the plotted time range (data scale).
        """
        assert x_0.ndim == 2, "Only 1 input accepted (no batch dim)"
        x_0 = x_0.float().cuda()
        mask = mask.float().cuda()
        ts = torch.linspace(0, max_t, 1000)
        # Evaluate the same history at every t by batching over ts.
        x_0 = einops.repeat(x_0, "s c -> t s c", t=ts.shape[0])
        mask = einops.repeat(mask, "s -> t s", t=ts.shape[0])
        log_prob, int_h = self.last_forward(
            x_0,
            mask,
            einops.rearrange(
                einops.repeat(ts, "t -> t s", s=x_0.shape[1]), "t s -> t s 1"
            ),
        )

        # log_prob = log_h - int_h, hence log_h = log_prob + int_h.
        log_h = log_prob + int_h
        log_prob, log_h, int_h = log_prob.cpu(), log_h.cpu(), int_h.cpu()

        def plot_f():
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ys = torch.exp(log_prob)
            ax.plot(ts, ys)
            ax.set_title("f(t) (probability density)")
            ax.set_xlabel("t")
            return fig

        def plot_S():
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            S = torch.exp(-int_h)
            ax.plot(ts, S)
            ax.set_title("S(t) (survival function)")
            ax.set_xlabel("t")
            return fig

        def plot_h():
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.plot(ts, torch.exp(log_h))
            ax.set_title("h(t) (hazard function)")
            ax.set_xlabel("t")
            return fig

        def plot_int_h():
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.plot(ts, int_h)
            ax.set_title("integral of h from [0, t]")
            ax.set_xlabel("t")
            return fig

        figs = [plot_f(), plot_S(), plot_h(), plot_int_h()]
        plt.close("all")
        return figs


class DiscreteTrainable(kdai.train.BaseTrainable):
    """Trainable that outputs a discrete distribution over time.

    Works with EventSeqDatasets and MarkedEventSeqDatasets, which have format:

        x: [b, t, c]  (batch, events, channels)
        y: [b, c]     (batch, channels)

    In causal mode, y instead has a target per event: [b, t, c].

    "timestamp" of the event is the 0th channel in the data. Marks, if present,
    are the remaining channels.
    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        bin_mode: Literal["fixed", "auto"] = "fixed",
        causal: bool = True,
        init_weights=True,
    ):
        super().__init__(ds_mgr, model, label)
        # Dispatches to `_eval_loss` or `_eval_info`.
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        self.is_causal = causal
        self.loss_fn = (
            self.loss_fn_causal if self.is_causal else self.loss_fn_noncausal
        )
        # Padding value of the targets, in *time* space (not bin-index
        # space). Defaults to F.cross_entropy's default when there is no
        # padding.
        self.ignore_index = getattr(self.ds_mgr, "pad_val", -100)

        # Speed-up tweak/hack.
        # The model has an out_resolution parameter—it determines the size of
        # the output layer. This trainable refers to that parameter in the
        # loss function; however, when compiling the model and the cforward
        # function (on pytorch 2.5), this reference to the out_resolution
        # parameter of the model prevents the compilation of the cforward
        # function. So, we make a record of the value here, and use it in the
        # loss function. This prevents an assertion error from dynamo.
        self.out_resolution = self.model.out_resolution

        self.bin_mode = bin_mode
        if self.bin_mode == "auto":
            # Bins cover ds_mgr's dt_mean ± 3 dt_sd, floored at 0.
            self.t_min = max(0, self.ds_mgr.dt_mean - 3 * self.ds_mgr.dt_sd)
            self.t_max = self.ds_mgr.dt_mean + 3 * self.ds_mgr.dt_sd
            self.bin_width = (self.t_max - self.t_min) / self.out_resolution
        elif self.bin_mode == "fixed":
            # For probability mass, we have discrete probability at 1, 2,
            # 3, ... n. Which we represent with edges (0.5, 1.5), (1.5, 2.5)...
            self.t_min = 0.5
            self.t_max = self.out_resolution + 0.5
            self.bin_width = 1
        else:
            raise ValueError(f"Invalid bin_mode: {self.bin_mode}")
        assert self.t_max > 0
        if init_weights:
            self.init_weights()

    def bin_edges(self, device):
        """Return the `out_resolution + 1` bin edges spanning [t_min, t_max].

        Cached per instance and device. (A functools.lru_cache on the method
        would key on `self` and keep every instance alive.)
        """
        cache = self.__dict__.setdefault("_bin_edges_cache", {})
        key = str(device)
        if key not in cache:
            cache[key] = torch.linspace(
                self.t_min, self.t_max, self.out_resolution + 1, device=device
            )
        return cache[key]

    def init_weights(self):
        """Initialise the model's input scaling from dataset statistics."""
        if hasattr(self.model, "input_norm"):
            self.model.input_norm.set_mean_sd(
                self.ds_mgr.dt_mean, self.ds_mgr.dt_sd
            )
        elif hasattr(self.model, "init_input_scale"):
            self.model.init_input_scale(
                self.ds_mgr.dt_range_max, self.ds_mgr.dt_range_min
            )

    def _ts(self, device):
        """Bin centres, one per output bin."""
        return torch.linspace(
            self.t_min + self.bin_width / 2,
            self.t_max - self.bin_width / 2,
            self.out_resolution,
            device=device,
        )

    def extract_target(self, y):
        """Return the target time."""
        target_t = y[..., 0]
        return target_t

    def to_last(self, m_out, y):
        """In causal mode, keep only the last timestep of m_out and y."""
        # Causal if y has shape [B, T, C], where C is typically 1, and T > 1.
        if self.is_causal:
            assert len(y.shape) == 3
            m_out = m_out[:, -1]
            y = y[:, -1]
        else:
            assert len(y.shape) == 2
        return m_out, y

    def _time_to_bin(self, target_t):
        """Map times to bin indices; out-of-range times go to the end bins."""
        t_idx = torch.floor((target_t - self.t_min) / self.bin_width).long()
        return torch.clamp(t_idx, 0, self.out_resolution - 1)

    def loss_fn_causal(self, m_out_full, target_t):
        """Cross-entropy over all timesteps, ignoring padded targets."""
        # Target is only time. target_t is array of ts (causal training trick).
        assert len(target_t.shape) == 3
        target_t = target_t[:, :, 0].float().cuda()
        is_pad = target_t == self.ignore_index
        t_idx = self._time_to_bin(target_t)
        # self.ignore_index is a pad value in time space, while
        # F.cross_entropy's ignore_index is in class-index space. After the
        # clamp in _time_to_bin, padded targets would otherwise be scored as
        # real bins, so mark them explicitly with a class-space sentinel.
        ce_ignore = -100
        t_idx = torch.where(is_pad, torch.full_like(t_idx, ce_ignore), t_idx)
        res = F.cross_entropy(
            einops.rearrange(m_out_full, "b s r -> (b s) r"),
            einops.rearrange(t_idx, "b s -> (b s)"),
            ignore_index=ce_ignore,
        )
        return res

    def loss_fn_noncausal(self, m_out, target_t):
        """Cross-entropy for a single (unpadded) target per sequence."""
        # Target is only time.
        assert len(target_t.shape) == 2
        assert torch.all(target_t != self.ignore_index), "Expected no padding"
        target_t = target_t[:, 0].float().cuda()
        t_idx = self._time_to_bin(target_t)
        res = F.cross_entropy(m_out, t_idx)
        return res

    def auto_ll(self, m_out, target_t):
        """Log-likelihood for the auto bin mode.
        This mode needs special treatment as the first and last bin are not
        the same as the others. The first bin is extended to 0, and the last
        is an exponential tail.

        Density for a target in bin k with probability p_k:
          - first bin:  p_0 / (t_min + bin_width)
          - middle:     p_k / bin_width
          - last bin:   p_last * exp(-(t - (t_max - bin_width)))

        Note: `ll` uses `prob.auto_ll` rather than this method.

        Returns:
            Per-sample log-likelihood, shape "b".
        """
        D = m_out.device
        assert len(target_t.shape) == 2
        assert torch.all(target_t != self.ignore_index), "Expected no padding"
        target_t = target_t[:, 0].float().to(D)
        R = self.out_resolution
        t_idx = self._time_to_bin(target_t)
        log_probs = F.log_softmax(m_out, dim=1)
        target_log_prob = log_probs.gather(1, t_idx.unsqueeze(1)).squeeze(1)
        # Log of the mass-to-density factor of each target's bin.
        log_scale = torch.full_like(target_t, -math.log(self.bin_width))
        log_scale = torch.where(
            t_idx == 0,
            torch.full_like(target_t, -math.log(self.t_min + self.bin_width)),
            log_scale,
        )
        log_scale = torch.where(
            t_idx == R - 1,
            -(target_t - (self.t_max - self.bin_width)),
            log_scale,
        )
        return target_log_prob + log_scale

    def cforward(self, sample):
        """Forward call that may be compiled.

        Route model calls that also need loss calculated through this fn."""
        x, mask, y = sample
        x = x.float().cuda()
        mask = mask.float().cuda()
        m_out = self.model(x, mask)
        loss = self.loss_fn(m_out, y)
        return m_out, loss

    def forward(self, sample):
        """Return the (last-timestep, if causal) logits and the loss."""
        m_out, loss = self.cforward(sample)
        if self.is_causal:
            last_m_out = m_out[:, -1]
        else:
            last_m_out = m_out
        return last_m_out, loss

    def expected_val(self, m_out):
        """Mean of the binned distribution, using bin centres."""
        ts = self._ts(m_out.device)
        assert len(ts) == self.model.out_resolution
        probs = F.softmax(m_out, dim=1)
        expected_val = (probs * ts).sum(dim=1)
        return expected_val

    def mode(self, m_out):
        """Centre of the most probable bin."""
        ts = self._ts(m_out.device)
        assert len(ts) == self.model.out_resolution
        probs = F.softmax(m_out, dim=1)
        mode = ts[probs.argmax(dim=1)]
        return mode

    def median(self, m_out):
        """
        Calculate the least number x that satisfies
        P(t≤x) ≥ 0.5 and P(t≥x) ≥ 0.5.
        """
        x_idx = torch.argmax(
            # argmax doesn't work for bool.
            (F.softmax(m_out, dim=1).cumsum(dim=1) >= 0.5).int(),
            dim=1,
        )
        ts = self._ts(m_out.device)
        median = ts[x_idx]
        return median

    def infer(
        self,
        m_out,
        method: Literal["expected_val", "mode", "median"] = "expected_val",
    ):
        """Point prediction from logits using the chosen summary statistic."""
        if method == "expected_val":
            return self.expected_val(m_out)
        elif method == "mode":
            return self.mode(m_out)
        elif method == "median":
            return self.median(m_out)
        else:
            raise ValueError(f"Invalid method: {method}")

    def _interval_log_prob(self, t, m_out, interval_len):
        """Log mass of [t - len/2, t + len/2] (clipped at 0); auto mode only."""
        if self.bin_mode == "fixed":
            raise ValueError(
                "Fixed bin mode currently doesn't need or support intervals."
            )
        from_t = torch.clip(t - interval_len / 2, 0)
        to_t = from_t + interval_len
        # m_out to probs
        probs = F.softmax(m_out, dim=1)
        res = prob.interval_prob3(probs, self.t_min, self.t_max, from_t, to_t)
        if torch.any(torch.isnan(res)):
            _logger.warning("NaN in interval log prob.")
        elif torch.any(torch.isinf(res)):
            _logger.warning("Inf in interval log prob.")
        assert torch.all(res <= 0), "Probability mass must be less-equal 1."
        return res

    def ll(self, m_out, target_t):
        """Mean log-likelihood (density, data scale) of single targets."""
        assert len(target_t.shape) == 2
        if self.bin_mode == "auto":
            # prob.auto_ll takes a 1-D float target on m_out's device, so
            # extract channel 0 here (loss_fn_noncausal does its own).
            target_t = target_t[:, 0].float().to(m_out.device)
            ll = prob.auto_ll(
                m_out,
                target_t,
                self.t_min,
                self.t_max,
                self.model.out_resolution,
            ).mean()
        elif self.bin_mode == "fixed":
            # Same as the cross entropy loss, just with scaling for bin width.
            ll = -(
                self.loss_fn_noncausal(m_out, target_t)
                + math.log(self.bin_width)
            )
        else:
            raise ValueError(f"Invalid bin_mode: {self.bin_mode}")
        return ll

    def eval_metrics(self, dl):
        """Compute average loss, NLL and prediction metrics over `dl`."""
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        mean_abs_err_mode = kdai._logging.Meter("mean_abs_err_mode")
        mean_abs_err_mean = kdai._logging.Meter("mean_abs_err_mean")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        # Same default as ZuoTHPTrainable when the dataset doesn't specify.
        interval_len = getattr(self.ds_mgr, "density_interval_len", 1)
        if self.bin_mode != "auto" and self.bin_width != interval_len:
            # In fixed mode the interval NLL is copied from pred_nll, which
            # is only valid when one bin is exactly one interval.
            raise ValueError(
                f"interval prob won't match mass. {self.bin_width=} "
                f"!= {self.ds_mgr.density_interval_len=}"
            )
        for X, mask, y in dl:
            N = X.shape[0]
            m_out_full, loss = self.cforward((X, mask, y))
            loss_meter.update(loss.item(), N)
            m_out, y = self.to_last(m_out_full, y)
            rescaled_ll = self.ll(m_out, y)
            pred_nll.update(-rescaled_ll.item(), N)

            y = y[:, 0].float().cuda()
            pred_exp_val = self.expected_val(m_out)
            pred_mode = self.mode(m_out)
            pred_median = self.median(m_out)
            mean_abs_err.update((pred_median - y).abs().mean().item(), N)
            mean_abs_err_mode.update((pred_mode - y).abs().mean().item(), N)
            mean_abs_err_mean.update((pred_exp_val - y).abs().mean().item(), N)
            if self.bin_mode == "auto":
                interval_pred_nll.update(
                    -self._interval_log_prob(y, m_out, interval_len)
                    .mean()
                    .item(),
                    N,
                )
            else:
                # Just copy from pred_nll (checked above that bins match).
                interval_pred_nll.update(-rescaled_ll.item(), N)
        return {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "mean_abs_err_mode": mean_abs_err_mode.avg,
            "mean_abs_err_mean": mean_abs_err_mean.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }

    def _eval_loss(self, dl):
        """Evaluate loss only."""
        loss_meter = kdai._logging.Meter("loss")
        for x, mask, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, mask, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Evaluate with additional metrics and distribution plots."""
        metrics = self.eval_metrics(dl)

        results = {
            "metrics": [
                kdai._logging.loss_metric(metrics["loss"]),
                kdai._logging.Metric(
                    "mean_abs_err", metrics["mean_abs_err"], increasing=False
                ),
                kdai._logging.Metric(
                    "mean_abs_err_mean",
                    metrics["mean_abs_err_mean"],
                    increasing=False,
                ),
                kdai._logging.Metric(
                    "mean_abs_err_mode",
                    metrics["mean_abs_err_mode"],
                    increasing=False,
                ),
                kdai._logging.Metric(
                    "pred_nll", metrics["pred_nll"], increasing=False
                ),
                kdai._logging.Metric(
                    "interval_pred_nll",
                    metrics["interval_pred_nll"],
                    increasing=False,
                ),
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_dists(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Evaluate on at most `eval_len` items of `dl`."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl)
        return results

    @torch.no_grad()
    def plot_dists(self, dl, max_n_plots):
        """Plot predicted distributions for the first batch of `dl`."""
        sample = next(iter(dl))
        X, mask, t = sample
        m_out = self.model(X.float().cuda(), mask.float().cuda())
        if self.is_causal:
            m_out = m_out[:, -1]
            t = t[:, -1]
        return self._plot_dists(m_out, t, max_n_plots)

    def _plot_dists(self, m_out, target_t, max_n_plots):
        """Plots a distribution as a histogram; the target is highlighted.

        `target_t` may be "b" times or "b c" targets (channel 0 is the time).
        """
        batch_size = m_out.shape[0]
        probs = F.softmax(m_out, dim=1).detach().cpu().numpy()
        if target_t.ndim > 1:
            target_t = self.extract_target(target_t)
        target_t = target_t.detach().cpu().numpy()
        ts = self._ts("cpu").numpy()
        assert len(ts) == self.model.out_resolution
        figs = []
        for i in range(min(max_n_plots, batch_size)):
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.set_xlim(self.t_min, self.t_max)
            ax.bar(ts, probs[i], width=self.bin_width)
            ax.axvline(target_t[i], color="r", linestyle="--")
            ax.set_title(f"t={target_t[i]}")
            figs.append(fig)
        return figs


# Smallest allowed bin width, 2**-17 (about 7.6e-6). Presumably consumed by
# VarBinTrainable when building variable-width bins; its use is not visible
# in this part of the file.
VAR_MIN_BIN_WIDTH = 1 / 2**17


class VarBinTrainable(kdai.train.BaseTrainable):
    """
    DiscreteTrainable, but with bins that vary to include equal probability mass.

    This trainable doesn't have a "fixed" and "auto" mode like
    DiscreteTrainable, as this trainable is only tailored for the probability
    density setting that "auto" mode is used for in DiscreteTrainable. If
    there is a fixed number of events, then just use DiscreteTrainable in
    "fixed" mode.

    Works with EventSeqDatasets and MarkedEventSeqDatasets, which have format:

        x: [b, t, c]  (batch, events, channels)
        y: [b, t, c]  (batch, events, channels); reduced to [b, c] (the last
                      timestep) by `to_last`.

    "timestamp" of the event is the 0th channel in the data. Marks, if present,
    are the remaining channels, although not used in this trainable.
    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        causal: bool = True,
        init_weights=True,
    ):
        super().__init__(ds_mgr, model, label)
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        self.is_causal = causal
        self.loss_fn = (
            self.loss_fn_causal if self.is_causal else self.loss_fn_noncausal
        )
        # Targets equal to the padding value are excluded from the causal
        # loss; -100 mirrors the default ignore_index of F.cross_entropy.
        if hasattr(self.ds_mgr, "pad_val"):
            self.ignore_index = self.ds_mgr.pad_val
        else:
            self.ignore_index = -100

        self._bin_edges = self.calculate_bin_edges(
            self.ds_mgr.train_dts_flat(), self.model.out_resolution
        )
        # The last bin is unbounded and has an exponential tail. Its rate is
        # currently hardcoded to 1 / (left edge of the last bin), i.e. the last
        # finite edge. This is better than assigning the tail zero density.
        # Optionally, this could be a learned parameter.
        # NOTE(review): this was previously described as "based on the second
        # last bin width", but the code uses an edge position, not a width.
        # Behavior is unchanged; confirm which was intended.
        self.last_bin_lambda = 1 / self._bin_edges[-2].item()
        self._bin_widths = self._bin_edges[1:] - self._bin_edges[:-1]
        assert isinstance(self._bin_edges, torch.Tensor)
        assert isinstance(self._bin_widths, torch.Tensor)
        assert len(self._bin_edges) == len(self._bin_widths) + 1
        assert self._bin_edges[0] == 0 and self._bin_edges[-1] == np.inf
        if init_weights:
            self.init_weights()
        # Speed-up tweak/hack.
        # The model has an out_resolution parameter—it determines the size of
        # the output layer. This trainable refers to that parameter in the
        # loss function; however, when compiling the model and the cforward
        # function (on pytorch 2.5), this reference to the out_resolution
        # parameter of the model prevents the compilation of the cforward
        # function. So, we make a record of the value here, and use it in the
        # loss function. This prevents an assertion error from dynamo.
        self.out_resolution = self.model.out_resolution

    # NOTE: lru_cache on these methods uses a cache shared by all instances,
    # keyed on (self, device), and holds references to cached instances.
    # The returned tensors are shared between callers; do not modify them
    # in place.
    @functools.lru_cache(maxsize=2)
    def bin_edges(self, device=None):
        """Return bin edges on the desired or default device."""
        if device is None:
            device = self.in_device()
        return self._bin_edges.clone().to(device)

    @functools.lru_cache(maxsize=2)
    def bin_widths(self, device=None):
        """Return bin widths on the desired or default device.

        The last width is infinite, because the last edge is +inf.
        """
        if device is None:
            device = self.in_device()
        return self._bin_widths.clone().to(device)

    @functools.lru_cache(maxsize=2)
    def bin_centers(self, device=None):
        """Return bin centers on the desired or default device.

        The last center is infinite, because the last edge is +inf.
        """
        bin_edges = self.bin_edges(device)
        bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
        return bin_centers

    @staticmethod
    def bin_edges_from_quantiles(quantiles):
        """Build bin edges [0, *quantiles, inf] with a minimum bin width.

        Any quantile closer than VAR_MIN_BIN_WIDTH to the previous edge is
        pushed up to exactly that distance. Edges are stored as float32.

        Args:
            quantiles: increasing sequence of interior edge positions.

        Returns:
            1-D float32 tensor of len(quantiles) + 2 edges, starting at 0
            and ending at +inf.
        """
        min_bin_width = VAR_MIN_BIN_WIDTH
        edges = [0]
        for q in quantiles:
            prev_edge = edges[-1]
            next_edge = q
            if next_edge - prev_edge < min_bin_width:
                # Bin too small.
                next_edge = prev_edge + min_bin_width
            edges.append(next_edge)
        edges.append(float("inf"))

        n_center_bins = len(quantiles) - 1
        assert len(edges) == n_center_bins + 2 + 1
        # Maintain 32bit representation, as some edges are very close together.
        bin_edges = torch.tensor(
            edges, dtype=torch.float32, requires_grad=False
        )
        # Check that numerical precision doesn't cause bins to be unexpectedly
        # small.
        tol = min_bin_width * 2 ** (-3)
        assert torch.all(
            torch.diff(bin_edges) >= min_bin_width - tol
        ), torch.diff(bin_edges).min()
        return bin_edges

    @classmethod
    def calculate_bin_edges(cls, train_dts, n_bins, edge_frac=None):
        """Calculate bin edges that have equal probability mass.

        Some caveats: the first bin will extend to 0, and the last bin will
        have an exponential tail extending to infinity. This is to prevent
        obtaining 0 likelihood on unseen data. Because these bins will extend
        their probability mass over a potentially large range, we would prefer
        these bins to be assigned less mass. In effect, we want to make it
        less likely that a sample falls into these lower density bins. To that
        effect, each edge bin is given `edge_frac` times the mass of a center
        bin (default 0.25, i.e. 1/4).

        Args:
            train_dts: training time deltas used to compute the quantiles.
            n_bins: total number of bins, including the two edge bins.
            edge_frac: mass of each edge bin relative to a center bin.

        Returns:
            Bin edges tensor (see `bin_edges_from_quantiles`).
        """
        edge_bins = 2
        DEFAULT_EDGE_FRAC = 0.25
        if edge_frac is None:
            edge_frac = DEFAULT_EDGE_FRAC
        n_center_bins = n_bins - edge_bins
        # n_center * c + 2 * edge_frac * c == 1
        center_bin_mass = 1 / (n_center_bins + 2 * edge_frac)
        edge_bin_mass = center_bin_mass * edge_frac
        prob_boundaries = np.linspace(
            edge_bin_mass,
            1 - edge_bin_mass,
            n_center_bins + 1,
        )
        quantiles = np.quantile(train_dts, prob_boundaries)
        edges = cls.bin_edges_from_quantiles(quantiles)
        return edges

    def get_bin_idx(self, target_t):
        """
        Return the 0-based bin index of each time in `target_t`.

        Edges are fetched on `target_t`'s device, because `torch.bucketize`
        needs the input and boundaries on the same device.

        right=True
        ----------
        right=True means that if the query t is equal to the left edge of a bucket,
        then it will be placed in that bucket. Example for when it's important:

          Case:   t = 0, bin_edges = [0, 1, 2]
          Desired result: we want the bin whose edges are [0, 1] to be selected.

        -1
        --
        bucketize will match a bin like:  edges[i-1] <= t < edges[i], and then
        return i for a match. This means that i is always 1 greater than the
        actual bin index (0-based). So, we subtract 1 to get the correct index.
        """
        edges = self.bin_edges(target_t.device)
        return torch.bucketize(target_t, edges, right=True) - 1

    def init_weights(self):
        """Initialize the model's input scaling from dataset statistics.

        Uses `input_norm.set_mean_sd` when the model has `input_norm`;
        otherwise uses `init_input_scale` when available.
        """
        if hasattr(self.model, "input_norm"):
            self.model.input_norm.set_mean_sd(
                self.ds_mgr.dt_mean, self.ds_mgr.dt_sd
            )
        elif hasattr(self.model, "init_input_scale"):
            self.model.init_input_scale(
                self.ds_mgr.dt_range_max, self.ds_mgr.dt_range_min
            )

    def extract_target(self, y):
        """Return the target time (channel 0)."""
        target_t = y[..., 0]
        return target_t

    def to_last(self, m_out, y):
        """Reduce targets (and, if causal, outputs) to the last timestep.

        `y` must be [B, T, C] in both modes. In causal mode `m_out` also has
        a per-timestep dimension, which is reduced too. In non-causal mode
        `m_out` is already one distribution per sequence.
        """
        assert len(y.shape) == 3, y.shape
        if self.is_causal:
            m_out = m_out[:, -1]
        return m_out, y[:, -1]

    def loss_fn_causal(self, m_out_full, target_t):
        """Mean NLL over every timestep (causal training trick).

        m_out_full: [B, S, R] per-timestep logits. target_t: [B, S, C], where
        only the time channel is used. Padded targets contribute 0.
        """
        assert len(target_t.shape) == 3
        D = m_out_full.device
        target_t = target_t[:, :, 0].float().to(D)
        m_out_flat = einops.rearrange(m_out_full, "b s r -> (b s) r")
        target_flat = einops.rearrange(target_t, "b s -> (b s)")
        nll = -prob.var_bin_ll(
            m_out_flat, self.bin_edges(D), target_flat, self.last_bin_lambda
        )
        # The mean is over the masked indices also, which is fine for
        # training, but don't interpret this as likelihood e.g. when
        # evaluating.
        loss = torch.where(
            target_flat == self.ignore_index,
            torch.tensor(0.0, device=D),
            nll,
        )
        return loss.mean()

    def loss_fn_noncausal(self, m_out, target_t):
        """Mean NLL of each sequence's final target time.

        m_out: [B, out_resolution] logits, one distribution per sequence.
        target_t: [B, T, C]. Only the time channel (0) of the last timestep
        is used; extra mark channels are ignored. Padding is not expected.
        """
        B, N = m_out.shape
        D = m_out.device
        assert N == self.out_resolution
        assert len(target_t.shape) == 3
        assert target_t.shape[0] == B
        target_t = target_t[:, -1, 0].float().to(D)
        assert target_t.shape == (B,)
        assert torch.all(target_t != self.ignore_index), "Expected no padding"
        # Use the same bin-width- and tail-aware likelihood as the causal
        # loss and `ll`, rather than a plain cross-entropy over bin indices.
        nll = -prob.var_bin_ll(
            m_out, self.bin_edges(D), target_t, self.last_bin_lambda
        )
        return nll.mean()

    def ll(self, m_out, target_t):
        """Mean log-likelihood of the target times.

        m_out: [B, N] logits. target_t: [B, C], with time in channel 0 (any
        further mark channels are ignored). Padding is not expected.
        """
        B, N = m_out.shape
        assert target_t.ndim == 2 and target_t.shape[0] == B, target_t.shape
        D = m_out.device
        target_t = target_t[:, 0].float().to(D)
        assert torch.all(target_t != self.ignore_index), "Expected no padding"
        res = prob.var_bin_ll(
            m_out, self.bin_edges(D), target_t, self.last_bin_lambda
        ).mean()
        return res

    def cforward(self, sample):
        """Forward call that may be compiled.

        Route model calls that also need loss calculated though this fn."""
        x, mask, y = sample
        x = x.float().cuda()
        mask = mask.float().cuda()
        m_out = self.model(x, mask)
        loss = self.loss_fn(m_out, y)
        return m_out, loss

    def forward(self, sample):
        """Return (last-timestep model output, loss) for a batch."""
        m_out, loss = self.cforward(sample)
        if self.is_causal:
            last_m_out = m_out[:, -1]
        else:
            last_m_out = m_out
        return last_m_out, loss

    def expected_val(self, m_out):
        """Expected time under each predicted distribution; m_out is [B, N]."""
        B, L = m_out.shape  # Shape check only.
        res = prob.var_bin_expected_val(
            m_out, self.bin_edges(m_out.device), self.last_bin_lambda
        )
        return res

    @torch.no_grad()
    def mode(self, m_out):
        """Mode of each predicted distribution.

        no_grad() because grad won't work anyway, as we are using argmax.
        """
        mode = prob.var_bin_mode(
            m_out, self.bin_edges(m_out.device), self.last_bin_lambda
        )
        return mode

    @torch.no_grad()
    def median(self, m_out):
        """
        Median of each predicted distribution.

        no_grad() because grad won't work anyway, as we are using searchsorted.
        """
        return prob.var_bin_median(
            m_out, self.bin_edges(m_out.device), self.last_bin_lambda
        )

    def infer(
        self,
        m_out,
        method: Literal["expected_val", "mode", "median"] = "expected_val",
    ):
        """Point prediction from `m_out` using the chosen statistic.

        Raises:
            ValueError: if `method` is not a supported statistic.
        """
        if method == "expected_val":
            return self.expected_val(m_out)
        elif method == "mode":
            return self.mode(m_out)
        elif method == "median":
            return self.median(m_out)
        else:
            raise ValueError(f"Invalid method: {method}")

    def _interval_log_prob(self, t, m_out, interval_len):
        """Log-probability mass of an interval of length `interval_len`.

        The interval is centered on `t`, but its start is clipped at 0.
        """
        from_t = torch.clip(t - interval_len / 2, 0)
        to_t = from_t + interval_len
        # NOTE(review): unlike the other var_bin_* calls, last_bin_lambda is
        # not passed here. Confirm var_bin_interval_prob does not need it.
        res = torch.log(
            prob.var_bin_interval_prob(
                m_out, self.bin_edges(m_out.device), from_t, to_t
            )
        )
        if torch.any(torch.isnan(res)):
            _logger.warning("NaN in interval log prob.")
        elif torch.any(torch.isinf(res)):
            _logger.warning("Inf in interval log prob.")
        # Small tolerance above log(1) = 0 for floating point error.
        assert torch.all(res <= 1e-5), "Probability mass must be less-equal 1."
        return res

    def eval_metrics(self, dl):
        """Average loss, NLL, interval NLL and abs errors over `dl`.

        mean_abs_err uses the median prediction; the _mode and _mean variants
        use the mode and the expected value respectively.
        """
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        mean_abs_err_mode = kdai._logging.Meter("mean_abs_err_mode")
        mean_abs_err_mean = kdai._logging.Meter("mean_abs_err_mean")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        for X, mask, y in dl:
            N = X.shape[0]
            m_out_full, loss = self.cforward((X, mask, y))
            loss_meter.update(loss.item(), N)
            m_out, y = self.to_last(m_out_full, y)
            rescaled_ll = self.ll(m_out, y)
            pred_nll.update(-rescaled_ll.item(), N)

            y = y[:, 0].float().cuda()
            pred_exp_val = self.expected_val(m_out)
            pred_mode = self.mode(m_out)
            pred_median = self.median(m_out)
            mean_abs_err.update((pred_median - y).abs().mean().item(), N)
            mean_abs_err_mode.update((pred_mode - y).abs().mean().item(), N)
            mean_abs_err_mean.update((pred_exp_val - y).abs().mean().item(), N)
            interval_pred_nll.update(
                -self._interval_log_prob(
                    y, m_out, self.ds_mgr.density_interval_len
                )
                .mean()
                .item(),
                N,
            )
        return {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "mean_abs_err_mode": mean_abs_err_mode.avg,
            "mean_abs_err_mean": mean_abs_err_mean.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }

    def _eval_loss(self, dl):
        """Evaluate loss only."""
        loss_meter = kdai._logging.Meter("loss")
        for x, mask, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, mask, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Evaluate all metrics from `eval_metrics` and plot distributions."""
        metrics = self.eval_metrics(dl)

        results = {
            "metrics": [
                kdai._logging.loss_metric(metrics["loss"]),
                kdai._logging.Metric(
                    "mean_abs_err", metrics["mean_abs_err"], increasing=False
                ),
                kdai._logging.Metric(
                    "mean_abs_err_mean",
                    metrics["mean_abs_err_mean"],
                    increasing=False,
                ),
                kdai._logging.Metric(
                    "mean_abs_err_mode",
                    metrics["mean_abs_err_mode"],
                    increasing=False,
                ),
                kdai._logging.Metric(
                    "pred_nll", metrics["pred_nll"], increasing=False
                ),
                kdai._logging.Metric(
                    "interval_pred_nll",
                    metrics["interval_pred_nll"],
                    increasing=False,
                ),
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_dists(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Run the configured evaluation on a ConstrainedIterable of `dl`."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl)
        return results

    @torch.no_grad()
    def plot_dists(self, dl, max_n_plots):
        """Plot predicted distributions for the first batch of `dl`."""
        X, mask, t = next(iter(dl))
        m_out = self.model(X.float().cuda(), mask.float().cuda())
        # Targets are [B, T, C] in both modes, so always reduce to the last
        # timestep (to_last also reduces m_out when causal).
        m_out, t = self.to_last(m_out, t)
        return self._plot_dists(m_out, t, max_n_plots)

    def _plot_dists(self, m_out, target_t, max_n_plots):
        """Plot distributions as bar charts; the target is highlighted.

        Each bar spans its bin and has height equal to the bin's probability.
        The target time is drawn as a dashed red vertical line.

        Args:
            m_out: [B, n_bins] logits.
            target_t: [B] times, or [B, C] with time in channel 0.
            max_n_plots: maximum number of figures to produce.
        """
        batch_size = m_out.shape[0]
        probs = F.softmax(m_out.detach(), dim=1).cpu().numpy()
        target_t = target_t.detach().cpu().numpy()
        if target_t.ndim == 2:
            target_t = target_t[:, 0]
        edges = self.bin_edges("cpu").numpy()
        widths = edges[1:] - edges[:-1]
        # Don't use infinity for last bin. Just twice the previous bin.
        widths[-1] = widths[-2] * 2
        figs = []
        for i in range(min(max_n_plots, batch_size)):
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.bar(
                edges[:-1],
                probs[i],
                width=widths,
                align="edge",
                edgecolor="black",
            )
            ax.axvline(target_t[i], color="r", linestyle="--")
            # NOTE(review): this limit shows only half of the (drawn) last
            # bin; kept as-is.
            ax.set_xlim(0, edges[-2] + widths[-2])
            ax.set_title(f"t={target_t[i]}")
            figs.append(fig)
        return figs


def init_gompertz_output(dt_mean, dt_sd):
    """Roughly center and scale the Gompertz distribution over the data.

    This is an improvement over the default initialization.

    The mean and variance of the Gompertz distribution are both affected
    by the distribution's 2 parameters: changing 1 changes both the
    mean and variance. Furthermore, the mapping from parameter to moment
    is non-linear and hard to invert. While it is tempting to abandon
    trying to sensibly initialize the model, this will lead to very
    slow convergence and possible initial instability. Our approach is
    to use the approximate functions that map the parameters to their
    moments described by Adam Lenart in:

    ["The Gompertz Distribution and Maximum Likelihood Estimation of its
    Parameters" (2012)](https://www.demogr.mpg.de/papers/working/wp-2012-008.pdf).

    Mean approximation (in the paper's notation):

    mean ≈ 1/b * exp(a/b) * (a/b - log(a/b) - 0.75522)
    var  ≈ 1/b² * π²/6 - 2*a/b³

    Even these approximations are not linear, so we will use a numerical
    optimization to find good starting parameters. A simple gradient
    descent does well in practice and seems to find initial parameters
    that are much better than the default initializations.

    If inspecting these parameters, it is worth being aware than a and
    b both get lower as the mean and variance increase.

    It's noteworthy that we added an extra scalar parameter to the model
    to allow the model offset to be set. We use a learnable parameter
    instead of a fixed offset, just so that we don't prevent the model
    from finding a better offset. This isn't really a change to the model
    as it already has a bias term in the output layer. We could have
    instead initialized this term; however, we felt that an explicit
    parameter made the intent clearer.

    Args:
        dt_mean: mean of the training time deltas.
        dt_sd: standard deviation of the training time deltas.

    Returns:
        (a, d) as 0-dim float32 tensors in the model's parameterization:
        a = -b_lenart and d = log(a_lenart / b_lenart).
    """
    # We overshoot the variance slightly, as we are most concerned with
    # assigning events probability approaching zero, which can generate
    # very large losses. The target is a variance, so it scales with sd².
    VAR_OVERSHOOT = 1.2
    data_var = dt_sd**2
    target_var = data_var * VAR_OVERSHOOT

    def loss(log_a, log_b):
        # Optimize in log space so that Lenart's a and b stay positive.
        a = torch.exp(log_a)
        b = torch.exp(log_b)
        approx_mean = (
            1 / b * torch.exp(a / b) * (a / b - torch.log(a / b) - 0.75522)
        )
        approx_var = 1 / b**2 * math.pi**2 / 6 - 2 * a / b**3
        # Both terms are dimensionless: the mean error is in units of sd and
        # the variance error is relative to the data variance.
        var_loss = torch.abs(target_var - approx_var) / data_var
        mean_loss = torch.abs(dt_mean - approx_mean) / dt_sd
        return var_loss + mean_loss

    n_iter = int(1e4)
    log_a = torch.tensor(0.0, requires_grad=True, dtype=torch.float64)
    log_b = torch.tensor(0.0, requires_grad=True, dtype=torch.float64)
    optimizer = torch.optim.AdamW([log_a, log_b], lr=1e-3)
    l = 0
    for _ in range(n_iter):
        optimizer.zero_grad()
        l = loss(log_a, log_b)
        l.backward()
        optimizer.step()
    # Back to our parameterization.
    # Our "a" parameter is Adam Lenart's "b" and includes the negative.
    # Our "d" parameter is log(a/b). Since log_a and log_b are already in
    # log form, the division becomes a subtraction. Compute both from the
    # optimized log-parameters; don't reuse the name "a" for the output
    # before d is computed.
    a = -log_b.detach().exp().float()
    d = (log_a - log_b).detach().float()
    _logger.info(
        f"Gompertz init optimization. Loss: {l.item():.4e} after "
        f"{n_iter} iterations."
    )
    return a, d


class GptHazardTrainable(kdai.train.BaseTrainable):
    """Const, exp and nn hazard heads on a transformer model.

    Two major changes from OmiTrainable:
      - causal training (so extra dimension for loss calculation)
      - data is converted to timestep embeddings. The reason this
        calls for a separate API is that the initialization of the
        model now requires specifying the maximum and minimum time
        scales that we consider to be important over the full range
        of the model's input. This is very different to the mean and
        sd of the time deltas used by the OmiTrainable models.

    Both changes are due to the nature of the transformer model.
    """

    HazardType = Literal["const", "exp", "nn"]

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        hazard_type: HazardType,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        init_weights=True,
    ):
        super().__init__(ds_mgr, model, label)
        valid_types = typing.get_args(self.HazardType)
        if hazard_type not in valid_types:
            raise ValueError(
                "Invalid hazard type. Expected one of "
                f"({valid_types}). Got ({hazard_type})."
            )
        self.hazard_type = hazard_type
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        # Dispatch to const/exp/nn_hazard_forward.
        self.hazard_forward = getattr(
            self, f"{self.hazard_type}_hazard_forward"
        )
        # Mirrors DiscreteTrainable's padding handling. Note that this class's
        # own loss masks with `mask` and does not read ignore_index.
        if hasattr(self.ds_mgr, "pad_val"):
            self.ignore_index = self.ds_mgr.pad_val
        else:
            # No padding, so just use the default for F.cross_entropy.
            self.ignore_index = -100
        # This can't be called automatically, as it's not needed when loading
        # a pre-trained model.
        if init_weights:
            self.init_normalization()

    def init_normalization(self):
        """Initialize model input and output normalization."""
        mean, sd = self.ds_mgr.dt_mean, self.ds_mgr.dt_sd
        _logger.info(
            "train ds:\n"
            f"(mean, sd): ({mean:.3g}, {sd:.3g})\n"
            f"(min, max): ({self.ds_mgr.dt_min:.3g}, {self.ds_mgr.dt_max:.3g})\n"
            f"dt range (min, max): ({self.ds_mgr.dt_range_min:.3g}, "
            f"{self.ds_mgr.dt_range_max:.3g})"
        )
        self.model.init_input_scale(
            self.ds_mgr.dt_range_max, self.ds_mgr.dt_range_min
        )
        # Each model has their own way of initializing the output layer.
        output_init_fn = getattr(self, f"_init_{self.hazard_type}_output")
        output_init_fn()

    def _init_nn_output(self):
        """Nothing to do for the output, but we do have a query to scale."""
        self.model.init_query_scale(self.ds_mgr.dt_mean, self.ds_mgr.dt_sd)

    def _init_const_output(self):
        """Initialize the constant hazard model.

        Assuming m_out starts around zero, we will set the offset to make the
        output distribution match the data mean. This amounts to setting the
        offset to log(1/dt_mean).
        """
        self.model.init_output(-math.log(self.ds_mgr.dt_mean))

    def _init_exp_output(self):
        """Initialize the exp (Gompertz) hazard output from data moments."""
        a, d = init_gompertz_output(self.ds_mgr.dt_mean, self.ds_mgr.dt_sd)
        self.model.init_output(a, d)
        _logger.info(f"Initialized ExpHazard model with {a=:.4e}, {d=:.4e}.")

    def loss_fn(self, log_prob, mask):
        """Negative masked log-probability, averaged over ALL b*s positions.

        Masked-out positions contribute 0 but still count in the denominator.
        """
        b, s = log_prob.shape
        n_elem = b * s
        mean_nll = -((log_prob * mask) / n_elem).sum()
        return mean_nll

    @staticmethod
    def const_hazard_forward(model, x, mask, y):
        """Constant hazard; returns (m_out, log_prob, log_h, int_h)."""
        m_out = model(x, mask)
        log_prob, log_h, int_h = prob.const_hazard(m_out, y)
        return m_out, log_prob, log_h, int_h

    @staticmethod
    def exp_hazard_forward(model, x, mask, y):
        """Exponential hazard with the model's shared scalar `a`."""
        a = torch.ones_like(y) * model.a
        m_out = model(x, mask)
        log_prob, log_h, int_h = prob.exp_hazard(a, m_out, y)
        return m_out, log_prob, log_h, int_h

    def nn_hazard_forward(self, model, x, mask, y):
        """NN hazard: the model itself returns (m_out, log_h, int_h)."""
        m_out, log_h, int_h = model(x, mask, y)
        if not torch.isfinite(log_h).all():
            _logger.error(f"log_h is not finite ({log_h})")
        assert torch.isfinite(int_h).all()
        assert torch.isfinite(m_out).all()
        # log f(t) = log h(t) - ∫h.
        log_prob = log_h - int_h
        return m_out, log_prob, log_h, int_h

    def to_last(self, tensor):
        """Select the last timestep of a (B, S, ...) tensor."""
        return tensor[:, -1]

    def _forward(self, model, x, mask, y):
        """Evaluate the hazard head at query times `y` of shape (B, S).

        Returns:
            (m_out, loss, log_prob, log_h, int_h). The loss covers all
            timesteps; the other tensors are reduced to the last timestep.
        """
        # to_last and loss_fn both need a (B, S) tensor; (B, 1) is fine.
        assert y.dim() == 2, "supports (B, 1) and (B, S) shapes."
        shape = y.shape
        m_out, log_prob, log_h, int_h = self.hazard_forward(model, x, mask, y)
        assert (
            m_out.shape == log_prob.shape == log_h.shape == int_h.shape == shape
        ), f"{m_out.shape=} {log_prob.shape=} {log_h.shape=} {int_h.shape=} {shape=}"
        loss = self.loss_fn(log_prob, mask)
        return (
            self.to_last(m_out),
            loss,
            self.to_last(log_prob),
            self.to_last(log_h),
            self.to_last(int_h),
        )

    def forward(self, sample):
        """Return (last-timestep m_out, loss) for an (x, mask, y) batch."""
        x, mask, y = sample
        x = x.float().cuda()
        y = y.float().cuda()
        mask = mask.float().cuda()
        # (b, s 1) -> (b, s)  | Only time channel
        y = y[:, :, 0]
        m_out, loss, _, _, _ = self._forward(self.model, x, mask, y)
        return m_out, loss

    def interval_log_prob(self, x, mask, y, interval_len):
        """Log-probability of the last event falling in an interval around y.

        The interval has length `interval_len` and is centered on `y`, but its
        start is clipped at 0. It is computed as log(S(t_min) - S(t_max)).
        """
        t_min = torch.clip(y - interval_len / 2, 0)
        t_max = t_min + interval_len
        _, _, lp_min, _, int_h0 = self._forward(self.model, x, mask, t_min)
        _, _, lp_max, _, int_h1 = self._forward(self.model, x, mask, t_max)
        if not torch.isfinite(int_h0).all():
            # Log instead of dropping into a debugger, which would hang
            # unattended evaluation runs.
            _logger.warning(
                "Non-finite integrated hazard at interval start (int_h0)."
            )
        log_prob_via_integral = prob.logsubexp(-int_h0, -int_h1)
        # If the integral version fails (infinity, due to numerical issues),
        # then we fall back to using a constant probability over the interval.
        # When we fall back, we are approximating with the lower density held
        # uniform over the interval: P ≈ f_min * Δ, so log P = log f_min + log Δ.
        log_interval_len = torch.log(
            torch.as_tensor(
                interval_len, dtype=lp_min.dtype, device=lp_min.device
            )
        )
        log_prob_via_lprob = torch.minimum(lp_min, lp_max) + log_interval_len
        log_prob = torch.where(
            torch.isfinite(log_prob_via_integral),
            log_prob_via_integral,
            log_prob_via_lprob,
        )
        assert torch.isfinite(log_prob).all()
        return log_prob

    @torch.no_grad()
    def infer(self, x, mask):
        """Infer *median* for last timestep.

        Only last timestep is considered. Short inputs are supported through
        datasets having left-padding.

        TODO: it would be good to have a mean and mode method as well.
        Median is chosen as mean-abs-err is the metric currently used. Mean
        should be considered if mean-squared-error is used.
        """
        B, L, C = x.shape
        # Bisect [low, high] looking for the median of the hazard function.
        # At the median, the cumulative distribution is 1/2, so the
        # integral hazard will be -log(1/2) = log(2). So we are looking to
        # zoom in to log(2) as the value for the integral hazard.
        # GPTs don't need input_norm, as they instead scale their positional
        # encoding. So use the trainable's access to the ds_mgr.
        # Time deltas are non-negative, so never search below 0.
        min_t = max(0.0, self.ds_mgr.dt_mean - 3 * self.ds_mgr.dt_sd)
        max_t = self.ds_mgr.dt_mean + 3 * self.ds_mgr.dt_sd
        low = torch.full(size=[B], fill_value=min_t, device=x.device)
        high = torch.full(size=[B], fill_value=max_t, device=x.device)
        for _ in range(13):
            mid = (low + high) / 2
            # _forward expects (B, T) shape. We could add if-elses to treat
            # the (B, 1) case differently. Currently, we just repeat the
            # single value to the sequence length.
            mid = einops.repeat(mid, "b -> b t", t=L)
            _, _, _, _, int_h = self._forward(self.model, x, mask, mid)
            # int_h is positive, with Survival(t) = exp(-int_h)
            # Targetting S(t_0) = 0.5, we want int_h = log(2), and note that
            # int_h is increasing in t.
            high = torch.where(int_h > math.log(2), mid[:, 0], high)
            low = torch.where(int_h < math.log(2), mid[:, 0], low)
        res = mid[:, 0]
        assert res.shape == (B,)
        return res

    def eval_metrics(self, dl, dtype=torch.float):
        """Average loss, NLL, interval NLL and median abs error over `dl`."""
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        for x, mask, y in dl:
            x = x.to(dtype).cuda()
            y = y.to(dtype).cuda()
            mask = mask.float().cuda()
            # (b, s, 1) -> (b, s)  | Only time channel.
            y = y[:, :, 0]
            N = y.shape[0]
            _, loss, log_prob, _, _ = self._forward(self.model, x, mask, y)
            loss_meter.update(loss.item(), N)
            pred_nll.update(-log_prob.mean().item(), N)
            interval_pred_nll.update(
                -self.interval_log_prob(
                    x, mask, y, interval_len=self.ds_mgr.density_interval_len
                )
                .mean()
                .item(),
                N,
            )
            pred = self.infer(x, mask)
            assert pred.shape == (N,)
            # It's only the last timestep that we use.
            mean_abs_err.update((pred - self.to_last(y)).abs().mean().item(), N)
        return {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }

    def _eval_loss(self, dl):
        """Evaluate loss only."""
        loss_meter = kdai._logging.Meter("loss")
        for x, mask, y in dl:
            x = x.float().cuda()
            y = y.float().cuda()
            mask = mask.float().cuda()
            # (b, s 1) -> (b, s)  | Only time channel.
            y = y[:, :, 0]
            N = y.shape[0]
            _, loss, _, _, _ = self._forward(self.model, x, mask, y)
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Evaluate all metrics from `eval_metrics` and plot f, S, h, ∫h."""
        metrics = self.eval_metrics(dl)
        results = {
            "metrics": [
                kdai._logging.loss_metric(metrics["loss"]),
                kdai._logging.Metric(
                    "mean_abs_err", metrics["mean_abs_err"], increasing=False
                ),
                kdai._logging.Metric(
                    "pred_nll", metrics["pred_nll"], increasing=False
                ),
                kdai._logging.Metric(
                    "interval_pred_nll",
                    metrics["interval_pred_nll"],
                    increasing=False,
                ),
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_figs(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Run the configured evaluation on a ConstrainedIterable of `dl`."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl)
        return results

    def plot_figs(self, dl, max_n_plots):
        """Four figures per sample for the first `max_n_plots` of one batch."""
        x, mask, _ = next(iter(dl))
        x = x.float().cuda()
        mask = mask.float().cuda()
        figs = []
        for i in range(min(max_n_plots, x.shape[0])):
            figs.extend(self.plot_f_h_S(x[i], mask[i]))
        return figs

    @torch.no_grad()
    def plot_f_h_S(self, x_0, mask):
        """Plot f(t), S(t), h(t) and ∫h for a single input sequence.

        Args:
            x_0: (S, C) input, without a batch dimension.
            mask: (S,) mask for `x_0`.

        Returns:
            A list of four matplotlib Figures.
        """
        assert x_0.ndim == 2, "Only 1 input accepted (no batch dim)"
        seq_len, _ = x_0.shape
        t_grid = torch.linspace(
            0, self.ds_mgr.dt_mean + 3 * self.ds_mgr.dt_sd, 1000
        )
        n_t = t_grid.shape[0]
        # _forward expects (B, T) shape. Each grid time becomes a batch
        # element repeated along the sequence; only the last timestep's
        # outputs are returned by _forward.
        ts = einops.repeat(t_grid, "t -> t s", s=seq_len)
        x_rep = einops.repeat(x_0, "s c -> t s c", t=n_t)
        mask_rep = einops.repeat(mask, "s -> t s", t=n_t)
        _, _, log_prob, log_h, int_h = self._forward(
            self.model, x_rep, mask_rep, ts.cuda()
        )
        log_prob, log_h, int_h = log_prob.cpu(), log_h.cpu(), int_h.cpu()

        def make_fig(ys, title):
            # Plot against the 1-D grid; plotting the repeated 2-D `ts` would
            # draw one identical line per sequence position.
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.plot(t_grid, ys)
            ax.set_title(title)
            ax.set_xlabel("t")
            return fig

        figs = [
            make_fig(torch.exp(log_prob), "f(t) (probability density)"),
            make_fig(torch.exp(-int_h), "S(t) (survival function)"),
            make_fig(torch.exp(log_h), "h(t) (hazard function)"),
            make_fig(int_h, "integral of h from [0, t]"),
        ]
        plt.close("all")
        return figs

    def model_summary(self, batch_size: int) -> str:
        """torchinfo summary of the model on one training batch."""
        dl = torch.utils.data.DataLoader(self.train_ds(), batch_size=batch_size)
        x, mask, y = kdai.train.to_device(next(iter(dl)), self.in_device())
        # (b, s 1) -> (b, s)  | Only time channel.
        y = y[:, :, 0]
        # Only the nn hazard model takes the query times as input.
        if self.hazard_type == "nn":
            input_data = (x, mask, y)
        else:
            input_data = (x, mask)
        summary = torchinfo.summary(
            self.model,
            input_data=input_data,
            col_names=["input_size", "output_size", "mult_adds", "num_params"],
            device=self.in_device(),
            depth=4,
        )
        return str(summary)


class OmiTrainable(kdai.train.BaseTrainable):
    """Trainable for next-event-time prediction with a selectable hazard.

    The hazard parameterization is chosen by ``hazard_type``:

      - "const": constant hazard parameterized by the model output.
      - "exp":   log h(t) = a*t + m_out, with ``a`` read from ``model.a``.
      - "rmtpp": log h(t) = softplus(w)*t + m_out + b, with ``w`` and ``b``
                 read from the model (RMTPP-style Gompertz intensity).
      - "nn":    the model itself returns log h and the integrated hazard.

    Each ``<hazard_type>_hazard_forward`` method returns the tuple
    ``(m_out, log_prob, log_h, int_h)`` for a batch of target times ``y``.
    Only the last timestep's time channel is used as the target.
    """

    HazardType = Literal["const", "exp", "nn", "rmtpp"]

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        hazard_type: HazardType,
        use_log_input=False,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        init_weights=True,
    ):
        """
        Args:
            ds_mgr: dataset manager; supplies dt statistics for normalization
                and ``density_interval_len`` for interval evaluation.
            model: the network being trained.
            label: passed through to ``kdai.train.BaseTrainable``.
            hazard_type: one of ``HazardType``. Selects the
                ``<hazard_type>_hazard_forward`` and
                ``_init_<hazard_type>_output`` methods.
            use_log_input: not implemented; must be False.
            eval_mode: "loss" (loss only) or "info" (metrics and figures).
            eval_len: limit passed to ``kdai.datasets.ConstrainedIterable``
                when evaluating.
            init_weights: initialize normalization from dataset statistics.
                This can't be done automatically, as it's not needed when
                loading a pre-trained model.

        Raises:
            ValueError: if ``hazard_type`` is not a valid ``HazardType``.
            NotImplementedError: if ``use_log_input`` is True.
        """
        super().__init__(ds_mgr, model, label)
        valid_hazard_types = typing.get_args(self.HazardType)
        if hazard_type not in valid_hazard_types:
            raise ValueError(
                "Invalid hazard type. Expected one of "
                f"({valid_hazard_types}). Got ({hazard_type})."
            )
        self.hazard_type = hazard_type
        if use_log_input:
            raise NotImplementedError("use_log_input not implemented.")
        self.use_log_input = use_log_input
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        self._forward_fn = getattr(self, f"{self.hazard_type}_hazard_forward")
        if init_weights:
            self.init_normalization()

    def init_normalization(self):
        """Initialize model input and output normalization.

        All models have their inputs normalized to be centered and scaled
        around 0; this is the same for all models and is done to improve
        training dynamics.

        Outputs are expected to be in the unscaled and unshifted units of the
        data (e.g. seconds for the Stack Overflow Badge dataset). To
        improve training dynamics, the bias and scale parameters that appear
        just before the hazard/survival function parameterization are
        initialized so that the layers prior have a distribution roughly
        centered and scaled around 0. How to set these parameters is specific
        to the functional form used by each model.

        The parameters used for scaling/shifting are scalars and will not be
        affected by weight decay, so there is not an issue with setting large
        values for these parameters.
        """
        _logger.info(
            "train ds (mean,sd): "
            f"({self.ds_mgr.dt_mean:.3g}, {self.ds_mgr.dt_sd:.3g})"
        )
        # All models have a normalization layer that they will use for input
        # normalization. So set the mean and sd here.
        self.model.input_norm.set_mean_sd(
            *(
                (self.ds_mgr.log_dt_mean, self.ds_mgr.log_dt_sd)
                if self.use_log_input
                else (self.ds_mgr.dt_mean, self.ds_mgr.dt_sd)
            )
        )
        # Each model has its own way of initializing the output layer; every
        # HazardType needs a matching _init_<hazard_type>_output method.
        output_init_fn = getattr(self, f"_init_{self.hazard_type}_output")
        output_init_fn()

    def _init_nn_output(self):
        """Setup output initialization for the NN integral hazard model.

        The hazard function is computed by the model directly, so its output
        is not parameters of a hazard function or probability distribution.
        So instead, the model can manipulate the time argument sent into its
        hazard function. It has the values in its input_norm module available
        to do this, so nothing needs to be done here.

        Furthermore, the output has the form:

            int_h = softplus(x) = log(1 + exp(x))

        Which means that when x=0, the survival probability is:

            S = exp(-int_h) = 1 / (1 + exp(x)) = 1 / (1 + 1) = 0.5

        Which is a reasonable initialization as it corresponds to the median
        of the distribution. If the distribution is highly skewed, it might
        be worth instead offsetting int_h so that the survival probability
        matches the data estimate of the survival probability at the mean.

        The default scaling is also not unreasonable. For example, when
        S = 0.05 or S = 0.95:

            S = 0.05
                ⇒ 0.05 = 1 / (1 + exp(x))
                ⇒ 19 = exp(x)
                ⇒ x = log(19) ≈ 2.94
            and,
            S = 0.95
                ⇒ 0.95 = 1 / (1 + exp(x))
                ⇒ 1.0526 - 1 = exp(x)
                ⇒ x = log(0.0526) ≈ -2.94

        Both of which values are neither too large nor too small. One possible
        tweak might be to aim for x at ±1 mapping to the values of the data
        distribution's survival values at μ ± σ (data mean and data sqrt(var)).
        For example, if the time deltas are Gaussian distributed, this would
        mean x=-1 maps through 1/(1+exp(β x)) to 0.8413, meaning that β gets
        assigned 1.67. Some more thought might be needed if the distribution
        is highly skewed.

        All in all, the default initialization is reasonable, and as a plus,
        it doesn't introduce another deviation from the original authors'
        implementation.
        """
        pass

    def _init_const_output(self):
        """Initialize the constant hazard model.

        Assuming m_out starts around zero, we will set the offset to make the
        output distribution match the data mean. This amounts to setting the
        offset to log(1/dt_mean).
        """
        self.model.init_output(-math.log(self.ds_mgr.dt_mean))

    def _init_exp_output(self):
        """Initialize the exp (Gompertz) hazard model from dt mean and sd."""
        a, d = init_gompertz_output(self.ds_mgr.dt_mean, self.ds_mgr.dt_sd)
        self.model.init_output(a, d)
        _logger.info(f"Initialized OmiExp model with {a=:.4e}, {d=:.4e}.")

    def _init_rmtpp_output(self):
        """Output initialization for the RMTPP hazard model.

        ``init_normalization`` dispatches on ``hazard_type``, so this method
        must exist for "rmtpp" trainables constructed with
        ``init_weights=True``. No data-dependent output initialization is
        applied: ``model.w`` and ``model.b`` keep the model's own initial
        values.
        """
        _logger.info(
            "No output initialization for the rmtpp hazard; using the "
            "model's default w and b."
        )

    @staticmethod
    def const_hazard_forward(model, x, mask, y):
        """Constant hazard forward pass.

        Args:
            y: 1-D tensor of target times, shape (B,).

        Returns:
            (m_out, log_prob, log_h, int_h), each of shape (B,).
        """
        assert y.dim() == 1
        B = y.shape[0]
        m_out = model(x, mask)
        assert m_out.shape == (B,)
        log_prob, log_h, int_h = prob.const_hazard(m_out, y)
        assert log_prob.shape == (B,)
        return m_out, log_prob, log_h, int_h

    @staticmethod
    def exp_hazard_forward(model, x, mask, y):
        """Exp hazard forward pass: log h(t) = model.a * t + m_out.

        Args:
            y: 1-D tensor of target times, shape (B,).

        Returns:
            (m_out, log_prob, log_h, int_h).
        """
        assert y.dim() == 1
        B = y.shape[0]
        a = torch.ones_like(y) * model.a
        m_out = model(x, mask)
        assert m_out.shape == (B,)
        log_prob, log_h, int_h = prob.exp_hazard(a, m_out, y)
        assert log_prob.shape == (B,)
        return m_out, log_prob, log_h, int_h

    def rmtpp_hazard_forward(self, model, x, mask, y):
        """
        The time delta is scaled then offset. The model output acts to offset.

               w--  m_out--  b--
                 |        |    |
             dt--*--------+----+--->log_intensity

        m_out is referred to as v_dot_h in the paper.

        Typical notation for the Gompertz intensity function is:
            h(t) = e^{a t + d}
            so
            log∘h(t) = a t + d

        Returns:
            (m_out, log_prob, log_h, int_h).
        """
        assert y.dim() == 1
        B = y.shape[0]
        m_out = model(x, mask)
        assert m_out.shape == (B,), m_out.shape
        d = m_out + model.b
        # softplus like Shchur et al.
        # https://github.com/shchur/ifl-tpp/blob/original-code/code/dpp/distributions/rmtpp.py
        a = torch.ones_like(y) * F.softplus(model.w)
        # inline:
        # log_h = a * y + d
        # int_h = 1 / a * (torch.exp(a * y + d) - torch.exp(d))
        log_prob, log_h, int_h = prob.exp_hazard(a, d, y)
        return m_out, log_prob, log_h, int_h

    def nn_hazard_forward(self, model, x, mask, y):
        """NN hazard forward pass; the model returns log_h and int_h itself.

        Returns:
            (rnn_out, log_prob, log_h, int_h) with log_prob = log_h - int_h.
        """
        assert y.dim() == 1
        B = y.shape[0]
        if self.use_log_input:
            eps = 1e-10
            y_eps = y + eps
            assert torch.all(y_eps > 0)
            y = torch.log(y_eps)
        rnn_out, log_h, int_h = model(x, mask, y)
        assert log_h.shape == int_h.shape == (B,)
        # A non-finite log_h is reported but tolerated; int_h and the model
        # output must be finite.
        if not torch.isfinite(log_h).all():
            _logger.error(f"log_h is not finite ({log_h})")
        assert torch.isfinite(int_h).all()
        assert torch.isfinite(rnn_out).all()
        log_prob = log_h - int_h
        return rnn_out, log_prob, log_h, int_h

    def forward(self, sample):
        """Return (m_out, nll), nll being the mean NLL of the last timestep."""
        x, mask, y = sample
        x = x.float().cuda()
        y = y.float().cuda()
        mask = mask.float().cuda()
        # (b, s, 1) -> (b, )  | Only time channel, only last timestep.
        y = y[:, -1, 0]
        m_out, log_prob, _, _ = self._forward_fn(self.model, x, mask, y)
        nll = -log_prob.mean()
        return m_out, nll

    def interval_log_prob(self, x, mask, y, interval_len):
        """Log probability mass of a width-``interval_len`` interval around y.

        The interval is [y - interval_len/2, y + interval_len/2], shifted
        right where needed so it does not start before 0. The mass is computed
        from the integrated hazard via ``prob.logsubexp(-int_h0, -int_h1)``
        (i.e. log(S(t_min) - S(t_max)) with S = exp(-int_h)). Where that is
        not finite, the mass is approximated as the lower of the two endpoint
        densities times ``interval_len``.
        """
        t_min = torch.clip(y - interval_len / 2, 0)
        t_max = t_min + interval_len
        _, lp_min, _, int_h0 = self._forward_fn(self.model, x, mask, t_min)
        _, lp_max, _, int_h1 = self._forward_fn(self.model, x, mask, t_max)
        log_prob_via_integral = prob.logsubexp(-int_h0, -int_h1)
        # If the integral version fails (infinity, due to numerical issues),
        # then we fall back to using a constant probability over the interval.
        # When we fall back, we are approximating with the lower probability
        # uniform over the interval: mass ≈ density * interval_len, which in
        # log space is log_density + log(interval_len).
        log_interval_len = torch.log(
            torch.as_tensor(
                interval_len, dtype=lp_min.dtype, device=lp_min.device
            )
        )
        log_prob_via_lprob = torch.min(lp_min, lp_max) + log_interval_len
        log_prob = torch.where(
            torch.isfinite(log_prob_via_integral),
            log_prob_via_integral,
            log_prob_via_lprob,
        )
        if not torch.isfinite(log_prob).all():
            _logger.error(
                f"log_prob is not finite ({log_prob}). This should "
                "be very rare, due to a validation/test example being very "
                "far from the training data."
            )
        return log_prob

    def infer(self, x, mask):
        """Predict the median next-event time for each batch element.

        Bisects [low, high] looking for the median of the predicted
        distribution. At the median, the cumulative distribution is 1/2, so
        the integral hazard will be -log(1/2) = log(2). The search range is
        [1e-4, 100] times ``model.input_norm.mean`` and 13 bisection steps are
        used.
        """
        B, _, _ = x.shape
        low = (
            torch.ones(size=[B], device=x.device)
            * self.model.input_norm.mean
            * 0.0001
        )
        high = torch.ones_like(low) * self.model.input_norm.mean * 100.0
        for _ in range(13):
            mid = (low + high) / 2
            _, _, _, int_h = self._forward_fn(self.model, x, mask, mid)
            # int_h is positive, with Survival(t) = exp(-int_h)
            # Targetting S(t_0) = 0.5, we want int_h = log(2), and note that
            # int_h is increasing in t.
            high = torch.where(int_h > math.log(2), mid, high)
            low = torch.where(int_h < math.log(2), mid, low)
        return mid

    def eval_metrics(self, dl, dtype=torch.float):
        """Compute loss, pred_nll, mean_abs_err and interval_pred_nll on dl.

        Only the last timestep's time channel is evaluated. "pred_nll" equals
        "loss" for this trainable.
        """
        loss_meter = kdai._logging.Meter("loss")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        for x, mask, y in dl:
            x = x.to(dtype).cuda()
            y = y.to(dtype).cuda()
            mask = mask.float().cuda()
            # (b, s, 1) -> (b, )  | Only time channel, only last timestep.
            y = y[:, -1, 0]
            N = y.shape[0]
            _, log_prob, _, _ = self._forward_fn(self.model, x, mask, y)
            loss = -log_prob.mean().item()
            loss_meter.update(loss, N)
            pred = self.infer(x, mask)
            assert pred.shape == y.shape
            mean_abs_err.update((pred - y).abs().mean().item(), N)
            interval_pred_nll.update(
                -self.interval_log_prob(
                    x, mask, y, interval_len=self.ds_mgr.density_interval_len
                )
                .mean()
                .item(),
                N,
            )
        return {
            "loss": loss_meter.avg,
            "pred_nll": loss_meter.avg,
            "mean_abs_err": mean_abs_err.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }

    def _eval_loss(self, dl):
        """Evaluate the mean NLL loss only."""
        loss_meter = kdai._logging.Meter("loss")
        for x, mask, y in dl:
            x = x.float().cuda()
            y = y.float().cuda()
            mask = mask.float().cuda()
            # (b, s, 1) -> (b, )  | Only time channel, only last timestep.
            y = y[:, -1, 0]
            N = y.shape[0]
            _, log_prob, _, _ = self._forward_fn(self.model, x, mask, y)
            loss = -log_prob.mean().item()
            loss_meter.update(loss, N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Evaluate all metrics and produce f/S/h/int_h figures."""
        metrics = self.eval_metrics(dl)
        results = {
            "metrics": [
                kdai._logging.loss_metric(metrics["loss"]),
                kdai._logging.Metric(
                    "mean_abs_err", metrics["mean_abs_err"], increasing=False
                ),
                kdai._logging.Metric(
                    "pred_nll", metrics["pred_nll"], increasing=False
                ),
                kdai._logging.Metric(
                    "interval_pred_nll",
                    metrics["interval_pred_nll"],
                    increasing=False,
                ),
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_figs(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Evaluate on a length-limited view of dl using the chosen eval_mode."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl)
        return results

    def plot_figs(self, dl, max_n_plots):
        """Plot f/S/h/int_h for up to max_n_plots inputs of dl's first batch."""
        sample = next(iter(dl))
        x, mask, y = sample
        x = x.float().cuda()
        mask = mask.float().cuda()
        figs = []
        for i in range(min(max_n_plots, x.shape[0])):
            figs.extend(self.plot_f_h_S(x[i], mask[i]))
        return figs

    @torch.no_grad()
    def plot_f_h_S(self, x_0, mask):
        """Plot density, survival, hazard and integrated hazard for one input.

        ``x_0`` (no batch dim) is repeated for 1000 evaluation times evenly
        spaced over [0, 5 * model.input_norm.mean].

        Returns:
            A list of 4 figures: f(t), S(t), h(t), and the integral of h.
        """
        assert x_0.ndim == 2, "Only 1 input accepted (no batch dim)"
        ts = torch.linspace(0, self.model.input_norm.mean * 5, 1000)
        x_0 = einops.repeat(x_0, "c v -> t c v", t=ts.shape[0])
        mask = einops.repeat(mask, "v -> t v", t=ts.shape[0])
        _, log_prob, log_h, int_h = self._forward_fn(
            self.model, x_0, mask, ts.cuda()
        )
        log_prob, log_h, int_h = log_prob.cpu(), log_h.cpu(), int_h.cpu()

        def line_fig(ys, title):
            # One line plot of ys against the evaluation times.
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.plot(ts, ys)
            ax.set_title(title)
            ax.set_xlabel("t")
            return fig

        figs = [
            line_fig(torch.exp(log_prob), "f(t) (probability density)"),
            line_fig(torch.exp(-int_h), "S(t) (survival function)"),
            line_fig(torch.exp(log_h), "h(t) (hazard function)"),
            line_fig(int_h, "integral of h from [0, t]"),
        ]
        plt.close("all")
        return figs

    def model_summary(self, batch_size: int) -> str:
        """Return a torchinfo summary of the model on one training batch."""
        dl = torch.utils.data.DataLoader(self.train_ds(), batch_size=batch_size)
        x, mask, y = kdai.train.to_device(next(iter(dl)), self.in_device())
        # (b, s, 1) -> (b, )  | Only time channel, only last timestep.
        y = y[:, -1, 0]
        if self.hazard_type == "nn":
            input_data = (x, mask, y)
        else:
            input_data = (x, mask)
        summary = torchinfo.summary(
            self.model,
            input_data=input_data,
            col_names=["input_size", "output_size", "mult_adds", "num_params"],
            device=self.in_device(),
            depth=4,
        )
        return str(summary)


class ShchurLogMixTrainable(kdai.train.BaseTrainable):
    """Trainable created for next event predictions using a log-normal mixture.

    The model maps (x, mask) to mixture parameters
    (log_tau, mu_scaled, log_sigma_scaled), each with one entry per mixture
    component. mu and log_sigma are shifted by dataset log-dt statistics
    (see ``denorm_mu_log_sigma``) before being used by ``prob.logmix_*``.
    The target is the time channel of the last timestep.
    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        use_log_input=False,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        init_weights=True,
    ):
        super().__init__(ds_mgr, model, label)
        self.use_log_input = use_log_input
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        if init_weights:
            self.init_weights()

    def init_weights(self):
        """Set the model's input normalization from dataset dt statistics."""
        self.model.input_norm.set_mean_sd(
            *(
                (self.ds_mgr.log_dt_mean, self.ds_mgr.log_dt_sd)
                if self.use_log_input
                else (self.ds_mgr.dt_mean, self.ds_mgr.dt_sd)
            )
        )

    def denorm_mu_log_sigma(self, mu, log_sigma):
        """Shift model outputs into the units of the data.

        Note: the mu here is the mu of the log-normal distribution. mu is
        offset by ``ds_mgr.log_dt_mean`` and log_sigma by
        ``ds_mgr.log_dt_sd``; neither is rescaled.
        """
        # NOTE(review): log_sigma is offset by log_dt_sd itself, so a zero
        # model output gives sigma = exp(log_dt_sd) rather than log_dt_sd.
        # If the intent is sigma ≈ log_dt_sd at zero output, the offset would
        # be log(log_dt_sd). Left unchanged: changing it would alter the
        # outputs of already-trained models.
        mu = mu + self.ds_mgr.log_dt_mean
        log_sigma = log_sigma + self.ds_mgr.log_dt_sd
        return mu, log_sigma

    def _forward(self, sample):
        """Run the model on a sample.

        Returns:
            (m_out, loss, nll): m_out stacks the scaled mixture parameters
            along a new last dim; loss is the mean clamped nll; nll is the
            per-sample unclamped negative log-likelihood.
        """
        x, mask, y = sample
        # Not float, as dtype can be changed.
        x = x.cuda()
        y = y.cuda()
        mask = mask.cuda()
        log_tau, mu_scaled, log_sigma_scaled = self.model(x, mask)
        # (b, s, 1) -> (b, )  | Only time channel, only last timestep.
        y = y[:, -1, 0]
        mu, log_sigma = self.denorm_mu_log_sigma(mu_scaled, log_sigma_scaled)
        nll = -prob.logmix_log_prob(log_tau, mu, log_sigma, y)
        # For the loss, we will clip each term to prevent infs. This is one
        # reason why the loss should not be used for evaluation.
        # loss = nll.clamp(max=20).mean() TODO
        loss = nll.clamp(max=int(1e4)).mean()
        m_out = torch.stack([log_tau, mu_scaled, log_sigma_scaled], dim=-1)
        return m_out, loss, nll

    def forward(self, sample):
        """Return (m_out, loss) for training."""
        m_out, loss, _ = self._forward(sample)
        return m_out, loss

    def infer_mean(self, x, mask):
        """Expected value of the predicted log-normal mixture."""
        x = x.float().cuda()
        mask = mask.cuda()
        log_tau, mu, log_sigma = self.model(x, mask)
        mu, log_sigma = self.denorm_mu_log_sigma(mu, log_sigma)
        expected_val = prob.logmix_expected_val(log_tau, mu, log_sigma)
        return expected_val

    def infer_median(self, x, mask):
        """Median of the predicted log-normal mixture."""
        B, C, L = x.shape
        x = x.float().cuda()
        mask = mask.cuda()
        log_tau, mu, log_sigma = self.model(x, mask)
        mu, log_sigma = self.denorm_mu_log_sigma(mu, log_sigma)
        median = prob.logmix_median(log_tau, mu, log_sigma)
        return median

    def eval_metrics(self, dl, dtype=torch.float):
        """Compute loss, pred_nll, mean_abs_err and interval_pred_nll on dl.

        The per-batch values are batch means, so they are averaged again
        (weighted by batch size) across batches.
        """
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")

        for x, mask, y in dl:
            x = x.to(dtype).cuda()
            mask = mask.to(dtype).cuda()
            y = y.to(dtype).cuda()
            m_out, loss, nll = self._forward((x, mask, y))
            # m_out: (batch, mixture components, 3 parameters).
            N, M, _ = m_out.shape
            assert m_out.shape == (y.shape[0], M, 3), m_out.shape
            loss_meter.update(loss.item(), N)
            # The loss is clamped, so the unclamped nll is reported as well.
            pred_nll.update(nll.mean().item(), N)
            log_tau, mu_scaled, log_sigma_scaled = m_out.unbind(dim=2)
            mu, log_sigma = self.denorm_mu_log_sigma(
                mu_scaled, log_sigma_scaled
            )
            # Target is only time, for now.
            target = y[:, -1, 0]
            # Interval pred is the probability mass [t-1/2, t+1/2], shifted
            # right where needed so that it doesn't start before 0.
            t_min = torch.clip(target - 0.5, min=0.0)
            t_max = t_min + 1
            interval_pred_nll.update(
                -prob.logmix_interval_log_prob(
                    log_tau, mu, log_sigma, t_min, t_max
                )
                .mean()
                .item(),
                N,
            )
            # Median prediction from the parameters above; no need for a
            # second model pass (as infer_median would do).
            pred = prob.logmix_median(log_tau, mu, log_sigma)
            assert pred.shape == target.shape
            mean_abs_err.update((pred - target).abs().mean().item(), N)

        return {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }

    def _eval_loss(self, dl):
        """Evaluate loss only."""
        loss_meter = kdai._logging.Meter("loss")
        for x, mask, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, mask, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Evaluate with additional metrics."""
        metrics = self.eval_metrics(dl)

        def to_log_metric(m, v):
            if m == "loss":
                return kdai._logging.loss_metric(v)
            else:
                return kdai._logging.Metric(m, v, increasing=False)

        results = {
            "metrics": [
                *[to_log_metric(m, v) for m, v in metrics.items()],
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_dists(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Evaluate on a length-limited view of dl using the chosen eval_mode."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl)
        return results

    def plot_dists(self, dl, max_n_plots):
        """Plot predicted distributions for dl's first batch."""
        sample = next(iter(dl))
        _, _, t = sample
        # We only care about the last timestep.
        t = t[:, -1, 0]
        m_out, _ = self.forward(sample)
        log_tau, mu_scaled, log_sigma_scaled = m_out.unbind(dim=2)
        mu, log_sigma = self.denorm_mu_log_sigma(mu_scaled, log_sigma_scaled)
        return vis.logmix_dist_plots(log_tau, mu, log_sigma, t, max_n_plots)


class GptLogMix(kdai.train.BaseTrainable):
    """Gpt with LogMix head.

    The model maps (x, mask) to per-timestep log-normal mixture parameters
    (log_tau, mu, log_sigma). Training uses every unmasked timestep;
    evaluation uses only the last timestep.
    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        init_weights=True,
    ):
        super().__init__(ds_mgr, model, label)
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        if init_weights:
            self.init_normalization()

    def init_normalization(self):
        """Initialize model input and output normalization."""
        mean_log, sd_log = self.ds_mgr.log_dt_mean, self.ds_mgr.log_dt_sd
        _logger.info(
            "train ds:\n"
            f"(mean-log, sd-log): ({mean_log:.3g}, {sd_log:.3g})\n"
            f"(min, max): ({self.ds_mgr.dt_min:.3g}, {self.ds_mgr.dt_max:.3g})\n"
            f"dt range (min, max): ({self.ds_mgr.dt_range_min:.3g}, "
            f"{self.ds_mgr.dt_range_max:.3g})"
        )
        self.model.init_input_scale(
            self.ds_mgr.dt_range_min, self.ds_mgr.dt_range_max
        )
        self.model.init_output(mean_log, sd_log)

    def loss_fn(self, log_tau, mu, log_sigma, mask, y):
        """Masked, clamped mean NLL over all (batch, seq) positions.

        Args:
            mask: (b, s); positions with mask <= 0 are ignored.
            y: (b, s) target times.

        Returns:
            (loss, nll) where nll has shape (b, s) and is unclamped.
        """
        b, s = y.shape

        def flatten_seq(v):
            return v.flatten(0, 1)

        # logmix expects t>0, but our mask may be 0 or negative.
        # Set masked y's to 1. We will ignore the nll for these later.
        y = torch.where(mask > 0, y, torch.ones_like(y))
        nll = -prob.logmix_log_prob(
            flatten_seq(log_tau),
            flatten_seq(mu),
            flatten_seq(log_sigma),
            flatten_seq(y),
        )
        nll = einops.rearrange(nll, "(b s) -> b s", b=b, s=s)
        # Ignore padded values (masked)
        assert mask.shape == nll.shape
        # Scale by the number of elements (s * b) (and not mask.sum()), so that
        # elements in sequences with more masking are not given more weight.
        n_elem = b * s
        scaled_mask = mask.float() / (n_elem)
        # For the loss, we will clip each term to prevent infs. This is one
        # reason why the loss should not be used for evaluation.
        # If one element is infinity, clipping allows the rest of the batch
        # to still be trained with.
        # loss = (nll.clamp(max=20) * scaled_mask).sum()
        loss = (nll.clamp(max=int(1e4)) * scaled_mask).sum()
        return loss, nll

    def _forward(self, x, mask, y):
        """Return (m_out, loss, nll); m_out and nll are for the last timestep."""
        log_tau, mu, log_sigma = self.model(x, mask)
        # (b, s, 1) -> (b, s)  | Only time channel
        y = y[:, :, 0]
        loss, nll = self.loss_fn(log_tau, mu, log_sigma, mask, y)
        m_out = torch.stack([log_tau, mu, log_sigma], dim=-1)
        # For m_out and nll, just the last timestep is fine.
        m_out = m_out[:, -1]
        nll = nll[:, -1]
        return m_out, loss, nll

    def forward(self, sample):
        """Return (m_out, loss) for training."""
        x, mask, y = sample
        # Not float, as dtype can be changed.
        x = x.cuda()
        y = y.cuda()
        mask = mask.cuda()
        m_out, loss, _ = self._forward(x, mask, y)
        return m_out, loss

    def infer_mean(self, x, mask):
        """Expected value of the predicted mixture (all timesteps)."""
        x = x.float().cuda()
        mask = mask.cuda()
        log_tau, mu, log_sigma = self.model(x, mask)
        expected_val = prob.logmix_expected_val(log_tau, mu, log_sigma)
        return expected_val

    def infer_median(self, x, mask):
        """Median of the predicted mixture (all timesteps)."""
        B, C, L = x.shape
        x = x.float().cuda()
        mask = mask.cuda()
        log_tau, mu, log_sigma = self.model(x, mask)
        median = prob.logmix_median(log_tau, mu, log_sigma)
        return median

    def eval_metrics(self, dl, dtype=torch.float):
        """Compute evaluation metrics on dl for the last timestep.

        Returns a dict with "loss", "pred_nll", "mean_abs_err",
        "interval_pred_nll" (currently a 0 placeholder), and per-component
        diagnostics "c{c}_μ" / "c{c}_log(σ)": the batch-size-weighted mean of
        mu[:, c] / log_sigma[:, c] over all evaluated batches.
        """
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        # TEMP: per-mixture-component diagnostics, accumulated over all
        # batches (not just the last one).
        n_mix = self.model.n_mix
        comp_mu = [kdai._logging.Meter(f"c{c}_μ") for c in range(n_mix)]
        comp_log_sigma = [
            kdai._logging.Meter(f"c{c}_log(σ)") for c in range(n_mix)
        ]

        # Note: the loss per batch is averaged, so we are averaging this
        # again as we loop through each batch.
        for x, mask, y in dl:
            x = x.to(dtype).cuda()
            mask = mask.to(dtype).cuda()
            y = y.to(dtype).cuda()
            m_out, loss, nll = self._forward(x, mask, y)
            # batch, mix, mix-params (for last timestep)
            N, M, _ = m_out.shape
            assert N == y.shape[0]
            assert nll.shape == (N,), nll.shape
            assert m_out.shape == (N, M, 3), m_out.shape
            loss_meter.update(loss.item(), N)
            # Only care about the last timestep for evaluation.
            pred_nll.update(nll.mean().item(), N)
            log_tau, mu, log_sigma = m_out.unbind(dim=2)
            # Target is the last timestep, and only time.
            target = y[:, -1, 0]
            # TEMP TODO numerical issues: the interval metric is disabled and
            # reported as 0. It is the probability mass [t-1/2, t+1/2]:
            # t_min = torch.clip(target - 0.5, min=0.0)
            # t_max = t_min + 1
            # interval_pred_nll.update(
            #     -prob.logmix_interval_log_prob(
            #         log_tau, mu, log_sigma, t_min, t_max
            #     )
            #     .mean()
            #     .item(),
            #     N,
            # )
            interval_pred_nll.update(0, N)

            # pred = prob.logmix_expected_val(log_tau, mu, log_sigma)
            pred = prob.logmix_median(log_tau, mu, log_sigma)
            assert pred.shape == target.shape
            mean_abs_err.update((pred - target).abs().mean().item(), N)

            for c in range(n_mix):
                comp_mu[c].update(mu[:, c].mean().item(), N)
                comp_log_sigma[c].update(log_sigma[:, c].mean().item(), N)

        res = {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }
        for c in range(n_mix):
            res[f"c{c}_μ"] = comp_mu[c].avg
            res[f"c{c}_log(σ)"] = comp_log_sigma[c].avg
        return res

    def _eval_loss(self, dl):
        """Evaluate loss only."""
        loss_meter = kdai._logging.Meter("loss")
        for x, mask, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, mask, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Evaluate with additional metrics."""
        metrics = self.eval_metrics(dl)

        def to_log_metric(m, v):
            if m == "loss":
                return kdai._logging.loss_metric(v)
            else:
                return kdai._logging.Metric(m, v, increasing=False)

        results = {
            "metrics": [
                *[to_log_metric(m, v) for m, v in metrics.items()],
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_dists(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Evaluate on a length-limited view of dl using the chosen eval_mode."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl)
        return results

    def plot_dists(self, dl, max_n_plots):
        """Plot predicted last-timestep distributions for dl's first batch."""
        sample = next(iter(dl))
        _, _, t = sample
        # We only care about the last timestep.
        t = t[:, -1, 0]
        m_out, _ = self.forward(sample)
        log_tau, mu, log_sigma = m_out.unbind(dim=2)
        return vis.logmix_dist_plots(log_tau, mu, log_sigma, t, max_n_plots)


class LogMixForSpikes(kdai.train.BaseTrainable):
    """Trainable for stim-spike samples in, log-normal mixture out.

    Similar to ShchurLogMixTrainable, the difference in input has led to
    a separate class: normalization of input is different, there is no need to
    handle a mask, and cell ids are included in the input.

    Samples are (x, cell_id, y) triples. The model maps (x, cell_id) to the
    mixture parameters (log_tau, mu, log_sigma). The integer target y is the
    left edge of a unit-length interval [y, y + 1); its mid-point y + 0.5 is
    used as the point target (see ``y_to_t``).
    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        eval_rec_cids=None,
        init_weights=True,
    ):
        """
        Args:
            ds_mgr: dataset manager; ``init_weights`` reads its
                ``log_dt_mean``/``log_dt_sd``.
            model: network exposing ``set_input_mean_sd``, ``init_output`` and
                ``gpt_base.input_len``.
            label: forwarded to ``BaseTrainable``.
            eval_mode: selects ``_eval_loss`` or ``_eval_info`` as the
                evaluation routine.
            eval_len: passed to ``ConstrainedIterable`` in ``evaluate``.
            eval_rec_cids: optional iterable of (recording name, cell id)
                pairs selecting the cells used for spike-train metrics. When
                None/empty, the first cell of the first recording is used.
            init_weights: whether to initialise normalisation weights.
        """
        super().__init__(ds_mgr, model, label)
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        self.eval_rec_cids = eval_rec_cids
        if init_weights:
            self.init_weights()
        # Not great abstraction here, but we aren't exactly making a library.
        self.model_in_len = self.model.gpt_base.input_len

    def init_weights(self):
        """Initialize weights associated with normalization.

        The input has heterogeneous data along the last dimension, and needs
        to be normalized differently: all 5 features get mean 0.5; the first
        four get sd 0.5 and the last gets sd 1. The output layer is
        initialised from the dataset's log-dt mean and sd."""
        self.model.set_input_mean_sd(
            torch.full((5,), 0.5),
            torch.tensor([0.5, 0.5, 0.5, 0.5, 1]),
        )
        mean_log, sd_log = self.ds_mgr.log_dt_mean, self.ds_mgr.log_dt_sd
        self.model.init_output(mean_log, sd_log)

    @staticmethod
    def y_to_t(y):
        """Map the integer interval start ``y`` to the interval mid-point."""
        return y + 0.5

    def loss_fn(self, log_tau, mu, log_sigma, y, clip_max=None):
        """Negative log-likelihood of the targets under the log-normal mixture.

        Args:
            log_tau, mu, log_sigma: mixture parameters, each 2-D (B, M).
            y: 1-D (B,) targets; evaluated at ``y_to_t(y)``.
            clip_max: if given, each per-sample NLL is clamped to at most this
                value before averaging, so one infinite term does not make the
                whole batch loss infinite. Default None: no clipping.

        Returns:
            (loss, nll): the scalar mean loss, and the unclipped per-sample
            NLL of shape (B,). Because the loss may be clipped, it should not
            be used for evaluation.
        """
        assert y.dim() == 1
        assert log_tau.dim() == mu.dim() == log_sigma.dim() == 2
        (b,) = y.shape
        t = self.y_to_t(y)

        nll = -prob.logmix_log_prob(log_tau, mu, log_sigma, t)
        assert nll.shape == (b,), nll.shape
        # Clipping is opt-in; by default the plain mean is used.
        loss_terms = nll if clip_max is None else nll.clamp(max=clip_max)
        loss = loss_terms.mean()
        return loss, nll

    def _forward(self, x, cell_id, y):
        """Run the model and compute the loss.

        Returns:
            (m_out, loss, nll) where m_out stacks (log_tau, mu, log_sigma)
            along a new last dimension.
        """
        log_tau, mu_scaled, log_sigma_scaled = self.model(x, cell_id)
        loss, nll = self.loss_fn(log_tau, mu_scaled, log_sigma_scaled, y)
        m_out = torch.stack([log_tau, mu_scaled, log_sigma_scaled], dim=-1)
        return m_out, loss, nll

    def forward(self, sample):
        """Move an (x, cell_id, y) sample to the GPU; return (m_out, loss)."""
        x, cell_id, y = sample
        # Not float, as dtype can be changed.
        x = x.cuda()
        y = y.long().cuda()
        cell_id = cell_id.long().cuda()
        m_out, loss, _ = self._forward(x, cell_id, y)
        return m_out, loss

    def infer_mean(self, x, cell_id):
        """Expected value of the predicted mixture."""
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        log_tau, mu, log_sigma = self.model(x, cell_id)
        expected_val = prob.logmix_expected_val(log_tau, mu, log_sigma)
        return expected_val

    def infer_median(self, x, cell_id):
        """Median of the predicted mixture. ``x`` must be 3-D."""
        B, C, L = x.shape
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        log_tau, mu, log_sigma = self.model(x, cell_id)
        median = prob.logmix_median(log_tau, mu, log_sigma)
        return median

    def ll(self, x, cell_id, t_min, t_max):
        """Log probability mass of the predicted mixture over [t_min, t_max]."""
        B, C, L = x.shape
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        t_min = t_min.float().cuda()
        t_max = t_max.float().cuda()
        log_tau, mu, log_sigma = self.model(x, cell_id)
        interval_pred_ll = prob.logmix_interval_log_prob(
            log_tau, mu, log_sigma, t_min, t_max
        )
        return interval_pred_ll

    def median_and_ll(self, x, cell_id, t_min, t_max):
        """This function is a convenient combo used by inferspikes.

        Without it, inferspikes would need to call forward twice to get
        both pieces of information.

        In future, a better API design would be to have a function:

            forward_and(x, cell_id, y=None, fns=(,))

        That would call the fns on the output of forward and return the
        results as a list. That way, the structure of the model outputs would
        not need to be known by the caller.

        We expect t_min to always be non-negative (and t_max > t_min), as
        whenever there are no spikes left, the correct interval should be
        [win, ∞).
        """
        B, C, L = x.shape
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        assert torch.all(t_min >= 0) and torch.all(t_max > t_min)
        t_min = t_min.float().cuda()
        t_max = t_max.float().cuda()
        log_tau, mu, log_sigma = self.model(x, cell_id)
        median = prob.logmix_median(log_tau, mu, log_sigma)
        interval_ll = prob.logmix_interval_log_prob(
            log_tau, mu, log_sigma, t_min, t_max
        )

        return median, interval_ll

    def eval_metrics(self, dl, ds, dtype=torch.float):
        """Compute spike-train metrics for the eval cells and per-snippet metrics.

        Returns:
            dict of metric name -> value.
        """
        recs = self._eval_recs(ds)
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        STRIDE = 80

        def tag_for(r, cid):
            # Only include the recording name when several are evaluated.
            return f"r{r.name}_c{cid}" if len(recs) > 1 else f"c{cid}"

        # 1. Spike train metrics
        strain_res = {}
        for rec in recs:
            # rec_ll returns per-cell values keyed by cell id (unpacked the
            # same way in SpikesDiscrete.eval_metrics).
            lls = kdtpp.inferspikes.rec_ll(
                self, [rec], STRIDE, dl.batch_size, dl.num_workers
            )
            for cid, v in lls.items():
                strain_res[f"ll_{tag_for(rec, cid)}"] = v
            for cid in rec.cell_ids:
                r = rec.cells({cid})
                assert r.cell_ids == [cid], r.cell_ids
                van_rossum, pcorr, schreiber, auto_ll = (
                    kdtpp.inferspikes.prob_auto_stats(
                        self, r, STRIDE, bin_ms=1000 / 992, sigma_ms=60
                    )
                )
                tag = tag_for(r, cid)
                strain_res[f"van_rossum_τ60_{tag}"] = van_rossum
                strain_res[f"pcorr_σ60_{tag}"] = pcorr
                strain_res[f"schreiber_σ60_{tag}"] = schreiber
                strain_res[f"auto_ll_{tag}"] = auto_ll

        # 2. Per-snippet metrics
        for x, cell_id, y in dl:
            N, _, _ = x.shape
            x = x.to(dtype).cuda()
            cell_id = cell_id.long().cuda()
            y = y.long().cuda()
            m_out, loss, nll = self._forward(x, cell_id, y)
            # batch, mix, mix-params
            _, M, c = m_out.shape
            assert c == 3, c
            assert N == y.shape[0] == m_out.shape[0]
            assert nll.shape == (N,), nll.shape
            loss_meter.update(loss.item(), N)
            pred_nll.update(nll.mean().item(), N)
            log_tau, mu, log_sigma = m_out.unbind(dim=2)
            # y is effectively the lhs of the sample interval, so the interval
            # prediction is the probability mass of [y, y + 1]. For the
            # point estimate, we use the mid-point. This isn't a metric that
            # will be used for reporting, as it's not so meaningful in the
            # sampling regime. Float bounds, as in ll()/median_and_ll().
            t_min = y.float()
            t_max = t_min + 1
            interval_pred_nll.update(
                -prob.logmix_interval_log_prob(
                    log_tau, mu, log_sigma, t_min, t_max
                )
                .mean()
                .item(),
                N,
            )
            target = t_min + 0.5
            pred = prob.logmix_median(log_tau, mu, log_sigma)
            assert pred.shape == target.shape
            mean_abs_err.update((pred - target).abs().mean().item(), N)

        res = {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "interval_pred_nll": interval_pred_nll.avg,
            **strain_res,
        }
        return res

    def _eval_loss(self, dl, ds):
        """Evaluate loss only."""
        del ds  # Not used.
        loss_meter = kdai._logging.Meter("loss")
        for x, cell_id, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, cell_id, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl, ds):
        """Evaluate with additional metrics."""
        metrics = self.eval_metrics(dl, ds)

        def to_log_metric(m, v):
            if m == "loss":
                return kdai._logging.loss_metric(v)
            else:
                return kdai._logging.Metric(m, v, increasing=False)

        results = {
            "metrics": [
                *[to_log_metric(m, v) for m, v in metrics.items()],
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_dists(dl, max_n_plots=5)
            ),
        }
        return results

    def _eval_recs(self, ds):
        """Return the recordings (restricted to the eval cells) to evaluate.

        Raises:
            ValueError: if none of the requested cells exist in a recording.
        """
        if self.eval_rec_cids is None or len(self.eval_rec_cids) == 0:
            # Just take the first cell of the first recording. Note: a set
            # literal; set(cid) would iterate over the id itself.
            first = ds.recordings[0]
            return [first.cells({first.cell_ids[0]})]
        rec_to_cid = defaultdict(set)
        for r_name, c_id in self.eval_rec_cids:
            rec_to_cid[r_name].add(c_id)
        eval_recs = []
        for r in ds.recordings:
            if r.name in rec_to_cid:
                rec = r.cells(rec_to_cid[r.name])
                if len(rec.cell_ids) == 0:
                    raise ValueError(
                        f"Cells ({rec_to_cid[r.name]}) not in rec {r.name}"
                    )
                eval_recs.append(rec)
        return eval_recs

    def evaluate_train(self, dl_fn):
        return self.evaluate(self.train_dl(dl_fn), self.ds_mgr.train_ds())

    def evaluate_val(self, dl_fn):
        if self.model.training:
            _logger.warning("Model should be in eval mode")
        # The ds_mgr is the BasicDatasetManager, which just wraps the dataset.
        return self.evaluate(self.val_dl(dl_fn), self.ds_mgr.val_ds())

    def evaluate(self, dl, rec):
        """Run the configured evaluation; ``rec`` is the dataset (see _eval_recs)."""
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl, rec)
        return results

    def plot_dists(self, dl, max_n_plots):
        """Plot predicted mixtures for the first batch of ``dl``."""
        sample = next(iter(dl))
        _, _, t = sample
        m_out, _ = self.forward(sample)
        log_tau, mu, log_sigma = m_out.unbind(dim=2)
        return vis.logmix_dist_plots(log_tau, mu, log_sigma, t, max_n_plots)


class SpikesDiscrete(kdai.train.BaseTrainable):
    """DiscreteTrainable for spike data.

    Non-causal only (same as LogMixForSpikes).

    We don't have a bin mode here. The model outputs logits over
    ``out_resolution`` unit-width bins; targets at or beyond
    ``out_resolution - 1`` are clamped into the last bin, which therefore
    covers [out_resolution - 1, ∞). This is dictated by the requirements of
    the prediction task, whereby we don't need any information breakdown in
    that final interval. (``eval_metrics`` currently asserts
    ``out_resolution - 1 == 80``.)

    The spike train generation process that allows for the calculation of a
    spike train similarity/distance will be done in a separate python file; it's
    not something we will do while training.

    Other changes:
        - the dataset won't use padding, so pad_val and ignore_index present
          in DiscreteTrainable are not used here.

    TODO: In order to compare against LogMix, LogMix will
    need to integrate over the final interval when events fall in it. I
    don't think this is done yet.
    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        eval_rec_cids=None,
        init_weights=True,
    ):
        """
        Args:
            eval_mode: selects ``_eval_loss`` or ``_eval_info``.
            eval_len: passed to ``ConstrainedIterable`` in ``evaluate``.
            eval_rec_cids: optional iterable of (recording name, cell id)
                pairs; when None/empty the first cell of the first recording
                is used.
            init_weights: whether to initialise normalisation weights.
        """
        super().__init__(ds_mgr, model, label)
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        self.eval_rec_cids = eval_rec_cids
        if init_weights:
            self.init_weights()
        # Not great abstraction here, but we aren't exactly making a library.
        self.model_in_len = self.model.gpt_base.input_len
        # Speed-up tweak/hack.
        # The model has an out_resolution parameter—it determines the size of
        # the output layer. This trainable refers to that parameter in the
        # loss function; however, when compiling the model and the cforward
        # function (on pytorch 2.5), this reference to the out_resolution
        # parameter of the model prevents the compilation of the cforward
        # function. So, we make a record of the value here, and use it in the
        # loss function. This prevents an assertion error from dynamo.
        self.out_resolution = self.model.out_resolution

    def init_weights(self):
        """Initialize weights associated with normalization.

        Same as LogMixForSpikes.

        The input has heterogeneous data along the last dimension, and needs
        to be normalized differently."""
        self.model.set_input_mean_sd(
            torch.full((5,), 0.5),
            torch.tensor([0.5, 0.5, 0.5, 0.5, 1]),
        )

    def _ts(self, device):
        """Bin indices 0..out_resolution-1 on ``device``."""
        return torch.arange(self.out_resolution, device=device)

    def loss_fn(self, m_out, target_t, reduction="mean"):
        """
        Cross-entropy loss with clamping to within bin range.

        Args:
            m_out: (B, out_resolution) logits. Bin i covers [i, i + 1)
                milliseconds, except the last bin, which covers
                [out_resolution - 1, ∞).
            target_t (long): milliseconds until next spike. Unlike other
                trainables, we don't bother giving target_t an extra dimension,
                so it's just (B, ). Values beyond the last bin are clamped
                into it.
            reduction: forwarded to ``F.cross_entropy``.
        """
        B, _ = m_out.shape
        assert len(target_t.shape) == 1 and target_t.shape[0] == B
        # Follow the logits' device rather than assuming the default GPU.
        target_t = target_t.long().to(m_out.device)
        assert torch.all(target_t >= 0), target_t
        target_t = torch.clamp(target_t, 0, self.out_resolution - 1)
        # Cross-entropy is equivalent to the negative log-likelihood.
        res = F.cross_entropy(m_out, target_t, reduction=reduction)
        return res

    def forward(self, sample):
        """Forward call that may be compiled.

        Route model calls that also need loss calculated though this fn."""
        x, cell_id, y = sample
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        m_out = self.model(x, cell_id)
        loss = self.loss_fn(m_out, y)
        return m_out, loss

    def expected_val(self, m_out):
        """Not supported: the last bin is unbounded, so the mean is undefined.

        (Summing bin index × probability would only give a lower bound.)

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "Can't calc expected val as last bin is infinite"
        )

    def mode(self, m_out):
        """Index of the most probable bin, per sample."""
        ts = self._ts(m_out.device)
        assert len(ts) == self.out_resolution
        probs = F.softmax(m_out, dim=1)
        mode = ts[probs.argmax(dim=1)]
        return mode

    def median(self, m_out):
        """
        Calculate the least number x that satisfies
        P(t≤x) ≥ 0.5 and P(t≥x) ≥ 0.5.
        """
        x_idx = torch.argmax(
            # argmax doesn't work for bool; it returns the first True bin.
            (F.softmax(m_out, dim=1).cumsum(dim=1) >= 0.5).int(),
            dim=1,
        )
        ts = self._ts(m_out.device)
        median = ts[x_idx]
        return median

    def infer(
        self,
        m_out,
        method: Literal["expected_val", "mode", "median"] = "expected_val",
    ):
        """Point prediction by ``method``.

        Note: the default "expected_val" always raises NotImplementedError
        for this trainable (see ``expected_val``).
        """
        if method == "expected_val":
            return self.expected_val(m_out)
        elif method == "mode":
            return self.mode(m_out)
        elif method == "median":
            return self.median(m_out)
        else:
            raise ValueError(f"Invalid method: {method}")

    def infer_median(self, x, cell_id):
        """Median bin of the prediction for a 3-D input ``x``."""
        B, C, L = x.shape
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        m_out = self.model(x, cell_id)
        median = self.median(m_out)
        return median

    def ll(self, x, cell_id, t_min, t_max):
        """Per-sample log-likelihood (B,) of the bin indexed by ``t_min``.

        ``t_max`` is only shape-checked.
        """
        B, C, L = x.shape
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        assert t_min.shape == t_max.shape == (B,), t_min.shape
        # For the DiscreteSpikes trainable, the correct bin is indexed by t_min
        assert torch.all(t_min >= 0) and torch.all(t_min < self.out_resolution)
        y = t_min.long().cuda()
        m_out = self.model(x, cell_id)
        ll = -self.loss_fn(m_out, y, reduction="none")
        assert ll.shape == (B,), ll.shape
        return ll

    def _ll(self, m_out, t_min, t_max):
        """Mean log-likelihood (scalar) of the bins indexed by ``t_min``."""
        # Must be matching 1-D tensors. (Comparing a torch.Size to the int 1
        # is always False, which made this check unconditionally fail.)
        assert t_min.dim() == 1 and t_min.shape == t_max.shape, (
            t_min.shape,
            t_max.shape,
        )
        # For the DiscreteSpikes trainable, the correct bin is indexed by t_min
        assert torch.all(t_min >= 0) and torch.all(t_min < self.out_resolution)
        y = t_min.long().to(m_out.device)
        ll = -self.loss_fn(m_out, y)
        return ll

    def median_and_ll(self, x, cell_id, t_min, t_max):
        """This function is a convenient combo used by inferspikes.

        Without it, inferspikes would need to call forward twice to get
        both pieces of information.

        In future, a better API design would be to have a function:

            forward_and(x, cell_id, y=None, fns=(,))

        That would call the fns on the output of forward and return the
        results as a list. That way, the structure of the model outputs would
        not need to be known by the caller.

        t_min must be non-negative and t_max > t_min: when there are no
        ground-truth spikes remaining, inferspikes is expected to supply the
        interval [win, ∞) rather than negative values. Large t_min values
        fall into the final bin via the clamping in ``loss_fn``.
        """
        B, C, L = x.shape
        x = x.float().cuda()
        cell_id = cell_id.long().cuda()
        m_out = self.model(x, cell_id)
        median = self.median(m_out)
        # For the DiscreteSpikes trainable, the correct bin is indexed by t_min
        assert torch.all(t_min >= 0) and torch.all(t_max > t_min)
        y = t_min.long().to(m_out.device)
        ll = -self.loss_fn(m_out, y, reduction="none")
        return median, ll

    def eval_metrics(self, dl, ds):
        """Spike-train metrics for the eval cells plus per-snippet metrics."""
        recs = self._eval_recs(ds)
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        mean_abs_err_mode = kdai._logging.Meter("mean_abs_err_mode")
        stride = self.out_resolution - 1
        assert stride == 80, stride

        def tag_fn(r, c):
            # Only include the recording name when several are evaluated.
            return f"r{r.name}_c{c}" if len(recs) > 1 else f"c{c}"

        # 1. Per cell metrics. Spike train metrics and loss.
        strain_res = {}
        for rec in recs:
            ll = kdtpp.inferspikes.rec_ll(
                self, [rec], stride, dl.batch_size, dl.num_workers
            )
            ll = {f"ll_{tag_fn(rec, cid)}": v for cid, v in ll.items()}
            strain_res.update(ll)
            for cid in rec.cell_ids:
                r = rec.cells({cid})
                assert r.cell_ids == [cid], r.cell_ids
                van_rossum, pcorr, schreiber, auto_ll = (
                    kdtpp.inferspikes.auto_stats(
                        self, r, stride, bin_ms=1000 / 992, sigma_ms=60
                    )
                )
                strain_res[f"van_rossum_τ60_{tag_fn(r, cid)}"] = van_rossum
                strain_res[f"pcorr_σ60_{tag_fn(r, cid)}"] = pcorr
                strain_res[f"schreiber_σ60_{tag_fn(r, cid)}"] = schreiber
                strain_res[f"auto_ll_{tag_fn(r, cid)}"] = auto_ll

        # 2. Per-snippet metrics.
        for x, cell_id, y in dl:
            N = x.shape[0]
            m_out, loss = self.forward((x, cell_id, y))
            loss_val = loss.item()
            loss_meter.update(loss_val, N)
            # The (unclipped) cross-entropy loss is exactly the NLL.
            pred_nll.update(loss_val, N)
            y = y.float().cuda()
            pred_mode = self.mode(m_out)
            pred_median = self.median(m_out)
            mean_abs_err.update((pred_median - y).abs().mean().item(), N)
            mean_abs_err_mode.update((pred_mode - y).abs().mean().item(), N)
        res = {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "mean_abs_err_mode": mean_abs_err_mode.avg,
            **strain_res,
        }
        return res

    def _eval_loss(self, dl, ds):
        """Evaluate loss only."""
        del ds  # Not used.
        loss_meter = kdai._logging.Meter("loss")
        for x, cell_id, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, cell_id, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl, ds):
        """Evaluate with additional metrics."""
        metrics = self.eval_metrics(dl, ds)

        def to_log_metric(m, v):
            if m == "loss":
                return kdai._logging.loss_metric(v)
            else:
                return kdai._logging.Metric(m, v, increasing=False)

        results = {
            "metrics": [
                *[to_log_metric(m, v) for m, v in metrics.items()],
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_dists(dl, max_n_plots=5)
            ),
        }
        return results

    def _eval_recs(self, ds):
        """Return the recordings (restricted to the eval cells) to evaluate.

        Raises:
            ValueError: if none of the requested cells exist in a recording.
        """
        if self.eval_rec_cids is None or len(self.eval_rec_cids) == 0:
            # Just take the first cell of the first recording. Note: a set
            # literal; set(cid) would iterate over the id itself.
            first = ds.recordings[0]
            return [first.cells({first.cell_ids[0]})]
        rec_to_cid = defaultdict(set)
        for r_name, c_id in self.eval_rec_cids:
            rec_to_cid[r_name].add(c_id)
        eval_recs = []
        for r in ds.recordings:
            if r.name in rec_to_cid:
                rec = r.cells(rec_to_cid[r.name])
                if len(rec.cell_ids) == 0:
                    raise ValueError(
                        f"Cells ({rec_to_cid[r.name]}) not in rec {r.name}"
                    )
                eval_recs.append(rec)
        return eval_recs

    def evaluate_train(self, dl_fn):
        return self.evaluate(self.train_dl(dl_fn), self.ds_mgr.train_ds())

    def evaluate_val(self, dl_fn):
        # We don't switch the model's train/eval state here; the caller is
        # responsible for it, so we only warn if it looks wrong.
        if self.model.training:
            _logger.warning("Model should be in eval mode")
        # The ds_mgr is the BasicDatasetManager, which just wraps the dataset.
        return self.evaluate(self.val_dl(dl_fn), self.ds_mgr.val_ds())

    def evaluate(self, dl, ds):
        dl = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        results = self._evaluate(dl, ds)
        return results

    def plot_dists(self, dl, max_n_plots):
        """Plot predicted bin distributions for the first batch of ``dl``."""
        sample = next(iter(dl))
        x, cell_id, t = sample
        m_out = self.model(x.float().cuda(), cell_id.long().cuda())
        return self._plot_dists(m_out, t, max_n_plots)

    def _plot_dists(self, m_out, target_t, max_n_plots):
        """Plots a distribution as a histogram; the target is highlighted.

        The red dashed line marks the target, the green one the median.
        """
        batch_size = m_out.shape[0]
        # detach(): numpy() fails on tensors that require grad.
        probs = F.softmax(m_out.detach(), dim=1).cpu().numpy()
        median = self.median(m_out).cpu().numpy()
        target_t = target_t.cpu().numpy()
        ts = self._ts("cpu").numpy()
        assert len(ts) == self.out_resolution
        figs = []
        for i in range(min(max_n_plots, batch_size)):
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.set_xlim(0, self.out_resolution + 1)
            ax.bar(ts, probs[i], width=1)
            ax.axvline(target_t[i], color="r", linestyle="--")
            ax.axvline(median[i], color="g", linestyle="--")
            ax.set_title(f"t={target_t[i]}")
            figs.append(fig)
        return figs


class VarBinForSpikes(kdai.train.BaseTrainable):
    """
    VarBinTrainable, but for spike data.

    Not sure if this will be needed. Never actually got to the point of using.

    Non-causal only (same as LogMixForSpikes).


    "timestamp" of the event is the 0th channel in the data. Marks, if present,
    are the remaining channels, although not used in this trainable.
    """

    def __init__(
        self,
        ds_mgr,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        init_weights=True,
    ):
        """Build the variable-width bins from the training dts and set up state.

        Args:
            eval_mode: selects ``_eval_loss`` or ``_eval_info``.
            eval_len: stored for evaluation.
            init_weights: whether to initialise normalisation weights.
        """
        super().__init__(ds_mgr, model, label)
        self._evaluate = getattr(self, f"_eval_{eval_mode}")
        self.eval_len = eval_len
        self.loss_fn = self.loss_fn_noncausal
        # Targets equal to the dataset's padding value are treated as padding;
        # with no padding, fall back to F.cross_entropy's default (-100).
        self.ignore_index = getattr(self.ds_mgr, "pad_val", -100)

        edges = self.calculate_bin_edges(
            self.ds_mgr.train_dts_flat(), self.model.out_resolution
        )
        self._bin_edges = edges
        # Rate of the exponential tail in the final, unbounded bin. Hardcoded
        # as the reciprocal of the last finite edge (edges[-2]); better than
        # zero, and could optionally become a learned parameter.
        self.last_bin_lambda = 1 / edges[-2].item()
        self._bin_widths = edges[1:] - edges[:-1]
        assert (
            type(self._bin_edges) is torch.Tensor
            and type(self._bin_widths) is torch.Tensor
        )
        assert len(self._bin_edges) == len(self._bin_widths) + 1
        assert self._bin_edges[0] == 0 and self._bin_edges[-1] == np.inf
        if init_weights:
            self.init_weights()
        # Cache out_resolution on the trainable: referencing the model's
        # attribute from the loss prevented dynamo (pytorch 2.5) from
        # compiling cforward.
        self.out_resolution = self.model.out_resolution

    @functools.lru_cache(maxsize=2)
    def bin_edges(self, device=None):
        """Bin edges as a tensor on ``device`` (``self.in_device()`` if None).

        Memoised by lru_cache per (self, device), so repeated calls return the
        same tensor object; callers must not modify it in place.
        """
        target = self.in_device() if device is None else device
        return self._bin_edges.clone().to(target)

    @functools.lru_cache(maxsize=2)
    def bin_widths(self, device=None):
        """Return bin widths on the desired or default device.

        Args:
            device: target device; None (now the default, as in
                ``bin_edges``) means ``self.in_device()``.

        Memoised per (self, device); do not modify the result in place.
        """
        if device is None:
            device = self.in_device()
        return self._bin_widths.clone().to(device)

    @functools.lru_cache(maxsize=2)
    def bin_centers(self, device=None):
        """Return bin centers on the desired or default device.

        Args:
            device: target device; None (the default) lets ``bin_edges``
                pick ``self.in_device()``.

        The last center is infinite, as the last edge is ``inf``.
        """
        bin_edges = self.bin_edges(device)
        bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2
        return bin_centers

    @staticmethod
    def calculate_bin_edges(train_dts, n_bins):
        """Calculate ``n_bins + 1`` bin edges that give equal probability mass.

        The first bin extends down to 0 and the last extends to infinity
        (with an exponential tail), so unseen data never gets zero
        likelihood. Because these two edge bins spread their mass over
        potentially large ranges, they are given less mass: 1/4 of a center
        bin's mass (hardcoded as EDGE_FRAC).

        Returns:
            float32 tensor of edges, starting at 0 and ending at inf.
        """
        EDGE_FRAC = 0.25
        n_inner = n_bins - 2  # everything except the first and last bin
        inner_mass = 1 / (n_inner + 2 * EDGE_FRAC)
        outer_mass = inner_mass * EDGE_FRAC
        quantile_levels = np.linspace(outer_mass, 1 - outer_mass, n_inner + 1)
        finite_edges = np.quantile(train_dts, quantile_levels)
        edges = np.concatenate([[0], finite_edges, [np.inf]])
        assert len(edges) == n_bins + 1
        return torch.tensor(edges, dtype=torch.float32, requires_grad=False)

    def get_bin_idx(self, target_t):
        """Return the 0-based index of the bin containing each value of target_t.

        The edges are fetched on ``target_t``'s device, because
        ``torch.bucketize`` requires the boundaries and input to be on the
        same device.

        right=True
        ----------
        right=True means that if the query t is equal to the left edge of a bucket,
        then it will be placed in that bucket. Example for when it's important:

          Case:   t = 0, bin_edges = [0, 1, 2]
          Desired result: we want the bin whose edges are [0, 1] to be selected.

        -1
        --
        bucketize will match a bin like:  edges[i-1] <= t < edges[i], and then
        return i for a match. This means that i is always 1 greater than the
        actual bin index (0-based). So, we subtract 1 to get the correct index.
        """
        edges = self.bin_edges(target_t.device)
        return torch.bucketize(target_t, edges, right=True) - 1

    def init_weights(self):
        """Set the model's input normalisation (same as LogMixForSpikes).

        The last input dimension holds heterogeneous features: all five get
        mean 0.5, while the standard deviation is 0.5 for the first four and
        1 for the fifth.
        """
        input_mean = torch.full((5,), 0.5)
        input_sd = torch.tensor([0.5, 0.5, 0.5, 0.5, 1])
        self.model.set_input_mean_sd(input_mean, input_sd)

    def loss_fn_noncausal(self, m_out, target_t):
        """Mean negative log-likelihood of the first target column.

        Args:
            m_out: model output for the variable-width bins.
            target_t: 2-D targets; only column 0 is used. No element may
                equal ``ignore_index``.
        """
        assert len(target_t.shape) == 2
        assert torch.all(target_t != self.ignore_index), "Expected no padding"
        device = m_out.device
        t = target_t[:, 0].float().to(device)
        edges = self.bin_edges(device)
        nll = -prob.var_bin_ll(m_out, edges, t, self.last_bin_lambda)
        # Zero out padded targets (none are expected given the assert above).
        is_pad = t == self.ignore_index
        nll = torch.where(is_pad, torch.tensor(0.0, device=device), nll)
        return nll.mean()

    def ll(self, m_out, target_t):
        """Mean log-likelihood of ``target_t`` (shape (B, 1)) under ``m_out``."""
        batch, _ = m_out.shape
        assert target_t.shape == (batch, 1)
        device = m_out.device
        t = target_t[:, 0].float().to(device)
        assert torch.all(t != self.ignore_index), "Expected no padding"
        per_sample = prob.var_bin_ll(
            m_out, self.bin_edges(device), t, self.last_bin_lambda
        )
        return per_sample.mean()

    def cforward(self, sample):
        """Model call plus loss; the part of forward that may be compiled.

        Route any model call that also needs the loss through here.
        """
        inputs, cell_ids, targets = sample
        out = self.model(inputs.float().cuda(), cell_ids.long().cuda())
        return out, self.loss_fn(out, targets)

    def forward(self, sample):
        """Return (model output, loss) for an (x, cell_id, y) sample."""
        out, loss_val = self.cforward(sample)
        return out, loss_val

    def expected_val(self, m_out):
        """Expected value of the variable-bin distribution for 2-D ``m_out``."""
        _batch, _n_bins = m_out.shape  # enforce a 2-D input
        edges = self.bin_edges(m_out.device)
        return prob.var_bin_expected_val(m_out, edges, self.last_bin_lambda)

    @torch.no_grad()
    def mode(self, m_out):
        """Mode of the variable-bin distribution.

        Wrapped in no_grad() because the computation (argmax-like) is not
        differentiable anyway.
        """
        edges = self.bin_edges(m_out.device)
        return prob.var_bin_mode(m_out, edges, self.last_bin_lambda)

    @torch.no_grad()
    def median(self, m_out):
        """Median of the predicted distribution.

        Wrapped in ``no_grad()``: the computation relies on searchsorted,
        which provides no useful gradient anyway.
        """
        edges = self.bin_edges(m_out.device)
        return prob.var_bin_median(m_out, edges, self.last_bin_lambda)

    def infer(
        self,
        m_out,
        method: Literal["expected_val", "mode", "median"] = "expected_val",
    ):
        """Point estimate from the model output using the chosen statistic.

        Args:
            m_out: model output.
            method: one of ``"expected_val"``, ``"mode"`` or ``"median"``.

        Raises:
            ValueError: if ``method`` is not one of the supported names.
        """
        estimators = (
            ("expected_val", self.expected_val),
            ("mode", self.mode),
            ("median", self.median),
        )
        for name, estimator in estimators:
            if method == name:
                return estimator(m_out)
        raise ValueError(f"Invalid method: {method}")

    def _interval_log_prob(self, t, m_out, interval_len):
        """Log probability mass of an interval of length ``interval_len`` around ``t``.

        The interval starts at ``t - interval_len / 2`` clipped to 0, so near
        zero it is shifted right rather than shortened.

        Logs a warning on NaN (checked first) or otherwise on Inf. The final
        assertion requires every value to be ``<= 0``; NaN values also fail it.
        """
        start = torch.clip(t - interval_len / 2, 0)
        stop = start + interval_len
        edges = self.bin_edges(m_out.device)
        log_mass = torch.log(
            prob.var_bin_interval_prob(m_out, edges, start, stop)
        )
        if log_mass.isnan().any():
            _logger.warning("NaN in interval log prob.")
        elif log_mass.isinf().any():
            _logger.warning("Inf in interval log prob.")
        assert (log_mass <= 0).all(), "Probability mass must be less-equal 1."
        return log_mass

    @torch.no_grad()
    def eval_metrics(self, dl):
        """Compute averaged evaluation metrics over all batches of ``dl``.

        Runs under ``torch.no_grad()``: only ``.item()`` values are kept, so
        building an autograd graph for every batch would waste memory and
        compute.

        Args:
            dl: iterable of ``(X, cell_id, y)`` batches; ``y[:, 0]`` holds the
                target used for the error metrics.

        Returns:
            dict of ``Meter.avg`` values (each batch fed with ``update(value, N)``,
            N = batch size):
              - ``loss``: ``self.loss_fn`` via ``cforward``.
              - ``pred_nll``: negative of ``self.ll``.
              - ``mean_abs_err``: |median - target| (median-based).
              - ``mean_abs_err_mode``: |mode - target|.
              - ``mean_abs_err_mean``: |expected value - target|.
              - ``interval_pred_nll``: negative log probability of an interval
                of length ``self.ds_mgr.density_interval_len`` around the target.
        """
        loss_meter = kdai._logging.Meter("loss")
        pred_nll = kdai._logging.Meter("pred_nll")
        mean_abs_err = kdai._logging.Meter("mean_abs_err")
        mean_abs_err_mode = kdai._logging.Meter("mean_abs_err_mode")
        mean_abs_err_mean = kdai._logging.Meter("mean_abs_err_mean")
        interval_pred_nll = kdai._logging.Meter("interval_pred_nll")
        for X, cell_id, y in dl:
            N = X.shape[0]
            m_out, loss = self.cforward((X, cell_id, y))
            loss_meter.update(loss.item(), N)
            rescaled_ll = self.ll(m_out, y)
            pred_nll.update(-rescaled_ll.item(), N)

            y = y[:, 0].float().cuda()
            pred_exp_val = self.expected_val(m_out)
            pred_mode = self.mode(m_out)
            pred_median = self.median(m_out)
            # "mean_abs_err" is the median-based error (kept under its
            # established metric name).
            mean_abs_err.update((pred_median - y).abs().mean().item(), N)
            mean_abs_err_mode.update((pred_mode - y).abs().mean().item(), N)
            mean_abs_err_mean.update((pred_exp_val - y).abs().mean().item(), N)
            interval_pred_nll.update(
                -self._interval_log_prob(
                    y, m_out, self.ds_mgr.density_interval_len
                )
                .mean()
                .item(),
                N,
            )
        return {
            "loss": loss_meter.avg,
            "pred_nll": pred_nll.avg,
            "mean_abs_err": mean_abs_err.avg,
            "mean_abs_err_mode": mean_abs_err_mode.avg,
            "mean_abs_err_mean": mean_abs_err_mean.avg,
            "interval_pred_nll": interval_pred_nll.avg,
        }

    @torch.no_grad()
    def _eval_loss(self, dl):
        """Evaluate loss only, averaged over the batches of ``dl``.

        Runs under ``torch.no_grad()`` because only ``loss.item()`` is kept.

        Args:
            dl: iterable of ``(x, cell_id, y)`` batches, the same layout that
                ``cforward`` unpacks.

        Returns:
            ``{"metrics": [loss_metric(avg_loss)]}``.
        """
        loss_meter = kdai._logging.Meter("loss")
        for x, cell_id, y in dl:
            N = x.shape[0]
            _, loss = self.forward((x, cell_id, y))
            loss_meter.update(loss.item(), N)
        results = {
            "metrics": [
                kdai._logging.loss_metric(loss_meter.avg),
            ],
        }
        return results

    def _eval_info(self, dl):
        """Full evaluation: scalar metrics plus distribution plots.

        Scalar metrics come from ``self.eval_metrics(dl)``. All error/NLL
        metrics are reported as lower-is-better (``increasing=False``).
        Figures come from ``self.plot_dists(dl, max_n_plots=5)``.

        Returns:
            dict with ``"metrics"`` (list of metric objects) and ``"figs"``
            (an ``MplFigureList``).
        """
        metrics = self.eval_metrics(dl)

        def lower_is_better(name):
            return kdai._logging.Metric(name, metrics[name], increasing=False)

        results = {
            "metrics": [
                kdai._logging.loss_metric(metrics["loss"]),
                lower_is_better("mean_abs_err"),
                lower_is_better("mean_abs_err_mean"),
                lower_is_better("mean_abs_err_mode"),
                lower_is_better("pred_nll"),
                # eval_metrics already computes and returns this; report it
                # instead of discarding it.
                lower_is_better("interval_pred_nll"),
            ],
            "figs": kdai._logging.MplFigureList(
                self.plot_dists(dl, max_n_plots=5)
            ),
        }
        return results

    def evaluate(self, dl):
        """Wrap ``dl`` in ``ConstrainedIterable`` and delegate to ``_evaluate``.

        The wrapper is constructed with ``self.eval_len``; presumably this caps
        how much of ``dl`` is consumed (confirm in ``kdai.datasets``).
        """
        limited = kdai.datasets.ConstrainedIterable(dl, self.eval_len)
        return self._evaluate(limited)

    @torch.no_grad()
    def plot_dists(self, dl, max_n_plots):
        """Plot predicted distributions for the first batch of ``dl``.

        Runs under ``torch.no_grad()``. ``_plot_dists`` converts the model
        output to numpy, which fails on tensors that require grad.

        Args:
            dl: iterable of ``(X, cell_id, t)`` batches; only the first is used.
            max_n_plots: maximum number of figures to produce.

        Returns:
            list of matplotlib figures from ``_plot_dists``.
        """
        X, cell_id, t = next(iter(dl))
        # Match cforward's dtypes: cell_id is passed to the model as a long
        # tensor (it was previously cast to float here).
        m_out = self.model(X.float().cuda(), cell_id.long().cuda())
        return self._plot_dists(m_out, t, max_n_plots)

    def _plot_dists(self, m_out, target_t, max_n_plots):
        """Plot predicted bin probabilities as bar charts, one figure per sample.

        Args:
            m_out: per-bin logits; softmaxed along dim 1, and row ``i`` is drawn
                as bars starting at ``self.bin_edges("cpu")[:-1]``.
            target_t: targets for the batch; ``target_t[i]`` is shown in the
                title of figure ``i``.
            max_n_plots: upper bound on the number of figures produced.

        Returns:
            list of ``matplotlib.figure.Figure`` (at most
            ``min(max_n_plots, batch_size)`` entries).
        """
        batch_size = m_out.shape[0]
        # detach() so .numpy() also works when m_out is part of an autograd
        # graph; calling .numpy() on a tensor that requires grad raises.
        probs = F.softmax(m_out, dim=1).detach().cpu().numpy()
        target_t = target_t.detach().cpu().numpy()
        figs = []
        edges = self.bin_edges("cpu").numpy()
        widths = edges[1:] - edges[0:-1]
        # Don't use the (possibly infinite) width of the last bin; draw it at
        # twice the width of the preceding bin instead.
        widths[-1] = widths[-2] * 2
        for i in range(min(max_n_plots, batch_size)):
            fig = mpl.figure.Figure()
            ax = fig.subplots()
            ax.bar(
                edges[:-1],
                probs[i],
                width=widths,
                align="edge",
                edgecolor="black",
            )
            # NOTE(review): the last bar spans edges[-2] .. edges[-2] + widths[-1]
            # (= 2 * widths[-2]); this limit stops at edges[-2] + widths[-2],
            # so only the first half of the last bar is visible.
            ax.set_xlim(0, edges[-2] + widths[-2])
            ax.set_title(f"t={target_t[i]}")
            figs.append(fig)
        return figs
