import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import kdai
import kdai.nn
import math
from typing import Tuple
import logging
import kdai.train

_logger = logging.getLogger(__name__)


class RMTPP(nn.Module):
    def __init__(self, n_q, n_h, n_unroll=10):
        """
        Implements the model described in
            "Recurrent Marked Temporal Point Processes: Embedding Event History
            to Vector" by Du et al. (2016)

        The paper for this model specifies using truncated backpropagation
        through time (TBPTT), but it doesn't specify the number of unroll steps.
        Inspecting the repository reveals that the unroll length is varied
        depending on the dataset. The `bptt` argument is set to:

            - synthetic datasets: 1 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/scripts/synthetic_run.sh)
            - taxi dataset: 10 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/scripts/taxi_run.sh#L16)
            - stock orders: 3 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/scripts/book_order.sh)
            - stackoverflow dataset: 6 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/monkeys/so_server.sh#L16)
            - MIMIC II: 2 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/scripts/mimic2_run.sh)

        These datasets are used in the repo but not reported in the paper:

            - lastfm: 8 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/scripts/lastfm_run.sh)
            - sns: 3 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/scripts/sns.sh)
            - ali: 1 (https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/nn/scripts/ali_run.sh)

        The bptt argument is given as the `unroll` argument to the `Run` method,
        so it seems very likely that this parameter is the number of unroll steps.

        Run method: https://github.com/dunan/NeuralPointProcess/blob/ce05ec14624d9127a978bf1e782bd459a77568cb/code/network/src/main.cpp#L100

        Args:
            n_q: number of event types.
            n_h: length of t embedding and length of hidden state.
            n_unroll: number of trailing steps that are backpropagated
                through (truncated BPTT). Must be at least 1.

        Raises:
            ValueError: if n_unroll < 1 (the model would have no step to
                produce a differentiable output from).
        """
        super().__init__()
        if n_unroll < 1:
            raise ValueError(f"n_unroll must be at least 1, got {n_unroll}.")
        self.n_q = n_q
        self.n_h = n_h
        self.n_unroll = n_unroll
        self.w = nn.Parameter(torch.tensor(1.0))
        self.b = nn.Parameter(torch.tensor(0.0))
        self.t_embed = nn.Linear(1, self.n_h)
        self.q_embed = nn.Embedding(self.n_q, self.n_h)
        # Not mentioned in RMTPP paper, but used here as it's used elsewhere
        # and has significant impact on training.
        self.input_norm = kdai.nn.Normalize()
        self.rnn = nn.RNN(
            input_size=self.n_h,
            hidden_size=self.n_h,
            num_layers=1,
            batch_first=True,
            nonlinearity="relu",
        )
        self.t_out = nn.Linear(self.n_h, 1, bias=True)

    def forward(self, dt_seq):
        """Run the RNN over a sequence of time deltas.

        Marks are not supported yet (`q_embed` is unused here).

        Args:
            dt_seq: time deltas, expected shape (batch, seq_len, 1).

        Returns:
            Tensor of shape (batch,): `t_out` applied to the RNN output at
            the final step.

        Only the last `n_unroll` steps are backpropagated through. Any
        earlier steps are run without gradient tracking and only provide
        the initial hidden state. If seq_len <= n_unroll, the whole sequence
        is backpropagated through and the RNN starts from a zero state.
        """
        # NOTE(review): assumes kdai.nn.Normalize preserves the (b, s, 1)
        # shape, as t_embed expects a last dimension of 1 — TODO confirm.
        x = self.input_norm(dt_seq)
        # Embed the *normalized* deltas (previously the normalized tensor
        # was computed and then discarded).
        t_embed = self.t_embed(x)
        n_context = t_embed.shape[1] - self.n_unroll
        h0 = None
        if n_context > 0:
            # Equivalent to running with grad and detaching the final hidden
            # state, but no autograd graph is built for the warm-up steps.
            with torch.no_grad():
                _, h0 = self.rnn(t_embed[:, :n_context])
        x, _ = self.rnn(t_embed[:, -self.n_unroll :], h0)
        x = x[:, -1]
        x = self.t_out(x)
        x = einops.rearrange(x, "b 1 -> b")
        return x


"""
PyTorch implementation of the common model used by Omi et al. (2019).

Omi et al. (2019) used a model description with a configurable hazard
layer. A single class in their implementation handles:

    - RNN to constant hazard function 
    - RNN to exponential hazard function
    - RNN to piecewise-constant hazard function
    - RNN that parameterizes an implicit hazard function (defined by a 
        2nd neural network)

Notes:

  - the model handles input normalization (the original RMTPP
    implementation doesn't seem to).
  - there is no input embedding step (inputs are scalars).
  - the activation function is tanh.
  - event types are not supported, not even as input-only event types.
  - the input sequence length must match the unroll length.

"""


class OmiConstant(nn.Module):
    """Constant-hazard model, following Omi et al. (2019).

    The hazard does not depend on the time since the last event:

        h(t) = e^{m_out}

    i.e. the next inter-event time is exponentially distributed with rate
    exp(m_out). Separating out the final bias add, the hazard reads:

        h(t) = e^{m_out + d}

           m_out  d
              │   │
              └───+──> log_intensity

    `forward` returns this log intensity, one value per sequence.
    """

    def __init__(self, rnn_stem):
        super().__init__()
        self.rnn = rnn_stem
        self.input_norm = kdai.nn.NormMask()
        self.fc = nn.Linear(rnn_stem.n_h, 1)

    @torch.no_grad()
    def init_output(self, d: float):
        """Set the output bias d to match the data.

        Assuming m_out starts near zero, the initial hazard is e^d, i.e. an
        exponential distribution with rate e^d. One choice is
        d = log(1 / mean(dt)), so the initial mean matches the data mean.
        Note that the exponential distribution is highly skewed: with rate
        λ = 1/a the mean is a and the variance is a².
        """
        _logger.info(f"Const hazard initialization. {d=:.3g}")
        # no_grad (decorator) permits the in-place write to a leaf parameter.
        self.fc.bias.fill_(d)

    def forward(self, seq, mask):
        normed = self.input_norm(seq, mask)
        log_intensity = self.fc(self.rnn(normed))
        return einops.rearrange(log_intensity, "b 1 -> b")


class OmiExponential(nn.Module):
    """Exponential-hazard model, following Omi et al. (2019).

    The hazard grows (or decays) exponentially with time:

            h(t) = e^{at + m_out}

    Separating out the final bias add, the hazard reads:

            h(t) = e^{at + m_out + d}

          a ──┐ m_out ─┐ d ─┐
              │        │    │
         dt ──*────────+────+───> log_intensity

    Written this way, the equivalence with the RMTPP model is visible.

    d is the bias term and m_out is the value just before the bias is added.
    This split is what `init_output` relies on. `forward` returns
    m_out + d; the learnable slope `a` is exposed as an attribute.
    """

    def __init__(self, rnn_stem):
        super().__init__()
        self.input_norm = kdai.nn.NormMask()
        self.rnn = rnn_stem
        self.fc = nn.Linear(rnn_stem.n_h, 1)
        # Placeholder slope; init_output() should set a data-driven value.
        self.a = nn.Parameter(torch.tensor(-1.0))

    @torch.no_grad()
    def init_output(self, a: float, d: float):
        """Set the slope a and output bias d to match the data.

        The output hazard can be viewed as:

            h(t) = e^{at + m_out + d}

        With m_out assumed to start near zero, pick a and d so that the
        resulting Gompertz distribution roughly matches the shape and scale
        of the data's time deltas, then call this method with them.
        """
        _logger.info(f"Exponential hazard initialization. {a=:.3g}, {d=:.3g}")
        self.a.fill_(a)
        self.fc.bias.fill_(d)

    def forward(self, seq, mask):
        normed = self.input_norm(seq, mask)
        log_intensity = self.fc(self.rnn(normed))
        return einops.rearrange(log_intensity, "b 1 -> b")


class OmiNN(nn.Module):
    """Fully-neural-network hazard model, following Omi et al. (2019).

    The RNN stem summarizes the input history. A feed-forward network maps
    (normalized time, RNN summary) through a softplus output to the
    cumulative hazard int_h. The hazard itself is obtained as d(int_h)/dt
    via autograd (see `hazard`).

    All weights are clamped to be non-negative (`clamp_weights`). Together
    with the monotone tanh/softplus activations, this makes int_h
    non-decreasing in the normalized time input.
    """

    def __init__(self, rnn_stem, n_h, n_layers):
        """
        Args:
            rnn_stem: module applied to the normalized input sequence in
                `forward`; must expose `n_h`, the last dimension of its
                output (used to size `fc_h`).
            n_h: width of the hazard network's hidden layers.
            n_layers: number of extra (n_h -> n_h) hidden layers in `mids`.
        """
        super().__init__()
        self.rnn = rnn_stem
        self.n_h = n_h
        self.n_layers = n_layers
        self.input_norm = kdai.nn.NormMask()
        # Original paper uses bias=False. Why?
        # EasyTPP uses bias: https://github.com/ant-research/EasyTemporalPointProcess/blob/main/easy_tpp/model/torch_model/torch_fullynn.py
        self.fc_t = nn.Linear(1, self.n_h, bias=False)
        # self.fc_t = nn.Linear(1, self.n_h, bias=True)
        self.fc_h = nn.Linear(self.rnn.n_h, self.n_h, bias=True)
        self.mids = nn.ModuleList(
            [
                nn.Linear(self.n_h, self.n_h, bias=True)
                for _ in range(self.n_layers)
            ]
        )
        # Final layer producing the (pre-softplus) cumulative hazard.
        self.int_hazard_fc = nn.Linear(self.n_h, 1)
        self._init_weights()

    @torch.no_grad()
    def init_output(self, in_gain):
        """Currently not used.

        Sets the output bias to log(2) and scales the time-input weights
        (`fc_t`) by `in_gain` in place.
        """
        self.int_hazard_fc.bias.fill_(math.log(2))
        self.fc_t.weight.data *= in_gain

    @torch.no_grad()
    def _init_weights(self):
        """
        This is the initialization protocol described in Omi et al. (2019).

        Each linear layer's weights are drawn from
        U(0, gain * sqrt(6 / (fan_in + fan_out))), i.e. the non-negative
        half of a Xavier-uniform range, and biases are set to zero.
        """

        def positive_xavier_uniform(m, gain=1.0):
            if type(m) == nn.Linear:
                # nn.Linear stores weight as (out_features, in_features).
                fan_out, fan_in = m.weight.shape
                nn.init.uniform_(
                    m.weight,
                    0,
                    gain * math.sqrt(6.0 / (fan_in + fan_out)),
                )
                # Safeguard only: the sampling range above is already >= 0.
                m.weight.data.abs_()
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            else:
                # Anything that is not a Linear must be a ModuleList, which is
                # left untouched (its members are initialized individually).
                assert (
                    type(m) == nn.ModuleList
                ), f"Unexpected module type ({type(m)})"

        # The default gain regularly caused very high initial loss.
        # Some trial and error landed us on a reduction to 1/3, which
        # seems to perform better across the Badges, Cyclic and Rand Proc
        # datasets. One possible reason for this is that PyTorch does not
        # initialise RNNs well by default.
        positive_xavier_uniform(self.fc_t, gain=1 / 3)
        positive_xavier_uniform(self.fc_h, gain=1 / 3)
        for m in self.mids:
            positive_xavier_uniform(m, gain=1 / 3)
        positive_xavier_uniform(self.int_hazard_fc)

    def clamp_weights(self):
        """Clamp all hazard-network weights (not biases) to >= 1e-10 in place."""
        # Clamp (with epsilon for stability).
        # Omi et al. (2019) do this after a gradient update; however, in Pytorch
        # there isn't any post-update hook.
        # Epsilon was added, following the advice from https://github.com/wassname/torch-neuralpointprocess/tree/master
        eps = 1e-10
        self.fc_t.weight.data.clamp_(min=eps)
        self.fc_h.weight.data.clamp_(min=eps)
        for m in self.mids:
            m.weight.data.clamp_(min=eps)
        self.int_hazard_fc.weight.data.clamp_(min=eps)

    @torch.enable_grad()
    def hazard(self, rnn_out, t):
        """Compute the log hazard and the cumulative hazard at times `t`.

        Gradients must be tracked even when called inside a torch.no_grad()
        context (e.g. during evaluation), hence torch.enable_grad(): the
        hazard is the autograd derivative of the cumulative hazard wrt t.

        Side effects: sets `t.requires_grad` to True in place, and clamps the
        network weights in place via `clamp_weights()`.

        Args:
            rnn_out: output of the RNN stem; last dimension `self.rnn.n_h`.
            t: times at which to evaluate, presumably one per batch element,
                shape (batch,) given the "b -> b 1" rearrange below (assuming
                input_norm.norm preserves shape — TODO confirm).

        Returns:
            (log_h, int_h), with identical shapes: log(dint_h/dt + 1e-10) and
            the cumulative hazard (softplus output).
        """
        # t, if loaded from a dataloader, will not have a Jacobian computed. As
        # we need to calculate the gradient of the cumulative hazard wrt t, we
        # need to ensure that requires_grad is set to True.
        t.requires_grad_(True)
        # Careful with normalizing `t`. Without normalization, training will
        # likely encounter numerical issues with NaNs. However, normalizing
        # to mean 0 and variance 1 may be too aggressive in the sense that
        # it may be preferable for t to vary over a wider range.
        tt = self.input_norm.norm(torch.clone(t))
        # Or:
        #tt = torch.clone(t)
        # Clamp on every call, not just in training, as the last optimizer
        # update step can leave negative weights.
        # if self.training:
        self.clamp_weights()
        tt = einops.rearrange(tt, "b -> b 1")
        fc_t = self.fc_t(tt)
        rnn_out = self.fc_h(rnn_out)
        # Divide by 2 to maintain variance. Both are assumed to have ~the same.
        x = torch.tanh((fc_t + rnn_out) / 2.0)

        def mid(x):
            for m in self.mids:
                x = torch.tanh(m(x))
            return x

        x = mid(x)
        x = self.int_hazard_fc(x)
        # softplus keeps the cumulative hazard positive.
        int_h = F.softplus(x)
        int_h = einops.rearrange(int_h, "b 1 -> b")
        # Differentiating the *sum* gives per-element derivatives as long as
        # each int_h[i] depends only on t[i] (true for the per-row layers
        # here, assuming input_norm.norm is element-wise).
        # create_graph=True makes the resulting hazard itself differentiable,
        # so a loss on log_h can backpropagate into the parameters.
        # retain_graph defaults to the value of create_graph, so passing it is
        # redundant but explicit. The graph must be kept because int_h is also
        # returned and presumably backpropagated through by the caller's loss.
        h = torch.autograd.grad(
            int_h.sum(), t, create_graph=True, retain_graph=True
        )
        eps = 1e-10
        # autograd.grad returns a tuple (one entry per input); h[0] is d/dt.
        log_h = torch.log(h[0] + eps)
        assert torch.isfinite(log_h).all(), log_h
        assert torch.isfinite(int_h).all(), int_h
        assert log_h.shape == int_h.shape
        return log_h, int_h

    def forward(self, x, mask, t):
        """Return (rnn_out, log_h, int_h) for history `x` and query times `t`."""
        x = self.input_norm(x, mask)
        assert x.shape[-1] == 2, "Expected (b s 2) (timestamps and mask)."
        x = self.rnn(x)
        log_h, int_h = self.hazard(x, t)
        return x, log_h, int_h


class OmiRNN(nn.Module):
    """Recurrent stem shared by all Omi et al. (2019) models.

    A single-layer tanh RNN whose output at the final step is returned as
    the history summary.
    """

    def __init__(self, n_h, n_in=2, n_unroll=10):
        """
        Mapping from Omi et al.'s model arguments to this class:
            size_rnn -> n_h
            time_step -> n_unroll
            log_mode -> NA (handled by the Trainable)

        The paper feeds time straight into the RNN as a 1D input. This class
        also backs derivative models that may feed embedded inputs, so the
        input width is configurable via `n_in`.

        Omi et al. use no mask or padding. They slide a window of length
        model_in_len+1 over the data, so every forward pass has full
        context and produces a single output. That is fine for sequences
        much longer than the model input, but for short sequences a large
        part of the input never receives a prediction.

        Args:
            n_h: size of the RNN hidden state.
            n_in: number of input channels. Defaults to 2 (time and mask).
            n_unroll: required input sequence length (unroll steps).
        """
        super().__init__()
        self.n_h = n_h
        self.n_in = n_in
        self.n_unroll = n_unroll
        # The paper has no input embedding (its input width is 1).
        self.rnn = nn.RNN(
            input_size=n_in,
            hidden_size=n_h,
            num_layers=1,
            batch_first=True,
            nonlinearity="tanh",
        )

    def forward(self, x):
        """Return the RNN output at the last step.

        Args:
            x: tensor of shape (batch, seq_len, channel), with
                seq_len == n_unroll.

        Raises:
            ValueError: if seq_len differs from n_unroll.
        """
        _, seq_len, _ = x.shape
        if seq_len != self.n_unroll:
            raise ValueError(
                f"Input sequence length ({seq_len}) must match unroll length "
                f"({self.n_unroll})."
            )
        outputs, _ = self.rnn(x)
        return outputs[:, -1]


class ShchurLogMix(nn.Module):
    """GRU-based TPP model with a log-normal-mixture output.

    Supports marked and unmarked event sequences. The output is the raw
    mixture parameters (log_tau, mu, log_sigma) for the last step.
    """

    def __init__(self, n_marks=0, n_h=32, n_mix: int = 16):
        super().__init__()
        self.n_marks = n_marks
        self.input_norm = kdai.nn.NormMask()
        self.n_h = n_h
        self.n_mix = n_mix
        # Normalized time delta + mask make up the first 2 input channels.
        input_dim = 2
        if n_marks > 0:
            self.mark_embed = nn.Embedding(n_marks, n_h)
            # Shchur et al. concatenate the raw time deltas with the mark
            # embedding (a 32+1 dim input). Encoding the deltas (e.g. a
            # positional encoding) and adding them might be preferable.
            input_dim += n_h
        self.rnn_input_dim = input_dim
        # Unused: intended initial state of the RNN.
        self.context_init = nn.Parameter(torch.zeros(n_h))
        self.rnn = nn.GRU(
            input_size=self.rnn_input_dim,
            hidden_size=self.n_h,
            batch_first=True,
        )
        # 3 parameters per mixture component: log_tau, mu, log_sigma.
        self.fc = nn.Linear(self.n_h, n_mix * 3)

    def forward(
        self, x, mask
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (log_tau, mu, log_sigma), each of shape (batch, n_mix).

        Channel 0 of `x` is the time delta; any remaining channels are mark
        indices, which are embedded and summed when n_marks > 0.
        """
        batch, steps, channels = x.shape
        times, marks = torch.split(x, [1, channels - 1], dim=-1)
        rnn_in = self.input_norm(times, mask)
        if self.n_marks > 0:
            mark_vecs = einops.reduce(
                self.mark_embed(marks.long()),
                "b s m e -> b s e",
                "sum",
                b=batch,
                s=steps,
                e=self.n_h,
            )
            rnn_in = torch.cat([rnn_in, mark_vecs], dim=-1)
        hidden_seq, _ = self.rnn(rnn_in)
        params = self.fc(hidden_seq[:, -1])
        log_tau, mu, log_sigma = params.chunk(3, dim=1)
        return log_tau, mu, log_sigma


class BaseRNN(nn.Module):
    """Single-layer GRU encoder that returns its output at the last step."""

    def __init__(self, n_in, n_h=64):
        super().__init__()
        self.n_h = n_h
        self.n_in = n_in
        self.rnn = nn.GRU(input_size=n_in, hidden_size=n_h, batch_first=True)

    def forward(self, x):
        """Encode x of shape (batch, seq, channel) into (batch, n_h)."""
        # Unpacking also rejects inputs that are not 3-D.
        _batch, _steps, _channels = x.shape
        outputs, _ = self.rnn(x)
        return outputs[:, -1]


class RnnCat(nn.Module):
    """RNN stem followed by a linear discrete head.

    Produces `out_resolution` values per sequence (presumably logits over
    discretized time bins — confirm with the loss that consumes them).
    """

    def __init__(self, rnn_stem, out_resolution=128):
        super().__init__()
        self.input_norm = kdai.nn.NormMask()
        self.rnn = rnn_stem
        # Informational: the stem receives 2 channels (time delta and mask).
        self.rnn_input_dim = 2
        self.out_resolution = out_resolution
        self.fc = nn.Linear(rnn_stem.n_h, out_resolution)

    def forward(self, x, mask):
        """Map time deltas x of shape (batch, seq, 1) to head outputs."""
        _, _, c = x.shape
        assert c == 1, "Just time deltas expected."
        return self.fc(self.rnn(self.input_norm(x, mask)))



class NoContextDiscrete(nn.Module):
    """Outputs the same discrete distribution (learnable) for all input."""

    def __init__(self, out_resolution: int = 128):
        super().__init__()
        self.out_resolution = out_resolution
        self.dist_logits = nn.Parameter(torch.zeros(out_resolution))
        nn.init.normal_(self.dist_logits)

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py.

        Returns:
            [with_decay, no_decay]: two lists. No parameter is decayed here.
            Both are real lists (not generators) so they can be iterated more
            than once, matching the other models' implementations.
        """
        with_decay = []
        no_decay = list(self.parameters())
        return [with_decay, no_decay]

    def forward(self, x):
        """Broadcast the logits to (batch, seq, out_resolution).

        Only the shape of `x` is used; it must be 3-D.
        """
        b, s, c = x.shape
        res = einops.repeat(self.dist_logits, "r -> b s r", b=b, s=s)
        return res


class GptWithHead(nn.Module):
    """Compose a GPT-style stem with an output head.

    Attribute lookups that fail on this wrapper are forwarded to the head, so
    head-specific helpers stay reachable from the composite model.
    """

    def __init__(self, gpt_stem, head):
        super().__init__()
        self.gpt_stem = gpt_stem
        self.head = head
        # Pass through the input normalization and output initialization.
        self.init_input_scale = self.gpt_stem.init_input_scale
        self.init_output = self.head.init_output

    def forward(self, x, mask):
        return self.head(self.gpt_stem(x, mask))

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py.
        """
        res = kdai.nn.get_optim_param_groups(self)
        return res

    def __getattr__(self, name):
        """Forward attribute access to the head.

        Raises:
            AttributeError: if neither this module nor its head has `name`,
                or if no head has been registered yet.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Read the head from the module registry directly. Going through
            # `self.head` would re-enter __getattr__ when no head is
            # registered yet (e.g. a partially constructed instance) and
            # recurse until RecursionError.
            head = self.__dict__.get("_modules", {}).get("head")
            if head is None:
                raise
            return getattr(head, name)


class GptNNHead(nn.Module):
    """GPT stem composed with an NN head.

    The NN head needs the query `y` in addition to the stem's features, so
    it cannot share the two-argument forward of the other stem/head wrappers.
    """

    def __init__(self, gpt_stem, head):
        super().__init__()
        self.gpt_stem = gpt_stem
        self.head = head
        # Expose the stem's input initializer and the head's query
        # initializer. The NN head has no output initializer.
        self.init_input_scale = gpt_stem.init_input_scale
        self.init_query_scale = head.init_query_scale

    def forward(self, x, mask, y):
        features = self.gpt_stem(x, mask)
        return self.head(features, y)

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py. Only trainable
        parameters are included; matrices (dim > 1) are decayed, vectors and
        scalars are not.
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        decayed = [p for p in trainable if p.dim() > 1]
        undecayed = [p for p in trainable if p.dim() <= 1]
        return [decayed, undecayed]


class GPTv2Stem(nn.Module):
    """Causal GPT-style stem built from kdai embedding modules.

    Exists so that one transformer backbone can be paired with different
    heads. Causal only.
    """

    def __init__(
        self,
        input_len,
        n_layer,
        n_head,
        head_dim,
    ):
        super().__init__()
        self.input_len = input_len
        self.n_embd = n_head * head_dim
        self.idx_embed = kdai.nn.IdxEmbed(self.input_len, self.n_embd)
        self.t_embed = kdai.nn.ValueEmbed(self.n_embd)
        self.transformer = kdai.nn.Transformer(
            dim=self.n_embd,
            depth=n_layer,
            heads=n_head,
            mlp_dim=4 * self.n_embd,
            dropout=0.1,
            is_causal=True,
        )

    @torch.no_grad()
    def init_input_scale(self, t_max_range, t_epsilon):
        """Delegate time-embedding initialization to `t_embed.init_weights`."""
        _logger.info(
            f"Initializing t_embed weights. {t_max_range=}, {t_epsilon=}"
        )
        self.t_embed.init_weights(t_max_range, t_epsilon)

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py.
        """
        return kdai.nn.get_optim_param_groups(self)

    @torch.no_grad()
    def to_t_embed(self, x):
        """Embed elapsed time (cumulative sum of the deltas along dim 1)."""
        # NOTE(review): because of @torch.no_grad(), no gradient reaches
        # t_embed's parameters (if any) through this path, and the returned
        # embedding does not require grad. GPTStem.to_t_embed has no such
        # decorator and learns its frequencies — confirm this is intended.
        elapsed = torch.cumsum(x, dim=1)
        return self.t_embed(elapsed)

    def forward(self, x, mask):
        # Unpacking also rejects inputs that are not 3-D.
        _batch, _steps, _channels = x.shape
        # (B, S, E) time embedding + (S, E) index embedding, broadcast over B.
        embedded = self.to_t_embed(x) + self.idx_embed()
        return self.transformer(embedded, mask)


class GPTStem(nn.Module):
    """Causal GPT-style stem, shareable across output heads.

    The per-step time deltas are turned into elapsed time by a cumulative
    sum and embedded with sin/cos features of learnable frequencies
    (`slope`). A learned index (position) embedding is added, and the
    result is passed through a causal transformer. Causal only.
    """

    def __init__(
        self,
        input_len,
        n_layer,
        n_head,
        head_dim,
    ):
        """
        Args:
            input_len: number of positions in the index embedding; the input
                sequence length must not exceed it.
            n_layer: transformer depth.
            n_head: number of attention heads.
            head_dim: per-head width; n_embd = n_head * head_dim.

        Raises:
            ValueError: if n_embd is odd (the time embedding is split evenly
                into sin and cos halves).
        """
        super().__init__()
        self.input_len = input_len
        self.n_embd = n_head * head_dim
        # Validate before building any submodules.
        if self.n_embd % 2 != 0:
            raise ValueError("n_embd must be even.")
        # self.input_norm = kdai.nn.Normalize()
        self.idx_embed = nn.Embedding(self.input_len, self.n_embd)
        self.transformer = kdai.nn.Transformer(
            dim=self.n_embd,
            depth=n_layer,
            heads=n_head,
            mlp_dim=4 * self.n_embd,
            dropout=0.1,
            is_causal=True,
        )
        # Learnable frequencies; init_input_scale() gives a best-effort init.
        self.slope = nn.Parameter(torch.zeros(self.n_embd // 2))
        self.register_buffer(
            "slope_cache",
            # 1 -> 1/10,000. Linearly in log space over n_embd//2 steps.
            # Note that this differs slightly from the THP implementation
            # that uses [sin(0), cos(1), sin(2),...cos(N-1)], assuming N is
            # even. Here we keep to another common practice of [sin(0), sin(1),
            # ...sin(N//2), cos(0), cos(1),...cos(N//2)].
            torch.exp(
                -torch.linspace(
                    0, math.log(10000), self.n_embd // 2, dtype=torch.float32
                )
            ),
            persistent=False,
        )

        # # Tensor tags for debugging.
        # self.x_tag = kdai.train.TensorTag("0.0 x")
        # self.mask_tag = kdai.train.TensorTag("0.1 mask")
        # self.t_embed_tag = kdai.train.TensorTag("1.0 t_embed")
        # self.idx_embed_tag = kdai.train.TensorTag("1.1 idx_embed")
        # self.embed_tag = kdai.train.TensorTag("1.2 embed")
        # self.out_tag = kdai.train.TensorTag("2.0 out")
        # self.slope_tag = kdai.train.TensorTag("w 1.0 slope")

    @torch.no_grad()
    def init_input_scale(self, max_dt, min_dt):
        """Initialize the sin/cos frequencies (`slope`) from a time range.

        The n_embd // 2 frequencies are spaced geometrically (linearly in
        log space) between:

          - min_scale = (3/2 π) / max_dt: the largest meaningful time maps
            to 3/4 of a full turn. Large times then do not wrap around the
            circle, yet most of [0, 2π] is still used.
          - max_scale = (3/2 π) / min_dt: the smallest meaningful time
            difference likewise maps to 3/4 of a turn, so it is resolvable
            without wrapping.

        Log spacing dedicates more resolution to the low end, which matters
        for ranges like [1e-5, 3000].

        Design notes: a meaningful (min, max) range is really only knowable
        by the data manager. The deltas are cumulatively summed before
        embedding, so the largest value also grows with sequence length.
        Heuristics such as max_dt = input_len * sd * 2 and
        min_dt = sd / 100 tend to over-estimate max_dt for long inputs and
        give no principled min_dt.

        The values are copied into the existing parameter, so its device,
        dtype and requires_grad flag are preserved. (Reassigning `.data`
        with a freshly created tensor would silently move the parameter to
        the CPU and cast it to float32.)

        Args:
            max_dt: largest time span considered meaningful (> 0).
            min_dt: smallest time difference considered meaningful (> 0).

        Raises:
            ValueError: if max_dt or min_dt is not positive.
        """
        if max_dt <= 0 or min_dt <= 0:
            raise ValueError(
                f"max_dt and min_dt must be positive, got {max_dt=}, {min_dt=}."
            )
        # Use up most of the range [0, 2π] by landing on 3/4 * 2π.
        min_scale = (3 / 2 * math.pi) / max_dt
        max_scale = (3 / 2 * math.pi) / min_dt
        scales = torch.exp(
            torch.linspace(
                math.log(min_scale),
                math.log(max_scale),
                self.n_embd // 2,
                dtype=torch.float32,
                device=self.slope.device,
            )
        )
        # In-place copy (allowed on a leaf under no_grad) casts to the
        # parameter's dtype and keeps the same Parameter object.
        self.slope.copy_(scales)
        assert self.slope.requires_grad is True

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py.
        """
        with_grad = {
            n: p for n, p in self.named_parameters() if p.requires_grad
        }
        with_decay = [p for p in with_grad.values() if p.dim() > 1]
        no_decay = [p for p in with_grad.values() if p.dim() <= 1]
        return [with_decay, no_decay]

    def to_t_embed(self, x):
        """Embed elapsed time as [sin(slope * t), cos(slope * t)] (float32).

        Args:
            x: time deltas, shape (batch, seq, 1); summed along dim 1 to give
                elapsed time.

        Returns:
            Tensor of shape (batch, seq, n_embd).
        """
        x = torch.cumsum(x, dim=1)
        # The trailing dimension of 1 broadcasts against the n_embd//2
        # frequencies.

        # TODO: legacy runs that haven't been evaluated yet used the fixed
        # `slope_cache` frequencies; set LEGACY = True to evaluate them.
        LEGACY = False
        if LEGACY:
            t_embd = self.slope_cache * x
        else:
            t_embd = self.slope * x

        t_embd = torch.cat([t_embd.sin(), t_embd.cos()], dim=-1).to(
            dtype=torch.float32
        )
        return t_embd

    def forward(self, x, mask):
        b, s, c = x.shape
        # (B, S, E) <- (B, S, E) + (S, E)
        t_embed = self.to_t_embed(x)
        idx_embed = self.idx_embed(torch.arange(s, device=x.device))
        x = t_embed + idx_embed

        # # Tags
        # self.slope_tag(self.slope.data)
        # mask = self.mask_tag(mask)
        # t_embed = self.t_embed_tag(t_embed)
        # idx_embed = self.idx_embed_tag(idx_embed)
        # x = self.embed_tag(x)

        # Mask is expected to be broadcastable to (B, e, target_len, source_len)
        # TODO: check we have the correct shape.
        # mask = einops.rearrange(mask, "b t -> b 1 1 t")
        x = self.transformer(x, mask)

        # # Tags
        # x = self.out_tag(x)
        return x


class LogMixHead(nn.Module):
    """Head producing log-normal-mixture parameters (log_tau, mu, log_sigma).

    A linear layer produces 3 * n_mix values per position, split into the
    three parameter groups. mu and log_sigma are then shifted by learnable
    offsets (`out_offset`, `out_scale`) for basic output normalization.
    """

    def __init__(self, n_in, n_mix: int = 16):
        super().__init__()
        self.n_in = n_in
        self.n_mix = n_mix
        # Three parameters per mixture component.
        self.fc = nn.Linear(n_in, n_mix * 3)
        self.out_offset = nn.Parameter(torch.tensor(0.0))
        self.out_scale = nn.Parameter(torch.tensor(1.0))
        # # Tensor tags for debugging.
        self.last_x = kdai.train.TensorTag("0.0 x")
        self.fc_tag = kdai.train.TensorTag("1.0 fc W (param)")
        self.fc_bias_tag = kdai.train.TensorTag("1.1 fc bias (param)")
        self.out_offset_tag = kdai.train.TensorTag("1.2 out_offset (param)")
        self.out_scale_tag = kdai.train.TensorTag("1.3 out_scale (param)")
        self.log_tau_tag = kdai.train.TensorTag("2.0 log_tau")
        self.mu_tag = kdai.train.TensorTag("2.1 mu")
        self.log_sigma_tag = kdai.train.TensorTag("2.2 log_sigma")

    @torch.no_grad()
    def init_output(self, mean_log, sd_log):
        """Set the output normalization offsets.

        After initialization the head computes:
          - mu        = mu_raw + mean_log
          - log_sigma = log_sigma_raw + sd_log
        so the raw network outputs can stay near zero.

        NOTE(review): `out_scale` is added in log space, so filling it with
        sd_log multiplies sigma by e^{sd_log}. If sd_log is the standard
        deviation of log(dt) (rather than its logarithm), log(sd_log) may be
        the intended offset — confirm against the caller.
        """
        _logger.info(
            "LogMix output initialization. " f"{mean_log=:.3g}, {sd_log=:.3g}"
        )
        # no_grad (decorator) permits in-place writes to leaf parameters.
        self.out_offset.fill_(mean_log)
        self.out_scale.fill_(sd_log)
        # TODO: spread the mixture components out initially (e.g. via bias)?

    def forward(self, x):
        """Map x of shape (batch, seq, n_in) to (log_tau, mu, log_sigma)."""
        # Unpacking also rejects inputs that are not 3-D.
        _batch, _steps, _channels = x.shape
        x = self.last_x(x)
        log_tau, mu, log_sigma = self.fc(x).chunk(3, dim=-1)

        # Debug tags, recorded before output normalization.
        self.fc_tag(self.fc.weight)
        self.fc_bias_tag(self.fc.bias)
        self.log_tau_tag(log_tau)
        self.mu_tag(mu)
        self.log_sigma_tag(log_sigma)
        self.out_offset_tag(self.out_offset)
        self.out_scale_tag(self.out_scale)

        return log_tau, mu + self.out_offset, log_sigma + self.out_scale


class ConstHazard(nn.Module):
    """Constant hazard with a fully connected layer.

    ``forward`` maps features of shape (b, s, n_in) to one raw value per
    step, shape (b, s). Judging by ``init_output``'s suggestion of
    d = log(1 / mean), the caller presumably treats this value as a log
    rate. Verify against the loss code.
    """

    def __init__(self, n_in):
        super().__init__()
        self.n_in = n_in
        self.fc = nn.Linear(n_in, 1)
        # Tags
        self.fc_tag = kdai.train.TensorTag("2. fc out")
        self.fc_bias_tag = kdai.train.TensorTag("w 2.0 fc bias")
        self.fc_weight_tag = kdai.train.TensorTag("w 2.1 fc weight")

    @torch.no_grad()
    def init_output(self, d: float):
        """Initialize the output bias to match the data.

        Assuming m_out starts near zero, choose d such that the exponential
        distribution effectively covers the time deltas in the data.

        One possibility is to set d to be log(1/(data mean)). Note however
        that the exponential distribution is highly skewed: with rate
        λ = 1/a (mean a), the variance is a² (standard deviation a).

        Args:
            d: value written into ``fc.bias``.
        """
        _logger.info(f"Const hazard initialization. {d=:.3g}")
        self.fc.bias.fill_(d)

    def forward(self, x):
        """Return the linear output for every step, shape (b, s)."""
        b, s, c = x.shape
        assert c == self.n_in, f"{c=} != {self.n_in=}"
        x = einops.rearrange(x, "b s c-> (b s) c", b=b, s=s, c=c)
        x = self.fc(x)
        x = einops.rearrange(x, "(b s) 1 -> b s", b=b, s=s)
        self.fc_tag(x)
        self.fc_bias_tag(self.fc.bias)
        self.fc_weight_tag(self.fc.weight)
        return x


class ExpHazard(nn.Module):
    """Exponential (Gompertz-style) hazard head built on one linear layer.

    ``forward`` returns only the linear projection, shape (b, s). The slope
    parameter ``a`` is stored on the module but is not read by ``forward``
    itself.
    """

    def __init__(self, n_in):
        super().__init__()
        self.n_in = n_in
        self.fc = nn.Linear(n_in, 1)
        # Placeholder slope; init_output() should overwrite it with a value
        # fitted to the data.
        self.a = nn.Parameter(torch.tensor(-1.0))

    @torch.no_grad()
    def init_output(self, a: float, d: float):
        """Seed the slope and bias so the hazard roughly fits the data.

        The output hazard can be read as

            h(t) = e^{a t + m_out + d}

        where d is the fc bias and m_out is the fc output before the bias.
        Assuming m_out is near zero, pick a and d so that the resulting
        Gompertz distribution has approximately the right shape and scale
        for the observed time deltas, then pass them here.
        """
        _logger.info(f"Exponential hazard initialization. {a=:.3g}, {d=:.3g}")
        for param, value in ((self.a, a), (self.fc.bias, d)):
            param.fill_(value)

    def forward(self, x):
        """Project (b, s, n_in) features to a (b, s) tensor."""
        batch, seq, _ = x.shape
        projected = self.fc(x)
        return einops.rearrange(projected, "b s 1 -> b s", b=batch, s=seq)


class NNHazard(nn.Module):
    """NN intensity head for the GPT model (or any causal training).

    Fully neural network hazard head in the style of Omi et al. (2019). An
    MLP over a query time ``t`` and features ``x`` produces the cumulative
    hazard through a softplus. The hazard itself is recovered as the
    gradient of the cumulative hazard with respect to ``t`` (see
    ``forward``). ``clamp_weights`` keeps all weight matrices positive.
    """

    def __init__(self, n_in, n_h, n_layers):
        super().__init__()
        self.n_h = n_h
        self.n_in = n_in
        self.n_layers = n_layers
        # Scales the query time before fc_t; configured by init_query_scale().
        self.query_norm = kdai.nn.Normalize()
        # Original paper uses bias=False. Why?
        # EasyTPP uses bias: https://github.com/ant-research/EasyTemporalPointProcess/blob/main/easy_tpp/model/torch_model/torch_fullynn.py
        # Here the time projection has no bias; the history projection has one.
        self.fc_t = nn.Linear(1, self.n_h, bias=False)
        self.fc_h = nn.Linear(self.n_in, self.n_h, bias=True)
        # n_layers hidden n_h -> n_h layers; forward() applies tanh after each.
        self.mids = nn.ModuleList(
            [
                nn.Linear(self.n_h, self.n_h, bias=True)
                for _ in range(self.n_layers)
            ]
        )
        # Projects to one value per row, which forward() passes through softplus.
        self.int_hazard_fc = nn.Linear(self.n_h, 1)
        self._init_weights()

    @torch.no_grad()
    def init_query_scale(self, mu, sd):
        """Set query_norm's mean/sd so query times are scaled to a reasonable range."""
        self.query_norm.set_mean_sd(mu, sd)

    @torch.no_grad()
    def _init_weights(self):
        """
        This is the initialization protocol described in Omi et al. (2019).

        Every Linear gets non-negative weights drawn from
        U(0, gain * sqrt(6 / (fan_in + fan_out))), i.e. the Xavier-uniform
        bound restricted to the positive half, and a zero bias.
        """

        def positive_xavier_uniform(m, gain=1.0):
            if type(m) == nn.Linear:
                fan_out, fan_in = m.weight.shape
                nn.init.uniform_(
                    m.weight,
                    0,
                    gain * math.sqrt(6.0 / (fan_in + fan_out)),
                )
                # Values are already >= 0 given the [0, bound) range; kept as
                # a safeguard.
                m.weight.data.abs_()
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            else:
                # NOTE(review): the calls below only pass nn.Linear modules, so
                # this branch is not reached. A ModuleList would pass the
                # assert but be left uninitialized.
                assert (
                    type(m) == nn.ModuleList
                ), f"Unexpected module type ({type(m)})"

        # The default gain regularly caused very high initial loss.
        # Some trial and error landed on a reduction to 1/3, which seems to
        # perform better across the Badges, Cyclic and Rand Proc datasets.
        # One possible reason is that PyTorch's default initialization does
        # not suit this kind of network well. NOTE(review): the original remark
        # mentioned RNNs, but this module contains only Linear layers.
        positive_xavier_uniform(self.fc_t, gain=1 / 3)
        positive_xavier_uniform(self.fc_h, gain=1 / 3)
        for m in self.mids:
            positive_xavier_uniform(m, gain=1 / 3)
        positive_xavier_uniform(self.int_hazard_fc)

    def clamp_weights(self):
        """Clamp every weight matrix to be >= eps, in place (biases untouched).

        Omi et al. (2019) do this after each gradient update. Here forward()
        calls it at the start of every pass instead. NOTE(review): newer
        PyTorch optimizers provide register_step_post_hook, which could run
        this right after each optimizer step.
        """
        # Clamp (with epsilon for stability).
        # Epsilon was added, following the advice from https://github.com/wassname/torch-neuralpointprocess/tree/master
        eps = 1e-10
        self.fc_t.weight.data.clamp_(min=eps)
        self.fc_h.weight.data.clamp_(min=eps)
        for m in self.mids:
            m.weight.data.clamp_(min=eps)
        self.int_hazard_fc.weight.data.clamp_(min=eps)

    @torch.enable_grad()
    def forward(self, x, t):
        """
        Evaluate the cumulative hazard and the log hazard at query times ``t``.

        Args:
            x: features, rank 3 (b, s, c) with c == n_in (asserted below).
            t: query times, (b, s). Pass un-normalized times; normalization
                is applied internally (see the query_norm comment below).

        Returns:
            (m_out, log_h, int_h), each reshaped to (b, s):
              - m_out: int_hazard_fc output, just before the softplus.
              - log_h: log(d int_h / d t + eps), the log hazard.
              - int_h: softplus(m_out), the cumulative hazard.

        The hazard is computed with torch.autograd.grad. The
        torch.enable_grad() decorator keeps autograd on even when the caller
        runs under torch.no_grad() (e.g. during evaluation).

        Side effect: calls clamp_weights(), modifying weights in place.
        """
        b, s, c = x.shape
        assert c == self.n_in, f"Expected {self.n_in} channels. {x.shape=}."
        # Move all parallel computation to batch dimension.
        x = einops.rearrange(x, "b s c -> (b s) c", b=b, s=s, c=c)
        t = einops.rearrange(t, "b s -> (b s) 1")
        # t, if loaded from a dataloader, will not require grad. As we need
        # the gradient of the cumulative hazard with respect to t, mark it
        # here.
        t.requires_grad_(True)
        # Only a clone of t is normalized. The gradient below is taken with
        # respect to the un-normalized t, so the hazard is per unit of the
        # original time scale. The clone is part of the graph, so the chain
        # rule still accounts for query_norm's scaling.
        tt = self.query_norm(torch.clone(t))
        eps = 1e-10
        # Clamp on every call, not just in training: the last optimizer step
        # can leave negative weights.
        # if self.training:
        self.clamp_weights()
        fc_t = self.fc_t(tt)
        x = self.fc_h(x)
        # Combine the time and history projections. NOTE(review): dividing by
        # 2 preserves variance only if both terms are perfectly correlated.
        # For independent terms of equal variance it halves the variance;
        # dividing by sqrt(2) would preserve it.
        x = torch.tanh((fc_t + x) / 2.0)

        def mid(x):
            # Hidden layers: Linear followed by tanh.
            for m in self.mids:
                x = torch.tanh(m(x))
            return x

        x = mid(x)
        x = self.int_hazard_fc(x)
        # Let's choose just before the softplus as the m_out (used for logging).
        m_out = x
        # Cumulative hazard; the softplus keeps it positive.
        int_h = F.softplus(x)
        int_h = einops.rearrange(int_h, "b 1 -> b")
        # Differentiate the sum rather than the mean, so each element's
        # derivative d int_h_i / d t_i is not rescaled. This relies on each
        # int_h element depending only on its own t element; that assumes
        # query_norm acts element-wise (TODO confirm).
        # create_graph=True keeps the hazard differentiable with respect to
        # the parameters. retain_graph defaults to the value of create_graph,
        # so passing it is redundant but explicit. The graph must not be freed
        # here, presumably because the loss later backpropagates through
        # int_h as well.
        h = torch.autograd.grad(
            int_h.sum(), t, create_graph=True, retain_graph=True
        )
        # eps guards against log(0).
        log_h = torch.log(h[0] + eps)
        assert torch.isfinite(log_h).all(), log_h
        assert torch.isfinite(int_h).all(), int_h
        log_h = einops.rearrange(log_h, "(b s) 1 -> b s", b=b, s=s)
        int_h = einops.rearrange(int_h, "(b s) -> b s", b=b, s=s)
        m_out = einops.rearrange(m_out, "(b s) 1 -> b s", b=b, s=s)
        return m_out, log_h, int_h


class GPT(nn.Module):
    """Transformer over event-time sequences with a fixed sinusoidal time encoding.

    ``forward`` takes per-step time deltas ``x`` with a trailing dimension
    of 1. It embeds their cumulative sum with sinusoids (``to_t_embed``),
    adds a learned per-position embedding and runs ``kdai.nn.Transformer``.

    - ``causal=True``: ``fc`` maps each position's embedding to
      ``out_resolution`` outputs.
    - ``causal=False``: all positions are flattened before ``fc``, giving
      one ``out_resolution`` vector per sequence. This requires the sequence
      length to equal ``input_len``.
    """

    def __init__(
        self,
        input_len,
        n_layer,
        n_head,
        head_dim,
        out_resolution=128,
        causal=True,
    ):
        super().__init__()
        self.input_len = input_len
        self.n_embd = n_head * head_dim
        self.idx_embed = nn.Embedding(self.input_len, self.n_embd)
        self.out_resolution = out_resolution
        self.causal = causal
        if self.causal:
            self.fc = nn.Linear(self.n_embd, self.out_resolution)
        else:
            self.fc = nn.Linear(
                self.n_embd * self.input_len, self.out_resolution
            )
        self.transformer = kdai.nn.Transformer(
            dim=self.n_embd,
            depth=n_layer,
            heads=n_head,
            mlp_dim=4 * self.n_embd,
            dropout=0.1,
            is_causal=self.causal,
        )
        self.register_buffer(
            "slope_cache",
            # 1 -> 1/10,000, linearly in log space over n_embd // 2 steps.
            # Note that this differs slightly from the THP implementation,
            # which interleaves [sin(0), cos(1), sin(2), ..., cos(N-1)]
            # (assuming N is even). Here we follow another common practice
            # of concatenating [sin(0), ..., sin(N//2), cos(0), ..., cos(N//2)].
            # n_embd should be even so the concatenation has n_embd channels.
            torch.exp(
                -torch.linspace(
                    0, math.log(10000), self.n_embd // 2, dtype=torch.float32
                )
            ),
            persistent=False,
        )

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py. Trainable
        parameters with more than one dimension (weight matrices,
        embeddings) get weight decay; vectors and scalars (biases, norms)
        do not.
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        with_decay = [p for p in trainable if p.dim() > 1]
        no_decay = [p for p in trainable if p.dim() <= 1]
        return [with_decay, no_decay]

    @torch.no_grad()
    def to_t_embed(self, x):
        """Sinusoidal embedding of the cumulative sum of ``x`` along dim 1.

        ``x``'s trailing dimension of 1 broadcasts against ``slope_cache``.
        The result is ``cat([sin, cos], -1)`` as float32.
        """
        x = torch.cumsum(x, dim=1)
        t_embd = self.slope_cache * x
        t_embd = torch.cat([t_embd.sin(), t_embd.cos()], dim=-1).to(
            dtype=torch.float32
        )
        return t_embd

    def forward(self, x, mask):
        """Encode time deltas ``x`` (B, S, 1) and project with ``fc``."""
        _, s, _ = x.shape
        # Build position indices directly on x's device instead of allocating
        # on the CPU and copying on every call.
        positions = torch.arange(s, device=x.device)
        # (B, S, E) <- (B, S, E) + (S, E)
        x = self.to_t_embed(x) + self.idx_embed(positions)
        # Mask is expected to be broadcastable to
        # (B, heads, target_len, source_len).
        # TODO: check we have the correct shape.
        x = self.transformer(x, mask)
        if not self.causal:
            x = einops.rearrange(x, "b t c -> b (t c)")
        x = self.fc(x)
        return x


class GPTv2(nn.Module):
    """Transformer over event-time sequences with a module-based time embedding.

    The cumulative sum of the per-step time deltas is embedded by
    ``kdai.nn.ValueEmbed`` (``t_embed``) and added to ``kdai.nn.IdxEmbed``
    position embeddings before ``kdai.nn.Transformer``. With
    ``causal=False`` all positions are flattened before ``fc``, which
    requires the sequence length to equal ``input_len``.
    """

    def __init__(
        self,
        input_len,
        n_layer,
        n_head,
        head_dim,
        out_resolution=128,
        causal=True,
    ):
        super().__init__()
        self.input_len = input_len
        self.n_embd = n_head * head_dim
        self.idx_embed = kdai.nn.IdxEmbed(self.input_len, self.n_embd)
        self.t_embed = kdai.nn.ValueEmbed(self.n_embd)
        self.out_resolution = out_resolution
        self.causal = causal
        if self.causal:
            self.fc = nn.Linear(self.n_embd, self.out_resolution)
        else:
            self.fc = nn.Linear(
                self.n_embd * self.input_len, self.out_resolution
            )
        self.transformer = kdai.nn.Transformer(
            dim=self.n_embd,
            depth=n_layer,
            heads=n_head,
            mlp_dim=4 * self.n_embd,
            dropout=0.1,
            is_causal=self.causal,
        )

    @torch.no_grad()
    def init_input_scale(self, t_max_range, t_epsilon):
        """Initialize ``t_embed``'s weights for the expected time range."""
        _logger.info(
            f"Initializing t_embed weights. {t_max_range=}, {t_epsilon=}"
        )
        self.t_embed.init_weights(t_max_range, t_epsilon)

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py.
        """
        return kdai.nn.get_optim_param_groups(self)

    def to_t_embed(self, x):
        """Embed the cumulative sum of ``x`` along dim 1 with ``t_embed``.

        Only the cumulative sum of the input data runs without autograd.
        ``t_embed`` is a module with weights, so it must run with autograd
        enabled. Under ``torch.no_grad()`` its weights would never receive
        gradients and would silently stay at their initial values.
        """
        with torch.no_grad():
            t = torch.cumsum(x, dim=1)
        return self.t_embed(t)

    def forward(self, x, mask):
        """Encode time deltas ``x`` (B, S, ...) and project with ``fc``."""
        b, s, c = x.shape
        # (B, S, E) <- (B, S, E) + (S, E)
        x = self.to_t_embed(x) + self.idx_embed()
        # Mask is expected to be broadcastable to
        # (B, heads, target_len, source_len).
        # TODO: check we have the correct shape.
        x = self.transformer(x, mask)
        if not self.causal:
            x = einops.rearrange(x, "b t c -> b (t c)")
        x = self.fc(x)
        return x


class TransformerBase(nn.Module):
    """
    A shared base for the Spike data models.

    Pipeline: a 1-D CNN downsamples the 5-channel input. A per-sequence
    embedding and a positional embedding are added, a non-causal
    transformer runs, and the output at the last position is returned.

    This transformer does _not_ use the causal learning trick because:
      - there are too many data points (992). A 992-length input is very
        large to keep for all transformer layers. Instead, it is better to
        use a much smaller length for the transformer.
      - it makes sense to downsample the stimulus.
      - the distance array model would be very difficult to port to a
        many-output model, and would require a very large output, e.g. 992x128.

    NOTE(review): the CNN comment below assumes 1024 input samples
    (1024 -> 64 = enc_len). A 992-sample input would give 62 positions, which
    would not match pos_embd's 64. Confirm that inputs are padded to 1024.

    The cell id is encoded using an embedding (n_cells = n_seqs). We use
    n_seqs rather than n_cells, as it seems to generalize fine to any dataset
    where we allow knowledge of a sequence's identity. This identity prevents
    generalization to new sequences, and only extends to other sequences
    generated from the same processes that generated each sequence, i.e. more
    spikes from the same cells. This could be relaxed by creating an embedding
    for each cell based on a longer sequence. Another (positive) perspective
    on the use of sequence ids is that we are simply assuming a good
    embedding can be calculated for each sequence from some subset of the
    sequence; so it is nothing more than that assumption plus the shortcut
    of using the sequence id.
    """

    def __init__(self, input_len, n_layer, n_head, head_dim, n_seqs, expansion=4):
        super().__init__()
        self.input_len = input_len
        self.n_embd = n_head * head_dim
        # Number of positions after the CNN downsampling.
        self.enc_len = 64
        # Zero-initialized here; init_weights() replaces it with sinusoids.
        self.pos_embd = nn.Parameter(torch.zeros(1, self.enc_len, self.n_embd))
        # Embedding table size: n_seqs rounded up to the next power of two.
        n_seq_embd = 2**math.ceil(math.log2(n_seqs))
        self.seq_embd = nn.Embedding(n_seq_embd, self.n_embd)
        # Inner-width multiplier for the ResBlock1d layers.
        self.expansion = expansion

        # 1024 -> 512 -> 256 -> 128 -> 64
        # (one stride-2 conv, then n_downsample - 1 ResBlock1d layers built
        # with downsample=True; presumably each halves the length.)
        n_downsample = 4
        self.input_n_c = 5  # assumed 4 stim, 1 spike
        # Unused (not referenced anywhere in this class).
        self.stim_embed = nn.Conv1d(self.input_n_c, self.n_embd, kernel_size=1)

        self.transformer = kdai.nn.Transformer(
            dim=self.n_embd,
            depth=n_layer,
            heads=n_head,
            mlp_dim=4 * self.n_embd,
            dropout=0.1,
            is_causal=False,
        )

        # Channel widths and kernel sizes for the two CNN stages.
        self.l0_n_c = 16
        k0_size = 21
        k1_size = 7
        self.l1_n_c = 64
        self.cnn = nn.Sequential(
            # Stage 0: wide-kernel stride-2 conv, then a same-length conv + BN.
            nn.Conv1d(
                self.input_n_c,
                self.l0_n_c,
                kernel_size=k0_size,
                stride=2,
                padding=(k0_size - 1) // 2,
                bias=True,
            ),
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(
                self.l0_n_c,
                self.l0_n_c,
                kernel_size=k0_size,
                stride=1,
                padding=(k0_size - 1) // 2,
                bias=False,
            ),
            kdai.nn.create_batch_norm(self.l0_n_c),
            nn.LeakyReLU(0.2, True),
            # Stage 1: downsampling residual blocks (16 -> 64 channels on the
            # first block, 64 -> 64 afterwards).
            *[
                kdai.nn.ResBlock1d(
                    self.l1_n_c if i else self.l0_n_c,
                    self.l1_n_c * self.expansion,
                    self.l1_n_c,
                    kernel_size=k1_size,
                    downsample=True,
                )
                for i in range(n_downsample - 1)
            ],
            # Final block projects to the transformer width without
            # downsampling.
            kdai.nn.ResBlock1d(
                self.l1_n_c,
                self.l1_n_c * self.expansion,
                self.n_embd,
                kernel_size=k1_size,
                downsample=False,
            ),
        )

    @torch.no_grad()
    def init_weights(self):
        """Replace pos_embd's values with a sinusoidal embedding.

        Not called from __init__; the caller must invoke it. NOTE(review):
        the shape returned by kdai.nn.sinusoidal_embedding may be
        (enc_len, n_embd) rather than (1, enc_len, n_embd). Either broadcasts
        in encode(); confirm.
        """
        self.pos_embd.data = kdai.nn.sinusoidal_embedding(
            self.enc_len, self.n_embd
        )
        # TODO: conv1d layers

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py.
        """
        res = kdai.nn.get_optim_param_groups(self)
        return res

    def encode(self, x, seq_id):
        """CNN features (b, t, c) plus sequence and positional embeddings.

        The per-sequence embedding (b, 1, e) broadcasts over time.
        """
        x = self.cnn(x)
        x = einops.rearrange(x, "b c t -> b t c")
        seq_embd = einops.rearrange(self.seq_embd(seq_id), "b e -> b 1 e")
        x = x + seq_embd
        x = x + self.pos_embd
        return x

    def forward(self, x, seq_id):
        """Return the transformer output at the last position.

        ``x`` is fed to Conv1d layers, so its layout is (batch, channels=5,
        length). The result is presumably (batch, n_embd).
        """
        # Unpacking only enforces a rank-3 input; the names do not reflect
        # the (batch, channel, length) layout.
        b, s, c = x.shape
        x = self.encode(x, seq_id)
        x = self.transformer(x)
        # What to do with this? Currently, just getting the last output.
        # We could use a query token instead.
        x = x[:, -1]
        return x


class LogmixTf(nn.Module):
    """Mixture-output head on top of a shared transformer base.

    Normalizes the 5 input channels with fixed buffers, runs ``gpt_base``
    and maps its output to ``3 * n_mix`` values. These are split into
    ``(log_tau, mu, log_sigma)``, with learnable shifts applied to ``mu``
    and ``log_sigma``.
    """

    def __init__(self, gpt_base, n_mix: int = 16):
        super().__init__()
        # Fixed (non-learnable) per-channel input statistics.
        for buffer_name, make in (("input_mean", torch.zeros), ("input_sd", torch.ones)):
            self.register_buffer(buffer_name, make(size=(5,)))
        self.gpt_base = gpt_base
        self.n_mix = n_mix
        # Three parameters per mixture component.
        self.fc = nn.Linear(self.gpt_base.n_embd, 3 * n_mix)
        # Learnable output shift (for mu) and log-scale shift (for log_sigma).
        self.out_offset = nn.Parameter(torch.tensor(0.0))
        self.out_scale = nn.Parameter(torch.tensor(1.0))

    @torch.no_grad()
    def set_input_mean_sd(self, m: torch.Tensor, sd: torch.Tensor):
        """Copy per-channel input statistics into the normalization buffers.

        Raises:
            ValueError: if ``m`` or ``sd`` does not have shape (5,).
        """
        for label, stat in (("mean", m), ("sd", sd)):
            if stat.shape != (5,):
                raise ValueError(
                    f"Input {label} must have shape (5,). Got ({stat.shape})."
                )
        self.input_mean.copy_(m)
        self.input_sd.copy_(sd)

    @torch.no_grad()
    def init_output(self, mean_log, sd_log):
        """Seed the learnable output offset/scale from data statistics.

        ``forward`` adds ``out_offset`` to ``mu`` and ``out_scale`` to
        ``log_sigma``:
          - mu => mu + mean_log          | so raw mu starts near 0
          - log_sigma => log_sigma + sd_log | sigma is in log space
        """
        _logger.info(
            f"LogMix output initialization. {mean_log=:.3g}, {sd_log=:.3g}"
        )
        # no_grad() permits the in-place fill_ on these leaf Parameters.
        for param, value in ((self.out_offset, mean_log), (self.out_scale, sd_log)):
            param.fill_(value)
        # Idea: spread the mixture components apart initially, e.g. via bias.

    def normalize_input(self, x):
        """Standardize channel dim 1 of ``x`` using the stored statistics."""
        mean = self.input_mean.view(1, -1, 1)
        sd = self.input_sd.view(1, -1, 1)
        return (x - mean) / sd

    def optim_param_groups(self):
        return kdai.nn.get_optim_param_groups(self)

    def forward(
        self, x, seq_id
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        _batch, _chans, _time = x.shape  # enforces a rank-3 input
        features = self.gpt_base(self.normalize_input(x), seq_id)
        log_tau, mu, log_sigma = self.fc(features).chunk(3, dim=1)
        # Learnable output normalization.
        return log_tau, mu + self.out_offset, log_sigma + self.out_scale


class DiscreteTf(nn.Module):
    """Discrete-resolution output head on top of a shared transformer base.

    Normalizes the 5 input channels with fixed buffers, runs ``gpt_base``
    and maps its output to ``out_resolution`` values.
    """

    def __init__(self, gpt_base, out_resolution=128):
        super().__init__()
        # Fixed (non-learnable) per-channel input statistics.
        self.register_buffer("input_mean", torch.zeros(size=(5,)))
        self.register_buffer("input_sd", torch.ones(size=(5,)))
        self.gpt_base = gpt_base
        self.out_resolution = out_resolution
        self.fc = nn.Linear(self.gpt_base.n_embd, self.out_resolution)

    @torch.no_grad()
    def set_input_mean_sd(self, m: torch.Tensor, sd: torch.Tensor):
        """Copy per-channel input statistics into the normalization buffers.

        Runs under ``torch.no_grad()``, matching ``LogmixTf``. Without it,
        copying from a tensor that requires grad would make the buffers
        require grad and attach them to that autograd graph.

        Raises:
            ValueError: if ``m`` or ``sd`` does not have shape (5,).
        """
        if m.shape != (5,):
            raise ValueError(
                "Input mean must have shape (5,). " f"Got ({m.shape})."
            )
        if sd.shape != (5,):
            raise ValueError(
                "Input sd must have shape (5,). " f"Got ({sd.shape})."
            )
        self.input_mean.copy_(m)
        self.input_sd.copy_(sd)

    def normalize_input(self, x):
        """Standardize channel dim 1 of ``x`` (presumably (batch, 5, time))."""
        return (x - self.input_mean[None, :, None]) / self.input_sd[
            None, :, None
        ]

    def optim_param_groups(self):
        """Return parameter groups as designated by train.py."""
        res = kdai.nn.get_optim_param_groups(self)
        return res

    def forward(self, x, seq_id) -> torch.Tensor:
        """Normalize ``x``, run the base model and project with ``fc``."""
        b, c, t = x.shape
        x = self.normalize_input(x)
        x = self.gpt_base(x, seq_id)
        x = self.fc(x)
        return x


class DistTf(nn.Module):
    """Distance-array output head on top of a shared transformer base.

    ``forward`` returns ``OUT_LEN`` raw outputs per input. Use
    ``denormalize_output`` to map them back: ``x * output_scale +
    output_mean``.
    """

    OUT_LEN = 128

    def __init__(self, gpt_base):
        super().__init__()
        self.gpt_base = gpt_base
        self.fc = nn.Linear(self.gpt_base.n_embd, self.OUT_LEN)
        # Fixed (non-learnable) per-channel input statistics.
        self.register_buffer("input_mean", torch.zeros(size=(5,)))
        self.register_buffer("input_sd", torch.ones(size=(5,)))
        # Output denormalization: x * output_scale + output_mean.
        self.register_buffer("output_mean", torch.tensor(1.0))
        self.output_scale = 2.0

    @torch.no_grad()
    def set_input_mean_sd(self, m: torch.Tensor, sd: torch.Tensor):
        """Copy per-channel input statistics into the normalization buffers.

        Runs under ``torch.no_grad()``, matching ``LogmixTf``. Without it,
        copying from a tensor that requires grad would make the buffers
        require grad and attach them to that autograd graph.

        Raises:
            ValueError: if ``m`` or ``sd`` does not have shape (5,).
        """
        if m.shape != (5,):
            raise ValueError(
                "Input mean must have shape (5,). " f"Got ({m.shape})."
            )
        if sd.shape != (5,):
            raise ValueError(
                "Input sd must have shape (5,). " f"Got ({sd.shape})."
            )
        self.input_mean.copy_(m)
        self.input_sd.copy_(sd)

    @torch.no_grad()
    def set_output_mean(self, m: float):
        """Set the offset that ``denormalize_output`` adds to the scaled output."""
        self.output_mean.fill_(m)

    def denormalize_output(self, x):
        """Map raw model outputs back: ``x * output_scale + output_mean``."""
        return (x * self.output_scale) + self.output_mean

    def normalize_input(self, x):
        """Standardize channel dim 1 of ``x`` (presumably (batch, 5, time))."""
        return (x - self.input_mean[None, :, None]) / self.input_sd[
            None, :, None
        ]

    def optim_param_groups(self):
        """Return parameter groups as designated by train.py."""
        res = kdai.nn.get_optim_param_groups(self)
        return res

    def forward(self, x, seq_id):
        """Normalize ``x``, run the base model and project to ``OUT_LEN`` values."""
        b, s, m = x.shape
        x = self.normalize_input(x)
        x = self.gpt_base(x, seq_id)
        x = self.fc(x)
        return x


class ZuoAttention(nn.Module):
    """
    Multi-head causal self-attention in the style of Zuo et al.'s THP.

    This is separate from kdai.nn.Attention because Zuo et al. do not keep
    the combined head size (n_head * d_qkv) equal to the model width
    (d_model). Adapted from Yangalan123's implementation:
        https://github.com/yangalan123/anhp-andtt/blob/master/thp/thp_training/transformer/SubLayers.py
    with some simplifications. The only notable functional difference is
    that the pre-attention LayerNorm feeds all of q, k and v, not just the
    query as in the original (which is not standard). No residual is added
    here; the caller adds it.
    """

    def __init__(self, n_head, d_model, d_qkv, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        inner = n_head * d_qkv
        self.to_qkv = nn.Linear(d_model, 3 * inner, bias=False)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.fc = nn.Linear(inner, d_model)
        for weight in (self.to_qkv.weight, self.fc.weight):
            nn.init.xavier_uniform_(weight)
        self.dropout_p = dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Attend causally over ``x`` (b, t, d_model); zero rows where mask is 0."""
        normed = self.layer_norm(x)
        q, k, v = (
            einops.rearrange(part, "b t (h c) -> b h t c", h=self.n_head)
            for part in self.to_qkv(normed).chunk(3, dim=-1)
        )
        keep = einops.rearrange(mask, "b t -> b t 1")
        attn_dropout = self.dropout_p if self.training else 0.0
        # SDPA does not accept is_causal=True together with an attn_mask, so
        # the padding mask is applied multiplicatively afterwards, as the THP
        # implementation does.
        attended = F.scaled_dot_product_attention(
            q, k, v, dropout_p=attn_dropout, is_causal=True
        )
        merged = einops.rearrange(attended, "b h t d -> b t (h d)") * keep
        # Zuo uses two dropouts: inside attention and after the projection.
        return self.dropout(self.fc(merged))


class ZuoTHP(nn.Module):
    """
    Transformer Hawkes Process encoder following Zuo et al.'s implementation.

    Paper's default model seems to be:

        4 layers, 4 heads, 512/4 = 128 head-dim.
    Interesting deviations from a standard gpt like transformer:
        - the temporal embedding is re-added before every layer (i.e. to the
          previous layer's output; the first layer sees 2 * t_enc). It's not
          clear why this is done.
        - there are two residual additions per layer, once after attention and
          once after the MLP, as opposed to once per layer.
        - they implement a mask to cover padding that can prefix model inputs.
          Here, we don't use this, as it seems easier to just use the
          earlier outputs of the transformer that are normally only used for
          training. These earlier outputs already have an effectively
          shortened input length.
        - mlp inner dimension is x2 rather than x4.
    """

    # A GPT style transformer would have:
    #    n_head * head_dim = d_embd
    # and the mlp_dim would be 4 * d_embd.
    # Zuo's transformer's attention layer significantly expands d_embd then
    # projects back down again.
    # These parameters come from Zuo et al.'s supplemental. They are called
    # parameter set 1, 2, and 3, but we index from 0.
    PARAM_SET_0 = {
        "n_layer": 3,
        "n_head": 3,
        # The main bus size. Positional/type embedding use this size. Input and output dim of transformer.
        "d_embd": 64,
        "d_qkv": 16,
        # The inner expanded dimension of the MLP.
        "d_mlp": 256,
    }
    PARAM_SET_1 = {
        "n_layer": 6,
        "n_head": 6,
        "d_embd": 128,
        "d_qkv": 64,
        "d_mlp": 2048,
    }
    PARAM_SET_2 = {
        "n_layer": 4,
        "n_head": 4,
        "d_embd": 512,
        "d_qkv": 512,
        "d_mlp": 1024,
    }

    @classmethod
    def from_param_set(cls, param_set_idx):
        """Build a model from one of Zuo et al.'s parameter sets.

        ``param_set_idx`` indexes (PARAM_SET_0, PARAM_SET_1, PARAM_SET_2);
        the paper numbers these 1-3. Uses ``cls`` so subclasses construct
        instances of themselves.
        """
        param_sets = (cls.PARAM_SET_0, cls.PARAM_SET_1, cls.PARAM_SET_2)
        return cls(**param_sets[param_set_idx])

    def __init__(
        self,
        n_layer,
        n_head,
        d_embd,
        d_qkv,
        d_mlp,
        dropout=0.1,
    ):
        super().__init__()
        self.d_embd = d_embd
        # Not applied inside this class: Zuo, Yang and EasyTPP hard-code a
        # per-dataset normalization instead.
        self.input_norm = kdai.nn.Normalize()
        # Parameters a and b from:
        #   λ(t) = softplus( a*(t-t_j)/t_j + w^T h + b )
        # C. Yang's and Zuo's implementations both initialize α = -0.1 and β = 1.0.
        self.alpha = nn.Parameter(torch.tensor(-0.1))
        self.beta = nn.Parameter(torch.tensor(1.0))
        self.layers = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attn": ZuoAttention(
                            n_head,
                            self.d_embd,
                            d_qkv=d_qkv,
                            dropout=dropout,
                        ),
                        "mlp": kdai.nn.FeedForward(
                            self.d_embd, d_mlp, dropout=dropout
                        ),
                    }
                )
                for _ in range(n_layer)
            ]
        )
        # Fixed (non-learnable) per-channel divisors for the sinusoidal
        # temporal encoding, as in Zuo's implementation: 10000^(2*(i//2)/d).
        # Registered as a non-persistent buffer so it follows the module's
        # device/dtype (.to(), .cuda(), ...) instead of being hard-wired to
        # CUDA. It stays out of the state_dict, as before, so existing
        # checkpoints still load.
        self.register_buffer(
            "position_vec",
            torch.tensor(
                [
                    math.pow(10000.0, 2.0 * (i // 2) / self.d_embd)
                    for i in range(self.d_embd)
                ]
            ),
            persistent=False,
        )

        # Just 1 parameter output for probability.
        # This is the `linear` parameter in the C. Yang implementation.
        self.fc_h = nn.Linear(self.d_embd, 1)
        # Just 1 scalar output for time.
        # This is the `time_predictor` module in the C. Yang implementation.
        # The paper uses no bias, which seems strange: without one there is
        # no way to offset the prediction, and it would negatively affect the
        # shared encoder. So we add a bias.
        self.fc_t = nn.Linear(self.d_embd, 1, bias=True)
        self.init_weights()
        # Tags. Should be no-ops
        self.x_tag = kdai.train.TensorTag("0 x")
        self.enc_tag = kdai.train.TensorTag("2 enc")
        self.last_mid = kdai.train.TensorTag("4 last_mid")
        self.t_tag = kdai.train.TensorTag("5.1 t")
        self.h_tag = kdai.train.TensorTag("5.2 h")
        self.t_denorm_tag = kdai.train.TensorTag("6 t_denorm")
        self.pos_vec_tag = kdai.train.TensorTag("0 pos_vec")
        self.pos_vec_tag_out = kdai.train.TensorTag("1 pos_vec_out")

    def init_weights(self):
        """Xavier-normal init of fc_t's weight, copying C. Yang's initialization."""
        nn.init.xavier_normal_(self.fc_t.weight)

    def optim_param_groups(self):
        """Return parameters with and without decay, in separate groups.

        This function signature is designated by train.py. Trainable
        parameters with more than one dimension get weight decay; vectors
        and scalars do not.
        """
        trainable = [p for p in self.parameters() if p.requires_grad]
        with_decay = [p for p in trainable if p.dim() > 1]
        no_decay = [p for p in trainable if p.dim() <= 1]
        return [with_decay, no_decay]

    @torch.no_grad()
    def temporal_enc(self, x, mask):
        """Sinusoidal encoding of event times, multiplied by ``mask``.

        ``x`` holds per-step time deltas with a trailing dimension of 1.
        Their cumulative sum along dim 1 is divided by ``position_vec``
        (broadcast over the last dim). sin is applied to even channels and
        cos to odd channels.
        """
        mask = einops.rearrange(mask, "b t -> b t 1")
        x = torch.cumsum(x, dim=1)
        self.pos_vec_tag(self.position_vec)
        t_embd = x / self.position_vec
        self.pos_vec_tag_out(t_embd)
        t_embd[:, :, 0::2] = torch.sin(t_embd[:, :, 0::2])
        t_embd[:, :, 1::2] = torch.cos(t_embd[:, :, 1::2])
        return t_embd * mask

    def forward(self, x, mask):
        """Encode time deltas ``x`` and return ``(h, t)``, each of shape (b, s, 1).

        ``h`` is the ``fc_h`` output and ``t`` the ``fc_t`` output.
        """
        b, s, c = x.shape
        # Zuo, Yang and EasyTPP's implementations don't have normalization.
        # Rather, they hard-code a normalization per-dataset. Here, we let
        # each model do its own normalization.
        x = self.x_tag(x)
        # Apart from normalization, the implementation matches: https://github.com/yangalan123/anhp-andtt/blob/bf8708631cf9ab67a0b682a5a784025b75d0602c/thp/thp_training/transformer/Models.py#L85
        # but without doing causal attention masking manually, and without
        # the pad mask, which we don't use.
        t_enc = self.temporal_enc(x, mask)
        t_enc = self.enc_tag(t_enc)
        x = t_enc.clone()
        for layer in self.layers:
            x = x + t_enc
            x = layer["attn"](x, mask) + x
            x = layer["mlp"](x) + x
        x = self.last_mid(x)

        h = self.fc_h(x)
        h = self.h_tag(h)
        t = self.fc_t(x)
        t = self.t_tag(t)
        return h, t
