import logging
import math
from pathlib import Path
from collections.abc import Iterator
import re
from functools import partial
import torch
import json
import gc
import numpy as np
import polars as pl
import scipy.stats
import kdai
import kdai._logging
import kdai.train
import kdai.lrfind
import kdai.datasets
import kdtpp
import kdtpp.datasets as ds
import kdtpp.models as models
import kdtpp.trainables as trainables
import kdtpp.disttrainable
import kdtpp.inferspikes as inferspikes
import kdtpp.mea as mea
import multiprocessing
from typing import (
    Sequence,
    Literal,
    Callable,
    Tuple,
    TypeAlias,
    Optional,
    Set,
    Dict,
)
from tqdm import tqdm
from functools import reduce
import operator
from dataclasses import dataclass

_logger = logging.getLogger(__name__)


# Root directory handed to kdai._logging.get_outdir() by start_logging().
OUTPUT_ROOT_DIR = "./out/exp"
# File name of the log file created inside a run's output directory.
LOG_FILENAME = "train.log"
# NOTE(review): presumably train/val/test proportions; not referenced in the
# visible part of this file -- confirm against callers.
SPLIT_RATIO = (7, 2, 1)
# Maximum input length over all models. Used to sync the starting index of
# some of the datasets.
MAX_IN_LEN = 128


class IgnoreFilter(logging.Filter):
    """Logging filter that drops records whose message contains a substring.

    Args:
        substrings: the substrings to look for in the formatted log message.
    """

    def __init__(self, substrings: Set[str]):
        super().__init__()
        self.substrings = substrings

    def filter(self, record: logging.LogRecord) -> bool:
        """Filter out log records that contain any of the substrings.

        Returns:
            False (drop the record) if the formatted message contains at
            least one of the substrings, True (keep it) otherwise.
        """
        if not self.substrings:
            # Nothing to ignore, so there is no need to format the message.
            return True
        # Format the message once, rather than once per substring.
        message = record.getMessage()
        return not any(substring in message for substring in self.substrings)


def start_logging(labels, ignore_substrings=None):
    """Set up logging for a run and return the run's output directory.

    Args:
        labels: passed to kdai._logging.get_outdir() together with
            OUTPUT_ROOT_DIR to determine the output directory.
        ignore_substrings: optional iterable of substrings. When given, an
            IgnoreFilter for them is added to every handler the root logger
            has right after setup.

    Returns:
        The output directory returned by kdai._logging.get_outdir().
    """
    root = kdai._logging.setup_logging(logging.INFO)
    if ignore_substrings is not None:
        for handler in root.handlers:
            handler.addFilter(IgnoreFilter(set(ignore_substrings)))
    run_dir = kdai._logging.get_outdir(OUTPUT_ROOT_DIR, labels)
    kdai._logging.enable_file_logging(run_dir / LOG_FILENAME)
    # Keep a copy of the code that produced this run next to its outputs.
    kdai._logging.snapshot_importing_script(run_dir)
    for module_name in ("kdtpp", "kdai"):
        kdai._logging.snapshot_module(run_dir, module_name)
    logging.info(f"Output dir: {run_dir}")
    return run_dir


"""ModelMode is used for two things:
    1. To change how much detail to go into when evaluating. You want to 
       evaluate loss using many samples when the training job is being used to
       create a recordable model. In this case, you don't need detailed metrics.
       When investigating models, you may wish to trade off metric accuracy for
       metric detail (train-info). 
    2. The introduction of the exp head's gradient descent based initialization
       means that weight initialization can't be called by default in the
       trainable's constructor, as pre-trained models are often loaded in a
       no-grad context and the init can't run (and shouldn't really, as it's not
       needed). So, the init should be conditionally called. ModelMode will be
       used to decide whether to call the init. Currently, only OmiTrainable uses
       this feature.
"""
ModelMode = Literal[
    # Many samples, loss only.
    "train-loss",
    # Fewer samples, but more metrics. For investigating.
    "train-info",
    # More samples and more metrics. For when you want to report details about
    # training dynamics.
    "train-info2",
    # More samples and all metrics. For accurately evaluating models.
    # Typically used after training when the best models have been selected.
    "eval-metrics",
]


def eval_opts(mode: ModelMode):
    """Convert a ModelMode to parameters for Trainable.__init__

    Returns:
        A new dict with "eval_mode" and "eval_len"; for "eval-metrics" it
        additionally contains "init_weights": False.

    Raises:
        ValueError: if mode is not one of the known modes.
    """
    # (mode, eval_mode, log2 of eval_len)
    table = (
        ("train-loss", "loss", 10 + 6),
        ("train-info", "info", 10 + 4),
        ("train-info2", "info", 10 + 6),
        ("eval-metrics", "info", 10 + 7),
    )
    for known_mode, eval_mode, exponent in table:
        if mode == known_mode:
            break
    else:
        raise ValueError(f"Unknown mode: {mode}")
    opts = {"eval_mode": eval_mode, "eval_len": 2**exponent}
    if known_mode == "eval-metrics":
        opts["init_weights"] = False
    return opts


def rnn_nn(ds_fn, model_mode):
    """Build the "rnn-nn" trainable: models.OmiNN on top of models.BaseRNN.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=32)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.OmiTrainable.
    """
    datasets = ds_fn(model_in_len=32)
    rnn = models.BaseRNN(n_in=2, n_h=64)
    net = models.OmiNN(rnn, n_h=64, n_layers=2)
    return trainables.OmiTrainable(
        datasets,
        net,
        label="rnn-nn",
        hazard_type="nn",
        use_log_input=False,
        **eval_opts(model_mode),
    )


def rnn_const(ds_fn, model_mode):
    """Build the "rnn-const" trainable: models.OmiConstant on models.BaseRNN.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=32)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.OmiTrainable.
    """
    datasets = ds_fn(model_in_len=32)
    rnn = models.BaseRNN(n_in=2, n_h=64)
    net = models.OmiConstant(rnn)
    return trainables.OmiTrainable(
        datasets,
        net,
        label="rnn-const",
        hazard_type="const",
        use_log_input=False,
        **eval_opts(model_mode),
    )


def rnn_exp(ds_fn, model_mode):
    """Build the "rnn-exp" trainable: models.OmiExponential on models.BaseRNN.

    Both the model and the trainable's ``forward`` are wrapped with
    torch.compile.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=32)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.OmiTrainable.
    """
    datasets = ds_fn(model_in_len=32)
    rnn = models.BaseRNN(n_in=2, n_h=64)
    compiled_net = torch.compile(models.OmiExponential(rnn))
    trainable = trainables.OmiTrainable(
        datasets,
        compiled_net,
        label="rnn-exp",
        hazard_type="exp",
        use_log_input=False,
        **eval_opts(model_mode),
    )
    trainable.forward = torch.compile(trainable.forward)
    return trainable


def rnn_logmix(ds_fn, model_mode, n_mix=64):
    """Build the "rnn-logmix" trainable around models.ShchurLogMix.

    Both the model and the trainable's ``forward`` are wrapped with
    torch.compile.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=32)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().
        n_mix: forwarded to models.ShchurLogMix as ``n_mix``.

    Returns:
        The constructed trainables.ShchurLogMixTrainable.
    """
    datasets = ds_fn(model_in_len=32)
    compiled_net = torch.compile(models.ShchurLogMix(n_h=64, n_mix=n_mix))
    trainable = trainables.ShchurLogMixTrainable(
        datasets,
        compiled_net,
        label="rnn-logmix",
        use_log_input=False,
        **eval_opts(model_mode),
    )
    trainable.forward = torch.compile(trainable.forward)
    return trainable


def rnn_cat(ds_fn, model_mode, out_resolution=128):
    """Build the "rnn-cat" trainable: models.RnnCat on top of models.BaseRNN.

    Only the model is wrapped with torch.compile.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=32)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().
        out_resolution: forwarded to models.RnnCat as ``out_resolution``.

    Returns:
        The constructed trainables.VarBinTrainable.
    """
    datasets = ds_fn(model_in_len=32)
    rnn = models.BaseRNN(n_in=2, n_h=64)
    compiled_net = torch.compile(
        models.RnnCat(rnn, out_resolution=out_resolution)
    )
    # The trainable's forward is deliberately left uncompiled: the lru_cache
    # decorator gives dynamo warnings.
    return trainables.VarBinTrainable(
        datasets,
        compiled_net,
        label="rnn-cat",
        causal=False,
        **eval_opts(model_mode),
    )


def omi_nn(ds_fn, model_mode):
    """Build the "omi-nn" trainable: models.OmiNN on top of models.OmiRNN.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=8)`` to get the dataset manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.OmiTrainable.
    """
    datasets = ds_fn(model_in_len=8)
    rnn = models.OmiRNN(n_h=10, n_unroll=8)
    net = models.OmiNN(rnn, n_h=64, n_layers=2)
    return trainables.OmiTrainable(
        datasets,
        net,
        label="omi-nn",
        hazard_type="nn",
        use_log_input=False,
        **eval_opts(model_mode),
    )


def omi_const(ds_fn, model_mode):
    """Build the "omi-const" trainable: models.OmiConstant on models.OmiRNN.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=8)`` to get the dataset manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.OmiTrainable.
    """
    datasets = ds_fn(model_in_len=8)
    rnn = models.OmiRNN(n_h=10, n_unroll=8)
    net = models.OmiConstant(rnn)
    return trainables.OmiTrainable(
        datasets,
        net,
        label="omi-const",
        hazard_type="const",
        use_log_input=False,
        **eval_opts(model_mode),
    )


def omi_exp(ds_fn, model_mode):
    """Build the "omi-exp" trainable: models.OmiExponential on models.OmiRNN.

    Both the model and the trainable's ``forward`` are wrapped with
    torch.compile.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=8)`` to get the dataset manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.OmiTrainable.
    """
    datasets = ds_fn(model_in_len=8)
    rnn = models.OmiRNN(n_h=10, n_unroll=8)
    compiled_net = torch.compile(models.OmiExponential(rnn))
    trainable = trainables.OmiTrainable(
        datasets,
        compiled_net,
        label="omi-exp",
        hazard_type="exp",
        use_log_input=False,
        **eval_opts(model_mode),
    )
    trainable.forward = torch.compile(trainable.forward)
    return trainable


def shchur_logmix(ds_fn, model_mode, n_marks=1, n_mix=64):
    """Build the "shchur-logmix" trainable around models.ShchurLogMix.

    Both the model and the trainable's ``forward`` are wrapped with
    torch.compile.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=8)`` to get the dataset manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().
        n_marks: forwarded to models.ShchurLogMix as ``n_marks``.
        n_mix: forwarded to models.ShchurLogMix as ``n_mix``.

    Returns:
        The constructed trainables.ShchurLogMixTrainable.
    """
    datasets = ds_fn(model_in_len=8)
    net = models.ShchurLogMix(n_marks=n_marks, n_h=64, n_mix=n_mix)
    compiled_net = torch.compile(net)
    trainable = trainables.ShchurLogMixTrainable(
        datasets,
        compiled_net,
        label="shchur-logmix",
        use_log_input=False,
        **eval_opts(model_mode),
    )
    trainable.forward = torch.compile(trainable.forward)
    return trainable


# Head variants accepted by gpt_base(). "logmix" and "logmix64" both select
# models.LogMixHead there, with n_mix=16 and n_mix=64 respectively.
HeadType = Literal["const", "exp", "nn", "logmix", "logmix64"]


def gpt_base(
    head_type: HeadType,
    n_layer,
    n_head,
    head_dim,
    ds_fn,
    model_mode,
):
    """Build a trainable made of a models.GPTv2Stem plus the requested head.

    Args:
        head_type: which head to put on the stem (see HeadType).
        n_layer: forwarded to models.GPTv2Stem.
        n_head: forwarded to models.GPTv2Stem.
        head_dim: forwarded to models.GPTv2Stem; ``n_head * head_dim`` is the
            size handed to the head.
        ds_fn: called as ``ds_fn(model_in_len=128)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        A trainables.GptHazardTrainable for the "const", "exp" and "nn"
        heads, a trainables.GptLogMix for the "logmix" heads.

    Raises:
        ValueError: if head_type is not a known head type.
    """
    input_len = 128
    datasets = ds_fn(model_in_len=input_len)
    stem = models.GPTv2Stem(
        input_len=input_len,
        n_layer=n_layer,
        n_head=n_head,
        head_dim=head_dim,
    )
    width = n_head * head_dim

    uses_logmix = False
    if head_type == "nn":
        # NN head has a different interface by taking time as a query.
        # It is not compiled; tracking the following issue for when this can
        # be enabled.
        #   https://github.com/pytorch/pytorch/issues/91469
        compile_it = False
        model = models.GptNNHead(
            stem, models.NNHazard(width, n_h=64, n_layers=2)
        )
    else:
        compile_it = True
        if head_type == "const":
            head = models.ConstHazard(width)
        elif head_type == "exp":
            head = models.ExpHazard(width)
        elif head_type == "logmix":
            uses_logmix = True
            head = models.LogMixHead(width, n_mix=16)
        elif head_type == "logmix64":
            uses_logmix = True
            head = models.LogMixHead(width, n_mix=64)
        else:
            raise ValueError(f"Unknown head type: {head_type}")
        model = models.GptWithHead(stem, head)

    if compile_it:
        model = torch.compile(model)

    label = f"gpt-{n_layer}-{n_head}-{head_dim}-{head_type}"
    if uses_logmix:
        trainable = trainables.GptLogMix(
            datasets, model, label=label, **eval_opts(model_mode)
        )
    else:
        trainable = trainables.GptHazardTrainable(
            datasets,
            model,
            label=label,
            hazard_type=head_type,
            **eval_opts(model_mode),
        )
    if compile_it:
        trainable.forward = torch.compile(trainable.forward)

    return trainable


def zuo_thp(param_set_idx, ds_fn, model_mode, n_marks=1):
    """Build the "zuo-thp-<param_set_idx>" trainable around models.ZuoTHP.

    Args:
        param_set_idx: forwarded to models.ZuoTHP.from_param_set() and used
            in the label.
        ds_fn: called as ``ds_fn(model_in_len=128)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().
        n_marks: must be 1.

    Returns:
        The constructed trainables.ZuoTHPTrainable.

    Raises:
        NotImplementedError: if n_marks is not 1.
    """
    # Reject unsupported arguments before doing any work, so that ds_fn is not
    # invoked for a configuration that cannot run.
    if n_marks != 1:
        raise NotImplementedError("Only supports 1 mark.")
    # THP only has a causal mode.
    ds_mgr = ds_fn(model_in_len=128)
    model = models.ZuoTHP.from_param_set(param_set_idx)
    model = torch.compile(model)
    res = trainables.ZuoTHPTrainable(
        ds_mgr, model, label=f"zuo-thp-{param_set_idx}", **eval_opts(model_mode)
    )
    # We can't compile the softmax_hazard, so cforward is left uncompiled and
    # only the loss_fn part, which covers the MSE, is compiled.
    res.loss_fn = torch.compile(res.loss_fn)
    return res


def rmtpp(ds_fn, model_mode, n_marks=1):
    """Build the "rmtpp" trainable around models.RMTPP.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=100)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().
        n_marks: accepted but not used by this function.

    Returns:
        The constructed trainables.OmiTrainable.
    """
    datasets = ds_fn(model_in_len=100)
    net = models.RMTPP(n_q=1, n_h=10)
    return trainables.OmiTrainable(
        datasets,
        net,
        label="rmtpp",
        hazard_type="rmtpp",
        use_log_input=False,
        **eval_opts(model_mode),
    )


def gpt(n_layer, n_head, head_dim, causal, fixed_bins, ds_fn, model_mode):
    """Build a trainables.DiscreteTrainable around models.GPTv2.

    The label is ``gpt-<n_layer>-<n_head>-<head_dim>``, with "-f" appended
    when fixed_bins is truthy and "-nc" appended when causal is falsy. Both
    the model and the trainable's ``cforward`` are wrapped with torch.compile.

    Args:
        n_layer: forwarded to models.GPTv2.
        n_head: forwarded to models.GPTv2.
        head_dim: forwarded to models.GPTv2.
        causal: forwarded to both models.GPTv2 and the trainable.
        fixed_bins: selects bin_mode "fixed" when truthy, "auto" otherwise.
        ds_fn: called as ``ds_fn(model_in_len=128)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.DiscreteTrainable.
    """
    input_len = 128
    datasets = ds_fn(model_in_len=input_len)
    net = models.GPTv2(
        input_len=input_len,
        n_layer=n_layer,
        n_head=n_head,
        head_dim=head_dim,
        out_resolution=128,
        causal=causal,
    )
    suffix = ("-f" if fixed_bins else "") + ("" if causal else "-nc")
    label = f"gpt-{n_layer}-{n_head}-{head_dim}" + suffix
    compiled_net = torch.compile(net)
    if fixed_bins:
        bin_mode = "fixed"
    else:
        bin_mode = "auto"
    trainable = trainables.DiscreteTrainable(
        datasets,
        compiled_net,
        label=label,
        bin_mode=bin_mode,
        causal=causal,
        **eval_opts(model_mode),
    )
    trainable.cforward = torch.compile(trainable.cforward)
    return trainable


def gptvar(n_layer, n_head, head_dim, causal, n_bins, ds_fn, model_mode):
    """Build a trainables.VarBinTrainable around models.GPTv2.

    Only the model is wrapped with torch.compile.

    Args:
        n_layer: forwarded to models.GPTv2.
        n_head: forwarded to models.GPTv2.
        head_dim: forwarded to models.GPTv2.
        causal: forwarded to both models.GPTv2 and the trainable.
        n_bins: forwarded to models.GPTv2 as ``out_resolution``.
        ds_fn: called as ``ds_fn(model_in_len=128)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.VarBinTrainable, labelled
        ``gptvar-<n_layer>-<n_head>-<head_dim>``.
    """
    input_len = 128
    datasets = ds_fn(model_in_len=input_len)
    net = models.GPTv2(
        input_len=input_len,
        n_layer=n_layer,
        n_head=n_head,
        head_dim=head_dim,
        out_resolution=n_bins,
        causal=causal,
    )
    compiled_net = torch.compile(net)
    return trainables.VarBinTrainable(
        datasets,
        compiled_net,
        label=f"gptvar-{n_layer}-{n_head}-{head_dim}",
        causal=causal,
        **eval_opts(model_mode),
    )


def control_discrete(ds_fn, model_mode, n_marks=1):
    """Build the "control-discrete" reference trainable.

    Uses models.NoContextDiscrete (wrapped with torch.compile) inside a
    trainables.DiscreteTrainable with fixed bins.

    Args:
        ds_fn: called as ``ds_fn(model_in_len=128)`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().
        n_marks: accepted but not used by this function.

    Returns:
        The constructed trainables.DiscreteTrainable.
    """
    # Full_y not needed, but currently, the trainable assumes it.
    datasets = ds_fn(model_in_len=128)
    compiled_net = torch.compile(models.NoContextDiscrete())
    return trainables.DiscreteTrainable(
        datasets,
        compiled_net,
        label="control-discrete",
        bin_mode="fixed",
        **eval_opts(model_mode),
    )


# Dataset flavours a spike dataset-manager factory can be asked for.
SpikesDsType: TypeAlias = Literal["next_event", "interval", "distance"]
# Factory used by the tf_* builders; called as ds_fn(model_in_len, ds_type).
SpikesDsMgrFn = Callable[[int, SpikesDsType], kdai.train.DatasetManager]
rid = 12  # chicken-2021-08-17
# Passed to the spike trainables as eval_rec_cids.
EVAL_CIDS = [(rid, 21), (rid, 22), (rid, 120)]
# Passed to models.TransformerBase as n_seqs.
N_CELLS = 4203  # non-filtered. Filtered, there is 1611.


def tf_logmix(
    n_layer,
    n_head,
    head_dim,
    n_mix,
    expansion,
    ds_fn: SpikesDsMgrFn,
    model_mode,
):
    """Logmix on top of a transformer, for equi-spaced binned data.

    Both the model and the trainable's ``forward`` are wrapped with
    torch.compile.

    Args:
        n_layer: forwarded to models.TransformerBase.
        n_head: forwarded to models.TransformerBase.
        head_dim: forwarded to models.TransformerBase.
        n_mix: forwarded to models.LogmixTf and included in the label.
        expansion: forwarded to models.TransformerBase.
        ds_fn: called as ``ds_fn(1024, "next_event")`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.LogMixForSpikes.
    """
    input_len = 1024
    datasets = ds_fn(input_len, "next_event")
    base = models.TransformerBase(
        input_len=input_len,
        n_layer=n_layer,
        n_head=n_head,
        head_dim=head_dim,
        n_seqs=N_CELLS,
        expansion=expansion,
    )
    compiled_net = torch.compile(models.LogmixTf(base, n_mix=n_mix))
    trainable = trainables.LogMixForSpikes(
        datasets,
        compiled_net,
        label=f"tf-{n_layer}-{n_head}-{head_dim}-logmix{n_mix}",
        eval_rec_cids=EVAL_CIDS,
        **eval_opts(model_mode),
    )
    trainable.forward = torch.compile(trainable.forward)
    return trainable


def tf_dist(
    n_layer, n_head, head_dim, expansion, ds_fn: SpikesDsMgrFn, model_mode
):
    """Build a DistTrainable around models.DistTf on a transformer base.

    Both the model and the trainable's ``forward`` are wrapped with
    torch.compile.

    Args:
        n_layer: forwarded to models.TransformerBase.
        n_head: forwarded to models.TransformerBase.
        head_dim: forwarded to models.TransformerBase.
        expansion: forwarded to models.TransformerBase.
        ds_fn: called as ``ds_fn(1024, "distance")`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed kdtpp.disttrainable.DistTrainable.
    """
    input_len = 1024
    datasets = ds_fn(input_len, "distance")
    base = models.TransformerBase(
        input_len=input_len,
        n_layer=n_layer,
        n_head=n_head,
        head_dim=head_dim,
        n_seqs=N_CELLS,
        expansion=expansion,
    )
    compiled_net = torch.compile(models.DistTf(base))
    trainable = kdtpp.disttrainable.DistTrainable(
        datasets,
        compiled_net,
        label=f"tf-{n_layer}-{n_head}-{head_dim}-dist",
        eval_rec_cids=EVAL_CIDS,
        **eval_opts(model_mode),
    )
    trainable.forward = torch.compile(trainable.forward)
    return trainable


def tf_discrete(
    n_layer,
    n_head,
    head_dim,
    out_res,
    expansion,
    ds_fn: SpikesDsMgrFn,
    model_mode,
):
    """Build a SpikesDiscrete trainable around models.DiscreteTf.

    Both the model and the trainable's ``forward`` are wrapped with
    torch.compile.

    Args:
        n_layer: forwarded to models.TransformerBase.
        n_head: forwarded to models.TransformerBase.
        head_dim: forwarded to models.TransformerBase.
        out_res: forwarded to models.DiscreteTf as ``out_resolution``.
        expansion: forwarded to models.TransformerBase.
        ds_fn: called as ``ds_fn(1024, "next_event")`` to get the dataset
            manager.
        model_mode: ModelMode; converted to trainable options by eval_opts().

    Returns:
        The constructed trainables.SpikesDiscrete.
    """
    input_len = 1024
    datasets = ds_fn(input_len, "next_event")
    base = models.TransformerBase(
        input_len=input_len,
        n_layer=n_layer,
        n_head=n_head,
        head_dim=head_dim,
        n_seqs=N_CELLS,
        expansion=expansion,
    )
    compiled_net = torch.compile(
        models.DiscreteTf(base, out_resolution=out_res)
    )
    trainable = trainables.SpikesDiscrete(
        datasets,
        compiled_net,
        label=f"tf-{n_layer}-{n_head}-{head_dim}-discrete",
        eval_rec_cids=EVAL_CIDS,
        **eval_opts(model_mode),
    )
    trainable.forward = torch.compile(trainable.forward)
    return trainable


# Registry mapping a trainable name to a factory. After the partial
# applications below, every factory still expects the dataset-manager function
# and the ModelMode.
trainable_fns = {
    # As per original implementations.
    "omi-nn": omi_nn,
    "omi-const": omi_const,
    "omi-exp": omi_exp,
    "rmtpp": rmtpp,  # Not used. Too similar to omi-exp.
    "shchur-logmix": partial(shchur_logmix, n_mix=16),
    "zuo-thp-0": partial(zuo_thp, 0),
    "zuo-thp-1": partial(zuo_thp, 1),
    "zuo-thp-2": partial(zuo_thp, 2),  # Not used.
    # With identical RNN base.
    "rnn-const": partial(rnn_const),
    "rnn-exp": partial(rnn_exp),
    "rnn-logmix": partial(rnn_logmix),
    "rnn-nn": partial(rnn_nn),
    "rnn-cat": partial(rnn_cat),
    # Reference
    "control-discrete": control_discrete,
    # gpt-a and gpt-b
    "gptvar-2-4-16": partial(gptvar, 2, 4, 16, True, 128),
    "gptvar-6-4-32": partial(gptvar, 6, 4, 32, True, 128),
    "gpt-2-4-16-f": partial(gpt, 2, 4, 16, True, True),
    "gpt-6-4-32-f": partial(gpt, 6, 4, 32, True, True),
    "gpt-2-4-16-const": partial(gpt_base, "const", 2, 4, 16),
    "gpt-2-4-16-exp": partial(gpt_base, "exp", 2, 4, 16),
    "gpt-2-4-16-nn": partial(gpt_base, "nn", 2, 4, 16),
    "gpt-6-4-32-const": partial(gpt_base, "const", 6, 4, 32),
    "gpt-6-4-32-exp": partial(gpt_base, "exp", 6, 4, 32),
    "gpt-6-4-32-nn": partial(gpt_base, "nn", 6, 4, 32),
    "gpt-2-4-16-logmix": partial(gpt_base, "logmix", 2, 4, 16),
    "gpt-6-4-32-logmix": partial(gpt_base, "logmix", 6, 4, 32),
    # 64 mixtures. Not really used as it had similar to poorer scores.
    "shchur-logmix64": partial(shchur_logmix, n_mix=64),
    "gpt-2-4-16-logmix64": partial(gpt_base, "logmix64", 2, 4, 16),
    "gpt-6-4-32-logmix64": partial(gpt_base, "logmix64", 6, 4, 32),
    # For spike data. Call it tf as it's not gpt. But keep it as
    # model-spec-head format.
    "tf-2-4-16-logmix": partial(tf_logmix, 2, 4, 16, 16, 4),
    "tf-6-4-32-logmix": partial(tf_logmix, 6, 4, 32, 16, 4),
    "tf-2-4-16-logmix64": partial(tf_logmix, 2, 4, 16, 64, 4),
    "tf-6-4-32-logmix64": partial(tf_logmix, 6, 4, 32, 64, 4),
    "tf-2-4-16-dist": partial(tf_dist, 2, 4, 16, 4),
    "tf-6-4-32-dist": partial(tf_dist, 6, 4, 32, 4),
    "tf-2-4-16-discrete": partial(tf_discrete, 2, 4, 16, 81, 4),
    "tf-6-4-32-discrete": partial(tf_discrete, 6, 4, 32, 81, 4),
    # With expansion=2 instead of 4.
    "tf-2-4-16-E2-logmix": partial(tf_logmix, 2, 4, 16, 16, 2),
    "tf-6-4-32-E2-logmix": partial(tf_logmix, 6, 4, 32, 16, 2),
    "tf-2-4-16-E2-logmix64": partial(tf_logmix, 2, 4, 16, 64, 2),
    "tf-6-4-32-E2-logmix64": partial(tf_logmix, 6, 4, 32, 64, 2),
    "tf-2-4-16-E2-dist": partial(tf_dist, 2, 4, 16, 2),
    "tf-6-4-32-E2-dist": partial(tf_dist, 6, 4, 32, 2),
    "tf-2-4-16-E2-discrete": partial(tf_discrete, 2, 4, 16, 81, 2),
    "tf-6-4-32-E2-discrete": partial(tf_discrete, 6, 4, 32, 81, 2),
    # Larger.
    "tf-2-8-128-E2-logmix64": partial(tf_logmix, 2, 8, 128, 64, 2),
    "tf-2-8-128-E2-logmix": partial(tf_logmix, 2, 8, 128, 16, 2),
    "tf-2-8-128-E2-dist": partial(tf_dist, 2, 8, 128, 2),
    "tf-2-8-128-E2-discrete": partial(tf_discrete, 2, 8, 128, 81, 2),
    # Each key must appear only once: a repeated key in a dict display
    # silently replaces the earlier entry.
    "tf-8-4-32-E2-logmix": partial(tf_logmix, 8, 4, 32, 16, 2),
    "tf-8-4-32-E2-logmix64": partial(tf_logmix, 8, 4, 32, 64, 2),
    "tf-8-4-32-E2-dist": partial(tf_dist, 8, 4, 32, 2),
    "tf-8-4-32-E2-discrete": partial(tf_discrete, 8, 4, 32, 81, 2),
    "tf-6-8-128-E2-logmix": partial(tf_logmix, 6, 8, 128, 16, 2),
    "tf-6-8-128-E2-logmix64": partial(tf_logmix, 6, 8, 128, 64, 2),
    "tf-6-8-128-E2-dist": partial(tf_dist, 6, 8, 128, 2),
    "tf-6-8-128-E2-discrete": partial(tf_discrete, 6, 8, 128, 81, 2),
    # Didn't get used:
    "gpt-8-4-32-f": partial(gpt, 8, 4, 32, True, True),
    "gpt-10-4-32-f": partial(gpt, 10, 4, 32, True, True),
    "gpt-2-4-64-f": partial(gpt, 2, 4, 64, True, True),
    "gpt-2-4-128-f": partial(gpt, 2, 4, 128, True, True),
    "gpt-6-4-64-f": partial(gpt, 6, 4, 64, True, True),
    "gpt-4-6-64-f": partial(gpt, 4, 6, 64, True, True),
    "gpt-4-8-64-f": partial(gpt, 4, 8, 64, True, True),
    "gpt-2-8-128-f": partial(gpt, 2, 8, 128, True, True),
    "gpt-6-4-72-f": partial(gpt, 6, 4, 72, True, True),
    "gpt-6-6-72-f": partial(gpt, 6, 6, 72, True, True),
}

# The four tables below are layered by get_train_args(); a table with a higher
# priority overrides the keys of every table with a lower priority.

# Last (lowest) priority: the starting point for every configuration.
default_train_args = {
    "eval_batch_size": 2048,
    "evals_til_eval_train_ds": 16,
    "weight_decay": 1e-2,
    "n_workers": 0,
    "steps_til_log": 256,
    "pin_memory": False,
    # Fallback. Only gets used for debugging and testing runs.
    "lr": 1e-4,
}

# 3rd priority: overrides keyed by dataset name.
train_args_by_ds = {
    # All configurations except those for spike prediction task use a learning
    # rate lookup file. For the spike prediction task, the chosen LRs are
    # the same for all models, and there is only 1 dataset, so we can just
    # save the LRs in the config here.
    # The LR was chosen based on the 0/118/3/1/0 lrcalc.
    "chicken": {
        "lr": 5e-4,
    },
    "chicken-full": {
        "lr": 5e-4,
        "eval_batch_size": 1024,
    },
    "pubg": {"lr": 5e-4},
    "reddit-askscience": {"lr": 5e-4},
    "reddit-politics": {"lr": 5e-4},
    "twitter": {"lr": 5e-4},
    "yelp-airport": {"lr": 5e-4},
    "yelp-mississauga": {"lr": 5e-4},
    "mooc": {"lr": 5e-4},
    "wikipedia": {"lr": 5e-4},
}

# 2nd priority: overrides keyed by trainable (model) name.
train_args_by_model = {
    "omi-nn": {
        # doesn't work with the gradient calculation.
        "fuse_adam": False,
    },
    # These models are bigger, and 2048 can sometimes cause OOM.
    "zuo-thp-0": {
        "eval_batch_size": 1024,
    },
    "zuo-thp-1": {
        "eval_batch_size": 1024,
    },
    "tf-6-8-128-E2-logmix": {
        "eval_batch_size": 512,
        "batch_size": 512,
        "lr": 1e-5,
    },
    "tf-6-8-128-E2-discrete": {
        "eval_batch_size": 512,
        "batch_size": 512,
        "lr": 1e-5,
    },
}

# 1st (highest) priority: overrides keyed by (ds_name, trainable_name) tuples.
train_args_by_ds_and_trainable = {}


def get_train_args(ds_name, trainable_name):
    """Assemble the training arguments for a dataset/trainable pair.

    Starts from a copy of default_train_args and applies, in this order, the
    overrides from train_args_by_ds, train_args_by_model and
    train_args_by_ds_and_trainable, so later tables win.

    Args:
        ds_name: key into train_args_by_ds.
        trainable_name: key into train_args_by_model.

    Returns:
        A new dict; the module-level tables are not modified.
    """
    layers = (
        train_args_by_ds.get(ds_name, {}),
        train_args_by_model.get(trainable_name, {}),
        train_args_by_ds_and_trainable.get((ds_name, trainable_name), {}),
    )
    merged = dict(default_train_args)
    for overrides in layers:
        merged.update(overrides)
    return merged


class MajorRunSpec:
    """Description of one run, shared by functions like train() and eval().

    The extra layer is a bit of a nuisance, but it avoids repeating training
    and evaluation code, and it fixes naming conventions and output
    locations in one place.

    When results are eventually written to dataframes, the properties of this
    class will become columns.
    """

    # Should be the same as the name of the trainable.
    model_name: str
    # train_len is not 1-1 with number of samples in the dataset. For example,
    # the cyclic group datasets are defined by the number of 1024 length
    # sequences, not the total number of events.
    train_len: int

    def __init__(
        self,
        model_name,
        ds_category,
        ds_name,
        train_len,
        n_epochs,
        batch_size,
        steps_til_eval,
        **other_labels,
    ):
        """Store the run description.

        Every entry of other_labels is kept in ``self.other_labels`` and is
        also set as an attribute of the same name.
        """
        self.model_name = model_name
        self.ds_category = ds_category
        self.ds_name = ds_name
        self.train_len = train_len
        self.other_labels = other_labels
        for label in other_labels:
            setattr(self, label, other_labels[label])
        # Train options
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.steps_til_eval = steps_til_eval

    def ds_fullname(self):
        """Return the dataset name with the train length appended."""
        return "-".join((f"{self.ds_name}", f"{self.train_len}"))

    @property
    def model_dir(self):
        """Relative path to the model directory.

        Convention is to have the model name as the highest level, and the
        dataset full name as the lowest level, with the other labels in
        between.
        """
        segments = [
            f"{key}{self.other_labels[key]}" for key in self.other_labels
        ]
        segments.append(self.ds_fullname())
        return Path(self.model_name).joinpath(*segments)

    def to_dict(self):
        """Return the spec as a flat dict, with other_labels merged in last."""
        row = {
            "model": self.model_name,
            "ds": self.ds_name,
            "ds_fullname": self.ds_fullname(),
            "train_len": self.train_len,
            "n_epochs": self.n_epochs,
            "batch_size": self.batch_size,
            "steps_til_eval": self.steps_til_eval,
        }
        row.update(self.other_labels)
        return row


"""
Callable that returns a DatasetManager. The single parameter is the 
model_in_len needed by the model. Only exists to speed up training for many
similar configurations. This callable is needed to delay the loading of the
datasets until needed. It also allows the function to be a closure and share a
lot of the time-consuming disk-reads between many dataset managers that simply
differ by training set length. """
# Type of the dataset-manager factories described above; the model builders in
# this file call them as ds_fn(model_in_len=...).
DsMgrFn = Callable[[int], kdai.train.DatasetManager]


class RandProc:
    """Entry point for the random-process datasets.

    The most efficient way to iterate over:

        - base datasets
        - train lengths
        - models

    is in that order. This way, the full length data is loaded first,
    then the shorter one, then all models share this data. However, the
    model is the unit that is most likely to be extended or investigated
    in isolation (testing model changes, or running larger models on a larger
    machine), and because of this, it is natural to output dataframes per model.
    For this reason, the run scripts will typically iterate over models, and
    it makes sense for the iteration order to be:

        - models
        - base datasets
        - train lengths

    This makes model-base_ds a unit at which we can create things once for
    reuse.
    """

    max_train_len = 2**25
    # Event count used in the file names read by from_disk().
    ds_len = "33816576"  # previously, with 1/3/1 it was "33619968"
    default_train_lens = [2**n for n in reversed(range(10, 26))]

    # Process name -> generator function (None when there is no generator).
    rand_processes = {
        "stationary-poisson": ds.gen_poisson,
        "nonstationary-poisson": ds.gen_nonstationary_poisson,
        "stationary-renewal": ds.gen_stationary_renewal,
        "nonstationary-renewal": ds.gen_nonstationary_renewal,
        "self-correcting": ds.gen_self_correcting,
        "hawkes1": ds.gen_hawkes1,
        "hawkes2": ds.gen_hawkes2,
        "metropolis-lognorm": None,  # Very slow, so only from disk currently.
    }

    def default_omi_processes(self):
        """Return the default random processes as a list of (name, gen_fn).

        "metropolis-lognorm" is skipped. Every gen_fn is a partial with its
        own ``rng`` created with seed 0.
        """
        res = []
        for name, gen_fn in self.rand_processes.items():
            if name == "metropolis-lognorm":
                continue
            gen_fn = partial(gen_fn, rng=np.random.default_rng(seed=0))
            res.append((name, gen_fn))
        return res

    @classmethod
    def multi_run_omi_processes(cls, n_runs):
        """For when we want multiple runs with unique random seeds.

        The run number will be used as the RNG seed.

        Returns:
            A list of (name, run, gen_fn) tuples, ordered by run and then by
            process; "metropolis-lognorm" is skipped.
        """
        res = []
        for run in range(n_runs):
            for name, gen_fn in cls.rand_processes.items():
                if name == "metropolis-lognorm":
                    continue
                gen_fn = partial(gen_fn, rng=np.random.default_rng(seed=run))
                res.append((name, run, gen_fn))
        return res

    @staticmethod
    def args_to_outpath(model_name: str, ds_name: str, train_len: int):
        """Convert model name, dataset name and train length to out_path."""
        out_path = Path(model_name) / f"{ds_name}-{train_len}"
        return out_path

    @classmethod
    def from_disk(cls, ds_name, n_events, data_dir: Optional[Path] = None):
        """Load inter-event times and log-probabilities from disk.

        Args:
            ds_name: process name. A name of the form ``<process>-r<run>``
                selects a specific run; otherwise run 0 is used.
            n_events: number of events to read. Not used for
                "metropolis-lognorm", for which the whole file is returned.
            data_dir: root data directory; defaults to "./data".

        Returns:
            A (dts, log_probs) tuple.
        """
        if data_dir is None:
            data_dir = Path("./data")
        data_dir = Path(data_dir)
        if ds_name == "metropolis-lognorm":
            p = data_dir / "metropolis/events.npy"
            ts = np.load(p)
            # Contains a leading zero, so no need to pad.
            dts = np.diff(ts)
            # Let's just use the target distribution (lognorm)
            log_probs = scipy.stats.lognorm.logpdf(dts, 1)
        else:
            # Check if the request is for a specific run.
            m = re.match(r"(.+?)-r(\d+)", ds_name)
            if m:
                ds_name = m.group(1)
                run = int(m.group(2))
            else:
                # Default just uses first run.
                run = 0
            base_path = data_dir / f"rand_process/r{run}"
            p = base_path / f"{ds_name}_{cls.ds_len}event.npz"
            arr = np.load(p)
            ts = arr["ts"][0:n_events]
            assert len(ts) == n_events, f"{ds_name=} {len(ts)=} {n_events=}"
            # Recorded without starting zero!
            dts = np.diff(np.pad(ts, (1, 0)))
            log_probs = arr["log_probs"][0:n_events]
        return dts, log_probs

    @classmethod
    def for_model_and_ds(
        cls,
        model_name: str,
        ds_name: str,
        train_lens=None,
        data_dir=None,
        total_samples=None,
    ) -> Iterator[tuple[MajorRunSpec, DsMgrFn]]:
        """Iterate over the run configurations for one model and one dataset.

        Args:
            model_name: name stored in each yielded MajorRunSpec.
            ds_name: dataset to load with from_disk().
            train_lens: train lengths to produce; defaults to
                default_train_lens.
            data_dir: forwarded to from_disk().
            total_samples: forwarded to epoch_opts().

        Yields:
            (MajorRunSpec, dataset-manager factory) pairs, one per train
            length. Every factory is bound to the data of its own train
            length, so it may be called at any time, also after the iterator
            has moved on.
        """
        if train_lens is None:
            train_lens = cls.default_train_lens

        # Omi et al. (2019) (and Shchur et al. (2020) which refers to Omi.) use
        # 100,000 events, split into 80k/20k. Let's get nice low variance
        # estimates by going a bit bigger with val and test lengths.
        VAL_LEN = 100 * 2**10
        TEST_LEN = 100 * 2**10

        def ratio_and_n_samples(tl):
            ratio = np.array([tl, VAL_LEN, TEST_LEN]) + MAX_IN_LEN
            n_samples = sum(ratio)
            return ratio, n_samples

        _, epoch_samples = ratio_and_n_samples(cls.max_train_len)

        dt_seq, log_probs = cls.from_disk(ds_name, epoch_samples, data_dir)

        for tl in train_lens:
            # Total length allows for MAX_IN_LEN offset to each split. In other
            # words, we are not considering any cases where the model has only
            # a partial input (close to the start of the sequence). We are
            # lucky enough to be able to do this, as we are dealing with only
            # a single long sequence, and not many smaller sequence, so any
            # effects of predictions made close to the start of the sequence
            # can be ignored. We do this as it greatly simplifies the likelihood
            # calculation, and as a result, the loss calculation and forward
            # pass.
            ratio, n_samples = ratio_and_n_samples(tl)
            # Cut from the end so that val and test are always the same.
            dt_subset = dt_seq[-n_samples:]
            assert len(dt_subset) == n_samples
            log_probs_subset = log_probs[-n_samples:]

            # The keyword-only defaults bind this iteration's values. A plain
            # closure would look the loop variables up when called, and would
            # see the values of a later iteration if the caller invokes the
            # factory after advancing the iterator.
            def to_data_mgr(
                model_in_len,
                *,
                _dts=dt_subset,
                _log_probs=log_probs_subset,
                _ratio=ratio,
                _train_len=tl,
            ):
                res = ds.RandProcessDatasets(
                    _dts,
                    _log_probs,
                    model_in_len=model_in_len,
                    ratio=_ratio,
                    synced_start=MAX_IN_LEN,
                )
                assert (
                    len(res.train_data) == _train_len
                ), f"{len(res.train_data)} != {_train_len}"
                assert len(res.val_data) == VAL_LEN
                assert len(res.test_data) == TEST_LEN
                return res

            n_epochs, batch_size, max_steps_per_eval = cls.epoch_opts(
                tl, total_samples
            )
            yield (
                MajorRunSpec(
                    model_name,
                    "rand_process",
                    ds_name,
                    tl,
                    n_epochs,
                    batch_size,
                    max_steps_per_eval,
                ),
                to_data_mgr,
            )

    @classmethod
    def for_model(
        cls,
        model_name: str,
        train_lens=None,
        data_dir=None,
    ) -> Iterator[tuple[MajorRunSpec, DsMgrFn]]:
        """Iterator for every configuration to be run for the given model.

        Chains for_model_and_ds() over every key of rand_processes.
        """
        for ds_name in cls.rand_processes.keys():
            yield from cls.for_model_and_ds(
                model_name, ds_name, train_lens, data_dir
            )

    @classmethod
    def epoch_opts(cls, train_len, total_samples=None):
        """Return (n_epochs, batch_size, max_steps_per_eval) for a train length.

        n_epochs and batch_size come from ds.epoch_opts(); max_steps_per_eval
        is always 1024.

        Args:
            train_len: forwarded to ds.epoch_opts().
            total_samples: forwarded to ds.epoch_opts(); defaults to 2**27.
        """
        if total_samples is None:
            # 4 epochs of 2**25
            total_samples = 2**27
        n_epochs, batch_size = ds.epoch_opts(
            train_len,
            total_samples=total_samples,
            max_epochs=512,
            max_batch_size=2048,
            min_steps_per_epoch=128,
        )
        max_steps_per_eval = 1024
        return (n_epochs, batch_size, max_steps_per_eval)


class Cyclic:
    """Entry point for cyclic datasets.

    Has two jobs: 1) record paths and labels, and 2) generate an iterator
    of run specs for the cyclic datasets.
    """

    MIN_POW = 0
    MAX_POW = 15
    SEQ_LEN = 1024
    CYCLIC_SPLIT_RATIO = (8, 1, 1)
    SAMPLES_PER_SEQ = 1024
    DATA_DIR = Path("./data/cyclic/")
    DS_CATEGORY = "cyclic"
    DS_NAME_FMT = DS_CATEGORY + "-{n_dim}"
    # Powers of two from 2**MIN_POW up to and including 2**MAX_POW.
    default_n_train_seqs = [2**n for n in range(MIN_POW, MAX_POW + 1)]
    default_n_dims = range(1, 11)

    @staticmethod
    def args_from_outpath(out_path):
        """Recover ``(model_name, n_dim, n_train_seq)`` from an output path.

        Inverse of :meth:`args_to_outpath`. Useful for working out which
        configuration a path belongs to when cluster jobs unexpectedly fail.

        Raises:
            ValueError: if ``out_path`` does not start with the
                ``<model>/n_dim<d>/cyclic-<d>-<n>`` layout.
        """
        out_path = Path(out_path)
        # example: gpt-6-4-32-logmix/n_dim1/cyclic-1-64 -> (gpt-6-4-32-logmix, 1, 64)
        # Match on the POSIX form so the "/" in the pattern also works where
        # the native path separator is a backslash.
        m = re.match(
            r"(.+?)/n_dim(\d+)/cyclic-(?:\d+)-(\d+)", out_path.as_posix()
        )
        if not m:
            raise ValueError(f"Invalid out_path: {out_path}")
        model_name, n_dim, n_train_seq = m.groups()
        return model_name, int(n_dim), int(n_train_seq)

    @staticmethod
    def args_to_outpath(model_name: str, n_dim: int, n_train_seq: int):
        """Convert model name, n_dim, and n_train_seq to out_path."""
        # example: gpt-6-4-32-logmix/n_dim1/cyclic-1-64
        return (
            Path(model_name) / f"n_dim{n_dim}" / f"cyclic-{n_dim}-{n_train_seq}"
        )

    @staticmethod
    def _data_mgr_fn(train_subset, val_seqs, test_seqs):
        """Build a dataset-manager factory bound to exactly these arrays.

        Binding through function arguments (rather than closing over the
        loop variables of ``for_model``) means the returned callable keeps
        working on its own data even if the caller collects several yielded
        factories before invoking any of them.
        """

        def to_data_mgr(model_in_len):
            return ds.EventSeqArrDatasets(
                train_subset,
                val_seqs,
                test_seqs,
                model_in_len=model_in_len,
                prepend_blanks=(False, False, False),
            )

        return to_data_mgr

    @classmethod
    def for_model(
        cls,
        model_name: str,
        n_dims=None,
        n_train_seqs=None,
        data_dir=None,
    ) -> Iterator[tuple[MajorRunSpec, DsMgrFn]]:
        """Iterator for every configuration to be run for the given model.

        Args:
            model_name: label stored in each yielded run spec.
            n_dims: dimensionalities to load; defaults to
                ``cls.default_n_dims``.
            n_train_seqs: numbers of training sequences to use; defaults to
                ``cls.default_n_train_seqs``.
            data_dir: directory holding the ``.npz`` files (``str`` or
                ``Path``); defaults to ``cls.DATA_DIR``.

        Yields:
            ``(run_spec, to_data_mgr)`` where ``to_data_mgr(model_in_len)``
            builds the datasets for that run.
        """
        if n_dims is None:
            n_dims = cls.default_n_dims
        if n_train_seqs is None:
            n_train_seqs = cls.default_n_train_seqs
        # Accept plain strings as well as Path objects.
        data_dir = Path(cls.DATA_DIR if data_dir is None else data_dir)

        # Start with the biggest dataset, as it's the most interesting.
        ordered_n_train_seqs = sorted(n_train_seqs, reverse=True)
        largest_n_train_seq = max(ordered_n_train_seqs, default=0)

        # Cyclic dataset has extra n_dim parameter.
        for n_dim in n_dims:
            path = data_dir / f"{n_dim}dim_1obj_1024event.npz"
            ts = np.load(path)["ts"]
            # For time deltas, prepend a 0 to each sequence.
            dts = np.diff(np.pad(ts, ((0, 0), (1, 0))))
            assert np.all(dts > 0)
            train_seqs, val_seqs, test_seqs = kdai.datasets.split(
                dts, cls.CYCLIC_SPLIT_RATIO
            )
            # Only the first half of the validation sequences is kept.
            val_seqs = val_seqs[0 : len(val_seqs) // 2]
            # Checked once per file: every requested subset must fit.
            assert len(train_seqs) >= largest_n_train_seq

            for n_train_seq in ordered_n_train_seqs:
                train_subset = train_seqs[0:n_train_seq]
                assert (
                    len(train_subset) == n_train_seq
                ), f"{len(train_subset)} != {n_train_seq}"

                n_epochs, batch_size, max_steps_per_eval = cls.epoch_opts(
                    n_train_seq
                )
                run_spec = MajorRunSpec(
                    model_name,
                    cls.DS_CATEGORY,
                    cls.DS_NAME_FMT.format(n_dim=n_dim),
                    n_train_seq,
                    n_epochs,
                    batch_size,
                    max_steps_per_eval,
                    n_dim=n_dim,
                )
                yield (
                    run_spec,
                    cls._data_mgr_fn(train_subset, val_seqs, test_seqs),
                )

    @classmethod
    def epoch_opts(cls, n_train_seq):
        """Return ``(n_epochs, batch_size, max_steps_per_eval)`` for a run.

        The training length handed to ``ds.epoch_opts`` is
        ``n_train_seq * SAMPLES_PER_SEQ``; the sample budget is 2**27.
        """
        n_epochs, batch_size = ds.epoch_opts(
            n_train_seq * cls.SAMPLES_PER_SEQ,
            total_samples=2 ** (15 + 10 + 2),
            max_epochs=512,
            max_batch_size=2048,
            min_steps_per_epoch=128,
        )
        max_steps_per_eval = 1024
        return (n_epochs, batch_size, max_steps_per_eval)


class Classic:
    """Non synthetic datasets like taxi, StackOverflow badges, etc.

    Multiple non-uniform sequences of events. Non uniformity means we need
    some slicing and trimming to get event counts that line up with the
    power of 2 train lengths that we want to use across all runs.
    """

    # List of the datasets that should be trained on by default. Other datasets
    # are used for testing or debugging.
    # nyc-taxi-12h has similar results to nyc-taxi, but the 12h split is not
    # well motivated, so we will just keep to nyc-taxi.
    datasets = ["so-badges", "nyc-taxi"]  # "nyc-taxi-12h"
    # 2**25 down to 2**10, largest first.
    _default_train_lens = [2**n for n in reversed(range(10, 26))]
    dbg_datasets = [
        "so-badges-mini",
        "so-badges-hours",
        "so-badges-hours-mini",
        "nyc-taxi-mini",
    ]
    other_datasets = [
        "nyc-taxi-12h",
    ]

    _data_paths = {
        "so-badges": "stackoverflow/2024_08_29/badges_dataset",
        "nyc-taxi": "nyc_taxi_2013/taxi_dataset/",
        "nyc-taxi-12h": "nyc_taxi_2013/taxi_12h_dataset/",
        # Used for testing.
        "so-badges-mini": "stackoverflow/2024_08_29/badges_mini_dataset",
        "nyc-taxi-mini": "nyc_taxi_2013/taxi_mini_dataset/",
    }

    @staticmethod
    def args_to_outpath(model_name: str, ds_name: str, train_len: int):
        """Convert model name, ds_name, and train_len to out_path."""
        # example layout: <model_name>/<ds_name>-<train_len>
        return Path(model_name) / f"{ds_name}-{train_len}"

    @classmethod
    def data_path(cls, ds_name, data_dir):
        """Return the directory holding the JSON files for ``ds_name``.

        The "-hours" variants of so-badges share the files of their base
        dataset; only the scaling differs (see :meth:`scale`).

        Raises:
            KeyError: if ``ds_name`` has no entry in ``_data_paths``.
        """
        if ds_name in {"so-badges", "so-badges-hours"}:
            res = data_dir / cls._data_paths["so-badges"]
        elif ds_name in {"so-badges-mini", "so-badges-hours-mini"}:
            res = data_dir / cls._data_paths["so-badges-mini"]
        else:
            res = data_dir / cls._data_paths[ds_name]
        return res

    @classmethod
    def default_train_lens(cls, ds_name):
        """Return the default train lengths; "mini" datasets get just 2**15."""
        if "mini" in ds_name:
            res = [2**15]
        else:
            res = cls._default_train_lens
        return res

    @classmethod
    def scale(cls, ds_name):
        """Return ``(factor, unit_name)`` for converting seconds.

        Raises:
            ValueError: if ``ds_name`` is not a known dataset.
        """
        if ds_name in {"so-badges", "so-badges-mini"}:
            res = (1 / (60 * 60 * 24), "days")  # seconds to days
        elif ds_name in {"so-badges-hours", "so-badges-hours-mini"}:
            res = (1 / (60 * 60), "hours")  # seconds to hours
        elif ds_name in {"nyc-taxi", "nyc-taxi-mini", "nyc-taxi-12h"}:
            res = (1 / 60, "minutes")  # seconds to minutes
        elif ds_name in {"nyc-taxi-hours", "nyc-taxi-hours-mini"}:
            res = (1 / (60 * 60), "hours")  # seconds to hours
        else:
            raise ValueError(f"Unknown dataset: {ds_name}")
        return res

    @classmethod
    def load_seqs(cls, ds_name, data_dir=None):
        """Load the train/val/test time-delta sequences of a dataset.

        Reads ``train.json``, ``val.json`` and ``test.json`` from the
        dataset directory. Each file is a list of event-time sequences;
        every sequence is turned into scaled time deltas.

        Args:
            ds_name: dataset name understood by :meth:`data_path` and
                :meth:`scale`.
            data_dir: root data directory; defaults to ``"./data"``.

        Returns:
            ``(train_dts, val_dts, test_dts)``, each a list of 1-D arrays.
        """
        if data_dir is None:
            data_dir = "./data"
        data_dir = Path(data_dir)
        ds_path = cls.data_path(ds_name, data_dir)
        # The scale depends only on the dataset, so look it up once.
        scale_by, _ = cls.scale(ds_name)

        def _load(path):
            # Hold the file open only while parsing.
            with open(ds_path / path) as f:
                event_times = json.load(f)
            # We cannot pad with zeros, as each individual sequence doesn't
            # have a known start time, unlike the synthetic datasets, where
            # we know the start time is 0.
            # Don't include sequences with only one event, as we can't
            # calculate a time delta.
            time_deltas = [
                np.diff(np.array(ts)) * scale_by
                for ts in event_times
                if len(ts) > 1
            ]
            assert all(
                np.all(dt > 0) for dt in time_deltas
            ), "All time deltas should be positive."
            return time_deltas

        _logger.info(f'Loading sequences for dataset "{ds_name}".')
        train_dts = _load("train.json")
        _logger.info(f"Loaded {len(train_dts):,} train sequences.")
        val_dts = _load("val.json")
        _logger.info(f"Loaded {len(val_dts):,} val sequences.")
        test_dts = _load("test.json")
        _logger.info(f"Loaded {len(test_dts):,} test sequences.")
        return train_dts, val_dts, test_dts

    @staticmethod
    def _data_mgr_fn(train_subset, val_dts, test_dts, density_interval_len):
        """Build a dataset-manager factory bound to exactly these sequences.

        Binding through function arguments (rather than closing over the
        loop variables of ``for_model_and_ds``) means the returned callable
        keeps working on its own training subset even if the caller
        collects several yielded factories before invoking any of them.
        """

        def to_data_mgr(model_in_len):
            res = ds.EventSeqListDatasets(
                train_subset,
                val_dts,
                test_dts,
                model_in_len,
                density_interval_len=density_interval_len,
            )
            n_samples = sum(len(a) - 1 for a in train_subset)
            assert (
                len(res.train_ds()) == n_samples
            ), f"{len(res.train_ds())} != {n_samples}"
            return res

        return to_data_mgr

    @classmethod
    def for_model_and_ds(
        cls,
        model_name: str,
        ds_name: str,
        train_lens=None,
        data_dir=None,
    ) -> Iterator[tuple[MajorRunSpec, DsMgrFn]]:
        """Yield one ``(run_spec, to_data_mgr)`` pair per train length.

        Args:
            model_name: label stored in each yielded run spec.
            ds_name: dataset to load.
            train_lens: requested numbers of training time deltas; defaults
                to :meth:`default_train_lens`.
            data_dir: root data directory forwarded to :meth:`load_seqs`.
        """
        train_dts, val_dts, test_dts = cls.load_seqs(ds_name, data_dir)
        assert isinstance(train_dts, list) and train_dts
        assert isinstance(train_dts[0], np.ndarray)

        if train_lens is None:
            train_lens = cls.default_train_lens(ds_name)
        n_train_dts = sum(len(a) for a in train_dts)
        assert n_train_dts >= max(train_lens, default=0)
        for n_dt in train_lens:
            train_subset = ds.trim_to(train_dts, n_dt)
            cum_lens = sum(len(a) for a in train_subset)
            assert cum_lens >= n_dt, f"{cum_lens=} < {n_dt=}"
            # Dropping the last sequence must take us below the target,
            # otherwise more sequences were kept than needed.
            assert cum_lens - len(train_subset[-1]) < n_dt, (
                "Too many events. "
                f"{cum_lens=} - {len(train_subset[-1])} >= {n_dt=}"
            )

            # Note, the idea about evaluating on intervals for the real-world
            # datasets didn't make it into the paper, other than just as a
            # suggestion, so this doesn't actually get used.
            #
            # This is a proposed option. All models should be evaluated on the
            # same interval length for calculating probability masses. This
            # doesn't need to be the same as the unit used in the dataset
            # event times (e.g. seconds, hours), and it doesn't have to be,
            # and most likely shouldn't be, the same unit used by the model,
            # which should apply its own scaling for the purposes of good
            # training dynamics. Set this interval to 1 if the evaluation
            # interval is the same as the dataset event times.
            density_interval_len = 1

            yield (
                MajorRunSpec(
                    model_name,
                    "classic",
                    ds_name,
                    n_dt,
                    *cls.epoch_opts(n_dt),
                ),
                cls._data_mgr_fn(
                    train_subset, val_dts, test_dts, density_interval_len
                ),
            )

    @classmethod
    def for_model(
        cls,
        model_name: str,
        train_lens=None,
        data_dir=None,
    ) -> Iterator[tuple[MajorRunSpec, DsMgrFn]]:
        """Chain :meth:`for_model_and_ds` over every name in ``datasets``."""
        for ds_name in cls.datasets:
            yield from cls.for_model_and_ds(
                model_name, ds_name, train_lens, data_dir
            )

    @classmethod
    def epoch_opts(cls, train_len):
        """Return ``(n_epochs, batch_size, max_steps_per_eval)`` for a run."""
        n_epochs, batch_size = ds.epoch_opts(
            train_len,
            # 4 epochs of 2**25
            total_samples=2**27,
            max_epochs=512,
            max_batch_size=2048,
            min_steps_per_epoch=128,
        )
        max_steps_per_eval = 1024
        return (n_epochs, batch_size, max_steps_per_eval)


class Spikes:
    """Chicken RGC spike data recorded on a single day."""

    _data_path = "chicken_2021_08_17/dataset.json"
    split_ratio = (7, 2, 1)
    ds_name = "chicken"
    # fmt: off
    default_cids = [20, 21, 22, 23, 25, 31, 34, 35, 36, 37, 40, 41, 42, 44, 46, 49, 51, 54, 56, 61, 62, 65, 70, 72, 73, 74, 77, 78, 83, 84, 88, 90, 91, 102, 105, 112, 114, 120, 121, 128, 137, 138, 157, 174, 179, 185, 190, 202, 235, 238, 243, 246, 250, 280, 290, 297, 342, 403, 516, 760]
    # fmt: on
    downsample = 18
    stride = 1

    @staticmethod
    def dt_mean_sd(train_chunks: mea.ContiguousChunks):
        """
        Calculate the mean and standard deviation of time until next event.

        "Average" is over all snippets, so we are not just averaging over
        times between events. Each spike pair corresponds to an arange(1, b-a)
        of time until next event.

        Returns:
            ``(mean, sd, log_mean, log_sd)`` over the strictly positive
            time-until-spike values of all chunks.
        """
        dts = np.concatenate(
            [rec.time_until_spike() for rec in train_chunks], axis=0
        )
        # Ignore negative values.
        dts = dts[dts > 0]
        log_dts = np.log(dts)
        return dts.mean(), dts.std(), log_dts.mean(), log_dts.std()

    @staticmethod
    def entropy(train_chunks, win_len):
        """Entropy (in nats) of the clipped time-until-spike histogram.

        Values are clipped to ``[0, win_len]`` and must be integer valued;
        a small constant is added inside the log so empty bins contribute
        (approximately) zero.
        """
        dts = np.concatenate(
            [rec.time_until_spike() for rec in train_chunks], axis=0
        )
        # Ignore negative values.
        dts = dts[dts > 0]
        dts = np.clip(dts, 0, win_len)
        assert dts.ndim == 1
        assert np.array_equal(dts.astype(int), dts)
        hist = np.bincount(dts.astype(int))
        prob = hist / len(dts)
        entropy = -np.sum(prob * np.log(prob + 1e-10))
        return entropy

    @classmethod
    def load_split(
        cls,
        cell_ids: Set[int],
        data_dir: Optional[Path] = None,
    ):
        """Load the recording, keep ``cell_ids`` and split it.

        Args:
            cell_ids: cells to keep from the recording.
            data_dir: root data directory; defaults to ``"./data"``.

        Returns:
            The result of ``mea.mirror_split2`` on the decompressed
            recording; ``for_model`` unpacks it as ``(train, val, test)``.
        """
        if data_dir is None:
            data_dir = "./data"
        data_dir = Path(data_dir)
        with open(data_dir / cls._data_path) as f:
            ds_details = json.load(f)
        rec = mea.CompressedSpikeRecording.from_json(ds_details)
        rec = rec.cells(cell_ids)
        rec = mea.decompress_recording(rec, downsample=cls.downsample)
        train_val_test_split = mea.mirror_split2(
            rec, split_ratio=cls.split_ratio
        )
        return train_val_test_split

    @classmethod
    def for_model(
        cls,
        model_name: str,
        train_cids: Optional[Sequence[int]] = None,
        data_dir: Optional[Path] = None,
    ) -> Iterator[tuple[MajorRunSpec, SpikesDsMgrFn]]:
        """Yield the single ``(run_spec, to_data_mgr)`` pair for this data.

        Args:
            model_name: label stored in the yielded run spec.
            train_cids: cell ids to train on; defaults to
                ``cls.default_cids``.
            data_dir: root data directory forwarded to :meth:`load_split`.

        The yielded ``to_data_mgr(model_in_len, type)`` builds the datasets
        for ``type`` "next_event" or "distance". It raises
        ``NotImplementedError`` for "interval" and ``ValueError`` for any
        other value.
        """
        if train_cids is None:
            train_cids = cls.default_cids
        split = cls.load_split(set(train_cids), data_dir)

        train, val, test = split
        assert train[0].num_cells() == len(train_cids)
        # Not exact, as we don't account for model_in_len.
        train_len = round(
            len(train_cids) * sum(len(a) for a in train) / cls.stride
        )

        # NOTE: the parameter name "type" shadows the builtin, but it is part
        # of the callable's interface, so it is kept.
        def to_data_mgr(model_in_len, type: SpikesDsType):
            if type == "next_event":
                dt_mean, dt_sd, log_dt_mean, log_dt_sd = cls.dt_mean_sd(train)
                return kdai.train.BasicDatasetManager(
                    mea.NextSpikeDataset(
                        train, model_in_len, cls.stride, augment=True
                    ),
                    mea.NextSpikeDataset(
                        val, model_in_len, cls.stride, augment=False
                    ),
                    mea.NextSpikeDataset(
                        test, model_in_len, cls.stride, augment=False
                    ),
                    train_ds_attrs={
                        "dt_mean": dt_mean,
                        "dt_sd": dt_sd,
                        "log_dt_mean": log_dt_mean,
                        "log_dt_sd": log_dt_sd,
                    },
                )
            if type == "distance":
                splits = [split]  # supports multiple recordings
                output_len = 128
                dist_prefix_len = 32
                return mea.DistDatasets(
                    splits,
                    model_in_len,
                    output_len,
                    cls.downsample,
                    dist_prefix_len,
                    stride=cls.stride,
                    use_augmentation=True,
                )
            if type == "interval":
                raise NotImplementedError()
            # Previously an unknown value fell through to an unbound local.
            raise ValueError(f"Unknown spikes dataset type: {type!r}")

        yield (
            MajorRunSpec(
                model_name,
                "spikes",
                cls.ds_name,
                train_len,
                # Hard-coded for all, as currently, there is just one train len.
                n_epochs=2,
                batch_size=64,
                steps_til_eval=5000,
            ),
            to_data_mgr,
        )


@dataclass
class SpikeStats:
    """Some basic stats that are calculated and stored on a spike dataset."""

    # Mean / standard deviation of the time deltas and of their logs.
    dt_mean: float
    dt_sd: float
    log_dt_mean: float
    log_dt_sd: float
    # Mean / standard deviation of the time until the next event.
    t_until_mean: float
    t_until_sd: float
    # Diffs on spikes separated by N.
    #   cell_gid → n → val    (where n determines arr[k:] - arr[:-k])
    # gid first allows for easy dictionary merging across recordings.
    cell_dNt_mean: Dict[int, Dict[int, float]]
    cell_dNt_min: Dict[int, Dict[int, float]]
    cell_dNt_max: Dict[int, Dict[int, float]]
    # Max rates in window W.
    #   cell_gid → W → val
    cell_max_rate: Dict[int, Dict[int, float]]


def per_rec_stats(
    rec, split_ratio, downsample, max_diff_n=5
) -> tuple[float, float, float, float, Dict, Dict, Dict, Dict]:
    """Compute spike statistics on the train portion of one recording.

    The train portion is taken to be the first and the last
    ``train_hfrac`` of the recording, where ``train_hfrac`` is half of the
    train share of ``split_ratio``. Spike times are converted to
    downsampled units *before* being compared with the (downsampled)
    boundaries; comparing raw sensor-sample times against downsampled
    boundaries would let most of the middle of the recording leak into the
    second part.

    Args:
        rec: recording exposing ``num_cells()``, ``num_sensor_samples``,
            ``spike_events`` (one array of spike times per cell, in sensor
            samples), ``cell_gids`` and ``name``.
        split_ratio: ``(train, val, test)`` ratio.
        downsample: factor dividing sensor-sample times.
        max_diff_n: largest spike separation ``n`` for which
            ``t[n:] - t[:-n]`` differences are computed.

    Returns:
        ``(mean, log_mean, t_until_mean, t_until_sd, c_mean_dNts,
        c_min_dNts, c_max_dNts, c_max_rate)``. The four dicts are keyed
        ``cell_gid → n → value`` (``cell_gid → W → value`` for
        ``c_max_rate``).
    """
    train_hfrac = (split_ratio[0] / np.sum(split_ratio)) / 2
    n_cells = rec.num_cells()
    # Both boundaries are expressed in downsampled time units.
    head_end = np.floor(train_hfrac * rec.num_sensor_samples / downsample)
    tail_start = math.ceil(
        (1 - train_hfrac) * rec.num_sensor_samples / downsample
    )
    # Convert first, then compare, so both sides share the same unit.
    ds_times = [np.asarray(c) / downsample for c in rec.spike_events]
    train_part1 = [t[t < head_end] for t in ds_times]
    train_part2 = [t[t > tail_start] for t in ds_times]
    parts = [train_part1, train_part2]

    # c_max_rate: per cell, spike times of both train parts back to back.
    ts = [np.concatenate([p[c] for p in parts]) for c in range(n_cells)]

    def max_spikes_in_win(times, win_len):
        """Largest number of spikes within any span of ``win_len``."""
        res = 0
        j = 0
        # Two pointers: j never moves backwards as i advances.
        for i in range(len(times)):
            while j < len(times) and times[j] - times[i] <= win_len:
                j += 1
            res = max(res, j - i)
        return res

    Ws = [50, 100, 150, 200]
    c_max_rate = {
        rec.cell_gids[c]: {w: max_spikes_in_win(ts[c], w) for w in Ws}
        for c in range(n_cells)
    }

    # Time differences: n → cell_idx → delta(part1, n) + delta(part2, n).
    # Differences are taken within each part, never across the gap.
    dts = {
        n: [
            np.concatenate([p[c][n:] - p[c][:-n] for p in parts])
            for c in range(n_cells)
        ]
        for n in range(1, max_diff_n + 1)
    }

    # do_op is a generalization of:
    # mean_dts = {
    #     n: [np.mean(arr) for arr in arrs] for n, arrs in dts.items()
    # }
    def do_op(d, op):
        """Apply op and convert to cell_gid dict."""
        by_n = {
            n: {rec.cell_gids[i]: op(arr) for i, arr in enumerate(arrs)}
            for n, arrs in d.items()
        }
        by_gid = {
            gid: {n: by_n[n][gid] for n in by_n} for gid in rec.cell_gids
        }
        return by_gid

    # Per cell.
    c_mean_dNts = do_op(dts, lambda x: np.round(np.mean(x)))
    c_min_dNts = do_op(dts, lambda x: np.ceil(np.min(x)))
    c_max_dNts = do_op(dts, lambda x: np.round(np.max(x)))
    all_d1ts = np.concatenate(dts[1])
    assert np.all(
        all_d1ts > 0
    ), f"All time deltas should be positive. {rec.name=}"
    mean = all_d1ts.mean()
    # Sanity check against a second way of computing the same mean.
    alt_mean = np.concatenate(
        [np.diff(s) for s in train_part1 + train_part2]
    ).mean()
    assert np.isclose(
        mean,
        alt_mean,
        rtol=1e-5,
        atol=1,
    ), f"Means should be the same. {mean=} {alt_mean=}"
    log_mean = np.log(all_d1ts).mean()
    # Time until next event is the decompressed spike times array. We are
    # looking for the expected value, where expectation is over all time points.
    # This is very different to the mean time between events. Large gaps will
    # be weighted more heavily, as they are present for more time points.
    # We will take a stab at estimating it here, but won't calculate it exactly
    # as it needs decompressing all recordings.
    # integrate then average
    t_untils = np.concatenate([np.arange(1, d + 1) for d in all_d1ts])
    t_until_mean = t_untils.mean()
    t_until_sd = np.sqrt(((t_untils - t_until_mean) ** 2).mean())

    return (
        mean,
        log_mean,
        t_until_mean,
        t_until_sd,
        c_mean_dNts,
        c_min_dNts,
        c_max_dNts,
        c_max_rate,
    )


def per_rec_var(rec, split_ratio, downsample, mean, log_mean):
    """Mean squared deviation of time deltas (and their logs) for one recording.

    Each cell's spike times are cut at ``floor((1 - test_frac) *
    num_sensor_samples)``; deltas are taken separately below and at/above
    the cut, so no delta spans it. All deltas are then pooled and divided
    by ``downsample``.

    NOTE(review): both sides of the cut are pooled, so the deltas at/above
    the cut contribute as well — confirm that this is intended.

    Args:
        rec: recording exposing ``num_sensor_samples``, ``spike_events``
            and ``name``.
        split_ratio: ``(train, val, test)`` ratio; only the test share is
            used, to locate the cut.
        downsample: divisor applied to the deltas.
        mean: value the deltas are centred on.
        log_mean: value the log-deltas are centred on.

    Returns:
        ``(var, log_var)``.
    """
    test_frac = split_ratio[2] / np.sum(split_ratio)

    def cut():
        return math.floor((1 - test_frac) * rec.num_sensor_samples)

    below_cut = np.concatenate(
        [np.diff(s[s < cut()]) for s in rec.spike_events]
    )
    from_cut = np.concatenate(
        [np.diff(s[s >= cut()]) for s in rec.spike_events]
    )
    dts = np.concatenate([below_cut, from_cut]) / downsample
    assert np.all(dts > 0), f"All time deltas should be positive. {rec.name=}"
    centred = dts - mean
    log_centred = np.log(dts) - log_mean
    return ((centred**2).mean(), (log_centred**2).mean())


class Spikes2:
    """Chicken RGC spike data recorded on multiple days.

    "Spikes" class is just one recording.
    """

    _data_path = "chicken_2021/dataset.json"
    split_ratio = (7, 2, 1)
    ds_name = "chicken-full"
    downsample = 18
    stride = 7
    # Precomputed value reported as the train length by ``for_model``.
    n_samples = 1006897845

    @classmethod
    def stats(cls, recs: Sequence[mea.CompressedSpikeRecording]) -> SpikeStats:
        """
        Calculate mean(dts), mean(log(dts)) and other stats.

        It's too slow to work with SpikeRecordings, so do it with the
        compressed versions.

        We avoid using the test set for this calculation.

        Per-recording results from ``per_rec_stats`` are combined with
        equal weight per recording (assumes roughly equal sizes). The
        time-until-event mean and standard deviation are aggregated to
        scalars as well, so every float field of the returned
        ``SpikeStats`` really is a single number.
        """
        _logger.info(
            f"Calculating mean and variance for {len(recs)} recordings."
        )
        with multiprocessing.Pool() as pool:
            res = pool.map(
                partial(
                    per_rec_stats,
                    split_ratio=cls.split_ratio,
                    downsample=cls.downsample,
                ),
                recs,
            )
        (
            means,
            log_means,
            t_until_means,
            t_until_sds,
            c_means,
            c_mins,
            c_maxs,
            c_max_rates,
        ) = zip(*res)
        mean = np.mean(means)  # assumes roughly equal sizes.
        log_mean = np.mean(log_means)
        assert mean >= 0, "Mean should be non-negative."
        assert log_mean >= 0, "Log mean should be non-negative."
        # Pool the per-recording time-until-event stats (equal weights):
        # total variance = mean within-recording variance + variance of the
        # per-recording means.
        t_until_mean = np.mean(t_until_means)
        t_until_var = np.mean(
            np.square(t_until_sds)
            + np.square(np.asarray(t_until_means) - t_until_mean)
        )
        t_until_sd = np.sqrt(t_until_var)
        # Merge cell dicts.
        c_mean = reduce(operator.ior, c_means, {})
        c_min = reduce(operator.ior, c_mins, {})
        c_max = reduce(operator.ior, c_maxs, {})
        c_max_rate = reduce(operator.ior, c_max_rates, {})
        # Second pass: deviations around the global means computed above.
        with multiprocessing.Pool() as pool:
            res = pool.map(
                partial(
                    per_rec_var,
                    split_ratio=cls.split_ratio,
                    downsample=cls.downsample,
                    mean=mean,
                    log_mean=log_mean,
                ),
                recs,
            )
        variances, log_variances = zip(*res)
        sd = np.sqrt(np.mean(variances))
        log_sd = np.sqrt(np.mean(log_variances))
        stats = SpikeStats(
            mean,
            sd,
            log_mean,
            log_sd,
            t_until_mean,
            t_until_sd,
            c_mean,
            c_min,
            c_max,
            c_max_rate,
        )
        return stats

    @staticmethod
    def entropy(train_chunks, win_len):
        """Entropy (in nats) of the clipped time-until-spike histogram.

        Values are clipped to ``[0, win_len]`` and must be integer valued;
        a small constant is added inside the log so empty bins contribute
        (approximately) zero.
        """
        dts = np.concatenate(
            [rec.time_until_spike() for rec in train_chunks], axis=0
        )
        # Ignore negative values.
        dts = dts[dts > 0]
        dts = np.clip(dts, 0, win_len)
        assert dts.ndim == 1
        assert np.array_equal(dts.astype(int), dts)
        hist = np.bincount(dts.astype(int))
        prob = hist / len(dts)
        entropy = -np.sum(prob * np.log(prob + 1e-10))
        return entropy

    @classmethod
    def load_recs(cls, data_dir: Optional[Path] = None):
        """Load every compressed recording listed in the dataset JSON.

        Each recording's ``cell_ids`` are replaced by the global ids found
        in the file's ``recording_cell_ids`` map before the recording is
        parsed.

        Args:
            data_dir: root data directory; defaults to ``"./data"``.

        Returns:
            List of ``mea.CompressedSpikeRecording``.
        """
        if data_dir is None:
            data_dir = "./data"
        data_dir = Path(data_dir)
        with open(data_dir / cls._data_path) as f:
            ds_details = json.load(f)
        gid_map = ds_details["recording_cell_ids"]
        recs = []
        for r in ds_details["recordings"]:
            # We don't need the cid-gid distinction, so just use gids.
            # json only allows string keys, so convert to str to look up gid.
            gids = [gid_map[r["name"]][str(cid)] for cid in r["cell_ids"]]
            r["cell_ids"] = gids
            recs.append(mea.CompressedSpikeRecording.from_json(r))
        return recs

    @classmethod
    def splits(cls, recs) -> mea.RecordingTrainValTest:
        """Split every recording and pool the chunks per train/val/test."""
        # rec_splits has 1 entry per recording. The entry is a Tuple,
        # with the (train, val, test) contributions from each recording.
        # Each contribution is a list of ContiguousChunks. This seemingly
        # excessive structure is used to support a single recording having
        # more than one chunk that goes to the train set. In turn, this
        # was motivated by wanting the test set to be in the middle, as
        # the end of the recording can be where the cell is dying.
        rec_splits = mea.recording_splits(
            recs,
            downsample=cls.downsample,
            split_ratio=cls.split_ratio,
            # 16 workers, as there are 16 recordings.
            num_workers=16,
        )
        train_chunks, val_chunks, test_chunks = [], [], []
        for train, val, test in rec_splits:
            train_chunks.extend(train)
            val_chunks.extend(val)
            test_chunks.extend(test)
        # So, we still have a tuple of list, but the lists are longer and
        # are constituted from multiple recordings.
        return (train_chunks, val_chunks, test_chunks)

    @classmethod
    def for_model(
        cls,
        model_name: str,
        split: mea.RecordingTrainValTest,
        stats: SpikeStats,
    ) -> Iterator[tuple[MajorRunSpec, SpikesDsMgrFn]]:
        """Yield the single ``(run_spec, to_data_mgr)`` pair for this data.

        Args:
            model_name: label stored in the yielded run spec.
            split: ``(train, val, test)`` chunks, e.g. from :meth:`splits`.
            stats: precomputed statistics attached to the train dataset.

        The yielded ``to_data_mgr(model_in_len, type)`` builds the datasets
        for ``type`` "next_event" or "distance". It raises
        ``NotImplementedError`` for "interval" and ``ValueError`` for any
        other value.
        """
        train, val, test = split
        # The train length is not derived from the chunks here; the
        # precomputed class constant is reported instead.
        train_len = cls.n_samples

        # NOTE: the parameter name "type" shadows the builtin, but it is part
        # of the callable's interface, so it is kept.
        def to_data_mgr(model_in_len, type: SpikesDsType):
            if type == "next_event":
                _logger.info(
                    f"Using dt_mean={stats.dt_mean:.3e}, "
                    f"dt_sd={stats.dt_sd:.3e}, "
                    f"log_dt_mean={stats.log_dt_mean:.3e}, "
                    f"log_dt_sd={stats.log_dt_sd:.3e}"
                )
                _logger.info("[Start] Creating datasets.")
                return kdai.train.BasicDatasetManager(
                    mea.NextSpikeDataset(
                        train, model_in_len, cls.stride, augment=True
                    ),
                    mea.NextSpikeDataset(
                        val, model_in_len, cls.stride, augment=False
                    ),
                    mea.NextSpikeDataset(
                        test, model_in_len, cls.stride, augment=False
                    ),
                    train_ds_attrs={
                        "dt_mean": stats.dt_mean,
                        "dt_sd": stats.dt_sd,
                        "log_dt_mean": stats.log_dt_mean,
                        "log_dt_sd": stats.log_dt_sd,
                        "cell_dt_mean": stats.cell_dNt_mean,
                        "cell_dt_min": stats.cell_dNt_min,
                        "cell_dt_max": stats.cell_dNt_max,
                        "cell_max_rate": stats.cell_max_rate,
                    },
                )
            if type == "distance":
                splits = [split]  # supports multiple recordings
                output_len = 128
                dist_prefix_len = 32
                _logger.info("[Start] Creating dist datasets.")
                res = mea.DistDatasets(
                    splits,
                    model_in_len,
                    output_len,
                    cls.downsample,
                    dist_prefix_len,
                    stride=cls.stride,
                    use_augmentation=True,
                )
                _logger.info("[End] Creating dist datasets.")
                return res
            if type == "interval":
                raise NotImplementedError()
            # Previously an unknown value fell through to an unbound local.
            raise ValueError(f"Unknown spikes dataset type: {type!r}")

        yield (
            MajorRunSpec(
                model_name,
                "spikes2",
                cls.ds_name,
                train_len,
                # Hard-coded for all, as currently, there is just one train len.
                n_epochs=2,
                batch_size=1024,
                steps_til_eval=5000,
            ),
            to_data_mgr,
        )


class Baseline:
    """Run specs and sequence loading for the real-world baseline datasets.

    add-and-thin datasets are dict's with keys:
      'sequences', 't_max', 'mean_number_items'
    """

    # List of the datasets that should be trained on by default. Other datasets
    # are used for testing or debugging.
    # nyc-taxi-12h has similar results to nyc-taxi, but the 12h split is not
    # well motivated, so we will just keep to nyc-taxi.
    datasets = (
        "wikipedia",
        "mooc",
        "pubg",
        "reddit-askscience",
        "reddit-politics",
        "twitter",
        "yelp-airport",
        "yelp-mississauga",
        "yelp-toronto",
        "lastfm",
        "taobao",
        "amazon",
        # Just use the longer ones in Classic.
        # "nyc-taxi-easytpp",
        # "so-badges-easytpp",
    )
    # These may be different if some of the sequences have a single event,
    # which will be ignored when loading.
    n_dts = {
        "pubg": 226702,
        "reddit-askscience": 399579,
        "reddit-politics": 1234048,
        "twitter": 28058,
        "yelp-airport": 9398,
        "yelp-mississauga": 17304,
        "yelp-toronto": 214846,
        "mooc": 389586,
        "wikipedia": 156471,
        "lastfm": 1267456,
        # train, validation, test
        "taobao": 73483 + 11472 + 28455,  # 113410
        # train, validation, test
        "amazon": 288377 + 40995 + 84048,  # 413420
        "nyc-taxi-easytpp": 51854 + 7404 + 14820,  # 74078,
        # This dataset had some duplicates or sequences with just 2 events.
        # After removing those, we have:
        # train, validation, test = 90039 + 25632 + 26391 = 142062
        "so-badges-easytpp": 90497 + 25762 + 26518,  # 142777
    }
    # Train/validation/test ratio used by load_seqs() for datasets that do not
    # come with fixed splits.
    tvt_ratio = (6, 2, 2)
    # The longest datasets (lastfm and reddit-politics) have roughly
    # 1.2 million events, i.e. about 2**20.
    # 2**27 samples would correspond to about 128 epochs of those.

    # Dataset locations, relative to the data directory.
    _data_paths = {
        # minutes between 0-40.
        "pubg": "baseline/pubg.pkl",
        "reddit-askscience": "baseline/reddit_askscience_comments.pkl",
        # hours [0, 24]
        "reddit-politics": "baseline/reddit_politics_submissions.pkl",
        "twitter": "baseline/twitter.pkl",
        "yelp-airport": "baseline/yelp_airport.pkl",
        "yelp-mississauga": "baseline/yelp_mississauga.pkl",
        "yelp-toronto": "baseline/yelp_toronto.npz",
        "mooc": "baseline/mooc.npz",
        "wikipedia": "baseline/wikipedia.npz",
        "lastfm": "baseline/lastfm.pkl",  # or npz
        "taobao": "baseline/easytpp/taobao",
        "amazon": "baseline/easytpp/amazon",
        "nyc-taxi-easytpp": "baseline/easytpp/nyc-taxi",
        "so-badges-easytpp": "baseline/easytpp/stackoverflow",
    }

    @classmethod
    def data_path(cls, ds_name, data_dir):
        """Return the location of dataset `ds_name` under `data_dir`.

        Raises:
            KeyError: if `ds_name` has no entry in `_data_paths`.
        """
        return data_dir / cls._data_paths[ds_name]

    @classmethod
    def default_train_lens(cls, ds_name):
        """Not implemented for the baseline datasets."""
        raise NotImplementedError()

    @classmethod
    def scale(cls, ds_name):
        """Return `(factor, unit)` for rescaling the time deltas of a dataset.

        `factor` is what load_seqs() multiplies the stored time deltas by;
        `unit` is a label describing the unit after scaling.

        Raises:
            ValueError: if `ds_name` is not a known dataset.
        """
        if ds_name in {"so-badges-mini"}:
            res = (1 / (60 * 60 * 24), "days")  # seconds to days
        elif ds_name in {"so-badges-hours", "so-badges-hours-mini"}:
            res = (1 / (60 * 60), "hours")  # seconds to hours
        elif ds_name in {"nyc-taxi", "nyc-taxi-mini", "nyc-taxi-12h"}:
            res = (1 / 60, "minutes")  # seconds to minutes
        # The PUBG data has 90% of events below 1 minute.
        elif ds_name in {"pubg"}:
            res = (60, "seconds")  # minutes to seconds
        # For taobao, the original unit was 3-hours, but we saved the
        # data in hours. >90% of deltas are < 1 minute, so changing to
        # minutes.
        elif ds_name in {
            "yelp-airport",
            "yelp-mississauga",
            "twitter",
            "taobao",
        }:
            res = (60, "minutes")  # hours to minutes
        elif ds_name in {"reddit-askscience", "reddit-politics"}:
            res = (60 * 60, "seconds")  # hours to seconds
        elif ds_name in {"wikipedia", "mooc"}:
            res = (1 / 24, "days")  # hours to days
        elif ds_name in {"lastfm", "amazon"}:
            res = (60 * 60, "unknown")  # unknown units.
        elif ds_name in {"nyc-taxi-easytpp"}:
            res = (1, "unknown (hours?)")  # mean value is 0.22
        elif ds_name in {"so-badges-easytpp"}:
            res = (1, "unknown (months?)")  # mean value is 0.97
        elif ds_name in {"yelp-toronto"}:
            res = (1 / (60 * 60), "hours")  # not 100% sure it's hours.
        else:
            raise ValueError(f"Unknown dataset: {ds_name}")
        return res

    @classmethod
    def random_split(cls, list_dts, ratio, rng=None):
        """Shuffle the sequences and split them into train/val/test lists.

        Args:
            list_dts: list of sequences to split.
            ratio: split ratio, forwarded to `kdai.datasets.split`.
            rng: optional numpy random generator. When omitted, a generator
                with a fixed seed is used so the split is reproducible.

        Returns:
            A `(train, val, test)` tuple of lists of sequences.
        """
        if rng is None:
            rng = np.random.default_rng(seed=123)
        indices = rng.permutation(np.arange(len(list_dts)))
        train_idxs, val_idxs, test_idxs = kdai.datasets.split(indices, ratio)
        train_dts = [list_dts[i] for i in train_idxs]
        val_dts = [list_dts[i] for i in val_idxs]
        test_dts = [list_dts[i] for i in test_idxs]
        return train_dts, val_dts, test_dts

    @classmethod
    def load_seqs(cls, ds_name, data_dir=None):
        """Load a dataset as train/val/test lists of scaled time deltas.

        Arrival times are differenced into inter-event times and multiplied
        by the factor from `scale()`. Sequences with fewer than two events are
        dropped. EasyTPP datasets are read from their fixed per-split JSON
        files; all other datasets are split randomly using `tvt_ratio`.

        Args:
            ds_name: key into `_data_paths`.
            data_dir: root data directory; defaults to "./data".

        Returns:
            A `(train, val, test)` tuple; each element is a list of 1-D numpy
            arrays of strictly positive time deltas.
        """
        if data_dir is None:
            data_dir = "./data"
        data_dir = Path(data_dir)
        ds_path = cls.data_path(ds_name, data_dir)
        # Decide on the relative dataset path only, so that a user supplied
        # data_dir that happens to contain "easytpp" cannot flip the format.
        is_easytpp = "easytpp" in cls._data_paths[ds_name]

        def _to_deltas(arrival_seqs, scale_by):
            """Convert arrival-time sequences to scaled, positive deltas."""
            time_deltas = [
                np.diff(np.array(seq)) * scale_by
                for seq in arrival_seqs
                if len(seq) > 1
            ]
            assert all(
                np.all(dt > 0) for dt in time_deltas
            ), "All time deltas should be positive."
            return time_deltas

        def _load_easytpp(path, scale_by):
            """Loads EasyTPP datasets from per-split JSON files."""

            def to_ds(split_path):
                with open(split_path, "r") as f:
                    ds = json.load(f)
                return _to_deltas(ds["arrival_times"], scale_by)

            return tuple(
                to_ds(path / f"{split}.json")
                for split in ["train", "val", "test"]
            )

        def _load(path):
            scale_by, _ = cls.scale(ds_name)
            if is_easytpp:
                # EasyTpp datasets have fixed splits.
                return _load_easytpp(path, scale_by)
            if path.suffix == ".pkl":
                # NOTE: weights_only=False unpickles arbitrary objects; only
                # use this with trusted dataset files.
                sequences = torch.load(
                    path, map_location="cpu", weights_only=False
                )["sequences"]
                arrival_seqs = (seq["arrival_times"] for seq in sequences)
            else:
                # NOTE: allow_pickle=True has the same trust requirement.
                arrival_seqs = np.load(path, allow_pickle=True)[
                    "arrival_times"
                ]
            time_deltas = _to_deltas(arrival_seqs, scale_by)
            return cls.random_split(time_deltas, cls.tvt_ratio)

        _logger.info(f'Loading sequences for dataset "{ds_name}".')
        train_dts, val_dts, test_dts = _load(ds_path)
        _logger.info(f"Loaded {len(train_dts):,} train sequences.")
        _logger.info(f"Loaded {len(val_dts):,} val sequences.")
        _logger.info(f"Loaded {len(test_dts):,} test sequences.")
        return train_dts, val_dts, test_dts

    @classmethod
    def for_model_and_ds(
        cls,
        model_name: str,
        ds_name: str,
        train_lens=None,
        data_dir=None,
    ) -> Iterator[tuple[MajorRunSpec, DsMgrFn]]:
        """Yield the run spec and dataset-manager factory for one dataset.

        Args:
            model_name: name of the model the run spec is for.
            ds_name: baseline dataset name.
            train_lens: must be None; custom train lengths are not supported.
            data_dir: root data directory, forwarded to `load_seqs()`.

        Raises:
            NotImplementedError: if `train_lens` is given.
            ValueError: if the dataset has no training sequences.
        """
        if train_lens is not None:
            raise NotImplementedError(
                "train_lens is not supported for baseline datasets."
            )
        train_dts, val_dts, test_dts = cls.load_seqs(ds_name, data_dir)
        if not train_dts:
            raise ValueError(
                f'Dataset "{ds_name}" has no training sequences.'
            )
        assert isinstance(train_dts, list)
        assert isinstance(train_dts[0], np.ndarray)

        # Only the full training set is supported at the moment.
        train_lens = [sum(len(a) for a in train_dts)]
        for n_dt in train_lens:
            train_subset = ds.trim_to(train_dts, n_dt)
            cum_lens = sum(len(a) for a in train_subset)
            assert cum_lens >= n_dt, f"{cum_lens=} < {n_dt=}"
            assert cum_lens - len(train_subset[-1]) < n_dt, (
                "Too many events. "
                f"{cum_lens=} - {len(train_subset[-1])} >= {n_dt=}"
            )
            # Note, the idea about evaluating on intervals for the real-world
            # datasets didn't make it into the paper, other than just as a
            # suggestion, so this doesn't actually get used.
            #
            # This is a proposed option. All models should be evaluated on the
            # same interval length for calculating probability masses. This
            # doesn't need to be the same as the unit used in the dataset
            # event times (e.g. seconds, hours), and it doesn't have to be,
            # and most likely shouldn't be, the same unit used by the model,
            # which should apply its own scaling for the purposes of good
            # training dynamics. Set this interval to 1 if the evaluation
            # interval is the same as the dataset event times.
            density_interval_len = 1

            # train_subset is bound as a keyword-only default so the factory
            # keeps this iteration's value even if it is called after the
            # loop has moved on (closures bind late).
            def to_data_mgr(model_in_len, *, train_subset=train_subset):
                # NOTE(review): the datasets are built from the full
                # train_dts; train_subset is only used for the sample-count
                # check below. This looks equivalent while only the full
                # training length is supported - confirm before enabling
                # train_lens.
                res = ds.EventSeqListDatasets(
                    train_dts,
                    val_dts,
                    test_dts,
                    model_in_len,
                    density_interval_len=density_interval_len,
                )
                n_samples = sum(len(a) - 1 for a in train_subset)
                assert (
                    len(res.train_ds()) == n_samples
                ), f"{len(res.train_ds())} != {n_samples}"
                return res

            yield (
                MajorRunSpec(
                    model_name,
                    "baseline",
                    ds_name,
                    n_dt,
                    *cls.epoch_opts(n_dt),
                ),
                to_data_mgr,
            )

    @classmethod
    def for_model(
        cls,
        model_name: str,
        train_lens=None,
        data_dir=None,
    ) -> Iterator[tuple[MajorRunSpec, DsMgrFn]]:
        """Yield run specs for `model_name` over every default dataset."""
        for ds_name in cls.datasets:
            yield from cls.for_model_and_ds(
                model_name, ds_name, train_lens, data_dir
            )

    @classmethod
    def epoch_opts(cls, train_len):
        """Return `(n_epochs, batch_size, max_steps_per_eval)` for a run.

        The number of epochs comes from `ds.epoch_opts`, targeting at least
        2**27 samples rounded up to a whole number of passes over
        `train_len`. The batch size and steps-per-eval are fixed.
        """
        max_epochs = 512
        # 4 epochs of 2**25
        total_samples = math.ceil(2**27 / train_len) * train_len
        n_epochs, _ = ds.epoch_opts(
            train_len,
            total_samples=total_samples,
            max_epochs=max_epochs,
            max_batch_size=2048,
            min_steps_per_epoch=128,
        )
        batch_size = 128
        max_steps_per_eval = 1024
        return (n_epochs, batch_size, max_steps_per_eval)


"""
RunItr is used to allow train(), eval(), lr_sweep(), etc. to work with very
different models and datasets, and to consume combinations of them.

A run iterator gives run specs and the function this spec will use to
create a dataset manager. The dataset manager is not created until needed,
as there should only need to be one in memory at a time, and also because
each model may want to specify the model_in_len. Currently, the model_in_len
argument is the only argument of the dataset manager function, but this could
be extended.
"""
# Each item pairs a run spec with the factory that builds its dataset manager.
RunItr = Iterator[tuple[MajorRunSpec, Callable[..., kdai.train.DatasetManager]]]


def _const_scheduler_fn(optimizer):
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _: 1.0)


def lr_sweep(
    run_itr: RunItr,
    out_dir,
    n_runs,
    n_lr_steps,
    lr_min,
    lr_max,
    batch_sizes,
    warm_up_steps: Optional[int] = None,
):
    """Run a learning-rate sweep for every run spec and batch size.

    Some requirements on the run specs:

        - all share the same keys when converted to a dictionary, as these keys
          will form the columns of the returned dataframe.
        - all must have a train_len that is long enough to do the lr_find;
          for example, at 1000 n_lr_steps at a 1024 batch size, we might need
          (1024 + n_skips) * 1000 samples, where n_skips is the somewhat
          unpredictable number of steps that will be skipped due to gradients
          being too large.
    Args:
        run_itr: iterator for run specs.
        out_dir: root output directory. A sub-directory is created per
            (run spec, batch size) and must not already exist.
        n_runs: number of sweep repetitions, forwarded to the sweep.
        n_lr_steps: number of learning-rate steps in each sweep.
        lr_min: lowest learning rate of the sweep.
        lr_max: highest learning rate of the sweep.
        batch_sizes: non-empty iterable of batch sizes to sweep over.
        warm_up_steps: when given, a model is first trained for this many
            steps and its weights are used as the sweep's initial weights.

    Raises:
        ValueError: if warm_up_steps is not positive, if batch_sizes is
            empty, or if the run specs do not share the same keys.

    Returned dataframe schema:
    |←────           run_spec keys        ─────→|
    ┌───────┬─────┬─────────────┬───────────┬───┬────────────┬─────────┬───────────┬────────┐
    | model | ds  | ds_fullname | train_len | * | batch_size |  lrs    |   loss    | n_runs |
    |  str  | str |     str     |    int    | * |    int     | [float] | [[float]] |   int  |

    At the *, any other key-value pairs from the run spec will be included.

    """
    if warm_up_steps is not None:
        if warm_up_steps <= 0:
            raise ValueError("warmup_steps must be > 0.")
        if warm_up_steps < 50:
            _logger.warning(f"warmup_steps is very low ({warm_up_steps=}).")
        warmup = True
    else:
        warmup = False

    # batch_sizes is consumed by max() and then iterated once per run spec,
    # so materialise it in case a one-shot iterable was passed.
    batch_sizes = list(batch_sizes)
    if not batch_sizes:
        raise ValueError("batch_sizes must not be empty.")
    out_dir = Path(out_dir)
    # Loop invariants used for the train length sanity check.
    min_train_len = n_lr_steps * max(batch_sizes)
    suggested_train_len = min_train_len * 2

    def m_dir_fn(out_dir, run_spec, batch_size):
        return out_dir / run_spec.model_dir / str(batch_size)

    rows = []
    first_run_spec = None
    for run_spec, ds_mgr_fn in run_itr:
        # Later the first run spec is used to create the DataFrame columns.
        if first_run_spec is None:
            first_run_spec = run_spec
        else:
            if first_run_spec.to_dict().keys() != run_spec.to_dict().keys():
                raise ValueError("Run specs must have the same keys.")
        ds_name = run_spec.ds_name
        model_name = run_spec.model_name
        # The training arguments only depend on the run spec, so fetch them
        # once rather than once per batch size.
        train_kwargs = get_train_args(ds_name, model_name)
        fuse_adam = train_kwargs.get("fuse_adam", True)
        weight_decay = train_kwargs["weight_decay"]
        trainable_fn = trainable_fns[model_name]
        # We need a dummy trainable to check train lengths. We will reuse this
        # trainable for warmup, if warmup is enabled.
        init_trainable = trainable_fn(ds_mgr_fn, model_mode="train-loss")
        # Check train length, as it's easy to get this wrong.
        train_ds_len = len(init_trainable.train_ds())
        if train_ds_len < min_train_len:
            _logger.warning(
                f"Train length {train_ds_len} isn't long enough to cover "
                f"{n_lr_steps=} and {max(batch_sizes)=} in a single epoch."
            )
        elif train_ds_len < suggested_train_len:
            _logger.warning(
                f"Train length {train_ds_len} is possibly short for n_lr_steps="
                f"{n_lr_steps} and max(batch_sizes)={max(batch_sizes)}. "
                f"Consider using {suggested_train_len}."
            )
        if warmup:
            _logger.info("Warm up training.")
            warm_up_batch_size = 1024
            init_out_dir = out_dir / "warmup"
            assert warm_up_steps is not None
            kdai.train.train(
                init_trainable,
                n_epochs=1,
                batch_size=warm_up_batch_size,
                # Hard-coded due to how train initialized AdamW. This should be
                # reworked in train.py eventually.
                lr=lr_min * 25,
                out_dir=init_out_dir,
                fuse_adam=fuse_adam,
                samples_per_epoch=warm_up_batch_size * warm_up_steps,
                weight_decay=0.1,
                scheduler_fn=_const_scheduler_fn,
                n_workers=5,
                save_checkpoints=False,
            )
            init_weights = init_trainable.model.state_dict()
        else:
            init_weights = None

        for batch_size in batch_sizes:
            trainable = trainable_fn(ds_mgr_fn, model_mode="train-loss")
            m_dir = m_dir_fn(out_dir, run_spec, batch_size)
            logging.info(f"{m_dir} [starting]")
            # exist_ok=False: refuse to mix results with an earlier sweep.
            m_dir.mkdir(parents=True, exist_ok=False)
            losses, lr_sched = kdai.lrfind.sweep(
                trainable,
                lr_min=lr_min,
                lr_max=lr_max,
                n_lr_steps=n_lr_steps,
                n_runs=n_runs,
                early_stopper=kdai.lrfind.lr_early_stopper(n_lr_steps),
                init_weights=init_weights,
                log_space=True,
                weight_decay=weight_decay,
                batch_size=batch_size,
                n_workers=5,
            )
            # Runs can have different lengths, so average over the common
            # prefix only.
            min_len = min(len(l) for l in losses)
            mloss = np.array([l[:min_len] for l in losses]).mean(axis=0)
            best_mloss_idx = np.argmin(mloss)
            logging.info(
                f"{m_dir} [finished] (best mean loss: "
                f"{mloss[best_mloss_idx]:.3f}, step: {best_mloss_idx})"
            )
            run_spec_dict = run_spec.to_dict()
            # We aren't using the run spec dict for the batch size. It's easiest
            # to just override it there.
            run_spec_dict["batch_size"] = batch_size
            rows.append(
                [
                    *run_spec_dict.values(),
                    lr_sched,
                    losses,
                    n_runs,
                ]
            )
            # Partially address parent-child dl issue.
            del trainable
            gc.collect()
        # Release the dummy/warm-up trainable before the next run spec builds
        # its own. The warm-up weights stay alive through init_weights.
        del init_trainable
        gc.collect()
    assert first_run_spec is not None
    res = pl.DataFrame(
        rows,
        schema=[
            *first_run_spec.to_dict().keys(),
            # We override the batch size in the run spec dict.
            # "batch_size",
            "lrs",
            "loss",
            "n_runs",
        ],
        orient="row",
    )
    assert len(res.filter(pl.col("loss").is_not_null())), res.head()
    return res


# Map from (model_name, batch_size, ds_name) to learning rate.
# The key order matches the lookup done in train() and the
# ("model", "batch_size", "ds") columns used by load_lr_map()/save_lr_map().
LrMap = dict[Tuple[str, int, str], float]


def load_lr_map(lr_df, lr_override_df=None):
    """Build a learning rate map from a dataframe of learning rates.

    The map is keyed by (model, batch_size, ds). Entries from the optional
    override dataframe replace those of the main dataframe. Overrides are
    handy when an automatically produced learning rate dataframe contains a
    few unsuitable entries that need to be chosen by hand. This happens quite
    a bit for the const head, which, for a few datasets, doesn't see a loss
    decrease while training, which prevents the learning rate finder from
    finding a good learning rate.

    The override dataframe only needs an "lr" value on the rows it
    overrides; rows with a null "lr" are ignored. The main dataframe should
    have an "lr" value on every row.
    """

    def rows_to_map(frame):
        selected = frame.select(
            pl.col("model"),
            pl.col("batch_size"),
            pl.col("ds"),
            pl.col("lr"),
        )
        mapping = {}
        for model, batch_size, ds_name, lr in selected.rows():
            mapping[(model, batch_size, ds_name)] = lr
        return mapping

    lr_map = rows_to_map(lr_df)
    _logger.debug(f"Using LR map: {lr_map}")
    if lr_override_df is None:
        return lr_map

    # Null entries in the override frame make dtype inference error prone,
    # so force the "lr" column to float64 before reading it.
    override_dtype = lr_override_df["lr"].dtype
    if override_dtype != pl.Float64:
        _logger.info(f"Changing 'lr' dtype from {override_dtype} to float64.")
        lr_override_df = lr_override_df.with_columns(
            pl.col("lr").cast(pl.Float64).alias("lr")
        )
    has_lr = pl.col("lr").is_not_null()
    overrides = rows_to_map(lr_override_df.filter(has_lr))
    _logger.debug(f"Using LR override map: {overrides}")
    lr_map.update(overrides)
    return lr_map


def save_lr_map(map, path):
    """Write a learning rate map to `path` as CSV.

    Each map entry becomes one row with the columns:
    - model
    - batch_size
    - ds
    - lr
    """
    columns = ["model", "batch_size", "ds", "lr"]
    records = [
        (model, batch_size, ds_name, lr)
        for (model, batch_size, ds_name), lr in map.items()
    ]
    pl.DataFrame(records, schema=columns, orient="row").write_csv(path)


def train(
    out_dir: Path,
    run_itr: RunItr,
    model_mode: ModelMode = "train-info",
    use_early_stopping=False,
    lr_map: LrMap | None = None,
    skip_until_m_dir=None,
    log_activations=False,
    num_workers=8,
):
    """Train one model per run spec produced by `run_itr`.

    Args:
        out_dir: root output directory; each run writes to
            `out_dir / run_spec.model_dir`.
        run_itr: iterator of (run spec, dataset manager factory) pairs.
        model_mode: one of "train-info", "train-loss" or "train-info2".
        use_early_stopping: when True, an early stopper is used that only
            becomes active after half of the total training steps.
        lr_map: optional map from (model_name, batch_size, ds_name) to the
            learning rate to use. When omitted, the "lr" entry from
            get_train_args() is used.
        skip_until_m_dir: when given, runs are skipped until a run spec whose
            model_dir equals this value is reached; training resumes with
            that run.
        log_activations: forwarded to kdai.train.train.
        num_workers: number of data loading workers.

    Raises:
        ValueError: if `model_mode` is not a training mode.
        KeyError: if `lr_map` is given but has no entry for a run.
    """
    if model_mode not in {"train-info", "train-loss", "train-info2"}:
        raise ValueError("Must use a training mode.")
    still_skipping = skip_until_m_dir is not None
    for run_spec, ds_mgr_fn in run_itr:
        if still_skipping:
            if str(run_spec.model_dir) == str(skip_until_m_dir):
                still_skipping = False
            else:
                _logger.info(
                    f"Skipping {run_spec.model_dir} (until: {skip_until_m_dir})"
                )
                continue
        m_dir = out_dir / run_spec.model_dir
        trainable_fn = trainable_fns[run_spec.model_name]
        trainable = trainable_fn(ds_mgr_fn, model_mode)
        train_kwargs = get_train_args(run_spec.ds_name, run_spec.model_name)
        logging.info(f"Model dir: {m_dir}")
        if use_early_stopping:
            total_steps = (
                run_spec.n_epochs
                * len(trainable.train_ds())
                / run_spec.batch_size
            )
            early_stopper = kdai.train.EarlyStopper(
                # You must at least train for half the total steps.
                # The reason being that the learning rate peaks at around 1/3
                # total steps, and we should wait until it comes down before
                # considering early stopping.
                min_steps=total_steps // 2,
                eval_patience=12,
            )
        else:
            early_stopper = None
        _logger.info(
            f"model: {run_spec.model_name}, ds: {run_spec.ds_name}, "
            f"n_epochs: {run_spec.n_epochs}"
        )
        # **kwargs overrides named arguments. Options from get_train_args()
        # shouldn't override the run_spec values, so set them into the dict.
        train_kwargs["n_epochs"] = run_spec.n_epochs
        train_kwargs["batch_size"] = run_spec.batch_size
        train_kwargs["steps_til_eval"] = run_spec.steps_til_eval
        if model_mode == "train-info2":
            train_kwargs["evals_til_eval_train_ds"] = 2
        if lr_map is not None:
            train_kwargs["lr"] = lr_map[
                (run_spec.model_name, run_spec.batch_size, run_spec.ds_name)
            ]
        else:
            _logger.info(f"Using lr: {train_kwargs['lr']}")
        # Uncomment and indent to enable a profiler.
        # import torch.profiler as tp
        # with tp.profile(
        #     activities=[tp.ProfilerActivity.CPU, tp.ProfilerActivity.CUDA],
        #     on_trace_ready=tp.tensorboard_trace_handler(
        #         m_dir / "profiler_trace"
        #     ),
        # ) as prof:
        train_kwargs["n_workers"] = num_workers
        try:
            kdai.train.train(
                trainable,
                out_dir=m_dir,
                early_stopper=early_stopper,
                log_activations=log_activations,
                **train_kwargs,
            )
        except kdai.train.ModelException as e:
            # A failing model should not abort the remaining runs.
            _logger.error(f"Caught ModelException:\n{e}")
            _logger.error(f"Moving on to next model.")
        # Can help fix the parent-child dl issue when using persistent dls.
        del trainable
        gc.collect()

    if still_skipping:
        # The resume target never showed up, so nothing was trained. Make
        # that visible rather than reporting a silent success.
        _logger.warning(
            f"No run matched {skip_until_m_dir=}; all runs were skipped."
        )
    logging.info("Finished training.")


@torch.no_grad()
def eval(
    out_dir: Path,
    run_itr: RunItr,
    batch_size=2048,
    eval_len=None,
    use_test_ds=False,
    num_workers=20,
):
    """Evaluate the best-loss checkpoint of every run and collect metrics.

    Args:
        out_dir: root output directory; each run's checkpoint is read from
            `out_dir / run_spec.model_dir / "checkpoint_best_loss.pth"`.
        run_itr: iterator of (run spec, dataset manager factory) pairs. Must
            yield at least one item.
        batch_size: evaluation batch size.
        eval_len: limit passed to `kdai.datasets.ConstrainedIterable`. When
            None, the "eval_len" entry of `eval_opts("eval-metrics")` is used.
        use_test_ds: evaluate on the test split instead of the validation
            split.
        num_workers: number of data loader worker processes.

    Returns:
        A polars DataFrame with one row per run: the run spec fields followed
        by the metric columns.
    """
    torch.set_float32_matmul_precision("high")
    all_metrics = []
    run_spec_keys = None
    for run_spec, ds_mgr_fn in run_itr:
        # Save the first run spec keys, as we will use them for the schema.
        if run_spec_keys is None:
            run_spec_keys = run_spec.to_dict().keys()
        m_dir = out_dir / run_spec.model_dir

        trainable_fn = trainable_fns[run_spec.model_name]
        trainable = trainable_fn(ds_mgr_fn, model_mode="eval-metrics")
        # Named eval_ds so the module-level `ds` alias is not shadowed.
        eval_ds = trainable.test_ds() if use_test_ds else trainable.val_ds()
        dl = torch.utils.data.DataLoader(
            eval_ds,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
            pin_memory=True,
            # No need for persistence, as we only evaluate once.
            persistent_workers=False,
        )
        # Constrain the dataloader to the eval_len.
        if eval_len is None:
            eval_len = eval_opts("eval-metrics")["eval_len"]
        _logger.info(f"Constraining dataloader to {eval_len=}.")
        dl = kdai.datasets.ConstrainedIterable(dl, eval_len)
        ckpt_path = m_dir / "checkpoint_best_loss.pth"
        metrics = run_spec.to_dict()
        _logger.info(
            f"[start] evaluating model: {run_spec.model_name}, "
            f"ds: {run_spec.ds_fullname()} ({['val', 'test'][use_test_ds]})"
        )
        metrics.update(eval_single(trainable, ckpt_path, dl))
        all_metrics.append(metrics)
        # Release the model and loader before the next run builds its own.
        del trainable
        del dl
        gc.collect()
    assert run_spec_keys is not None
    _logger.info("[finished] evaluation.")
    df = pl.from_dicts(
        all_metrics,
        schema=[
            *run_spec_keys,
            ("loss", pl.Float32),
            ("pred_nll", pl.Float32),
            ("interval_pred_nll", pl.Float32),
            ("mean_abs_err", pl.Float32),
            ("mean_abs_err_mode", pl.Float32),
            ("mean_abs_err_median", pl.Float32),
        ],
    )
    return df


@torch.no_grad()
def eval_for_spikes(
    out_dir,
    run_itr,
    pred_stride,
    sigma_ms,
    batch_size=1024,
    n_workers=14,
    use_test_ds=False,
    ll_stride=1,
):
    """Compute per-cell spike metrics for every run in `run_itr`.

    For each run the best-loss checkpoint is loaded onto the GPU. Then, per
    recording and per cell, `inferspikes.auto_stats` is evaluated and the
    results are collected as one row per (run, recording, cell). For
    trainables named "SpikesDiscrete" or "LogMixForSpikes" the log-likelihood
    from `inferspikes.rec_ll` is added as well; otherwise "ll" is NaN.

    Args:
        out_dir: root output directory; checkpoints are read from
            `out_dir / run_spec.model_dir / "checkpoint_best_loss.pth"`.
        run_itr: iterator of (run spec, dataset manager factory) pairs.
        pred_stride: forwarded to `inferspikes.rec_ll` and
            `inferspikes.auto_stats`; stored in the "stride" column.
        sigma_ms: forwarded to `inferspikes.auto_stats`.
        batch_size: batch size for `inferspikes.rec_ll`.
        n_workers: worker count for `inferspikes.rec_ll`.
        use_test_ds: evaluate on the test split instead of the validation
            split.
        ll_stride: `stride` argument of `inferspikes.rec_ll`.

    Returns:
        A polars DataFrame with the rows of all runs, or None when `run_itr`
        yields nothing.
    """
    torch.set_float32_matmul_precision("high")
    run_spec_keys = None
    bin_ms = 1000 / 992
    # Rows are accumulated across all runs. Previously the list was reset and
    # the frame returned inside the loop, so only the first run was evaluated.
    all_metrics = []
    for run_spec, ds_mgr_fn in run_itr:
        # Save the first run spec keys, as we will use them for the schema.
        if run_spec_keys is None:
            run_spec_keys = run_spec.to_dict().keys()
        m_dir = out_dir / run_spec.model_dir
        ckpt_path = m_dir / "checkpoint_best_loss.pth"
        trainable_fn = trainable_fns[run_spec.model_name]
        trainable = trainable_fn(ds_mgr_fn, model_mode="eval-metrics")
        trainable.model.cuda()
        trainable.model.eval()

        kdai._logging.load_model(
            trainable.model, ckpt_path, torch.device("cuda")
        )
        eval_ds = trainable.test_ds() if use_test_ds else trainable.val_ds()

        recs = eval_ds.recordings

        # 1. Calc interval log-likelihood for all recordings and all cells.
        _logger.info("[start] rec_ll")
        if type(trainable).__name__ in ["SpikesDiscrete", "LogMixForSpikes"]:
            ll_by_cid = inferspikes.rec_ll(
                trainable,
                recs,
                pred_stride,
                batch_size,
                n_workers,
                reduction="mean",
                stride=ll_stride,
            )
        else:
            ll_by_cid = None
        # 2. Calc autoregressive metrics (van rossum, etc. and ll).
        _logger.info("[start] prob_auto_stats")
        for rec in recs:
            for cid in tqdm(rec.cell_ids):
                rec_single_cid = rec.cells({cid})
                auto_stats = inferspikes.auto_stats(
                    trainable,
                    rec_single_cid,
                    pred_stride,
                    bin_ms=bin_ms,
                    sigma_ms=sigma_ms,
                )
                if ll_by_cid is None:
                    ll = float("nan")
                elif cid not in ll_by_cid:
                    _logger.warning(f"Missing {cid=}. Not enough spikes?")
                    ll = float("nan")
                else:
                    ll = ll_by_cid[cid]
                per_cell_metrics = {
                    **run_spec.to_dict(),
                    "rec_name": rec.name,
                    "split": rec.metadata.get("split", None),
                    "cid": cid,
                    "van_rossum": auto_stats.van_rossum,
                    "pcorr": auto_stats.pcorr,
                    "schreiber": auto_stats.schreiber,
                    "auto_ll": auto_stats.ll,
                    "ll": ll,  # This value is the same regardless of split.
                    "n_pred": auto_stats.n_pred,
                    "n_gt": auto_stats.n_gt,
                    "gt_spikes": auto_stats.gt_spikes,
                    "pred_spikes": auto_stats.pred_spikes,
                    "pred_start": auto_stats.pred_start,
                    "pred_end": auto_stats.pred_end,
                    "stride": pred_stride,
                    "bin_ms": bin_ms,
                    "sigma_ms": sigma_ms,
                    "test_ds": use_test_ds,
                }
                all_metrics.append(per_cell_metrics)
        # Release the model before the next run loads its own.
        del trainable
        gc.collect()

    if run_spec_keys is None:
        # Nothing to evaluate. Keep returning None, as before, but say so.
        _logger.warning("run_itr was empty; no spike metrics were computed.")
        return None
    df = pl.from_dicts(
        all_metrics,
        schema=[
            *run_spec_keys,
            ("rec_name", pl.String),
            ("split", pl.UInt8),
            ("cid", pl.UInt32),
            ("van_rossum", pl.Float32),
            ("pcorr", pl.Float32),
            ("schreiber", pl.Float32),
            ("auto_ll", pl.Float32),
            ("ll", pl.Float32),
            ("n_pred", pl.UInt32),
            ("n_gt", pl.UInt32),
            ("gt_spikes", pl.List(pl.Int32)),
            ("pred_spikes", pl.List(pl.Int32)),
            ("pred_start", pl.Int32),
            ("pred_end", pl.Int32),
            ("stride", pl.Float32),
            ("bin_ms", pl.Float32),
            ("sigma_ms", pl.Float32),
            ("test_ds", pl.Boolean),
        ],
    )
    return df


@torch.no_grad()
def eval_for_spikesv2(
    out_dir,
    run_itr,
    pred_stride,
    sigma_ms,
    batch_size=2048,
    n_workers=14,
    use_test_ds=False,
    ll_stride=1,
):
    """Parallel version.

    Evaluate a checkpointed model on spike prediction and return one row of
    metrics per cell of every recording in the evaluation dataset.

    For the run taken from ``run_itr`` the trainable is built through
    ``trainable_fns[run_spec.model_name]``, the weights stored in
    ``out_dir / run_spec.model_dir / "checkpoint_best_loss.pth"`` are loaded
    on CUDA, and then:

    1. ``inferspikes.rec_ll`` provides a per-cell interval log-likelihood,
       but only when the trainable's class is named ``SpikesDiscrete`` or
       ``LogMixForSpikes``; otherwise the ``ll`` column is NaN.
    2. ``inferspikes.auto_stats2`` provides the autoregressive metrics
       (van Rossum, pcorr, Schreiber, auto_ll, spike counts) per recording.

    Args:
        out_dir: base directory of the runs; joined with ``/``, so
            presumably a ``pathlib.Path`` -- TODO confirm.
        run_itr: iterable of ``(run_spec, ds_mgr_fn)`` pairs.
        pred_stride: forwarded to ``rec_ll`` and ``auto_stats2``; stored in
            the ``stride`` column.
        sigma_ms: forwarded to ``auto_stats2``; stored in ``sigma_ms``.
        batch_size: forwarded to ``rec_ll``.
        n_workers: forwarded to ``rec_ll``.
        use_test_ds: evaluate on ``trainable.test_ds()`` when true,
            otherwise on ``trainable.val_ds()``.
        ll_stride: forwarded to ``rec_ll`` as ``stride``.

    Returns:
        A polars DataFrame whose leading columns are the run spec's keys,
        followed by the per-cell metric columns. ``gt_spikes`` and
        ``pred_spikes`` are always null to save memory.

    NOTE(review): the ``return`` is inside the loop over ``run_itr``, so
    only the first ``(run_spec, ds_mgr_fn)`` pair is ever evaluated, and an
    empty ``run_itr`` yields ``None``. Confirm whether that is intended.
    """
    torch.set_float32_matmul_precision("high")
    # NOTE(review): presumably the bin width in ms for a 992 Hz sample
    # rate -- confirm the origin of this constant.
    bin_ms = 1000 / 992
    run_spec_keys = None
    for run_spec, ds_mgr_fn in run_itr:
        # The schema's leading columns come from the first run spec seen.
        if run_spec_keys is None:
            run_spec_keys = run_spec.to_dict().keys()
        ckpt_path = out_dir / run_spec.model_dir / "checkpoint_best_loss.pth"
        trainable = trainable_fns[run_spec.model_name](
            ds_mgr_fn, model_mode="eval-metrics"
        )
        trainable.model.cuda()
        trainable.model.eval()
        kdai._logging.load_model(
            trainable.model, ckpt_path, torch.device("cuda")
        )

        if use_test_ds:
            eval_ds = trainable.test_ds()
        else:
            eval_ds = trainable.val_ds()
        recs = eval_ds.recordings

        # Columns whose value is identical for every row of this run.
        run_constants = {
            "stride": pred_stride,
            "bin_ms": bin_ms,
            "sigma_ms": sigma_ms,
            "test_ds": use_test_ds,
        }
        all_metrics = []

        # 1. Interval log-likelihood for all recordings and all cells; only
        #    available for a subset of trainable classes.
        _logger.info("[start] rec_ll")
        ll_by_cid = None
        if type(trainable).__name__ in ["SpikesDiscrete", "LogMixForSpikes"]:
            ll_by_cid = inferspikes.rec_ll(
                trainable,
                recs,
                pred_stride,
                batch_size,
                n_workers,
                reduction="mean",
                stride=ll_stride,
            )

        # 2. Autoregressive metrics (van rossum, etc. and ll).
        _logger.info("[start] prob_auto_stats")
        for rec in tqdm(recs):
            # One row per cell; the original comment on the values reads
            # "[50]" -- presumably 50 entries per cell, TODO confirm.
            rate_lims = np.array(
                [
                    list(trainable.ds_mgr.cell_max_rate[gid].values())
                    for gid in rec.cell_gids
                ]
            )
            rate_wins = np.array(
                [
                    list(trainable.ds_mgr.cell_max_rate[gid].keys())
                    for gid in rec.cell_gids
                ]
            )
            # auto_stats2 takes the mins as an array, assuming all are in
            # order from 1 to n.
            cell_dt_min = np.array(
                [
                    list(trainable.ds_mgr.cell_dt_min[gid].values())
                    for gid in rec.cell_gids
                ]
            )
            stats_by_cell = inferspikes.auto_stats2(
                trainable,
                rec,
                pred_stride,
                bin_ms=bin_ms,
                sigma_ms=sigma_ms,
                cell_dt_min=cell_dt_min,
                cell_dt_max=None,
                rate_wins=rate_wins,
                rate_lims=rate_lims,
            )
            assert rec.num_cells() == len(stats_by_cell)
            for cid, cell_stats in zip(rec.cell_ids, stats_by_cell):
                # NaN unless rec_ll ran and reported a value for this cell.
                ll = float("nan")
                if ll_by_cid is not None:
                    if cid in ll_by_cid:
                        ll = ll_by_cid[cid]
                    else:
                        _logger.warning(
                            f"Missing {cid=}. Not enough spikes?"
                        )
                row = {
                    **run_spec.to_dict(),
                    "rec_name": rec.name,
                    "split": rec.metadata.get("split", None),
                    "cid": cid,
                    "van_rossum": cell_stats.van_rossum,
                    "pcorr": cell_stats.pcorr,
                    "schreiber": cell_stats.schreiber,
                    "auto_ll": cell_stats.ll,
                    # This value is the same regardless of split.
                    "ll": ll,
                    "n_pred": cell_stats.n_pred,
                    "n_gt": cell_stats.n_gt,
                    # Spike trains are deliberately dropped to save memory.
                    "gt_spikes": None,
                    "pred_spikes": None,
                    "pred_start": cell_stats.pred_start,
                    "pred_end": cell_stats.pred_end,
                    **run_constants,
                }
                all_metrics.append(row)

        # Release the model before materialising the frame.
        del trainable
        gc.collect()

        schema = [
            *run_spec_keys,
            ("rec_name", pl.String),
            ("split", pl.UInt8),
            ("cid", pl.UInt32),
            ("van_rossum", pl.Float32),
            ("pcorr", pl.Float32),
            ("schreiber", pl.Float32),
            ("auto_ll", pl.Float32),
            ("ll", pl.Float32),
            ("n_pred", pl.UInt32),
            ("n_gt", pl.UInt32),
            ("gt_spikes", pl.List(pl.Int32)),
            ("pred_spikes", pl.List(pl.Int32)),
            ("pred_start", pl.Int32),
            ("pred_end", pl.Int32),
            ("stride", pl.Float32),
            ("bin_ms", pl.Float32),
            ("sigma_ms", pl.Float32),
            ("test_ds", pl.Boolean),
        ]
        # NOTE(review): returning here ends the loop after the first run.
        return pl.from_dicts(all_metrics, schema=schema)


@torch.no_grad()
def eval_single(trainable, ckpt_path, val_dl):
    """
    Load a checkpoint and evaluate the trainable on one dataloader.

    The weights at ``ckpt_path`` are loaded into ``trainable.model`` on
    CUDA and the model is switched to eval mode. The returned dict always
    contains the keys listed below, each defaulting to ``None``; whatever
    ``trainable.eval_metrics(val_dl)`` returns is merged on top (and may
    add further keys).
    """
    model = trainable.model
    kdai._logging.load_model(model, ckpt_path, torch.device("cuda"))
    model.cuda()
    model.eval()

    # Metrics every caller can rely on being present, even if unset.
    metric_names = (
        "loss",
        "pred_nll",
        "mean_abs_err",
        "mean_abs_err_mode",
        "mean_abs_err_median",
        "interval_pred_nll",
    )
    all_metrics = dict.fromkeys(metric_names)
    all_metrics.update(trainable.eval_metrics(val_dl))
    return all_metrics


def map_multi_rand_proc_labels(df):
    """
    Derive 'model', 'ds' and 'train_len' columns from the label columns.

    'label_0' is copied to 'model'. 'label_1' must look like
    '<name>-<size>' (lower-case letters, digits and dashes, then a dash and
    an integer); it is parsed into a struct column 'temp' whose fields are
    then added as the columns 'ds' (name) and 'train_len' (size as int).
    The intermediate 'temp' column is left in the result.

    Asserts that every label parses and that the frame holds exactly one
    distinct model.
    """
    label_re = re.compile(r"^([a-z-\d]+)-(\d+)$")

    def split_label(label):
        parsed = label_re.match(label)
        assert parsed, label
        ds_name, n_train = parsed.groups()
        return {"ds": ds_name, "train_len": int(n_train)}

    temp_expr = (
        pl.col("label_1")
        .map_elements(split_label, return_dtype=pl.Struct)
        .alias("temp")
    )
    model_expr = pl.col("label_0").alias("model")

    # Two separate with_columns calls are required: presumably 'temp' is
    # not visible to expressions in the same call that creates it.
    with_temp = df.with_columns([temp_expr, model_expr])
    df = with_temp.with_columns(pl.col("temp").struct.field("*"))
    assert df["model"].n_unique() == 1
    return df


def multi_len_rand_process_df():
    """Dataframe of the training settings for multi-length random processes.

    Used to make table for report.

    One row per entry of ``RandProc.default_train_lens``, with the columns:

        Train length | Batch size | Epochs | Steps | Evals

    ``RandProc.epoch_opts(train_len)`` supplies, per train length, the
    triple (n_epochs, batch_size, max_steps_per_eval).
    """
    train_lens = np.array(RandProc.default_train_lens)
    opts_per_len = [RandProc.epoch_opts(tl) for tl in train_lens]
    # Transpose the per-length triples into three parallel arrays.
    n_epochs, batch_size, max_steps_per_eval = (
        np.array(column) for column in zip(*opts_per_len)
    )

    # Total optimisation steps, truncated to an integer.
    total_steps = train_lens * n_epochs / batch_size
    n_steps = total_steps.astype(np.int32)

    # The larger of the epoch count and n_steps // max_steps_per_eval.
    evals = np.maximum(n_epochs, n_steps // max_steps_per_eval)
    n_evals = evals.astype(np.int32)

    columns = {
        "Train length": train_lens,
        "Batch size": batch_size,
        "Epochs": n_epochs,
        "Steps": n_steps,
        "Evals": n_evals,
    }
    return pl.DataFrame(columns)
