from collections import namedtuple
from collections import deque
import concurrent.futures
import itertools
import json
import logging
import math
import pathlib
import bisect
from typing import (
    Deque,
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    TypeAlias,
    Any,
)
import numpy as np
import scipy
import scipy.signal
import bidict
import datetime
import torch
import kdtpp
import kdtpp.spikedistance as sdf
import kdai
import kdai.train
import kdai.datasets
import einops
import functools


# Module-level logger.
_logger = logging.getLogger(__name__)


# Sample rate of the 3Brain electrode array, before any downsampling.
ELECTRODE_FREQ = 17852.767845719834  # Hz
# Number of stimulus LEDs (colour channels); matches the four entries of
# `stimuli` below.
NUM_STIMULUS_LEDS = 4

# JSON files, inside a data directory, written by save_id_info() and read by
# load_id_info().
REC_IDS_FILENAME = "recording_ids.json"
REC_CELL_IDS_FILENAME = "recording_cell_ids.json"


# Identifies a cell within a specific recording: (recording-name, cell_id).
RecCellId: TypeAlias = Tuple[str, int]
# A dictionary, recording-name <-> id
RecIds: TypeAlias = bidict.BidirectionalMapping[str, int]
# A dictionary, (recording-name, cell_id) <-> global cell id
RecCellIds: TypeAlias = bidict.BidirectionalMapping[RecCellId, int]

# Static description of one stimulus colour. NOTE(review): `wavelength` looks
# like nanometres and `channel` presumably indexes the stimulus channel axis;
# `display_hex` is presumably a colour used when displaying this stimulus.
Stimulus = namedtuple("Stimulus", "name wavelength channel display_hex")

stimuli = [
    Stimulus(name="red", wavelength=630, channel=0, display_hex="#FE7C7C"),
    Stimulus(name="green", wavelength=530, channel=1, display_hex="#8AFE7C"),
    Stimulus(name="blue", wavelength=480, channel=2, display_hex="#7CFCFE"),
    Stimulus(name="uv", wavelength=420, channel=3, display_hex="#7C86FE"),
]
# Name -> Stimulus lookup table, used by stimulus_by_name().
_stimulus_map = dict((stim.name, stim) for stim in stimuli)


def stimulus_by_name(name: str) -> Stimulus:
    """Return the stimulus info for the stimulus with the given name.

    Args:
        name: One of the names in `stimuli` (e.g. "red").

    Raises:
        KeyError: If no stimulus has this name. The message lists the valid
            names. This is still a KeyError, so existing handlers keep working.
    """
    try:
        return _stimulus_map[name]
    except KeyError:
        raise KeyError(
            f"Unknown stimulus ({name}). Expected one of: "
            f"{', '.join(_stimulus_map)}."
        ) from None


def ms_to_num_bins(time_ms, downsample_factor):
    """Convert a duration in milliseconds to a (fractional) number of bins.

    A bin spans `downsample_factor` electrode samples at ELECTRODE_FREQ.
    The result is not rounded.
    """
    samples_per_ms = ELECTRODE_FREQ / 1000
    return time_ms * samples_per_ms / downsample_factor


def num_bins_to_ms(num_bins, downsample_factor):
    """Convert a number of bins to a duration in milliseconds.

    This is the inverse of ms_to_num_bins(): each bin spans
    `downsample_factor` electrode samples at ELECTRODE_FREQ.
    """
    samples_per_ms = ELECTRODE_FREQ / 1000
    return num_bins * downsample_factor / samples_per_ms


class CompressedSpikeRecording:
    """
    Data for a single recording, with spikes and stimulus stored as events.

    Storage as events means that a 10 bin long recording with a single spike
    at bin 5 would have spike data = [5], as opposed to
    [0,0,0,0,1,0,0,0,0,0]. The recordings are stored in this format anyway,
    so this class is mostly a thin wrapper around the stored event data. The
    SpikeRecording class is the decompressed version of this class.

    Some benefits of this class:

        - We are not sure what changes to the underlying dataframe we will be
          making in future, so this class provides a layer of abstraction.
        - The recording data is split into a triplet: (stimulus pattern,
          stimulus recording, spike recording) which typically needs to be
          moved around together. So combining them in a single place makes
          sense.
        - It's easy to make mistakes querying dataframes with the Pandas API.
          The queries can be done once here.


    Future work:

        - Currently, it's assumed that the stimulus pattern is full-field.
          Switching to a 2D stimulus might require some changes. At least we
          will need to decide the X-Y coordinate order and (0,0) position.
    """

    def __init__(
        self,
        name: str,
        stimulus_pattern: np.ndarray,
        stimulus_events: np.ndarray,
        spike_events: List[np.ndarray],
        cell_ids: List[int],
        sensor_sample_rate: float,
        num_sensor_samples: int,
        cell_gids: Optional[List[int]] = None,
    ):
        """
        Args:
            name: Recording name.
            stimulus_pattern: The stimulus steps, one row per step.
                NOTE(review): assumed shape (num_steps, num_channels), as
                used by decompress_stimulus(); confirm.
            stimulus_events: Sensor-sample indices of the stimulus triggers.
            spike_events: One array of spike sample indices per cell.
            cell_ids: The cell id of each entry of spike_events.
            sensor_sample_rate: Sensor sample rate in Hz.
            num_sensor_samples: Total number of sensor samples.
            cell_gids: Global cell ids, one per cell. Defaults to cell_ids.

        Raises:
            ValueError: If spike_events, cell_ids and (when given) cell_gids
                differ in length.
        """
        if len(spike_events) != len(cell_ids):
            raise ValueError(
                f"Mismatch between number of cell-spikes "
                f"({len(spike_events)}) and number of cell ids "
                f"({len(cell_ids)})."
            )
        if cell_gids is not None and len(cell_gids) != len(cell_ids):
            raise ValueError(
                f"Mismatch between number of cell global ids "
                f"({len(cell_gids)}) and number of cell ids "
                f"({len(cell_ids)})."
            )
        self.name = name
        self.stimulus_pattern = stimulus_pattern
        self.cell_ids = cell_ids
        self.stimulus_events = stimulus_events
        self.spike_events = spike_events
        self.sensor_sample_rate = sensor_sample_rate
        self.num_sensor_samples = num_sensor_samples
        self.cell_gids = cell_gids if cell_gids is not None else cell_ids

    def __str__(self):
        res = (
            f"Recording: {self.name}, "
            f"sensor sample rate: {self.sensor_sample_rate} Hz, "
            f"num samples: {self.num_sensor_samples}, "
            f"duration: {self.duration():.1f} seconds, "
            f"stimulus pattern shape: {self.stimulus_pattern.shape}, "
            f"num cells: {len(self.cell_ids)}."
        )
        return res

    def duration(self):
        """Duration of the recording in seconds."""
        return self.num_sensor_samples / self.sensor_sample_rate

    def cells(self, cell_ids: Iterable[int]) -> "CompressedSpikeRecording":
        """Return a new recording with only the specified cells.

        The cells keep their original order, whatever the order of cell_ids.
        An empty selection gives a recording with no cells.

        Raises:
            ValueError: If some of cell_ids are not in this recording.
        """
        cell_ids = set(cell_ids)
        if not cell_ids.issubset(self.cell_ids):
            raise ValueError(
                f"Cell ids ({cell_ids}) are not a subset of "
                f"the cell ids in this recording ({self.cell_ids})."
            )
        # Built as a list, not via zip(*...) unpacking, so that an empty
        # selection works.
        keep = [i for i, c in enumerate(self.cell_ids) if c in cell_ids]
        return CompressedSpikeRecording(
            self.name,
            self.stimulus_pattern,
            self.stimulus_events,
            [self.spike_events[i] for i in keep],
            [self.cell_ids[i] for i in keep],
            self.sensor_sample_rate,
            self.num_sensor_samples,
            [self.cell_gids[i] for i in keep],
        )

    def num_cells(self) -> int:
        return len(self.cell_ids)

    def filter_cells(
        self,
        min_rate: Optional[float] = None,
        max_rate: Optional[float] = None,
        min_count: Optional[int] = None,
    ) -> "CompressedSpikeRecording":
        """Filter out cells whose spike rate or count is outside the given range.

        Each bound is inclusive and is ignored when None. Rates are spikes per
        second over the whole recording. A new recording is returned; it may
        have no cells.
        """

        def _spike_rate(spike_events):
            assert len(spike_events.shape) == 1
            return len(spike_events) / self.duration()

        matching_cells = set()
        for cid, events in zip(self.cell_ids, self.spike_events):
            # Rates are only computed when a rate bound is given.
            is_match = (
                (min_rate is None or _spike_rate(events) >= min_rate)
                and (max_rate is None or _spike_rate(events) <= max_rate)
                and (min_count is None or len(events) >= min_count)
            )
            if is_match:
                matching_cells.add(cid)
        return self.cells(matching_cells)

    def cid_to_idx(self, cid: int) -> int:
        return self.cell_ids.index(cid)

    def to_json(self):
        """Return a JSON-serializable dict describing this recording."""
        res = {
            "name": self.name,
            "stimulus_pattern": self.stimulus_pattern.tolist(),
            "stimulus_events": self.stimulus_events.tolist(),
            "spike_events": [s.tolist() for s in self.spike_events],
            # int() so that numpy integer ids are JSON serializable.
            "cell_ids": [int(c) for c in self.cell_ids],
            "sensor_sample_rate": float(self.sensor_sample_rate),
            "num_sensor_samples": int(self.num_sensor_samples),
            "_export_date": str(datetime.datetime.now()),
            "_version": "1.0.0",
            # We don't serialize gids—those are managed globally.
        }
        return res

    @staticmethod
    def from_json(data):
        """Create a recording from a dict produced by to_json().

        The global cell ids default to the cell ids.
        """
        return CompressedSpikeRecording(
            data["name"],
            np.array(data["stimulus_pattern"]),
            np.array(data["stimulus_events"]),
            [np.array(s) for s in data["spike_events"]],
            data["cell_ids"],
            data["sensor_sample_rate"],
            data["num_sensor_samples"],
        )

    @staticmethod
    def from_json_list(data):
        """Create a list of CompressedSpikeRecording objects from a JSON list."""
        return [CompressedSpikeRecording.from_json(d) for d in data]


class SpikeRecording:
    """
    Data for a recording, with a spike and stimulus value for each time bin.

    The class allows indexing by time and cell, so an object may represent
    a part of a recording, and isn't limited to be a whole recording.

    A 10 bin long recording with a single spike at bin 5 would have spike
    data = [0,0,0,0,1,0,0,0,0,0], as opposed to [5]. The stimulus is stored
    similarly.

    This class is mostly just a wrapper around the already decompressed data.
    It's useful for the same reasons as the CompressedSpikeRecording class.

    Future work:

        - I'm slowly adding functionality to make this class more array like.
          The idea here is that the spike data and stimulus data have a common
          time dimension, which it is useful to index and slice over, but
          otherwise their dimensions don't match. xarray would be a suitable
          off-the-shelf solution for this. xarray is probably best, but while
          the required functionality of this class is still minimal, the
          current solution seems fine, and saves having to require knowledge of
          yet another indexing API in addition to the currently used numpy,
          Pandas and PyTorch.
    """

    name: str
    # Stimulus values with shape (time, channels).
    stimulus: np.ndarray
    # Spike data with shape (time, cells); column i belongs to cell_ids[i].
    spikes: np.ndarray
    # Cell ids, one per column of `spikes`.
    cell_ids: List[int]
    # The globally unique cell IDs. These will be the same as the cell
    # IDs if cell_gids is not provided when initializing the recording.
    cell_gids: List[int]
    # Samples per second along the time dimension.
    sample_rate: float
    # Free-form extra info. Currently, just used to add split info to the sub
    # recordings. It is copied by slicing and by cells().
    metadata: Dict[str, Any]

    def __init__(
        self,
        name,
        stimulus,
        spikes,
        cell_ids: List[int],
        sample_rate: float,
        cell_gids: Optional[List[int]] = None,
    ):
        """
        Raises:
            ValueError: If stimulus and spikes differ in length, or if the
                number of spike columns, cell_ids and (when given) cell_gids
                don't agree.
        """
        if len(stimulus) != len(spikes):
            raise ValueError(
                f"Length of stimulus ({len(stimulus)}) and length"
                f" of response ({len(spikes)}) do not match."
            )
        if spikes.shape[1] != len(cell_ids):
            raise ValueError(
                f"Mismatch between number of cell-spikes "
                f"({spikes.shape[1]}) and number of cell ids "
                f"({len(cell_ids)})."
            )
        if cell_gids is not None and len(cell_gids) != len(cell_ids):
            raise ValueError(
                f"Mismatch between number of cell global ids "
                f"({len(cell_gids)}) and number of cell ids "
                f"({len(cell_ids)})."
            )
        self.name = name
        self.stimulus = stimulus
        self.spikes = spikes
        self.cell_ids = cell_ids
        self.sample_rate = sample_rate
        self.cell_gids = cell_gids if cell_gids is not None else cell_ids
        self.metadata = {}

    def __str__(self):
        res = (
            f"Recording: {self.name}, "
            f"sample rate: {self.sample_rate:.1f} Hz "
            f"({1000/self.sample_rate:.3f} ms per sample), "
            f"duration: {self.duration():.1f} seconds, "
            f"stimulus shape: {self.stimulus.shape}, "
            f"spikes shape: {self.spikes.shape}, "
            f"num cells: {len(self.cell_ids)}."
        )
        return res

    def __len__(self):
        assert len(self.stimulus) == len(self.spikes)
        return len(self.stimulus)

    def duration(self) -> float:
        """Duration in seconds."""
        return self.stimulus.shape[0] / self.sample_rate

    def num_cells(self) -> int:
        return len(self.cell_ids)

    def __getitem__(self, key):
        """Return a new recording with only the data for the given time bin(s).

        `key` indexes the time dimension, e.g. a slice or an index array. A
        single integer gives a recording of length 1, so the result keeps its
        time axis. Metadata is copied to the new recording.
        """
        if isinstance(key, (int, np.integer)) and not isinstance(key, bool):
            # Plain integer indexing would drop the time axis and break the
            # (time, cells) layout. range() provides bounds checking and
            # negative-index handling.
            pos = range(len(self))[key]
            key = slice(pos, pos + 1)
        res = SpikeRecording(
            self.name,
            self.stimulus[key],
            self.spikes[key],
            self.cell_ids,
            self.sample_rate,
            self.cell_gids,
        )
        res.metadata = dict(self.metadata)
        return res

    def cells(self, cell_ids: Iterable[int]) -> "SpikeRecording":
        """Return a new recording with only the specified cells.

        The cells keep their original order, whatever the order of cell_ids.
        An empty selection gives a recording with zero spike columns. Metadata
        is copied.

        Raises:
            ValueError: If some of cell_ids are not in this recording.
        """
        cell_ids = set(cell_ids)
        if not cell_ids.issubset(self.cell_ids):
            raise ValueError(
                f"Cell ids ({cell_ids}) are not a subset of "
                f"the cell ids in this recording ({self.cell_ids})."
            )
        # Built as a list (not zip(*...) unpacking) so an empty selection works.
        keep = [i for i, c in enumerate(self.cell_ids) if c in cell_ids]
        res = SpikeRecording(
            self.name,
            self.stimulus,
            self.spikes[:, keep],
            [self.cell_ids[i] for i in keep],
            self.sample_rate,
            [self.cell_gids[i] for i in keep],
        )
        res.metadata = dict(self.metadata)
        return res

    def extend(self, recording):
        """Add the given recording to the end of this one.

        This was added in order to create test, train and validation sets that
        pick their data from multiple parts of the same recording.

        This modifies self in place and returns it.

        Raises:
            ValueError: If the sample rates, number of stimulus channels,
                cell ids or cell global ids differ.
        """
        if self.sample_rate != recording.sample_rate:
            raise ValueError(
                f"Sample rates do not match. Got ({self.sample_rate}) and "
                f"({recording.sample_rate})."
            )
        if self.stimulus.shape[1] != recording.stimulus.shape[1]:
            raise ValueError(
                f"The stimulus must have the same number of LEDs. Got "
                f"({self.stimulus.shape[1]}) and "
                f"({recording.stimulus.shape[1]})."
            )
        if not np.array_equal(self.cell_ids, recording.cell_ids):
            raise ValueError(
                f"Cell ids do not match. Got ({self.cell_ids}) and "
                f"({recording.cell_ids})."
            )
        if not np.array_equal(self.cell_gids, recording.cell_gids):
            raise ValueError(
                f"Cell global ids do not match. Got ({self.cell_gids}) "
                f"and ({recording.cell_gids})."
            )
        self.stimulus = np.concatenate((self.stimulus, recording.stimulus))
        self.spikes = np.concatenate((self.spikes, recording.spikes))
        return self

    def spike_snippets(self, cid: int, total_len: int, post_spike_len: int):
        """Stimulus snippets around each spike of cell `cid`.

        Delegates to the module-level spike_snippets() function.
        """
        res = spike_snippets(
            self.stimulus,
            compress_spikes(self.spikes[:, self.cid_to_idx(cid)]),
            total_len,
            post_spike_len,
        )
        return res

    def all_spike_snippets(self, total_len: int, post_spike_len: int):
        """Return a dict mapping each cell id to that cell's spike snippets."""
        snippets_by_cell = {}
        # Can this be done in a single call?
        # Could make compress_spikes operate on 2d array, then sent all
        # spikes to spike_snippets then split with np.split().
        for idx, cid in enumerate(self.cell_ids):
            snippets_by_cell[cid] = spike_snippets(
                self.stimulus,
                compress_spikes(self.spikes[:, idx]),
                total_len,
                post_spike_len,
            )
        return snippets_by_cell

    def cid_to_idx(self, cid: int) -> int:
        return self.cell_ids.index(cid)

    def time_until_spike(self):
        """Convert the 0/1 spike data to an array of times until next spike.

        The timesteps at or beyond the last spike are set to negative values
        representing the number of timesteps until the end of the recording.
        """
        # This used to run sdf.time_until2 per cell in a process pool; the
        # vectorized version handles all cells in one call.
        dts = sdf.time_until2vec(self.spikes)
        return dts


# The data of one split: a list of recording parts, each a contiguous slice
# of a recording (see split() and mirror_split2()).
ContiguousChunks: TypeAlias = List[SpikeRecording]
# A (train, val, test) triple, where each of train, val and test is a list of
# recording parts.
# NOTE(review): mirror_split2() builds a list of lists rather than a tuple,
# so values annotated with this alias may actually be lists.
RecordingTrainValTest: TypeAlias = Tuple[
    ContiguousChunks, ContiguousChunks, ContiguousChunks
]


def split(recording: SpikeRecording, split_ratio: Sequence[int]):
    """Cut a recording into consecutive, contiguous parts.

    Args:
        recording: The recording to cut.
        split_ratio: a list of weightings that determines how much data to
            give each split. For example, you might use the triplet (3, 1, 1)
            to create a train-val-test split with 60% of the data for training,
            20% for validation, and 20% for testing.

    Returns:
        A list with one recording per weighting, in order. Each part's
        metadata gets "split" (its index) and "n_splits". Samples left over
        by the integer division all go to the first part.

    Raises:
        ValueError: If fewer than 2 weightings are given, or any is not
            positive.
    """
    n_parts = len(split_ratio)
    if n_parts < 2:
        raise ValueError("Can't split a recording into fewer than 2 parts.")
    if not all([ratio > 0 for ratio in split_ratio]):
        raise ValueError(f"Split ratios must be positive. Got ({split_ratio}).")
    unit, leftover = divmod(len(recording), sum(split_ratio))
    part_lengths = [ratio * unit for ratio in split_ratio]
    part_lengths[0] += leftover
    # Cumulative boundaries: [0, end_0, end_1, ..., len(recording)].
    boundaries = [0, *itertools.accumulate(part_lengths)]
    parts = []
    for idx, (lo, hi) in enumerate(zip(boundaries, boundaries[1:])):
        part = recording[lo:hi]
        part.metadata["split"] = idx
        part.metadata["n_splits"] = n_parts
        parts.append(part)
    total_len = sum(len(part) for part in parts)
    assert total_len == len(
        recording
    ), f"Split lengths do not match ({total_len}) vs ({len(recording)})."
    return parts


def mirror_split(recording: SpikeRecording, split_ratio: Sequence[int]):
    """Split data from "outside-in".

    This approach of splitting the data tries to address the issue that the
    response of the retina will change over time. Training the model with the
    earlier responses and validating or testing with the later responses will
    likely lead to reduced accuracy that is not the fault of the model.

    One way to ameliorate this issue is to give each split a piece of the
    earlier data and a piece of the later data. There is still some flexibility
    in how to make this choice. If there are three splits (train, val, test),
    then they will be made as follows:

        +-------+------+----+----+------+-------+
        | train | val  |  test   | val  | train |
        +-------+------+----+----+------+-------+

    The benefit of this approach is that each split is not exposed solely to
    one end of the data. Additionally, the test data, which will often appear
    in reports, will still exist as a continuous block of data.

    The two pieces of each split are joined with SpikeRecording.extend().

    Args:
        split_ratio: a list of weightings that determines how much data to
            give each split. For example, you might use the triplet (3, 1, 1)
            to create a train-val-test split with 60% of the data for training,
            20% for validation, and 20% for testing.

    Returns:
        A list with one single-element list per split, matching the shape
        returned by mirror_split2().
    """
    if not all([ratio > 0 for ratio in split_ratio]):
        raise ValueError(f"Split ratios must be positive. Got ({split_ratio}).")
    midpoint = len(recording) // 2
    front = recording[:midpoint]
    back = recording[midpoint:]
    assert len(front) + len(back) == len(recording)
    front_parts = split(front, split_ratio)
    # The back half is split with reversed weights, then walked backwards, so
    # split i is paired with its mirror image.
    back_parts = split(back, tuple(reversed(split_ratio)))
    merged = []
    for head, tail in zip(front_parts, back_parts[::-1]):
        merged.append(head.extend(tail))
    total_len = sum(len(m) for m in merged)
    assert total_len == len(
        recording
    ), f"Split lengths do not match ({total_len}) vs ({len(recording)})."
    return [[m] for m in merged]


def mirror_split2(recording: SpikeRecording, split_ratio: Sequence[int]):
    """Split data from "outside-in".

    Same as mirror_split, but returns lists of recordings per-split instead of
    concatenating them together.

    Args:
        split_ratio: a list of weightings that determines how much data to
            give each split. For example, you might use the triplet (3, 1, 1)
            to create a train-val-test split with 60% of the data for training,
            20% for validation, and 20% for testing.

    Returns:
        A list with one list per split. Every split except the last holds two
        parts, one from each end of the recording. The last split is a single
        contiguous part in the middle.
    """
    if not all([ratio > 0 for ratio in split_ratio]):
        raise ValueError(f"Split ratios must be positive. Got ({split_ratio}).")
    # Mirror around the last split, which is doubled:
    # [7, 2, 1] -> [7, 2, 2, 2, 7]
    outer = list(split_ratio[:-1])
    mirrored_ratios = outer + [split_ratio[-1] * 2] + outer[::-1]
    split_parts = split(recording, mirrored_ratios)
    assert len(split_parts) % 2 == 1, "Expecting odd number of splits"
    mid = len(split_parts) // 2
    # Pair part i with part -(i + 1): [first, last], [second, second-last], ...
    splits = [[split_parts[i], split_parts[-1 - i]] for i in range(mid)]
    splits.append([split_parts[mid]])
    total_len = sum(len(part) for group in splits for part in group)
    assert total_len == len(
        recording
    ), f"Split lengths do not match ({total_len}) vs ({len(recording)})."
    return splits


def remove_few_spike_cells(splits: List[List[SpikeRecording]], min_counts):
    """Remove cells that have few spikes in any of the data splits.

    A cell is kept only if, for every split, its spike count summed over all
    parts of that split is at least the corresponding entry of min_counts.

    Args:
        splits: One list of recording parts per split (e.g. the output of
            mirror_split2()). All parts must have the same cell_ids in the
            same order, since filtering is done column-wise.
        min_counts: One minimum spike count per split.

    Returns:
        A tuple with one list per split, holding the parts restricted to the
        kept cells.

    Raises:
        ValueError: If splits and min_counts differ in length, or if the
            parts don't all have the same cell ids.
    """
    if len(splits) != len(min_counts):
        raise ValueError(
            "splits and min_counts should have the same shape. Got "
            f"({len(splits)} vs {len(min_counts)})."
        )
    all_parts = [r for rs in splits for r in rs]
    if not all_parts:
        # Nothing to filter; return the same (empty) structure.
        return tuple([] for _ in splits)
    ref_cids = all_parts[0].cell_ids
    for r in all_parts[1:]:
        if not np.array_equal(r.cell_ids, ref_cids):
            raise ValueError(
                "All recording parts must have the same cell ids. Got "
                f"({ref_cids}) and ({r.cell_ids})."
            )
    to_keep = np.ones(shape=len(ref_cids), dtype=bool)
    for rs, min_count in zip(splits, min_counts):
        # Sum each part separately instead of concatenating the (possibly
        # large) spike arrays first.
        counts = np.zeros(shape=len(ref_cids), dtype=np.int64)
        for r in rs:
            counts = counts + np.sum(r.spikes, axis=0)
        to_keep &= counts >= min_count
    kept_cids = {c for c, keep in zip(ref_cids, to_keep) if keep}
    return tuple([s.cells(kept_cids) for s in rs] for rs in splits)


def _assert_dir_exists(p: pathlib.Path):
    """Raise ValueError unless `p` exists and is a directory."""
    if not p.exists():
        problem = "Directory does not exist"
    elif not p.is_dir():
        problem = "Path is not a directory"
    else:
        return
    raise ValueError(f"{problem}: ({p}).")


def _assert_file_exists(p: pathlib.Path):
    """Raise ValueError unless `p` exists and is a regular file."""
    if not p.exists():
        problem = "File does not exist"
    elif not p.is_file():
        problem = "Path is not a file"
    else:
        return
    raise ValueError(f"{problem}: ({p}).")


def _rec_cell_ids(recs: Iterable[CompressedSpikeRecording]):
    """Create a rec_name -> cell_id -> ID mapping.

    IDs are consecutive integers starting at 0, assigned in iteration order
    of `recs` and then of each recording's cell_ids.
    """
    next_id = itertools.count()
    return {r.name: {c_id: next(next_id) for c_id in r.cell_ids} for r in recs}


def save_id_info(
    recs: Iterable[CompressedSpikeRecording],
    rec_id_map: Dict[str, int],
    data_dir: pathlib.Path | str,
):
    """
    Saves two dictionaries to disk as JSON:
        1. recording-name -> recording-id (REC_IDS_FILENAME)
        2. recording-name -> cell-id -> ID (REC_CELL_IDS_FILENAME)

    The IDs in (2) are consecutive integers starting at 0, assigned with the
    recordings ordered by their recording-id.

    The purpose of the loading and saving of id information is to maintain a
    consistent way of referring to recordings and cells.
    An example of where this is needed is when training a multi-cell model
    on a subset of the data. This subset could have been selected manually
    and/or automatically by the filtering procedure that removes unsuitable
    cells, such as those with too few spikes. It is important to be able to
    obtain the embeddings for the cells that were trained on, but not the
    ones that were not. Without a global ID, a snapshot of the data along with
    the filtering routine used while training would be needed. The cell ID
    given to the cell data by the spike sorter is not enough on its own, as
    it presumably only identifies a cell within a single recording.

    Args:
        recs: The recordings whose cells are given IDs.
        rec_id_map: recording-name -> recording-id. It must contain every
            recording in recs.
        data_dir: Existing directory in which to write the two files.

    Raises:
        ValueError: If data_dir doesn't exist or isn't a directory.
        KeyError: If a recording's name is missing from rec_id_map. In this
            case nothing is written.
    """
    data_dir = pathlib.Path(data_dir)
    _assert_dir_exists(data_dir)
    # Compute everything before writing, so that a missing recording name
    # can't leave the recording-id file written without its cell-id file.
    sorted_recs = sorted(recs, key=lambda r: rec_id_map[r.name])
    rec_cell_ids = _rec_cell_ids(sorted_recs)
    with open(data_dir / REC_IDS_FILENAME, "w") as f:
        json.dump(rec_id_map, f)
    with open(data_dir / REC_CELL_IDS_FILENAME, "w") as f:
        json.dump(rec_cell_ids, f)


def has_id_info(data_dir: pathlib.Path | str) -> bool:
    """Return True if `data_dir` is a directory containing both id files.

    These are the files written by save_id_info() and read by load_id_info().
    """
    directory = pathlib.Path(data_dir)
    is_existing_dir = directory.exists() and directory.is_dir()
    id_files_present = [
        (directory / filename).exists()
        for filename in (REC_IDS_FILENAME, REC_CELL_IDS_FILENAME)
    ]
    return is_existing_dir and all(id_files_present)


def load_id_info(data_dir: pathlib.Path | str) -> Tuple[RecIds, RecCellIds]:
    """Loads two dictionaries from disk, stored as JSON.

    The two mappings are:
        1. recording-name -> recording-id
        2. (recording-name, cell-id) -> ID.

    Both are returned as bidict.bidict objects, so they can also be looked up
    in reverse.

    Raises:
        ValueError: If data_dir or either file is missing, or if a stored
            global cell ID is not an integer.
    """
    data_dir = pathlib.Path(data_dir)
    _assert_dir_exists(data_dir)
    rec_id_path = data_dir / REC_IDS_FILENAME
    rec_cell_id_path = data_dir / REC_CELL_IDS_FILENAME
    _assert_file_exists(rec_id_path)
    _assert_file_exists(rec_cell_id_path)
    with open(rec_id_path, "r") as f:
        # Convert from normal to bi-directional dictionary.
        rec_ids = bidict.bidict(json.load(f))
    with open(rec_cell_id_path, "r") as f:
        rec_cell_ids = json.load(f)
    # The recording-cell map will be flattened to take tuples.
    flat_rec_cell_ids = bidict.bidict()
    for r_name, cell_map in rec_cell_ids.items():
        for c_id_str, c_id_flat in cell_map.items():
            # JSON only allows string dictionary keys, so the cell id came
            # back as a string; convert it back to int.
            c_id = int(c_id_str)
            # An explicit check rather than an assert, so it also runs under
            # `python -O`.
            if type(c_id_flat) is not int:
                raise ValueError(
                    f"Expected an integer global cell id for "
                    f"({r_name}, {c_id}), got ({c_id_flat!r})."
                )
            flat_rec_cell_ids[(r_name, c_id)] = c_id_flat
    return rec_ids, flat_rec_cell_ids


def decompress_recordings(
    recordings: Iterable[CompressedSpikeRecording],
    downsample: int = 1,
    num_workers=10,
) -> Deque[SpikeRecording]:
    """
    Decompress multiple recordings using a thread pool.

    Use this when you want multithreading without handling futures yourself.
    The downside is that you can't chain the results: they all arrive in one
    go. Chaining is useful if the next step involves filtering and the
    decompressed recordings are too large to fit in memory together.

    Results are in the same order as `recordings`.
    """
    decompress_one = functools.partial(decompress_recording, downsample=downsample)
    # Leaving the `with` block already waits for all workers to finish.
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=num_workers
    ) as executor:
        decompressed = deque(executor.map(decompress_one, recordings))
    return decompressed



def decompress_recording(
    recording: CompressedSpikeRecording, downsample: int
) -> SpikeRecording:
    """
    Decompress a compressed recording.

    The result holds numpy arrays where the first dimension is time. The
    stimulus is expanded by decompress_stimulus(). Each cell's spike events
    are binned by decompress_spikes() and stacked as columns.

    Without downsampling, a single recording takes up more than 1 gigabyte of
    memory. It's quite convenient to set the downsample to 18, as this will
    cause the resulting sample rate to be 991.8 Hz, which is the closest you
    can get to 1 kHz, given the 3Brain electrode's original frequency.

    Args:
        recording: The event-based recording to expand.
        downsample: Integer factor by which the sensor sample rate is reduced.

    Returns:
        A SpikeRecording with sample rate sensor_sample_rate / downsample.
        A recording with no cells gets a spike array with zero columns.
    """
    sample_rate = recording.sensor_sample_rate / downsample
    stimulus = decompress_stimulus(
        recording.stimulus_pattern,
        recording.stimulus_events,
        recording.num_sensor_samples,
        downsample,
    )
    if len(recording.spike_events) == 0:
        # np.stack() rejects an empty list. Use a (time, 0) spike array
        # instead, so a SpikeRecording can still be built.
        spikes = np.zeros(shape=(len(stimulus), 0), dtype=int)
    else:
        spikes = np.stack(
            [
                decompress_spikes(s, recording.num_sensor_samples, downsample)
                for s in recording.spike_events
            ],
            axis=1,
        )
    res = SpikeRecording(
        recording.name,
        stimulus,
        spikes,
        recording.cell_ids,
        sample_rate,
        recording.cell_gids,
    )
    return res


def decompress_stimulus(
    stimulus_pattern: np.ndarray,
    trigger_events: np.ndarray,
    total_length: int,
    downsample: int,
) -> np.ndarray:
    """Expand a stimulus pattern to one row per sensor sample, then downsample.

    Row i of `stimulus_pattern` is held from trigger_events[i] up to (but not
    including) trigger_events[i + 1]. The row for the last trigger is held
    until `total_length`.

    Args:
        stimulus_pattern: One row per stimulus step; the second dimension is
            the number of channels.
        trigger_events: Sensor-sample indices at which each step starts. They
            must start at 0, be non-decreasing and not exceed total_length.
        total_length: Number of sensor samples in the recording.
        downsample: Factor passed on to downsample_stimulus().

    Returns:
        The result of downsample_stimulus() on the (total_length, channels)
        expanded stimulus.

    Raises:
        ValueError: If the triggers are empty, don't start at 0, are more
            numerous than the pattern rows, are out of order, or extend past
            total_length.
    """
    if len(trigger_events) == 0:
        raise ValueError("There must be at least one trigger event.")
    if trigger_events[0] != 0:
        raise ValueError(
            "The trigger events are expected to start at zero, "
            f"but the first trigger was at ({trigger_events[0]})."
        )
    if len(trigger_events) > len(stimulus_pattern):
        raise ValueError(
            "Recorded stimulus is longer than the stimulus "
            f"pattern. ({len(trigger_events)} > {len(stimulus_pattern)})"
        )
    if np.any(np.diff(trigger_events) < 0):
        raise ValueError("The trigger events must be in non-decreasing order.")
    if trigger_events[-1] > total_length:
        raise ValueError(
            f"The last trigger ({trigger_events[-1]}) is beyond the end of "
            f"the recording ({total_length})."
        )
    # TODO: check assumption! Assuming that the stimulus does not continue
    # after the last trigger event. This makes the last trigger special in that
    # it doesn't mark the start of a new stimulus output.
    # TODO: Marvin says that the last trigger event isn't the end of the
    # recording, but the last stimulus event before the end of the recording.
    # This introduces issues in that it's not clear how long this final
    # stimulus event is. I'm keeping the functionality as it is now, as I think
    # a better solution is to encode the recording end in the last stimulus
    # event, which would make the existing code below work fine.
    _logger.info(
        f"Starting: decompressing stimulus. Resulting shape ({total_length})."
    )
    num_channels = stimulus_pattern.shape[1]
    res = np.empty(shape=(total_length, num_channels))
    # Given the checks above, the segments [trigger_i, trigger_{i+1}) plus the
    # final [last_trigger, total_length) tile the whole array, so every
    # element of the uninitialized array gets written.
    segment_ends = np.append(trigger_events[1:], total_length)
    for idx, (start, end) in enumerate(zip(trigger_events, segment_ends)):
        res[start:end] = stimulus_pattern[idx]
    _logger.info(
        f"Finished: decompressing stimulus. "
        f"The last trigger was at ({trigger_events[-1]}) making its "
        f"duration ({res.shape[0]} - {trigger_events[-1]} = "
        f"{res.shape[0] - trigger_events[-1]}) samples."
    )
    res = downsample_stimulus(res, downsample)
    return res


def recording_splits(
    recordings: Iterable[CompressedSpikeRecording],
    downsample: int,
    split_ratio: Tuple[int, int, int],
    num_workers: int,
) -> List[RecordingTrainValTest]:
    """Create N train/val/test splits from a list of N recordings.

    Each recording is decompressed with decompress_recording() and then split
    with mirror_split2(). The work runs on a thread pool.

    Args:
        recordings: List of recordings from which to create splits.
        downsample: Factor to downsample by.
        split_ratio: Weightings for the (train, val, test) splits, passed to
            mirror_split2().
        num_workers: Max number of thread workers to use.

    Returns:
        List of N train/val/test splits, in the same order as `recordings`.
    """

    def _decompress_then_split(rec):
        decompressed = decompress_recording(rec, downsample=downsample)
        return mirror_split2(decompressed, split_ratio=split_ratio)

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=num_workers
    ) as pool:
        res = list(pool.map(_decompress_then_split, recordings))
    return res


def rebin_spikes(spike_idxs: np.ndarray, downsample_factor: int) -> np.ndarray:
    """Map spike indices onto the bins of a downsampled time axis.

    The rule is just a floor division of each index by the factor. It lives
    in its own function so the behaviour is explicit and applied the same
    way everywhere.

    Args:
        spike_idxs: The spike indices before downsampling.
        downsample_factor: Number of original samples per new bin.

    Returns:
        The spike indices after downsampling.
    """
    return np.floor_divide(spike_idxs, downsample_factor)


def decompress_spikes(
    spikes: np.ndarray | np.ma.MaskedArray,
    num_sensor_samples: int,
    downsample_factor: int = 1,
) -> np.ndarray:
    """
    Fills an integer array counting the number of spikes that occurred.

    Setting downsample_factor to an integer greater than 1 will result in
    the spikes being counted in larger bin sizes that the original sensor
    sample period. So we are not talking about signal downsampling, Nyquist
    rates etc., rather we are talking about histogram binning where the bin
    size is scaled by downsample_factor. This behaviour is similar to Pandas's
    resample().sum() pattern.

    Binning behaviour
    -----------------
    As only integer values are accepted for downsample_factor, the binning is
    achieved by floor division of the original spike index. Examples:
        1. Input: [0, 0, 0, 1, 1], downsample_factor=2, output: [0, 1, 1]
        1. Input: [0, 0, 0, 1, 1, 1], downsample_factor=2, output: [0, 1, 2]
    """
    if np.ma.isMaskedArray(spikes):
        spikes = spikes.compressed()
    downsampled_spikes = rebin_spikes(spikes, downsample_factor)
    res = np.zeros(
        shape=[
            math.ceil(num_sensor_samples / downsample_factor),
        ],
        dtype=int,
    )
    np.add.at(res, downsampled_spikes, 1)
    return res


def factors_sorted_by_count(
    n, limit: Optional[int] = None
) -> Tuple[Tuple[int, ...], ...]:
    """
    Calculates factor decompositions of n, sorted by the number of factors.

    This method is used to choose downsampling factors when a single factor
    is too large. Each decomposition is a sorted tuple of factors whose
    product is n; the trivial decomposition (n,) is always included.
    Decompositions are sorted by the number of factors they contain, fewest
    first. Ties are broken by comparing the tuples, so the order is
    deterministic.

    Args:
        n: The number to decompose.
        limit: The maximum value allowed for any single factor (inclusive).
            If not None, decompositions containing a larger factor are
            removed. If no decomposition satisfies the limit, a tuple
            containing only the decomposition with the most factors is
            returned and an info message is logged.

    Returns:
        A tuple of decompositions, fewest factors first.
    """

    # Memoized: the same sub-products are decomposed many times over in the
    # recursion. The cache is local, so it is freed when this call returns.
    @functools.lru_cache(maxsize=None)
    def _factors(m: int) -> frozenset:
        res = {(m,)}
        for f1 in range(m // 2, 1, -1):
            f2, mod = divmod(m, f1)
            if mod:
                continue
            # Combine every decomposition of f1 with every decomposition of
            # f2. This includes the two-factor split (f1, f2) itself.
            for a, b in itertools.product(_factors(f1), _factors(f2)):
                res.add(tuple(sorted(a + b)))
        return frozenset(res)

    by_count = sorted(_factors(n), key=lambda f: (len(f), f))
    if limit is None:
        return tuple(by_count)
    # Filter the *sorted* list so that the fewest-factors-first order is kept.
    within_limit = tuple(f for f in by_count if max(f) <= limit)
    if within_limit:
        return within_limit
    _logger.info(
        f"No factor decomposition exists with factors under "
        f"{limit}. Returning the decomposition with the "
        f"most factors."
    )
    return (by_count[-1],)


def downsample_stimulus(stimulus: np.ndarray, factor: int) -> np.ndarray:
    """
    Filter (low-pass) a stimulus and then decimate by a factor.

    The low-pass filtering is needed to prevent aliasing. When possible, a
    large factor is split into several decimation steps, each no larger than
    13, whose product is `factor`.

    Args:
        stimulus: The stimulus, with time along axis 0.
        factor: The integer downsampling factor. If 1, the stimulus is
            returned unchanged (the same object, not a copy).

    Returns:
        The filtered and decimated stimulus. Its length along axis 0 is
        approximately the original length divided by `factor`.

    Resources on filtering
    ----------------------
    https://dsp.stackexchange.com/questions/45446/pythons-tt-resample-vs-tt-resample-poly-vs-tt-decimate
    https://dsp.stackexchange.com/questions/83696/downsample-a-signal-by-a-non-integer-factor
    https://dsp.stackexchange.com/questions/83889/decimate-a-signal-whose-values-are-calculated-not-stored?noredirect=1#comment176944_83889
    """
    if factor == 1:
        return stimulus
    time_axis = 0
    _logger.info(
        f"Starting: downsampling by {factor}. Initial length "
        f"{stimulus.shape[time_axis]:,}."
    )
    # SciPy recommends calling decimate multiple times for factors above 13
    # (the docs state this for IIR filtering; it is applied here to FIR too).
    # See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.decimate.html
    MAX_SINGLE_DECIMATION = 13
    decompositions = factors_sorted_by_count(
        factor, limit=MAX_SINGLE_DECIMATION
    )
    # Use the decomposition with the fewest factors, i.e. the fewest filter
    # passes. Selecting it explicitly avoids depending on the helper's
    # ordering; among equal lengths, the helper's order is kept.
    sub_factors = min(decompositions, key=len)
    for sf in sub_factors:
        _logger.debug(f"Starting: decimating by {sf}")
        stimulus = scipy.signal.decimate(
            stimulus, sf, ftype="fir", axis=time_axis
        )
        _logger.debug(f"Finished: decimating by {sf}")
    _logger.info(
        f"Finished: downsampling. Resulting length "
        f"({stimulus.shape[time_axis]:,})."
    )
    return stimulus


def spike_window(
    spike_idxs: int | Sequence[int], total_len: int, post_spike_len: int
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate the window endpoints around a spike, in samples of the stimulus.
    """
    if total_len < post_spike_len + 1:
        raise ValueError(
            f"Total snippet length must be at least 1 greater than the "
            f"post-spike length + 1. Got total_len ({total_len}) "
            f"and post_spike_len ({post_spike_len})."
        )
    spikes = np.array(spike_idxs)
    # Calculate the snippet start and end.
    # The -1 appears as we include the spike sample in the snippet.
    win_start = (spikes + post_spike_len) - (total_len - 1)
    win_end = spikes + post_spike_len + 1
    return win_start, win_end


def spike_snippets(
    stimulus: np.ndarray,
    spike_idxs: np.ndarray,
    total_len: int,
    post_spike_len: int,
) -> np.ndarray:
    """
    Return a subset of the stimulus around the spike points.

    Args:
        stimulus: The decompressed stimulus, a 2D array with time along
            axis 0, i.e. (time, channels).
        spike_idxs: 1D array of spike indices, re-binned to the stimulus
            sample rate. Must be a standard (non-masked) numpy array.
        total_len: The length of the snippet in stimulus frames.
        post_spike_len: The number of frames to include after the spike.
    Returns:
        A Numpy array of shape (len(spike_idxs), total_len, channels).
    Raises:
        ValueError: if spike_idxs is a masked array, or if
            total_len < post_spike_len + 1.

    Note 1: The total length describes the snippet length inclusive of the post
        spike padding.
    Note 2: If a spike index appears twice (a bin with 2 spikes), then the
        snippet they share will be added to the output twice.
    Note 3: If a spike happens early enough that the snippet would start before
        the stimulus, then the snippet will be padded with zeros. This applies
        to the end of the stimulus as well.

    Single spike example
    ====================

    (Shown transposed: rows are stimulus channels, columns are frames.)

        frame #:  |   0   |   1   |   2   |   3   |   4   |   5   |   6   |   7   |
        ===========================================================================
        stimulus: |   0   |   1   |   1   |   0   |   0   |   1   |   0   |   0   |
                  |   1   |   1   |   0   |   0   |   1   |   0   |   0   |   0   |
                  |   1   |   1   |   1   |   0   |   0   |   0   |   0   |   1   |
                  |   0   |   0   |   0   |   0   |   1   |   0   |   1   |   1   |
        ===========================================================================
        spikes:   |   0   |   0   |   0   |   0   |   1   |   0   |   0   |   0   |

    The slice with parameters:
        - total length = 5
        - post spike length = 1

    Would be:

                  |   1   |   1   |   0   |   0   |   1   |
                  |   1   |   0   |   0   |   1   |   0   |
                  |   1   |   1   |   0   |   0   |   0   |
                  |   0   |   0   |   0   |   1   |   0   |
    """
    if np.ma.isMaskedArray(spike_idxs):
        raise ValueError(
            "spikes must be a standard numpy array, not a masked array."
        )
    num_frames = stimulus.shape[0]
    # 1. Get the spike windows: [win_start, win_start + total_len).
    win_start, _ = spike_window(spike_idxs, total_len, post_spike_len)
    win_start = np.asarray(win_start)
    # 2. Pad the stimulus only if some window falls outside the stimulus.
    #    A window fits iff 0 <= win_start <= num_frames - total_len.
    #    Padding by total_len on each side covers every window of a spike
    #    index inside [0, num_frames).
    if np.any(win_start < 0) or np.any(win_start > num_frames - total_len):
        stimulus = np.pad(
            stimulus, ((total_len, total_len), (0, 0)), "constant"
        )
        # 3. Offset the windows, which is needed due to the leading padding.
        win_start = win_start + total_len
    # 4. Extract the slices. Row i of frame_idxs holds the frame indices
    #    win_start[i], win_start[i] + 1, ..., win_start[i] + total_len - 1,
    #    so indexing gives (num_spikes, total_len, channels).
    frame_idxs = win_start[:, None] + np.arange(total_len)
    snippets = stimulus[frame_idxs]
    return snippets


def compress_spikes(spikes: np.ndarray) -> np.ndarray:
    """Turn a 1D array of per-bin spike counts into a list of spike indices.

    Every bin index is emitted once per spike counted in that bin, so bins
    with several spikes contribute repeated indices.

    Example:
        [0, 0, 1, 0, 0, 2, 3] => [2, 5, 5, 6, 6, 6]
    """
    occupied_bins = np.flatnonzero(spikes)
    counts_per_bin = spikes[occupied_bins]
    return np.repeat(occupied_bins, counts_per_bin)


def labeled_spike_snippets(
    rec: CompressedSpikeRecording,
    snippet_len: int,
    snippet_pad: int,
    downsample: int = 1,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate spike snippets for a recording, paired with the cell ids.

    Args:
        rec: the recording. Its stimulus is decompressed at the downsampled
            rate and each cell's spike events are re-binned to match.
        snippet_len: the length of the snippet, in downsampled bins.
        snippet_pad: the number of timesteps to include after the spike.
        downsample: the downsample factor for both stimulus and spikes.

    Returns: two np.ndarrays tuple. The first element contains the spike
    snippets, one per spike, as produced by spike_snippets (shape
    (num_spikes, snippet_len, channels)). The second element contains, for
    each snippet, the id of the cell that spiked. If the recording has no
    spikes (or no cells), both arrays have length 0.
    """
    stimulus = decompress_stimulus(
        rec.stimulus_pattern,
        rec.stimulus_events,
        rec.num_sensor_samples,
        downsample,
    )

    # Gather the snippets per cell, with a matching array of cell ids.
    snippet_parts = []
    cell_id_parts = []
    for idx, cell_id in enumerate(rec.cell_ids):
        events = rec.spike_events[idx]
        # decompress_spikes accepts masked spike arrays but spike_snippets
        # rejects them, so drop any masked entries here.
        if np.ma.isMaskedArray(events):
            events = events.compressed()
        spike_idxs = rebin_spikes(events, downsample)
        snips = spike_snippets(
            stimulus,
            spike_idxs,
            snippet_len,
            snippet_pad,
        )
        snippet_parts.append(snips)
        cell_id_parts.append(np.full(len(snips), cell_id))
    if not snippet_parts:
        # No cells at all: np.concatenate needs at least one array.
        empty_snippets = np.empty(
            (0, snippet_len) + stimulus.shape[1:], dtype=stimulus.dtype
        )
        return empty_snippets, np.array([], dtype=int)
    # Concatenating (rather than np.stack of individual snippets) also works
    # when every cell has zero spikes.
    snippets = np.concatenate(snippet_parts)
    cell_ids = np.concatenate(cell_id_parts)
    return snippets, cell_ids


"""
--------------------------------------------------------------------------------
PyTorch dataset related.
--------------------------------------------------------------------------------
"""


def _ragged_ends(spikes, end_pad: Optional[int]) -> np.ndarray:
    """Calculates the per-cell last usable timestep.

    The return value is a 1D array where the i-th element is a long value
    representing the 1 past the index of the element that will appear as the
    last element of the last snippet for that cell.

    For some models, like LogMix, where a future spike can be arbitrarily far
    in the future, there is no need to include any information after the last
    spike, as it can't be used; the next spike either does not exist or not the
    val/test set, and we don't want this information bleeding into the training
    set. In comparison, the Distance model and Discrete model can both use the
    empty space after the last spike, up to some limit determined by their
    output length. The model can use right up to the end of the snippet minus
    the output length, here called end_pad. The larger the end_pad, the more
    restricted the model is, and the less of the recording it can use. For
    example, a Discrete model with 128 output bins can't use a snippet whose
    input ends closer than 128 samples before the end of the snippet, otherwise
    a spike in the following recording segment might need to be listed as the
    next spike in a bin other than the last [127, ∞) bin.

    Example: if you have out_resolution = 1, then end_pad = 0, as no matter
    where the next spike would fall in the subsequent segment, it would be
    registered in the catch all single bin. A more realistic example,
    if out_resolution is 128, then the pad is 127.

    Args:
        spikes (L, S): time x cells array of spike counts.
    """
    L, S = spikes.shape
    if L < S:
        _logger.warning(
            "Possible bug: number of timesteps is less than the number of cells"
            f"({L=}, {S=})."
        )
    increment_along_time = np.cumsum(spikes, axis=0)
    last_spike_idx = np.argmax(increment_along_time, axis=0)
    if end_pad:
        ends = np.maximum(last_spike_idx, spikes.shape[0] - end_pad)
    else:
        ends = last_spike_idx
    return ends


class SnippetDataset(torch.utils.data.Dataset):
    """
    Fixed-length windows of stimulus and single-cell spikes from a recording.

    The dataset is a flattened (window start, cell) grid; see _decode_index.
    Each sample is a dictionary with:
        "stimulus": the stimulus over the window, transposed (`.T`), so a
            (time, channels) stimulus yields (channels, snippet_len).
        "spikes": the spike counts of one cell over the window.
        "cell_id": that cell's global id, from recording.cell_gids.
    Both arrays are views into the recording (basic slicing), so consumers
    must not modify them in place.
    """

    _num_strided_timesteps: int
    _stride: int
    num_cells: int

    def __init__(
        self,
        recording: SpikeRecording,
        snippet_len: int,
        stride: int = 1,
        shuffle_stride: bool = False,
    ):
        """
        Args:
            recording: the recording to extract snippets from.
            snippet_len: the number of bins to include in each snippet.
            stride: viewing the output as a window, stride is the number of
                bins to move the window forward by between samples.
            shuffle_stride: when False, start positions are skipped by the
                nature of stride. If True, start positions are uniformly sampled
                from the range [pos, pos + stride). This is useful as a means
                of tuning the batch size. When stride is 1, a single batch
                can be considered to contain a lot of very similar data. This
                can lead to models becoming overfitted before a single batch
                ends. Setting stride to a larger value can help with this.
                When setting stride > 1, we can jitter the start position so
                as not to throw away data.

        Raises:
            ValueError: if snippet_len is larger than the recording, or if
                stride is less than 1.
        """
        if snippet_len > len(recording):
            raise ValueError(
                f"Snippet length ({snippet_len}) is larger than "
                f"the recording length ({len(recording)})."
            )
        self.recording = recording
        self.snippet_len = snippet_len
        self.num_cells = len(recording.cell_ids)
        # Number of window start positions for a stride of 1.
        self.num_timesteps = len(recording) - snippet_len + 1
        self.shuffle_stride = shuffle_stride
        assert (
            self.num_timesteps > 0
        ), "Snippet length is longer than the recording."
        # _num_strided_timesteps is set in the setter for stride.
        self.stride = stride

    def __len__(self):
        """
        Calculates the number of samples in the dataset: the number of
        strided window positions times the number of cells.
        """
        return self._num_strided_timesteps * self.num_cells

    def _decode_index(self, index: int) -> Tuple[int, int, int]:
        """
        Decodes the index into (start timestep, cell index, cell id).

        The data is effectively a 2D array with dimensions (time, cell).
        The index is the flattened index of this array, and so the timestep
        increases as the index increases and wraps to the next cell id when
        it reaches the end of the recording.
        """
        timestep_idx, cell_idx = kdai.datasets.decode_strided_wrap(
            index,
            len(self.recording),
            self.snippet_len,
            self.stride,
            self.shuffle_stride,
        )
        assert cell_idx < self.num_cells, f"{cell_idx} >= {self.num_cells}"
        cell_id = self.recording.cell_gids[cell_idx]
        return timestep_idx, cell_idx, cell_id

    def __getitem__(self, idx):
        """
        Returns the snippet at the given index.
        """
        if idx >= len(self):
            raise IndexError()
        start_time_idx, cell_idx, cell_id = self._decode_index(idx)
        end_time_idx = start_time_idx + self.snippet_len
        assert end_time_idx <= len(
            self.recording
        ), f"{end_time_idx} > {len(self.recording)}"
        stim = self.recording.stimulus[start_time_idx:end_time_idx].T
        spikes = self.recording.spikes[start_time_idx:end_time_idx, cell_idx]
        return {"stimulus": stim, "spikes": spikes, "cell_id": cell_id}

    @property
    def stride(self):
        return self._stride

    @stride.setter
    def stride(self, stride: int):
        if stride < 1:
            raise ValueError(f"Stride must be at least 1. Got {stride}.")
        self._stride = stride
        self._num_strided_timesteps = kdai.datasets.num_windows(
            len(self.recording), self.snippet_len, stride
        )

    @property
    def num_strided_timesteps(self):
        """The number of strided window start positions (per cell)."""
        return self._num_strided_timesteps


class SnippetΔtDataset(torch.utils.data.Dataset):
    """
    Same as SnippetDataset, but we also include the time until the next spike.

    Unlike SnippetDataset, the returned stimulus is not transposed: it keeps
    the recording's layout, with time along axis 0.

    dt convention
    -------------
    spikes               |
             — — — — — — — — — —
    t=0                |

    If a spike happens at the first output sample, then the next spike value
    is, by convention, 0. If you wish to interpret this as a time until, then
    you may wish to consider the interval [0, 1), or take a point sample like
    0.5.

    Additionally, we go to extra effort to avoid leaking information past the
    end of the snippet. This involves maintaining a per-cell sequence length,
    as the length is now determined by that cell's last spike, and by how far
    past the last spike the model is capable of using. For LogMix, only up to
    and including the last spike can be used; whereas the Discrete model can
    go right to the end of the snippet minus out_resolution timesteps.

    Previously on Failed Attempts:
    To prevent leaking information from the future, the time until next spike
    was infinity for all snippets after the last spike (the next spike might
    be known, but be part of the validation or test set).
    """

    _num_strided_snippets: np.ndarray
    _stride: int
    num_cells: int

    def __init__(
        self,
        recording: SpikeRecording,
        snippet_len: int,
        # This can be set for Discrete and Dist models to allow for more
        # information to be used. None is most restrictive. 0 is most
        # permissive.
        n_last_spike_suffix: Optional[int] = None,
        stride: int = 1,
        shuffle_stride: bool = False,
        augment: bool = False,
    ):
        """
        Args:
            recording: the recording to extract snippets from.
            snippet_len: the number of bins to include in each snippet.
            n_last_spike_suffix: forwarded to _ragged_ends as end_pad to
                determine each cell's last usable timestep.
            stride: viewing the output as a window, stride is the number of
                bins to move the window forward by between samples.
            shuffle_stride: when False, start positions are skipped by the
                nature of stride. If True, start positions are uniformly sampled
                from the range [pos, pos + stride). This is useful as a means
                of tuning the batch size. When stride is 1, a single batch
                can be considered to contain a lot of very similar data. This
                can lead to models becoming overfitted before a single batch
                ends. Setting stride to a larger value can help with this.
                When setting stride > 1, we can jitter the start position so
                as to not throw away data.
            augment: if True, randomly augment the stimulus and the spike
                history of each sample. The next-spike target is not
                augmented.

        Raises:
            ValueError: if snippet_len is larger than the recording, or if
                stride is less than 1.
        """
        if snippet_len > len(recording):
            raise ValueError(
                f"Snippet length ({snippet_len}) is larger than "
                f"the recording length ({len(recording)})."
            )
        self.recording = recording
        # dts is the steps until the spike _after_ the current timestep.
        # It's always greater than 0. This means that
        # snippet_len = model_input_len.
        self.dts = self.recording.time_until_spike()
        assert (
            self.dts.shape == recording.spikes.shape
        ), f"{self.dts.shape=} != {recording.spikes.shape=}"
        self.snippet_len = snippet_len
        self.num_cells = len(recording.cell_ids)
        self.n_last_spike_suffix = n_last_spike_suffix
        # Per-cell exclusive end index of the usable data.
        self._ends = _ragged_ends(recording.spikes, self.n_last_spike_suffix)
        self._max_start_t = self._ends - self.snippet_len
        self.shuffle_stride = shuffle_stride
        self.augment = augment
        self._num_strided_snippets = None
        # _num_strided_snippets is set in the setter for stride, which this
        # next statement will trigger.
        self.stride = stride

    def __len__(self):
        """The total number of snippets, summed over all cells."""
        return int(self._num_strided_snippets.sum())

    def _decode_index(self, index: int) -> Tuple[int, int, int]:
        """Decodes the index into (start timestep, cell index, cell id)."""
        timestep_idx, cell_idx = kdai.datasets.decode_strided_ragged_wrap(
            index,
            self._ends,
            self.snippet_len,
            self.stride,
            self.shuffle_stride,
        )
        assert cell_idx < self.num_cells, f"{cell_idx} >= {self.num_cells}"
        cell_id = self.recording.cell_gids[cell_idx]
        return timestep_idx, cell_idx, cell_id

    def _augment_stimulus(self, stimulus):
        """
        Augment a (time, channels) stimulus portion of a sample.

        Applies mixup with the stimulus of a random sample, a whole-block
        scale and offset, per-bin Gaussian noise, and sets randomly chosen
        timesteps (all channels) to MASK_VALUE. Returns a new array; the
        input is not modified.
        """
        NOISE_SD = 0.2
        NOISE_MU = 0.01
        STIM_MASK_RATE = 0.05
        MIXUP_P = 0.2
        MASK_VALUE = -3
        # Mixup
        mix_idx = np.random.randint(0, len(self))
        mix_t = self._decode_index(mix_idx)[0]
        mixup = self.recording.stimulus[mix_t : mix_t + stimulus.shape[0]]
        # This creates a new array, so later in-place ops are safe.
        stimulus = stimulus * (1 - MIXUP_P) + mixup * MIXUP_P
        # Whole block scale.
        mu = 1.0
        sd = 0.10
        scale = np.random.normal(mu, sd, size=(1,))
        # Whole block offset.
        mu = 0.0
        sigma = 0.10
        offset_noise = np.random.normal(mu, sigma, size=(1,))
        # Per bin noise.
        bin_noise = np.random.normal(
            NOISE_MU,
            NOISE_SD,
            size=stimulus.shape,
        )
        stimulus = stimulus * scale + offset_noise
        stimulus += bin_noise
        # Mask some timesteps (axis 0).
        mask_indicies = np.nonzero(
            np.random.binomial(1, p=STIM_MASK_RATE, size=len(stimulus))
        )
        stimulus[mask_indicies] = MASK_VALUE
        return stimulus

    def _augment_spikes(self, spikes):
        """
        Augment the spike portion of a sample.

        Each spike is independently jittered by an integer offset drawn
        uniformly from [-NOISE_JITTER, NOISE_JITTER] (clipped to the snippet)
        and then dropped with probability DROP_RATE. Spikes that land in the
        same bin are summed.

        Call this on the model input portion of the spike data, and not the
        portion that we are trying to predict.

        Returns a new array. The input is typically a view into the
        recording, so it must not be modified in place.
        """
        NOISE_JITTER = 4
        DROP_RATE = 0.08
        spikes = np.asarray(spikes)
        # One entry per spike: a bin with a count of k appears k times.
        occupied_bins = np.flatnonzero(spikes)
        spike_times = np.repeat(
            occupied_bins, spikes[occupied_bins].astype(int)
        )
        # Add jitter (randint's upper bound is exclusive, hence the +1).
        if NOISE_JITTER > 0:
            jitter = np.random.randint(
                -NOISE_JITTER, NOISE_JITTER + 1, size=len(spike_times)
            )
            spike_times = np.clip(spike_times + jitter, 0, len(spikes) - 1)
        # Drop some spikes, independently per spike.
        keep = np.random.binomial(
            1, p=(1 - DROP_RATE), size=len(spike_times)
        ).astype(bool)
        augmented = np.zeros_like(spikes)
        np.add.at(augmented, spike_times[keep], 1)
        return augmented

    def __getitem__(self, idx):
        """Returns the snippet at the given index."""
        if idx >= len(self):
            raise IndexError()
        start_time_idx, cell_idx, cell_id = self._decode_index(idx)
        end_time_idx = start_time_idx + self.snippet_len
        assert end_time_idx <= len(
            self.recording
        ), f"{end_time_idx} > {len(self.recording)}"
        # Both are views into the recording; augmentation returns new arrays.
        stim = self.recording.stimulus[start_time_idx:end_time_idx]
        spikes = self.recording.spikes[start_time_idx:end_time_idx, cell_idx]
        if self.augment:
            stim = self._augment_stimulus(stim)
            spikes = self._augment_spikes(spikes)
        # Only a single y value: no causal (per-timestep) training supported.
        # dts[t] is the number of steps from t until the next spike strictly
        # after t, so it is always > 0. Subtracting 1 makes the target count
        # from 0: a spike at the first timestep after the snippet gives 0.
        next_spike = self.dts[end_time_idx - 1, cell_idx] - 1
        assert np.all(next_spike % 1 == 0), f"Expected integer. {next_spike=}"
        assert np.all(next_spike >= 0), f"{next_spike=}"
        res = {
            "stimulus": stim,
            "spikes": spikes,
            "next_spike": next_spike,
            "cell_id": cell_id,
        }
        return res

    @property
    def stride(self):
        return self._stride

    @stride.setter
    def stride(self, stride: int):
        if stride < 1:
            raise ValueError(f"Stride must be at least 1. Got {stride}.")
        self._stride = stride
        self._num_strided_snippets = kdai.datasets.num_windows_2d(
            self._ends, self.snippet_len, self.stride
        )

    @property
    def num_strided_timesteps(self):
        """Per-cell number of strided snippets (cells end at different times)."""
        return self._num_strided_snippets


class NextSpikeDataset(torch.utils.data.Dataset):
    """
    (stimulus + spike history, next-spike time) samples over many recordings.

    Wraps one SnippetΔtDataset per recording (with shuffle_stride=True) in a
    ConcatDataset. Each sample is a tuple (X, cell_id, y):
        X: float32 array of the stimulus with the cell's spike history
            appended as an extra channel; (channels, time) if channels_first,
            else (time, channels).
        cell_id: the cell's global id.
        y: float32 scalar, the "next_spike" value of the underlying dataset.
    """

    def __init__(
        self,
        recordings: List[SpikeRecording],
        input_len: int,
        stride: int = 1,
        channels_first: bool = True,
        augment: bool = False,
    ):
        """
        Args:
            recordings: the recordings to draw samples from. They must share
                a sample rate and number of stimulus channels.
            input_len: the number of timesteps in each input snippet.
            stride: the stride of each underlying SnippetΔtDataset.
            channels_first: if True, X is returned as (channels, time).
            augment: passed to each underlying SnippetΔtDataset.

        Raises:
            ValueError: if recordings is empty, if their sample rates or
                stimulus channel counts differ, or if any recording is shorter
                than input_len + 1.
        """
        self.channels_first = channels_first
        if not recordings:
            raise ValueError("At least one recording is required.")
        # Check that recording chunks sufficiently match.
        _sample_rates = [rec.sample_rate for rec in recordings]
        _stim_shape = [rec.stimulus.shape[1] for rec in recordings]

        def _all_same(items):
            return all(i == items[0] for i in items)

        if not _all_same(_sample_rates):
            raise ValueError(
                "All recordings must have the same sample rate."
                f" Got ({_sample_rates})"
            )
        if not _all_same(_stim_shape):
            raise ValueError(
                "All recordings must have the same stimulus shape."
                f" Got ({_stim_shape})"
            )

        self._sample_rate = _sample_rates[0]
        self._recordings = recordings
        self.input_len = input_len
        # We don't support padding/masking, so every recording must be at
        # least input_len + 1 long (+1 for the output prediction). This
        # allows us to use a mask of all 1s.
        too_short = [
            len(rec) for rec in self._recordings if len(rec) < input_len + 1
        ]
        if too_short:
            raise ValueError(
                "The recordings must be at least the input length + 1 long. "
                f"Got recording lengths {too_short} for {input_len=}."
            )
        self.ds = ConcatDataset(
            [
                SnippetΔtDataset(
                    rec,
                    input_len,
                    stride=stride,
                    shuffle_stride=True,
                    augment=augment,
                )
                for rec in self._recordings
            ]
        )

    def __len__(self):
        """
        Calculates the number of samples in the dataset: the total over all
        underlying per-recording datasets.
        """
        return len(self.ds)

    @property
    def recordings(self):
        return self._recordings

    @property
    def datasets(self):
        return self.ds.datasets

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def stride(self):
        return self.ds.stride

    @stride.setter
    def stride(self, stride: int):
        self.ds.stride = stride

    def __getitem__(self, idx):
        """
        Returns the (X, cell_id, y) sample at the given index.

        The index is decoded by the underlying ConcatDataset.
        """
        sample = self.ds[idx]
        stim = sample["stimulus"].astype(np.float32)
        assert stim.ndim == 2
        spikes = sample["spikes"].astype(np.float32)
        seq_id = sample["cell_id"]
        assert isinstance(seq_id, int), type(seq_id)
        # Give spikes a channel axis so they can be appended to the stimulus.
        spikes = einops.rearrange(spikes, "t -> t 1")
        next_spike = sample["next_spike"].astype(np.float32)
        assert next_spike.ndim == 0, next_spike.shape
        assert stim.shape[0] == spikes.shape[0]
        X = np.concatenate((stim, spikes), axis=1)
        y = next_spike
        if self.channels_first:
            X = einops.rearrange(X, "t c -> c t")
            # Make the transposed view contiguous in memory.
            X = np.ascontiguousarray(X)
        return X, seq_id, y


class BasicDistDataset(torch.utils.data.Dataset):
    """
    Dataset that pairs a stimulus+spike history with a future distance array.

    Future stimulus information is not made available.

    Output is a dictionary:

        "snippet": (5, input_len), float
            - Returned as is from the underlying SnippetDataset.
               - the first 4 channels are the stimulus, containing float values
                 in the range [0, 1].
               - the 5th channel contains spike history, a binary array of
                 {0, 1}.
        "dist": (output_len,), float
            - A distance array of length output_len, starting from
              time_bin=0 - dist_prefix_len. The values are *not* the log
              values. The values are clipped to dist_clamp.
        "target_spikes": (output_len - dist_prefix_len,), float
            - Same as the 5th channel of the "snippet" array, but for the
              output_len - dist_prefix_len bins starting from t=0.
        "cell_id": int
            - The cell ID of the cell that the snippet is for.

    No normalization
    ----------------
    No normalizing is done by the dataset; do it yourself in a trainable or
    model. Why?

        - changing norm will affect all previously trained models.
        - normalized spikes would no longer be binary {0, 1}, which is
            unintuitive for subsequent analysis.
        - stimulus norm is not the same for all recordings.
    """

    def __init__(
        self,
        recordings: List[SpikeRecording],
        input_len: int,
        output_len: int,
        pad: int,
        dist_prefix_len: int,
        dist_clamp: float,
        stride: int = 1,
        shuffle_stride: bool = False,
        use_augmentation: bool = False,
    ):
        """
        Raises:
            ValueError: if output_len < dist_prefix_len, if recordings is
                empty, or if the recordings' stimulus channel counts or
                sample rates differ.
        """
        if output_len < dist_prefix_len:
            raise ValueError(
                "The output must be at least as long as its offset relative "
                f"to t=0. Got output_len ({output_len}) < dist_prefix_len "
                f"({dist_prefix_len})."
            )
        if not recordings:
            raise ValueError("At least one recording is required.")
        for r in recordings:
            if r.stimulus.shape[1] != recordings[0].stimulus.shape[1]:
                raise ValueError(
                    "All recordings must have the same number of stimulus "
                    "channels."
                )
            if r.sample_rate != recordings[0].sample_rate:
                raise ValueError(
                    "All recordings must have the same sample rate."
                )
        self.input_len = input_len
        self.output_len = output_len
        self._recordings = recordings
        self._sample_rate = self._recordings[0].sample_rate
        self.num_stim_channels = self._recordings[0].stimulus.shape[1]
        self.pad = pad
        self.dist_prefix_len = dist_prefix_len
        self.dist_clamp = dist_clamp
        self.use_augmentation = use_augmentation
        # Each underlying snippet covers the input, the future part of the
        # output, and the extra pad used only for the distance calculation.
        self.ds = ConcatDataset(
            [
                SnippetDataset(
                    rec,
                    self.input_len
                    + self.output_len
                    - self.dist_prefix_len
                    + self.pad,
                    stride,
                    shuffle_stride=shuffle_stride,
                )
                for rec in self._recordings
            ]
        )

    def __len__(self):
        return len(self.ds)

    @property
    def recordings(self):
        return self._recordings

    @property
    def datasets(self):
        return self.ds.datasets

    @property
    def sample_rate(self):
        return self._sample_rate

    @property
    def stride(self):
        return self.ds.stride

    @stride.setter
    def stride(self, stride: int):
        self.ds.stride = stride

    def output_spikes(self, cell_id=0) -> np.ndarray:
        """
        Returns the underlying 1D spike array that is equivalent to
        concatenating the spike portion of the dataset's outputs in ascending
        order. The equivalent concatenation must truncate the output to
        the stride length.

        This is useful for obtaining the full length ground truth spike array
        without having to iterate over the dataset. It is also useful to have
        this as a function so that it can be tested for any indexing issues—such
        issues could render evaluation metrics incorrect.

        If there are multiple recording chunks, then they will be concatenated
        together.

        Args:
            cell_id: the column (cell index) of the recordings' spike arrays.

        Raises:
            ValueError: if the stride exceeds the future part of the output.
        """
        res = []
        start = self.input_len
        future_out_len = self.output_len - self.dist_prefix_len
        if self.stride > future_out_len:
            raise ValueError(
                "Stride must be less-equal the length of the future portion of "
                "the output in order to recover the unbroken spike sequence "
                f"via successive concatenations. Got stride ({self.stride}) vs."
                f" ({future_out_len})."
            )
        for ds in self.datasets:
            # One window per strided timestep for a single cell. len(ds)
            # would also count every cell and overshoot when a recording has
            # more than one cell.
            end = start + self.stride * ds.num_strided_timesteps
            res.append(ds.recording.spikes[start:end, cell_id])
        res = np.concatenate(res)
        return res

    def _augment_stimulus(self, stimulus):
        """
        Augment the (channels, time) stimulus portion of a sample.

        Applies mixup with the stimulus of a random sample, a whole-block
        scale and offset, per-bin Gaussian noise over every channel and
        timestep, and sets randomly chosen timesteps (all channels) to
        MASK3_VALUE. Returns a new array; the input is not modified.
        """
        NOISE_SD = 0.2
        NOISE_MU = 0
        STIM_MASK_RATE = 0.01
        MIXUP_P = 0.2
        MASK3_VALUE = -3
        num_timesteps = stimulus.shape[1]
        # Mixup. This creates a new array, so later in-place ops are safe.
        mix_idx = np.random.randint(0, len(self))
        mixup = self.ds[mix_idx]["stimulus"][:, 0:num_timesteps]
        stimulus = stimulus * (1 - MIXUP_P) + mixup * MIXUP_P
        # Whole block scale.
        scale = np.random.normal(1.0, 0.10, size=(1,))
        # Whole block offset.
        offset_noise = np.random.normal(0.0, 0.10, size=(1,))
        # Per bin noise, covering every channel and timestep.
        bin_noise = np.random.normal(NOISE_MU, NOISE_SD, size=stimulus.shape)
        stimulus = stimulus * scale + offset_noise
        stimulus += bin_noise
        # Mask some timesteps. The mask is drawn along the time axis (axis 1),
        # not len(stimulus), which is the number of channels.
        mask_t = np.flatnonzero(
            np.random.binomial(1, p=STIM_MASK_RATE, size=num_timesteps)
        )
        stimulus[:, mask_t] = MASK3_VALUE
        return stimulus

    def __getitem__(self, idx):
        """
        Returns the sample at the given index.

        +---------------------------+
        |  a) input stimulus        |
        +---------------------------+
        |  b) input spike           |
        +---------------------------+
                                        this gap is the output offset,
                              |<--->|   typically negative.

                              |-----+---------------+---------+
                              |   c) dist target*   | d) pad* |
                              |-----+---------------+---------+
                                    | e) out spikes |
                                    +---------------+


        Note (c*): the target distance array is for the time interval
        [0-dist_prefix_len, 0-dist_prefix_len+output_len). For example,
        if the prefix length is 32 and the output length is 128, this would be
        [-32, 96). While the interval extends into the past by the prefix
        length, only future spikes are considered to calculate the distance
        array. The alternative of including spikes that land in [-prefix_len, 0)
        would be fine too, and this approach has been tried, but the inference
        is harder. It's harder because spikes within [-prefix_len, 0) will fix
        an upper limit on the target distance array, and the model will always
        predict lower than this, making inference susceptible to incorrectly
        placing spikes close to zero.

        Note (d*): there is an extra bit of spike data used when creating
        a sample, here called a pad. The pad is used to calculate the ground
        truth distance array. This bit of data is not placed in the sample that
        is returned.
        """
        sample = self.ds[idx]
        stimulus = sample["stimulus"][:, 0 : self.input_len]
        # Switch to float for spikes (this also copies out of the recording).
        spikes = np.array(sample["spikes"], copy=True, dtype=float)
        # Zeros stand in for the prefix so that only future spikes (t >= 0)
        # affect the distance array.
        dist_input_spikes = np.concatenate(
            [
                np.zeros(self.dist_prefix_len, dtype=float),
                spikes[self.input_len :],
            ]
        )
        dist = sdf.distance_arr(dist_input_spikes, self.dist_clamp)[
            0 : self.output_len
        ]
        in_spikes = spikes[0 : self.input_len]
        # Target spikes only include the future spikes (t>0).
        # Only include output_len - dist_prefix_len spikes, as those are the
        # only _future_ spikes that we will attempt to predict. Some plotting
        # tools will error if we exceed this.
        target_spikes = spikes[
            self.input_len : self.input_len
            + self.output_len
            - self.dist_prefix_len
        ]
        if self.use_augmentation:
            stimulus = self._augment_stimulus(stimulus)
        # Returning a dictionary is more flexible than returning a tuple, as
        # we can add to the dictionary without breaking existing consumers of
        # the dataset.
        res = {
            "snippet": np.vstack((stimulus, in_spikes)),
            "dist": dist,
            "target_spikes": target_spikes,
            "cell_id": sample["cell_id"],
        }
        return res


class DistDatasets(kdai.train.DatasetManager):
    """
    Interface for (train, val, test) BasicDistDatasets dataset.

    Holds per-recording train/val/test splits and builds BasicDistDataset
    instances from them. Building datasets currently supports only a single
    recording (see `_single`).
    """

    DIST_CLAMP_MS = 200
    LOSS_CALC_PAD_MS = 200

    def __init__(
        self,
        splits: Sequence[RecordingTrainValTest],
        input_len: int,
        output_len: int,
        downsample: int,
        dist_prefix_len: int,
        stride: int = 1,
        use_augmentation: bool = False,
    ):
        """
        Args:
            splits: multiple recordings worth of train, val, and test splits.
                Each element is unpacked as (train, val, test).
            input_len: passed to BasicDistDataset as `input_len`.
            output_len: passed to BasicDistDataset as `output_len`.
            downsample: downsample factor, used to convert this class's
                millisecond constants into bins.
            dist_prefix_len: passed to BasicDistDataset as `dist_prefix_len`.
            stride: stride of the training dataset. Values other than 1 are
                used, but trigger a warning (checkpointing by step is the
                preferred way to control checkpoint frequency).
            use_augmentation: whether the training dataset uses augmentation.
                Validation and test datasets never use it.

        Raises:
            ValueError: if the chunks of all splits do not share exactly one
                sample rate.
        """
        self._splits = splits
        tr, v, ts = zip(*splits)
        self.train_recs = tr
        self.val_recs = v
        self.test_recs = ts
        self._input_len = input_len
        self._output_len = output_len
        self._downsample = downsample
        self.dist_prefix_len = dist_prefix_len
        # Previously, this was varied to decrease epoch size and increase
        # checkpoint frequency, but we now have a better way: checkpointing by
        # steps. A non-default stride is still honoured, but flagged.
        if stride != 1:
            _logger.warning(
                f"Using {stride=} (!= 1). Why not use checkpointing by step?"
            )
        self._train_stride = stride
        self._use_augmentation = use_augmentation
        # Make sure all recordings have the same sample rate. Each split is a
        # sequence of chunks, so flatten every split of every recording.
        sample_rates = np.array(
            [
                chunk.sample_rate
                for chunk in itertools.chain(
                    *self.train_recs, *self.val_recs, *self.test_recs
                )
            ]
        )
        if np.unique(sample_rates).size != 1:
            raise ValueError(
                "All recordings must have the same sample rate."
                f" Got ({sample_rates})"
            )
        self.sample_rate = sample_rates[0]

    def _dist_clamp_bins(self) -> int:
        """DIST_CLAMP_MS converted to a whole number of (downsampled) bins."""
        return round(ms_to_num_bins(self.DIST_CLAMP_MS, self._downsample))

    def _loss_calc_pad_bins(self) -> int:
        """LOSS_CALC_PAD_MS converted to a whole number of (downsampled) bins."""
        return round(ms_to_num_bins(self.LOSS_CALC_PAD_MS, self._downsample))

    def _to_ds(
        self,
        rec_parts: ContiguousChunks,
        stride: int,
        shuffle_stride: bool,
        use_augmentation: bool,
    ):
        """Build a BasicDistDataset over `rec_parts` with this manager's
        lengths, padding and distance clamp."""
        res = BasicDistDataset(
            rec_parts,
            input_len=self._input_len,
            output_len=self._output_len,
            pad=self._loss_calc_pad_bins(),
            dist_prefix_len=self.dist_prefix_len,
            dist_clamp=self._dist_clamp_bins(),
            stride=stride,
            shuffle_stride=shuffle_stride,
            use_augmentation=use_augmentation,
        )
        return res

    @staticmethod
    def _single(recs: Sequence[ContiguousChunks]):
        """Return the only recording in `recs`.

        Raises:
            ValueError: if `recs` does not contain exactly one recording.
        """
        # This manager currently only supports 1 recording.
        if len(recs) != 1:
            raise ValueError(
                f"This manager only supports 1 recording. Got ({len(recs)})."
            )
        return recs[0]

    def to_train_ds(self, recs: Sequence[ContiguousChunks]):
        """Training dataset: shuffled stride, optional augmentation."""
        rec_parts = self._single(recs)
        return self._to_ds(
            rec_parts, self._train_stride, True, self._use_augmentation
        )

    def to_val_ds(self, recs: Sequence[ContiguousChunks]):
        """Validation dataset: stride 1, no shuffling, no augmentation."""
        rec_parts = self._single(recs)
        return self._to_ds(rec_parts, 1, False, False)

    def to_test_ds(self, recs: Sequence[ContiguousChunks]):
        """Test dataset: stride 1, no shuffling, no augmentation."""
        rec_parts = self._single(recs)
        return self._to_ds(rec_parts, 1, False, False)

    def train_ds(self) -> torch.utils.data.Dataset:
        return self.to_train_ds(self.train_recs)

    def val_ds(self) -> torch.utils.data.Dataset:
        return self.to_val_ds(self.val_recs)

    def test_ds(self) -> torch.utils.data.Dataset:
        return self.to_test_ds(self.test_recs)

    def single_cid_val_ds(self, cid: int) -> torch.utils.data.Dataset:
        """Validation dataset restricted to the single cluster `cid`.

        Raises:
            ValueError: if there is not exactly one validation recording.
        """
        rec_parts = [
            chunk.clusters({cid}) for chunk in self._single(self.val_recs)
        ]
        return self._to_ds(rec_parts, 1, False, False)

    def dist_mean_sd(self) -> Tuple[float, float]:
        """Return (mean, sd) of the training-set distance arrays."""
        return self._dist_mean_sd()

    def log_dist_mean_sd(self) -> Tuple[float, float]:
        """Return (mean, sd) of the log of the training-set distance arrays."""
        # NOTE(review): if distance arrays can contain 0, np.log yields -inf
        # and the statistics become non-finite — confirm against sdf.
        return self._dist_mean_sd(np.log)

    # TODO: test this function.
    def _dist_mean_sd(self, transform_fn=None) -> Tuple[float, float]:
        """Calculate μ and σ distance values for the training set.

        The distance arrays use the same clamp (in bins) as the datasets
        built by `_to_ds`, so the statistics describe the same quantity the
        datasets produce.

        Args:
            transform_fn: optional function applied to each chunk's distance
                array before accumulating (e.g. np.log).

        Returns:
            (mean, sd) over every element of every training chunk.

        Raises:
            ValueError: if the training chunks contain no spike data.
        """
        # Training data comes in the form:
        #    [
        #      [ # per recording
        #        [ # per chunk
        #          [ -- cell 1 -- ],
        #          [ -- cell 2 -- ],
        #          ...
        #        ],
        #        ...
        #      ],
        #      ...
        #    ]
        # We do an online mean and sd calculation, looping over chunks, as
        # this is as coarse as we can get with same-shape data.
        dist_clamp = self._dist_clamp_bins()
        N = sum(chunk.spikes.size for rec in self.train_recs for chunk in rec)
        if N == 0:
            raise ValueError(
                "Cannot compute distance statistics: no training spike data."
            )
        mean = 0.0
        x2 = 0.0
        n_check = 0
        for rec in self.train_recs:
            for chunk in rec:
                dist = sdf.distance_arr_vec(chunk.spikes, dist_clamp)
                if transform_fn:
                    dist = transform_fn(dist)
                mean += np.sum(dist) / N
                x2 += np.sum(dist**2) / N
                n_check += chunk.spikes.size
        assert N == n_check, f"{N=} != {n_check=}"
        # Var = E[x^2] - E[x]^2. Clamp tiny negative round-off to zero so a
        # (near-)constant distribution gives sd 0 rather than NaN.
        sd = np.sqrt(max(x2 - mean**2, 0.0))
        return (mean, sd)


class ConcatDataset(torch.utils.data.Dataset):
    """
    Dataset that concatenates datasets and inserts a dataset label.

    This is an edited version of PyTorch's ConcatDataset, with the addition
    of a dataset index being included in each sample. This is useful for
    making a multi-recording dataset easily from a list of single recording
    datasets. The PyTorch implementation wasn't sufficient, as we want to
    include information of which recording a sample belongs to.

    As stride affects dataset sizes, we add the setting and getting of strides
    here.
    """

    datasets: List[torch.utils.data.Dataset]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        """Return the running totals of the lengths of the items in
        `sequence`, e.g. lengths [2, 3] -> [2, 5]."""
        return list(itertools.accumulate(len(e) for e in sequence))

    def __init__(
        self,
        datasets: Iterable[torch.utils.data.Dataset],
        label_key: Optional[str] = None,
    ) -> None:
        """
        Args:
            datasets: map-style datasets to concatenate, in order. Must be
                non-empty and must not contain IterableDatasets.
            label_key: If not None, the dataset index will be included in each
                sample, under this key. If None, the dataset index will not be
                included. Example, label_key = "id". This was originally added
                to allow the recording id to be included in the sample when
                multiple recordings form a concatenated dataset. However,
                to allow for more flexibility in what recordings are loaded in
                train vs. test time, this responsibility has been delegated to
                the underlying per-recording dataset. We do however need to
                enable the recording identifier, as it is disabled by default.
                At the moment, this is expected to be done to each dataset
                before they are passed in here. With label_key = None, this
                class is equivalent to the PyTorch ConcatDataset. Might end
                up removing it if the label key isn't needed anymore.
        """
        super().__init__()
        self.datasets = list(datasets)
        # Assertions (rather than exceptions) mirror PyTorch's ConcatDataset.
        for d in self.datasets:
            assert not isinstance(
                d, torch.utils.data.IterableDataset
            ), "ConcatDataset does not support IterableDataset"
        self.label_key = label_key
        assert (
            len(self.datasets) > 0
        ), "datasets should not be an empty iterable"
        self._update_sizes()

    def _update_sizes(self):
        """Recompute cumulative sizes; needed whenever a dataset's length
        may have changed (e.g. after a stride change)."""
        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    @property
    def stride(self):
        """The stride shared by all sub-datasets (asserted to be equal)."""
        res = self.datasets[0].stride
        assert all(ds.stride == res for ds in self.datasets)
        return res

    @stride.setter
    def stride(self, stride: int):
        for ds in self.datasets:
            ds.stride = stride
        # Stride changes the sub-dataset lengths.
        self._update_sizes()

    def __getitem__(self, idx):
        """Return sample `idx` of the concatenation.

        Negative indices count from the end. If `label_key` is not None, the
        sample (which must be a dict) is updated in place with the index of
        the sub-dataset it came from.

        Raises:
            ValueError: if a negative index is out of range.
        """
        if idx < 0:
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # First sub-dataset whose cumulative end is strictly after idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample = self.datasets[dataset_idx][sample_idx]
        # Compare against None so that any string (including "") is a
        # valid key, as documented.
        if self.label_key is not None:
            assert isinstance(
                sample, dict
            ), "Sample must be a dictionary in order to add a label."
            sample[self.label_key] = dataset_idx
        return sample
