import bisect
import logging
import numpy as np
import einops
import scipy.stats
import torch
import torch.utils.data
import torch.nn.functional as F
from itertools import pairwise
import kdai
import kdai.datasets
import kdai.train
from typing import Optional, Tuple, Sequence
from numpy.typing import ArrayLike
from numpy.lib.stride_tricks import sliding_window_view

_logger = logging.getLogger(__name__)

# Seed used by gen_poisson() when the caller does not supply an rng.
DEFAULT_SEED = 1
# Sentinel used to pad/prefix sequences and to derive input masks (positions
# equal to it are masked out). Real data must never contain this value;
# SeqListDataset raises ValueError if a sequence does.
BLANK_TOKEN = -1


def estimate_relevant_scales(train_dt_seqs, input_len):
    """Estimate the smallest and largest time scales a model should resolve.

    When time deltas are turned into transformer inputs, the range of scales
    worth representing should be an informed choice. Unusually for time
    series, the upper end depends on the model input length: with an input of
    64 time deltas, ``cumsum(time_deltas)`` peaks at the oldest event, 64
    events back, so feeding in more events raises that peak.

    The lower end does not depend on the input length. Here it is simply the
    smallest time delta in the training data. One could argue this is not
    small enough to generalize to even shorter deltas; the issue is softened
    by how the value is used: a model is expected to place this minimum near,
    but not at, the smallest delta it can discern.

    Update: the absolute values matter less than the largest and smallest
    differences.

    Args:
        train_dt_seqs: sequence (iterated twice) of 1D time-delta sequences.
        input_len: model input length, i.e. time deltas per window.

    Returns:
        (min_dt, max_dt): the smallest single delta across all sequences, and
        the largest sum of ``input_len`` consecutive deltas. Sequences shorter
        than ``input_len`` are first left-padded by repeating their first
        element (``np.pad(..., mode="edge")``).
    """

    def _left_pad_to_input_len(seq):
        shortfall = input_len - len(seq)
        if shortfall <= 0:
            return seq
        return np.pad(seq, (shortfall, 0), mode="edge")

    # The smallest relevant delta is taken straight from the training data.
    min_dt = min(np.min(dt_seq) for dt_seq in train_dt_seqs)

    # The largest relevant span is the biggest windowed sum of deltas.
    window_totals = []
    for dt_seq in train_dt_seqs:
        windows = sliding_window_view(
            _left_pad_to_input_len(dt_seq), input_len, writeable=False
        )
        window_totals.append(np.max(windows.sum(axis=1)))
    max_dt = max(window_totals)

    return min_dt, max_dt


def epoch_opts(
    train_len, total_samples, max_epochs, max_batch_size, min_steps_per_epoch
):
    """Choose the number of epochs and the batch size for a training run.

    The batch size is as large as possible (capped at ``max_batch_size``)
    while still giving at least ``min_steps_per_epoch`` steps per epoch, and
    never drops below 1. The epoch count is ``total_samples / train_len``,
    capped at ``max_epochs``.

    Returns:
        (n_epochs, batch_size)

    Raises:
        ValueError: if ``total_samples`` is not a multiple of ``train_len``.
    """
    steps_limited_batch = train_len // min_steps_per_epoch
    batch_size = max(1, min(steps_limited_batch, max_batch_size))
    full_passes, leftover = divmod(total_samples, train_len)
    if leftover:
        raise ValueError(f"{total_samples=} must be divisible by {train_len=}.")
    return min(max_epochs, full_passes), batch_size


def split_and_slide(
    t_seq: ArrayLike, t_max: int, ratio: ArrayLike, input_len: int
):
    """Cut an event-time sequence into sliding-window matrices, one per split.

    The sequence is split with ``kdai.datasets.split_seq1d(t_seq, ratio,
    t_max)``. Each split is differenced into time deltas, and those deltas are
    turned into a read-only sliding-window view whose rows hold
    ``input_len + 1`` consecutive deltas (model inputs plus one target).

    Args:
        t_seq: 1D array-like of event times.
        t_max: forwarded to ``split_seq1d`` (presumably the end of the
            observation window — TODO confirm).
        ratio: split ratios forwarded to ``split_seq1d``.
        input_len: model input length.

    Returns:
        list of 2D read-only arrays, one per split.

    Raises:
        ValueError: if ``t_seq`` is not 1D (marks are not supported).
    """
    t_seq = np.array(t_seq)
    # Use an explicit exception: an assert would vanish under `python -O`.
    if t_seq.ndim != 1:
        raise ValueError("Marks are not yet supported.")
    t_seqs = kdai.datasets.split_seq1d(t_seq, ratio, t_max)
    dt_seqs = [np.diff(t) for t in t_seqs]
    dt_window_matrices = [
        sliding_window_view(dt_seq, input_len + 1, writeable=False)
        for dt_seq in dt_seqs
    ]
    return dt_window_matrices


def trim_lens(lens, n_events):
    """Generalization of trim_to() that works only with lengths.

    The actual slicing is left to the caller. This is useful when the objects
    being sliced are not simple arrays but more complex objects, such as
    torch datasets or tuples of arrays.

    Args:
        lens: positive lengths of the sequences, in order.
        n_events: total number of events to keep (> 0).

    Returns:
        (last_seq_idx, trim_last_to): keep sequences ``0..last_seq_idx``
        inclusive, and trim the last kept one to ``trim_last_to`` elements.

    Raises:
        ValueError: if ``n_events`` <= 0, ``lens`` is empty, any length is
            <= 0, or ``n_events`` exceeds the total length.
    """
    if n_events <= 0:
        raise ValueError(f"n_events must be greater than 0. ({n_events=})")
    if len(lens) == 0:
        # Without this check, cumsum[-1] below fails with an IndexError.
        raise ValueError("lens must contain at least one length.")
    if any(length <= 0 for length in lens):
        raise ValueError(f"Lengths must be greater than 0. ({lens=})")
    cumsum = np.cumsum(lens)
    if n_events > cumsum[-1]:
        raise ValueError(
            f"n_events ({n_events}) shouldn't be more than the "
            f"total number of events ({cumsum[-1]})."
        )
    # cumsum[last_seq_idx - 1] < n_events <= cumsum[last_seq_idx]
    # (last_seq_idx-1) would be too few, and last_seq_idx is sufficient.
    last_seq_idx = np.searchsorted(cumsum, n_events)
    sum_to_prev = cumsum[last_seq_idx - 1] if last_seq_idx > 0 else 0
    trim_last_to = n_events - sum_to_prev
    assert trim_last_to > 0, "Previous idx should have been sufficient."
    assert trim_last_to <= lens[last_seq_idx], "idx is insufficient."
    return last_seq_idx, trim_last_to


def trim_to(seqs, n_events):
    """Trim sequences so their total length is exactly ``n_events``.

    You probably want to pass in time-delta sequences.

    Note: kdai has a slice_to_cum_len() function that could replace this
    function.

    Example:
    With a target of 1024 events, a list of 12 sequences might be trimmed to
    5 sequences, and the 5th sequence might have its events trimmed to, say,
    57 events so that the total is 1024.

    Only the last of the returned sequences can have been trimmed.

    Returns:
        A new list. ``seqs`` itself is never modified, and tuples or numpy
        object arrays are accepted as input.

    Raises:
        ValueError: from trim_lens() for invalid ``n_events`` or lengths.
    """
    lens = [len(s) for s in seqs]
    last_seq_idx, trim_last_to = trim_lens(lens, n_events)
    # Copy into a fresh list. A tuple slice is immutable, and a numpy object
    # array slice is a view, so assigning into either would fail or silently
    # mutate the caller's container.
    trimmed_seqs = list(seqs[: last_seq_idx + 1])
    trimmed_seqs[-1] = trimmed_seqs[-1][:trim_last_to]
    # Double-check.
    n_res = sum(len(s) for s in trimmed_seqs)
    assert n_res == n_events, f"{n_res=} == {n_events=}"
    return trimmed_seqs


class _DatasetIndexError(IndexError, ValueError):
    """Raised by SeqListDataset for an out-of-range index.

    Subclasses IndexError so that Python's fallback iteration protocol
    (``iter()``/``list()`` on an object that only defines ``__getitem__``)
    stops cleanly. Also subclasses ValueError so that handlers written
    against the earlier behaviour, which raised ValueError, keep working.
    """


class SeqListDataset(torch.utils.data.Dataset):
    """
    Dataset of variable length sequences.

    Each sequence is left-padded ("pre-padded") with ``pad_val`` (BLANK_TOKEN
    by default) by ``model_in_len - 1`` elements, so every sample window of
    ``model_in_len + 1`` elements contains at least one real input.
    Pre-padding is done for a number of reasons:
        - pre-padding allows the last element of y to always be the target
          value. This avoids needing to pass around masks or lengths.
        - pre-padding is not detrimental to RNN performance.
        - every sample is expected to predict just 1 event. This can slow
          training but makes the interface consistent and easy to work with.

    Some characteristics:
      - each sample is (x, mask, y), where x and y have the same length and
        come from the same sequence, but y is shifted by 1 into the future.
      - x and y always have length model_in_len.
      - a sequence of length L yields L - 1 samples; sequences with fewer
        than 2 elements contribute no samples.
    """

    def __init__(self, seqs, model_in_len, pad_val=BLANK_TOKEN):
        """
        Args:
            seqs: sequence (iterated twice) of 1D numeric sequences. They must
                not contain ``pad_val``.
            model_in_len: length of each sample's x and y.
            pad_val: left-padding value; also defines the mask.
        """
        self.model_in_len = model_in_len
        self.pad_val = pad_val
        # Every sample contains at least this many real (non-padding) inputs.
        self.min_real_input = 1
        # cumulative_starts[i] is the global index of seq i's first sample;
        # the final entry is the total number of samples.
        self.cumulative_starts = self._cumsum(seqs, self.min_real_input)
        self.seqs = self._pad(
            seqs, self.model_in_len, self.min_real_input, pad_val
        )

    @staticmethod
    def _pad(seqs, model_in_len, min_in_len, pad_val):
        """Convert each sequence to float32 and left-pad it with ``pad_val``.

        Raises:
            ValueError: if a sequence already contains ``pad_val``.
        """
        res = []
        pad_len = model_in_len - min_in_len
        for s in seqs:
            s = torch.tensor(s, dtype=torch.float32)
            if torch.any(s == pad_val):
                raise ValueError(f"Sequences must not contain {pad_val=}.")
            res.append(F.pad(s, (pad_len, 0), mode="constant", value=pad_val))
        return res

    @staticmethod
    def _cumsum(seqs, min_in_len):
        """Running total of samples per sequence, starting at 0.

        A sequence of length L yields ``max(0, L - min_in_len)`` samples.
        Clamping at zero keeps the list non-decreasing, which the bisect in
        to_seq_idx() relies on, even for sequences shorter than
        ``min_in_len`` (e.g. empty ones).

        Example:
            _cumsum([
                [0, 1, 2],
                [0, 1, 2, 3],
                [0, 1, 2, 3, 4],
            ], min_in_len=2) == [0, 1, 3, 6]
        """
        r = [0]
        for s in seqs:
            n_samples = max(0, len(s) - min_in_len)
            r.append(r[-1] + n_samples)
        return r

    def to_seq_idx(self, idx):
        """Map a global sample index to (sequence index, index within it)."""
        # right and left bisects only differ when the requested value equals
        # one of the values in the list.
        #   - bisect_left([0, 5], 0) returns 0
        #   - bisect_right([0, 5], 0) returns 1.
        # Don't try the following line:
        ## seq_idx = bisect.bisect_left(self.cumulative_starts, idx)
        # Why? Example case: [0, 5, 12], idx=4 => 1 (correct is 0)
        # bisect_right also skips sequences with zero samples (repeated
        # entries in cumulative_starts).
        seq_idx = bisect.bisect_right(self.cumulative_starts, idx) - 1
        sample_idx = idx - self.cumulative_starts[seq_idx]
        return seq_idx, sample_idx

    def __len__(self):
        return self.cumulative_starts[-1]

    def __getitem__(self, idx):
        """
        Returns:
            (x, mask, y)
                x: float32 tensor of shape (model_in_len, 1)
                mask: bool tensor of shape (model_in_len, ), False at padding
                y: float32 tensor of shape (model_in_len, 1); x shifted one
                    step into the future.

        Raises:
            _DatasetIndexError (an IndexError and a ValueError): if ``idx``
                is out of range.
        """
        if idx < 0:
            # Support negative indexing. Note: -len(self) is allowed.
            if abs(idx) > len(self):
                raise _DatasetIndexError(
                    f"{idx=} out of range. {len(self)=}."
                )
            idx = len(self) + idx
        if idx >= len(self):
            raise _DatasetIndexError(f"{idx=} out of range. {len(self)=}.")
        seq_idx, sample_idx = self.to_seq_idx(idx)
        return self.get(seq_idx, sample_idx)

    def get(self, seq_idx, sample_idx):
        """Return (x, mask, y) for sample ``sample_idx`` of seq ``seq_idx``."""
        sl = slice(sample_idx, sample_idx + self.model_in_len + 1)
        sample = self.seqs[seq_idx][sl]
        assert len(sample) == self.model_in_len + 1, f"{len(sample)=}"
        # Split the sample into x and y.
        # While it looks wasteful here, the interface needs to allow datasets
        # to augment/transform the X and Y in different ways, so the consumers
        # shouldn't expect a single tensor.
        x, y = sample[:-1], sample[1:]
        mask_in = x != self.pad_val
        x = einops.rearrange(x, "s -> s 1")
        y = einops.rearrange(y, "s -> s 1")
        return x, mask_in, y


class EventSeqListDatasets(kdai.train.DatasetManager):
    """Manager for a list of variable length event sequences.

    Train time considerations
    -------------------------
    The overall goal is a training loop that doesn't need an if-else block to
    decide how to calculate the loss. Ideally this would extend to the
    evaluation of metrics as well. In reality, the beginning of a sequence
    needs to be treated differently from the rest of the sequence.

    Which strategy to use depends on whether A, B, and C are true or false:

    A. Are you using the "causal trick"? For an x = (batch, time, channels)
        tensor, do all models produce a (batch, time) tensor output?
    B. Is the model input length always greater-equal to the sequence length?
    C. Is the model input length always less-equal to the sequence length?

    Decision tree:
    - 1. A ∧ B = True.
        Every sample needs padding, and every forward pass needs to mask over
        the padded elements to calculate the loss. For evaluation, the same
        masking procedure can be used, since every timestep is known to
        contribute to a metric.

        This is probably an unlikely scenario: you would have to commit to a
        minimum input length, and commit to only accepting sequences that are
        shorter than the model input length.

    - 2. A ∧ C = True.
        No padding is needed, as every value in the sequence is an actual
        datapoint. Evaluation departs from training:
          - the first sample of a sequence should have a prediction for
            every timestep;
          - all other samples (those not flush with the beginning of the
            sequence) should only predict the next timestep.
        This difference can optionally be ignored for metrics logged during
        training. For evaluations used in reports, it should be respected,
        and care should be taken to treat the first sample of a sequence
        differently from the others.

        This is the case for the cyclic dataset. All sequences have length
        1024, and we are happy to assert that all model input lengths will be
        less than 1024. Indeed, what would be the point of a longer input,
        given that no sequence is longer than 1024?

        A sub-case of this scenario is when sequences are so long that
        ignoring the initial part of a sequence won't affect metric
        calculations. Here, the initial part means up to the longest model
        input length among all models. In this case, when evaluating, we can
        treat every sample as contributing only a single prediction to a
        metric. This is the case for the RandProc datasets, which have very
        long _evaluation_ sequences.

    - 3. A = True, but B ∨ C = False.
        Some samples need padding, and some don't. Those that do are samples
        flush with the beginning of a sequence that doesn't have enough
        timesteps to fill a model's input. For training, padding and masking
        are needed, just like in scenario 1. Unlike scenario 1, proper
        evaluation requires extra care, just like in scenario 2: the first
        sample of a sequence should contribute every timestep to the metric,
        whereas all other samples should only predict the next timestep.

        This is the case for the classic datasets, such as nyc-taxi and
        so-badges. They have variable length sequences: some are long enough
        to fill a model's input, and some are too short (even as short as
        just 2 events, which makes only 1 time-delta).

    A general interface
    -------------------
    If trainables are expected to work with various datasets, and at least
    one dataset requires padding/masking, then a general interface should
    include a mask alongside the input. For datasets that don't require
    padding/masking, the mask can be all ones.

    Right or left pad
    -----------------
    Right padding is more common. However, left padding has the benefit of
    making the last element of the sequence the target value. This is useful
    when both causal and non-causal models use the same dataset: the
    non-causal models can keep treating the last element as the target value.
    With right padding, the non-causal models would have to identify the last
    non-padded element.

    Old notes:
    There are two very distinct ways of treating input-output:

    1. Predict every event
    ----------------------
    Every event in a sequence is predicted. This means that elements at the
    beginning of a sequence should be predicted, even if they are closer to
    the beginning than the model input length. This is not too hard to
    implement with masking if the maximum sequence length isn't too long.

    If the maximum sequence length is long, then a single sequence is split
    across multiple samples. Behaviour then differs depending on whether a
    sub-sequence is the initial part of the sequence, or a part that is at
    least model_in_len from the beginning.

    If sequences are long, most forward passes will predict just a single
    event per sample. If sequences are short, most or all forward passes will
    predict multiple events per sample.

    What confuses this further is that some models may wish to predict all
    events in a sequence for training purposes (e.g. the transformer training
    trick), but only predict the next event for evaluation.

    A convention can be introduced that eases the implementation of these
    details:
      - the dataset only ever produces a middle part of the sequence, or the
        largest sub-sequence that touches the beginning.
      - the dataset always right pads sequences to the maximum length in the
        batch, accompanied by a length value or a mask. Only the sequences
        that touch the beginning will have padding applied, and only
        sequences shorter than the model's input length will ever have
        padding applied.
      - for sequences touching the beginning, equal weight is placed on
        every event, and all must be predicted. This affects both the loss
        and evaluation. For middle sequences, how the loss is affected is
        optional, but evaluation should only consider the prediction for the
        last element.

    2. Don't predict the initial part of the sequence
    -------------------------------------------------
    This is very easy to handle, and it doesn't mix up predictions that vary
    in how much context is available. Hopefully, this case can be handled by
    the above pipeline, where consumers just ignore the mask or length, or the
    dataset has an option to drop it.
    """

    def __init__(
        self,
        train_dt_seqs: Sequence,
        val_dt_seqs: Sequence,
        test_dt_seqs: Sequence,
        model_in_len: int,
        start_idx: int = 0,
        pad_val: int = BLANK_TOKEN,
        density_interval_len: int = 1,
    ):
        """
        Args:
            train_dt_seqs: list of time-delta sequences for training. All
                summary statistics (dt_mean, dt_range_min, ...) are computed
                from this split only.
            val_dt_seqs: list of time-delta sequences for validation.
            test_dt_seqs: list of time-delta sequences for testing.
            model_in_len: length of each sample's x and y. Each sequence is
                left-padded with ``pad_val`` by ``model_in_len - 1`` elements
                (see SeqListDataset), not padded *to* this length.
            start_idx: stored as ``self.start_idx``. Intended as the index of
                the first event to predict (no padding when greater than
                model_in_len). NOTE(review): the datasets built here do not
                apply it; confirm whether any consumer reads it.
            pad_val: padding value forwarded to SeqListDataset. Sequences must
                not contain it.
            density_interval_len: stored as ``self.density_interval_len``;
                not used inside this class.
        """
        self.n_train_seqs = len(train_dt_seqs)
        self.train_dt_seqs = train_dt_seqs
        self.model_in_len = model_in_len
        self.start_idx = start_idx
        self.pad_val = pad_val
        self.density_interval_len = density_interval_len

        self.dt_range_min, self.dt_range_max = estimate_relevant_scales(
            train_dt_seqs, model_in_len
        )
        # Don't need the padded sequences for the following stats.
        train_flat = np.concatenate(train_dt_seqs)
        self.dt_mean = np.mean(train_flat)
        self.dt_sd = np.std(train_flat)
        self.dt_min = np.min(train_flat)
        self.dt_max = np.max(train_flat)
        # eps guards against log(0) for zero-length intervals.
        eps = 1e-10
        log_diff_t = np.log(train_flat + eps)
        self.log_dt_mean = np.mean(log_diff_t)
        self.log_dt_sd = np.std(log_diff_t)

        self._train_ds = SeqListDataset(
            train_dt_seqs, self.model_in_len, self.pad_val
        )
        self._val_ds = SeqListDataset(
            val_dt_seqs, self.model_in_len, self.pad_val
        )
        self._test_ds = SeqListDataset(
            test_dt_seqs, self.model_in_len, self.pad_val
        )

    def train_dts_flat(self):
        """Return all training dts as a single (newly allocated) array.

        This was added to support calculation of quantiles.
        """
        return np.concatenate(self.train_dt_seqs)

    def train_ds(self) -> torch.utils.data.Dataset:
        return self._train_ds

    def val_ds(self) -> torch.utils.data.Dataset:
        return self._val_ds

    def test_ds(self) -> torch.utils.data.Dataset:
        return self._test_ds

    def val_dl(self) -> torch.utils.data.DataLoader:
        """Unshuffled validation loader with batch size 1."""
        return torch.utils.data.DataLoader(
            self.val_ds(), batch_size=1, shuffle=False
        )


class EventSeqArrDatasets(kdai.train.DatasetManager):
    """Dataset manager for splits made of event sequences given as rows.

    Each split (train, val, test) is a list of time-delta sequences. The
    consumer-facing interface matches the other managers in this module, but
    construction is more laborious: rather than one sequence cut into three,
    we take three lists of sequences.

    Every sequence is converted to a read-only sliding-window view with
    ``model_in_len + 1`` columns. Each window becomes one sample:
        x:    [N, t, 1]   window[:-1]
        mask: [N, t]      1.0 where x != BLANK_TOKEN, else 0.0
        y:    [N, t, 1]   window[1:], i.e. x shifted one step ahead
    where N is the number of windows and t == model_in_len. Only a single
    channel (the time delta) is produced.
    """

    def __init__(
        self,
        train_seqs: Sequence,
        val_seqs: Sequence,
        test_seqs: Sequence,
        model_in_len: int,
        prepend_blanks=(False, False, False),
        synced_start: Optional[int] = None,
        density_interval_len: int = 1,
    ):
        """
        Args:
            train_seqs: list of training time-delta sequences. All summary
                statistics are computed from these (without BLANKs).
            val_seqs: list of validation time-delta sequences.
            test_seqs: list of test time-delta sequences.
            model_in_len: number of events given to a forward call.
            prepend_blanks: per split (train, val, test), whether to prepend
                ``model_in_len`` BLANK tokens to each sequence.
                  - Prepend to the training split to train on sequence
                    beginnings while keeping the split length independent
                    of model input length.
                  - Prepend to val/test to evaluate every model on exactly
                    the same input-output pairs regardless of input length.
                    This may be unnecessary when evaluation sequences are
                    long enough that the differences are negligible; also
                    consider ``synced_start``.
                For models that use the "causal trick", this should probably
                be (False, False, False).
            synced_start: if not None, drop the first
                ``synced_start - model_in_len`` windows of every val and test
                sequence; the train split is left untouched. Without
                prepended blanks, window i targets element
                ``i + model_in_len``, so the first target becomes element
                ``synced_start`` for every model input length. Must be
                >= model_in_len. NOTE(review): combined with prepend_blanks
                on a split, the first target becomes element
                ``synced_start - model_in_len``, which depends on the model
                input length — confirm this combination is intended.
            density_interval_len: stored as ``self.density_interval_len``.
        """
        self.density_interval_len = density_interval_len
        self.n_train_seqs = len(train_seqs)

        window_len = model_in_len + 1
        blank_prefix = [BLANK_TOKEN] * model_in_len

        def _windows(split_seqs, with_prefix):
            # One read-only window matrix per sequence of the split.
            views = []
            for seq in split_seqs:
                source = (
                    np.concatenate([blank_prefix, seq]) if with_prefix else seq
                )
                views.append(
                    sliding_window_view(source, window_len, writeable=False)
                )
            return views

        raw_splits = (train_seqs, val_seqs, test_seqs)
        self.train_seqs, self.val_seqs, self.test_seqs = (
            _windows(split, flag)
            for split, flag in zip(raw_splits, prepend_blanks)
        )

        # Drop leading windows so that evaluation starts at a common target.
        if synced_start is not None:
            offset = synced_start - model_in_len
            if offset < 0:
                raise ValueError(
                    "synced_start must be greater-equal to the model input "
                    "length. Got (synced_start, model_in_len) = "
                    f"({synced_start}, {model_in_len})"
                )
            self.val_seqs = [view[offset:] for view in self.val_seqs]
            self.test_seqs = [view[offset:] for view in self.test_seqs]

        self.dt_range_min, self.dt_range_max = estimate_relevant_scales(
            train_seqs, model_in_len
        )
        # Statistics use the raw training data only, so BLANKs never leak in.
        flat = np.concatenate(train_seqs)
        self.dt_mean = np.mean(flat)
        self.dt_sd = np.std(flat)
        self.dt_min = np.min(flat)
        self.dt_max = np.max(flat)
        # Gaps between neighbouring sorted dt values.
        sorted_gaps = np.diff(np.sort(flat))
        self.ddt_min = np.min(sorted_gaps)
        self.ddt_max = np.max(sorted_gaps)
        # The 1e-10 offset guards against log(0).
        log_flat = np.log(flat + 1e-10)
        self.log_dt_mean = np.mean(log_flat)
        self.log_dt_sd = np.std(log_flat)

    @staticmethod
    def to_ds(data):
        """Concatenate one (x, mask, y) TensorDataset per window matrix."""

        def _as_tensor_dataset(windows):
            inputs = torch.tensor(windows[:, :-1], dtype=torch.float32)
            mask = torch.tensor(
                windows[:, :-1] != BLANK_TOKEN, dtype=torch.float32
            )
            targets = torch.tensor(windows[:, 1:], dtype=torch.float32)
            return torch.utils.data.TensorDataset(
                einops.rearrange(inputs, "n s -> n s 1"),
                mask,
                einops.rearrange(targets, "n s -> n s 1"),
            )

        return torch.utils.data.ConcatDataset(
            [_as_tensor_dataset(windows) for windows in data]
        )

    def train_ds(self) -> torch.utils.data.Dataset:
        return self.to_ds(self.train_seqs)

    def val_ds(self) -> torch.utils.data.Dataset:
        return self.to_ds(self.val_seqs)

    def test_ds(self) -> torch.utils.data.Dataset:
        return self.to_ds(self.test_seqs)

    def val_dl(self) -> torch.utils.data.DataLoader:
        """Unshuffled validation loader with batch size 1."""
        return torch.utils.data.DataLoader(
            self.val_ds(), batch_size=1, shuffle=False
        )


class EventSeqDatasets(kdai.train.DatasetManager):
    """Minimal dataset manager: one time-delta sequence per split.

    Every split is prefixed with ``model_in_len`` BLANK tokens and windowed
    into rows of ``model_in_len + 1`` values. Samples are 2-tuples (no mask):
        x: [N, t, 1]   the first model_in_len values of a window
        y: [N, 1]      the last value of the window (the next event)
    Only a single channel (the time delta) is produced.
    """

    def __init__(
        self,
        train_dt: np.ndarray,
        val_dt: np.ndarray,
        test_dt: np.ndarray,
        model_in_len: int,
        train_repeats: int = 1,
    ):
        """
        Args:
            train_dt: training time deltas; statistics come from these only.
            val_dt: validation time deltas.
            test_dt: test time deltas.
            model_in_len: number of events given to a forward call.
            train_repeats: how many times the training data is repeated.
                This keeps the number of steps in the training loop the same
                across different datasets.
        """
        self.train_repeats = train_repeats

        prefix = [BLANK_TOKEN] * model_in_len
        window_len = model_in_len + 1
        windowed = []
        for split in (train_dt, val_dt, test_dt):
            padded = np.concatenate([prefix, split])
            windowed.append(
                sliding_window_view(padded, window_len, writeable=False)
            )
        self.train_data, self.val_data, self.test_data = windowed

        # Statistics use the raw training deltas only, so BLANKs never leak in.
        self.dt_mean = np.mean(train_dt)
        self.dt_sd = np.std(train_dt)
        # The 1e-10 offset guards against log(0).
        log_train = np.log(train_dt + 1e-10)
        self.log_dt_mean = np.mean(log_train)
        self.log_dt_sd = np.std(log_train)

    @staticmethod
    def to_ds(data):
        """Build an (x, y) TensorDataset from a window matrix."""
        history = torch.tensor(data[:, :-1], dtype=torch.float32)
        target = torch.tensor(data[:, -1], dtype=torch.float32)
        return torch.utils.data.TensorDataset(
            einops.rearrange(history, "n s -> n s 1"),
            einops.rearrange(target, "n -> n 1"),
        )

    def train_ds(self) -> torch.utils.data.Dataset:
        base = self.to_ds(self.train_data)
        return kdai.datasets.LongerDataset.repeat(base, self.train_repeats)

    def val_ds(self) -> torch.utils.data.Dataset:
        return self.to_ds(self.val_data)

    def test_ds(self) -> torch.utils.data.Dataset:
        return self.to_ds(self.test_data)


class RandProcessDatasets(kdai.train.DatasetManager):
    """Dataset manager for a sequence generated by a random process.

    We have a separate dataset manager for these sequences for two reasons:
        - the log probabilities of the event times are also stored.
        - the sequences are very long, so it is reasonable to ignore
          predictions without a full-length context, which would have been
          the very few predictions at the beginning of the sequence. This
          allows us to work with 2D arrays rather than lists of sequences, so
          we don't need a custom dataset: standard PyTorch tensors built from
          a windowed view are enough.

    The log probabilities are used to calculate comparative log-likelihoods.
    Model mean log-likelihoods are compared to the mean log-likelihood (i.e.
    entropy) of the generating process, which acts as an upper bound on them.
    """

    INTERVAL_FRACTION_OF_MEAN = 0.01

    def __init__(
        self,
        dt_seq,
        log_prob_seq,
        model_in_len,
        ratio=(7, 2, 1),
        train_repeats=1,
        prepend_blanks=(False, False, False),
        synced_start: Optional[int] = None,
    ):
        """
        Args:
            dt_seq: 1D np.ndarray of event time deltas; all must be > 0.
            log_prob_seq: 1D np.ndarray of the log probabilities of the event
                times, aligned with dt_seq.
            model_in_len: the number of events given to a forward call.
            ratio: (train, val, test) split ratio, passed to
                ``kdai.datasets.split_borders``.
            train_repeats: number of times the training data is repeated.
            prepend_blanks: per split (train, val, test), whether to prepend
                ``model_in_len`` BLANK tokens.
                  - Prepend to the training split to train on sequence
                    beginnings while keeping the split length independent
                    of model input length.
                  - Prepend to val/test to evaluate every model on exactly
                    the same input-output pairs regardless of input length.
                    Maybe don't do this if your evaluation sequences are long
                    enough that the differences caused by different model
                    lengths are negligible. Also consider ``synced_start``.
                BLANK positions get mask 0 in the datasets built here.
            synced_start: if not None, the position in the sequence at which
                the model's outputs should start. This is useful when models
                with different input lengths are compared: it ensures that
                all models are evaluated on the same output sequence. The
                first ``synced_start - model_in_len`` windows of every split
                are dropped; must be >= model_in_len.

        Raises:
            ValueError: on mismatched lengths, a ratio that isn't length 3,
                non-positive dts, or synced_start < model_in_len.
        """
        # Explicit check rather than assert: asserts vanish under `python -O`.
        if len(dt_seq) != len(log_prob_seq):
            raise ValueError(
                "dt_seq and log_prob_seq must have the same length. Got "
                f"({len(dt_seq)}, {len(log_prob_seq)})."
            )
        if len(ratio) != 3:
            raise ValueError("train/val/test ratio must have length 3.")
        self.train_repeats = train_repeats
        if dt_seq.min() <= 0:
            raise ValueError("Event times must be strictly increasing.")
        split_borders = kdai.datasets.split_borders(ratio, len(dt_seq))
        train_dts, val_dts, test_dts = np.split(dt_seq, split_borders)
        # Need to keep these around, as they are queried by trainables.
        self.train_dts = train_dts

        def add_blank_prefix(seq):
            return np.concatenate([[BLANK_TOKEN] * model_in_len, seq])

        self.train_data, self.val_data, self.test_data = [
            sliding_window_view(
                add_blank_prefix(x) if do_prefix else x,
                model_in_len + 1,
                writeable=False,
            )
            for x, do_prefix in zip(
                (train_dts, val_dts, test_dts), prepend_blanks
            )
        ]
        # Remove the prefix from the output sequences.
        # NOTE(review): unlike EventSeqArrDatasets, the training windows are
        # trimmed here too — confirm this asymmetry is intended.
        if synced_start is not None:
            start_idx = synced_start - model_in_len
            if start_idx < 0:
                raise ValueError(
                    "synced_start must be greater-equal to the model input "
                    "length. Got (synced_start, model_in_len) = "
                    f"({synced_start}, {model_in_len})"
                )
            self.val_data = self.val_data[start_idx:]
            self.train_data = self.train_data[start_idx:]
            self.test_data = self.test_data[start_idx:]
        self.dt_range_min, self.dt_range_max = estimate_relevant_scales(
            [train_dts], model_in_len
        )
        # Note that mean & sd don't account for the BLANK tokens. And they are
        # only calculated over the training data.
        self.dt_mean = np.mean(train_dts)
        self.dt_sd = np.std(train_dts)
        self.dt_min = np.min(train_dts)
        self.dt_max = np.max(train_dts)
        eps = 1e-10
        log_diff_t = np.log(train_dts + eps)
        self.log_dt_mean = np.mean(log_diff_t)
        self.log_dt_sd = np.std(log_diff_t)
        (
            self.train_log_probs,
            self.val_log_probs,
            self.test_log_probs,
        ) = np.split(log_prob_seq, split_borders)
        (
            self.mean_train_log_prob,
            self.mean_val_log_prob,
            self.mean_test_log_prob,
        ) = [
            np.mean(x)
            for x in [
                self.train_log_probs,
                self.val_log_probs,
                self.test_log_probs,
            ]
        ]
        self.density_interval_len = (
            self.INTERVAL_FRACTION_OF_MEAN * self.dt_mean
        )

    def train_dts_flat(self):
        """Return all training dts as a single array."""
        return self.train_dts

    @staticmethod
    def to_ds(data):
        """Build an (x, mask, y) TensorDataset from a window matrix.

        x: (n, s, 1) float32; mask: (n, s) float32, 1.0 for real inputs and
        0.0 for BLANK_TOKEN padding; y: (n, s, 1) float32, x shifted one
        step ahead.
        """
        inputs = data[:, :-1]
        x = einops.rearrange(
            torch.tensor(inputs, dtype=torch.float32), "n s -> n s 1"
        )
        # n s   (specifically not n s 1)
        # Previously this was all ones, which fed prepended BLANK tokens to
        # models as real inputs. __init__ rejects dts <= 0, so real data never
        # equals BLANK_TOKEN, and without prepended blanks this is all ones.
        mask = torch.tensor(inputs != BLANK_TOKEN, dtype=torch.float32)
        y = einops.rearrange(
            torch.tensor(data[:, 1:], dtype=torch.float32),
            "n s -> n s 1",
        )
        return torch.utils.data.TensorDataset(x, mask, y)

    def train_ds(self) -> torch.utils.data.Dataset:
        return kdai.datasets.LongerDataset.repeat(
            self.to_ds(self.train_data), self.train_repeats
        )

    def val_ds(self) -> torch.utils.data.Dataset:
        return self.to_ds(self.val_data)

    def test_ds(self) -> torch.utils.data.Dataset:
        return self.to_ds(self.test_data)


def gen_poisson_events2(mu, t_max, rng=None):
    """Sample arrival times of a homogeneous Poisson process up to ``t_max``.

    Inter-arrival gaps are drawn one at a time from an exponential
    distribution with mean ``1 / mu``. Drawing stops as soon as the running
    time exceeds ``t_max``; that overshooting arrival is discarded.

    Args:
        mu: Event rate.
        t_max: Time horizon; only arrivals ``<= t_max`` are kept.
        rng (np.random.Generator): Random number generator. If None, a
            generator seeded with 124 is used.

    Returns:
        np.ndarray: 1D array of non-decreasing event times.

    Raises:
        RuntimeError: If no arrival falls within ``t_max``.
    """
    generator = rng if rng is not None else np.random.default_rng(seed=124)
    mean_gap = 1 / mu
    arrivals = []
    clock = 0
    while True:
        clock += generator.exponential(mean_gap)
        if clock > t_max:
            break
        arrivals.append(clock)
    if not arrivals:
        raise RuntimeError("No events generated.")
    return np.array(arrivals)


def gen_poisson(n_events, rng=None) -> Tuple[np.ndarray, np.ndarray]:
    """Generate ``n_events`` events from a unit-rate homogeneous Poisson process.

    Args:
        n_events: Number of events to generate.
        rng (np.random.Generator): Random number generator. If None, use
            default_rng with fixed seed.

    Returns:
        np.ndarray: 1D array of event times (cumulative sum of the gaps).
        np.ndarray: 1D array of per-event log-densities of the exponential
            inter-event gaps.
    """
    rate = 1.0
    generator = np.random.default_rng(seed=DEFAULT_SEED) if rng is None else rng
    # Frozen exponential distribution with mean 1 / rate.
    gap_dist = scipy.stats.expon(loc=0, scale=1 / rate)
    gaps = gap_dist.rvs(size=n_events, random_state=generator)
    gap_log_density = gap_dist.logpdf(gaps)
    return gaps.cumsum(), gap_log_density


def gen_nonstationary_poisson(
    n_events,
    period=20000,
    rng=None,
    c=None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Generates events from a non-stationary Poisson process.

    Uses Algorithm 1 from page 7 of Lewis and Shedler, 1979 (thinning).

    See: https://www.math.fsu.edu/~ychen/research/Thinning%20algorithm.pdf

    The intensity function is sinusoidal, with phase offset ``c``:

        λ(t) = sin(2 * pi * (t + c) / period) + 1

    The algorithm works as follows:
       - Take a value that is at least as big as the maximum of the intensity
         function. We use 2.01. An alternative is to use 2 and scale the λ
         function by a factor like 0.99. Denote this value lambda_max.
       - Generate event intervals from an exponential distribution with rate
         lambda_max.
       - For each candidate event, draw a uniform random number in
         [0, lambda_max]. Accept the candidate if the intensity at its time
         is greater than that number; otherwise reject it.

    Original code: https://github.com/omitakahiro/NeuralNetworkPointProcess/blob/master/code.ipynb

    Args:
        n_events: Number of events to generate.
        period: Period of the sinusoidal intensity function.
        rng (np.random.Generator): Random number generator. If None, use
            default_rng with fixed seed.
        c: Phase offset added to t inside the sinusoid. If None, it is drawn
            uniformly from [0, period) using ``rng``.

    Returns:
        np.ndarray: 1D array of ``n_events`` event times.
        np.ndarray: 1D array of per-event log-likelihoods
            log λ(t_i) - Λ(t_prev, t_i), where Λ is the integrated intensity
            and t_prev is the previous event time (0 for the first event).

    Raises:
        ValueError: If thinning fails to accept ``n_events`` candidates in
            any of the attempts.
    """
    if rng is None:
        rng = np.random.default_rng(seed=DEFAULT_SEED)
    λ_upper_bound = 2.01
    if c is None:
        c = rng.random() * period

    def λ(t):
        """Intensity at time(s) t."""
        return np.sin(2 * np.pi * (t + c) / period) + 1

    def Λ(t1, t2):
        """Integrated intensity over [t1, t2] (closed form; vectorized)."""
        return -period / (2 * np.pi) * (
            np.cos(2 * np.pi * (t2 + c) / period)
            - np.cos(2 * np.pi * (t1 + c) / period)
        ) + (t2 - t1)

    max_tries = 4
    for _ in range(max_tries):
        # There should usually be enough points, but if not, try again.
        maybe_enough_points = n_events * 10
        # Candidate event times from a homogeneous process at rate λ_upper_bound.
        T = rng.exponential(size=maybe_enough_points).cumsum() / λ_upper_bound
        r = rng.random(maybe_enough_points)
        accept_idxs = (r * λ_upper_bound) < λ(T)
        if accept_idxs.sum() >= n_events:
            break
    else:
        raise ValueError(
            "Could not generate enough events. Implementation is not robust."
        )
    # Only take as many events as needed.
    event_times = T[accept_idxs][:n_events]
    # Start of each inter-event interval: 0 for the first event, then the
    # previous event time. Slicing after the prepend also handles n_events=0.
    prev_times = np.concatenate(([0.0], event_times))[:-1]
    log_S = -Λ(prev_times, event_times)
    # f = λ * S   =>  log(f) = log(λ) + log(S)
    log_probs = np.log(λ(event_times)) + log_S
    return event_times, log_probs


def gen_stationary_renewal(n_events, rng=None) -> Tuple[np.ndarray, np.ndarray]:
    """Generate events from a stationary renewal process with lognormal gaps.

    Inter-event intervals are i.i.d. lognormal with shape
    ``s = sqrt(log(6**2 + 1))`` and ``scale = exp(-s**2 / 2)``. This gives
    gaps with mean 1 and standard deviation 6.

    Args:
        n_events: Number of events to generate.
        rng (np.random.Generator): Random number generator. If None, use
            default_rng with fixed seed.

    Returns:
        np.ndarray: 1D array of event times (cumulative sum of the gaps).
        np.ndarray: 1D array of per-event log-densities of the gaps.
    """
    generator = np.random.default_rng(seed=DEFAULT_SEED) if rng is None else rng
    shape = np.sqrt(np.log(6 * 6 + 1))
    # exp(mu) with mu = -s^2 / 2 makes the lognormal mean exactly 1.
    scale = np.exp(-shape * shape / 2)
    gaps = scipy.stats.lognorm.rvs(
        s=shape, scale=scale, size=n_events, random_state=generator
    )
    gap_log_density = scipy.stats.lognorm.logpdf(gaps, s=shape, scale=scale)
    return gaps.cumsum(), gap_log_density


def gen_nonstationary_renewal(
    n_events, period=20000, rng=None
) -> Tuple[np.ndarray, np.ndarray]:
    """Generates events from a non-stationary renewal process.

    The process is a renewal process in rescaled time. With the sinusoidal
    intensity

        λ(t) = 0.99 * sin(2 * pi * (t + c) / period) + 1

    and its integral Λ(t1, t2) = ∫_{t1}^{t2} λ(s) ds, the rescaled intervals
    r_i = Λ(t_{i-1}, t_i) are i.i.d. Gamma(shape=4) / 4 (mean 1). Each event
    time t_i is found by solving Λ(t_{i-1}, t_i) = r_i with Brent's method.
    The phase ``c`` is drawn uniformly from [0, period).

    Args:
        n_events: Number of events to generate.
        period: Period of the sinusoidal intensity.
        rng (np.random.Generator): Random number generator. If None, use
            default_rng with fixed seed.

    Returns:
        np.ndarray: 1D array of event times.
        np.ndarray: 1D array of per-event log-likelihoods
            log g(r_i) + log λ(t_i), where g is the density of the rescaled
            intervals. The log λ term is the change of variables from r_i
            to t_i.
    """
    # `import scipy.stats` does not guarantee scipy.optimize is importable as
    # an attribute; import it explicitly for brentq.
    import scipy.optimize

    if rng is None:
        rng = np.random.default_rng(seed=DEFAULT_SEED)
    # Drawn before the intervals so the RNG stream is consumed in a fixed order.
    c = rng.random() * period
    amp = 0.99

    def intensity(t):
        return np.sin(2 * np.pi * (t + c) / period) * amp + 1

    def integrated_intensity(t1, t2):
        return -period / (2 * np.pi) * (
            np.cos(2 * np.pi * (t2 + c) / period)
            - np.cos(2 * np.pi * (t1 + c) / period)
        ) * amp + (t2 - t1)

    def residual(t, t_prev, target):
        return integrated_intensity(t_prev, t) - target

    k = 4
    gamma_draws = scipy.stats.gamma.rvs(k, size=n_events, random_state=rng)
    # If X ~ Gamma(k), then r = X / k has log-density log f_X(X) + log(k).
    log_g = scipy.stats.gamma.logpdf(gamma_draws, k) + np.log(k)
    rescaled = gamma_draws / k

    event_times = np.empty(n_events)
    t_prev = 0
    for i, target in enumerate(rescaled):
        # residual(t_prev) = -target < 0. Find an upper end where it is >= 0.
        # Since λ >= 1 - amp > 0, Λ grows without bound and doubling ends.
        hi = t_prev + 1000
        while integrated_intensity(t_prev, hi) < target:
            hi = t_prev + 2 * (hi - t_prev)
        t_prev = scipy.optimize.brentq(
            residual, t_prev, hi, args=(t_prev, target)
        )
        event_times[i] = t_prev
    log_probs = log_g + np.log(intensity(event_times))
    return event_times, log_probs


def gen_self_correcting(
    n_events, rng=None, *, mu=1.0, alpha=1.0
) -> Tuple[np.ndarray, np.ndarray]:
    """Generates events from a self-correcting process.

    The log-intensity x(t) starts at 0, rises linearly at rate ``mu`` and
    drops by ``alpha`` at every event. The intensity is therefore
    λ(t) = exp(mu * t - alpha * N(t-)), where N(t-) counts earlier events.

    Each gap tau is sampled by inverting the compensator: with e ~ Exp(1),
    tau solves exp(x) * (exp(mu * tau) - 1) / mu = e.

    Args:
        n_events: Number of events to generate; exactly this many are returned.
        rng (np.random.Generator): Random number generator. If None, use
            default_rng with fixed seed.
        mu: Growth rate of the log-intensity between events. Must be > 0.
        alpha: Drop in log-intensity at each event.

    Returns:
        np.ndarray: 1D array of event times.
        np.ndarray: 1D array of per-event log-likelihoods
            log λ(t_i) - ∫_{t_prev}^{t_i} λ(s) ds.

    Raises:
        ValueError: If ``mu`` is not positive.
        RuntimeError: If the intensity becomes so large that gaps keep
            vanishing in floating point, so time can no longer advance.
    """
    if mu <= 0:
        raise ValueError(f"mu must be positive, got {mu}.")
    if rng is None:
        rng = np.random.default_rng(seed=DEFAULT_SEED)

    max_consecutive_skips = 1000
    event_times, log_probs = [], []
    t = 0.0
    x = 0.0  # Current log-intensity.
    skipped = 0
    # Loop until exactly n_events are collected. A fixed-count loop with
    # `continue` would silently return fewer events.
    while len(event_times) < n_events:
        e = rng.exponential()
        # log1p keeps precision when e * mu * exp(-x) is tiny.
        tau = np.log1p(e * mu / np.exp(x)) / mu
        # Skip intervals too small to advance t in floating point, so event
        # times stay strictly increasing.
        if t + tau == t:
            skipped += 1
            if skipped > max_consecutive_skips:
                raise RuntimeError(
                    "Intensity too large: inter-event times underflow."
                )
            continue
        skipped = 0
        t = t + tau
        x = x + mu * tau
        event_times.append(t)
        # log λ(t_i) is x just before the drop; the compensator equals e.
        log_probs.append(x - e)
        x = x - alpha

    return np.array(event_times), np.array(log_probs)


def gen_hawkes1(n_events, rng=None) -> Tuple[np.ndarray, np.ndarray]:
    """Generate events from a Hawkes process with a single effective kernel.

    The process has background rate 0.2 and an exponential kernel with
    alpha=0.8 and beta=1.0. A second kernel (beta=20.0) is present but has
    zero weight.

    Returns:
        np.ndarray: 1D array of event times.
        np.ndarray: 1D array of per-event log-likelihoods.
    """
    return simulate_hawkes(
        n_events, mu=0.2, alpha=[0.8, 0.0], beta=[1.0, 20.0], rng=rng
    )


def gen_hawkes2(n_events, rng=None) -> Tuple[np.ndarray, np.ndarray]:
    """Generate events from a Hawkes process with two exponential kernels.

    The process has background rate 0.2. Its two kernels each have
    alpha=0.4, with decay rates beta=1.0 and beta=20.0.

    Returns:
        np.ndarray: 1D array of event times.
        np.ndarray: 1D array of per-event log-likelihoods.
    """
    return simulate_hawkes(
        n_events, mu=0.2, alpha=[0.4, 0.4], beta=[1.0, 20.0], rng=rng
    )


def simulate_hawkes(n, mu, alpha, beta, rng=None):
    """Simulate a Hawkes process with a sum-of-exponentials excitation kernel.

    The conditional intensity is

        λ(t) = mu + Σ_k Σ_{t_j < t} alpha[k] * beta[k] * exp(-beta[k] * (t - t_j))

    Each event triggers on average ``alpha[k]`` further events through
    kernel ``k``, which decays at rate ``beta[k]``.

    Events are drawn by thinning. With non-negative alpha and positive beta,
    λ only decays between events, so the intensity at the start of a step
    bounds it over that step. A candidate at distance Exp(1) / bound is
    accepted with probability λ(candidate) / bound. Integrals of each term
    are accumulated since the last accepted event to form the compensator.

    Args:
        n: Number of events to generate. ``n <= 0`` returns empty arrays.
        mu: Background (baseline) rate.
        alpha: Sequence of kernel weights, one per exponential kernel.
        beta: Sequence of kernel decay rates, same length as ``alpha``.
        rng (np.random.Generator): Random number generator. If None, use
            default_rng with fixed seed.

    Returns:
        np.ndarray: 1D array of event times.
        np.ndarray: 1D array of per-event log-likelihoods
            log λ(t_i) - ∫_{t_prev}^{t_i} λ(s) ds.

    Raises:
        ValueError: If ``alpha`` and ``beta`` have different lengths.
    """
    if rng is None:
        rng = np.random.default_rng(seed=DEFAULT_SEED)
    if len(alpha) != len(beta):
        raise ValueError(
            f"alpha and beta must have the same length, got {len(alpha)}"
            f" and {len(beta)}."
        )
    n_kernels = len(alpha)

    def total_intensity(excitation):
        # Left-to-right sum mu + e_0 + e_1 + ... (fixed order, no sum()).
        lam = mu
        for value in excitation:
            lam += value
        return lam

    event_times = []
    log_likelihoods = []

    t = 0
    # Current excitation from each kernel.
    excitation = [0] * n_kernels
    # Integral of each kernel's excitation since the last accepted event.
    excitation_int = [0] * n_kernels
    # Integral of the background rate since the last accepted event.
    mu_int = 0

    # `while len(...) < n` (instead of `count == n` after an increment)
    # terminates immediately for n <= 0 rather than looping forever.
    while len(event_times) < n:
        lam_bound = total_intensity(excitation)
        step = rng.exponential() / lam_bound
        # Skip zero-length intervals.
        if t == t + step:
            continue
        t = t + step

        mu_int += mu * step
        for k in range(n_kernels):
            decay = np.exp(-beta[k] * step)
            excitation_int[k] += excitation[k] * (1 - decay) / beta[k]
            excitation[k] *= decay
        lam_next = total_intensity(excitation)

        if rng.random() < lam_next / lam_bound:  # accept
            ll = np.log(lam_next)
            for k in range(n_kernels):
                ll -= excitation_int[k]
            ll -= mu_int
            event_times.append(t)
            log_likelihoods.append(ll)
            for k in range(n_kernels):
                excitation[k] += alpha[k] * beta[k]
                excitation_int[k] = 0
            mu_int = 0

    ts, log_probs = np.array(event_times), np.array(log_likelihoods)
    if np.any(np.diff(ts) <= 0):
        # The cause: x2 - x1 = 0 even though x2 != x1 is not a bug, but a
        # characteristic of floating point arithmetic. It could be argued that
        # models should be able to handle this, which might involve assigning a
        # probability representative of the smallest possible time difference,
        # which will depend on the float precision.
        _logger.warning(
            "Event times are not all strictly increasing."
            f" {np.flatnonzero(np.diff(ts) <= 0)}"
        )
    return ts, log_probs
