import functools
import math
from collections import defaultdict
from typing import Dict, List, Literal, Sequence, Tuple

import einops
import numpy as np
import plotly
import plotly.graph_objects as go
import polars as pl
import torch
import torchinfo

import kdai._logging
import kdai.datasets
import kdai.train
import kdtpp.inferspikes
import kdtpp.mea as mea
import kdtpp.metrics
import kdtpp.spikedistance as sdf


class DistTrainable(kdai.train.BaseTrainable):
    """
    A base trainable for the distance models.

    There are quite a few things that are common to all such models, such as:
        - the conversion from distance array to model output (and vice versa)
        - the requirement to wrap properties like sample rate that would
          otherwise not be contained in one obvious place.

    The model's output space is the log of the distance array (see
    ``dist_to_nn_output`` / ``nn_output_to_dist``).
    """

    # Refractory period (ms) used when inferring spikes from a distance
    # array. Converted to bins in __init__ and stored as ``refactory_len``.
    DEFAULT_REFACTORY_MS = 2

    def __init__(
        self,
        ds_mgr: mea.DistDatasets,
        model,
        label,
        eval_mode: Literal["loss", "info"] = "info",
        eval_len=int(5e4),
        init_weights=True,
        eval_rec_cids=None,
    ):
        """
        Args:
            ds_mgr: dataset manager providing the train/val datasets, the
                sample rate and the log-distance statistics.
            model: the distance model. It must expose ``gpt_base.input_len``
                and the normalisation setters used by ``init_weights``.
            label: passed through to the base trainable.
            eval_mode: "loss" for a loss-only validation, "info" for the
                detailed validation (metrics and figures).
            eval_len: passed to ``kdai.datasets.ConstrainedIterable`` when
                building evaluation dataloaders (presumably a cap on how much
                of the dataloader is consumed — confirm in kdai).
            init_weights: whether to call ``init_weights()`` on construction.
            eval_rec_cids: optional iterable of (recording name, cell id)
                pairs selecting the cells for the spike-train metrics. When
                None or empty, the first cell of the first recording is used.
        """
        super().__init__(ds_mgr, model, label)
        self.eval_mode = eval_mode
        self.eval_len = eval_len
        self.refactory_len = self.ms_to_bins(self.DEFAULT_REFACTORY_MS)
        self.eval_rec_cids = eval_rec_cids
        if init_weights:
            self.init_weights()
        # Not great abstraction here, but we aren't exactly making a library.
        self.model_in_len = self.model.gpt_base.input_len

    def init_weights(self):
        """Initialise the model's output/input normalisation.

        The output mean comes from the dataset's log-distance statistics. The
        input mean/sd are fixed constants for the 5 input channels.
        """
        log_dist_mean, log_dist_sd = self.ds_mgr.log_dist_mean_sd()
        # TEMP: only the output mean is currently set; the sd is unused.
        # self.model.set_output_mean_sd(log_dist_mean, log_dist_sd)
        self.model.set_output_mean(log_dist_mean)
        self.model.set_input_mean_sd(
            torch.full((5,), 0.5),
            torch.tensor([0.5, 0.5, 0.5, 0.5, 1]),
        )

    def ms_to_bins(self, ms: float) -> int:
        """Convert a duration in milliseconds to a bin count (at least 1)."""
        num_bins = max(1, round(ms * (self.sample_rate / 1000)))
        return num_bins

    @property
    def dist_prefix_len(self):
        """Number of distance bins preceding t=0 (from the dataset manager)."""
        return self.ds_mgr.dist_prefix_len

    @functools.cached_property
    def max_bin_dist(self):
        """The dataset's distance clamp, expressed in bins."""
        return self.ms_to_bins(self.ds_mgr.DIST_CLAMP_MS)

    @functools.cached_property
    def sample_period_ms(self):
        """Duration of one sample bin in milliseconds."""
        return 1000 / self.ds_mgr.sample_rate

    @property
    def sample_rate(self):
        """Sample rate (Hz) of the datasets."""
        return self.ds_mgr.sample_rate

    @staticmethod
    def dist_to_nn_output(dist):
        """Map a linear distance array to the model's (log) output space."""
        return torch.log(dist)

    @staticmethod
    def nn_output_to_dist(nn_output):
        """Map model (log) output back to a linear distance array."""
        return torch.exp(nn_output)

    @staticmethod
    def _to_model_inputs(snippet, cell_id):
        """Cast snippet to float and cell ids to long, on the GPU."""
        return snippet.float().cuda(), cell_id.long().cuda()

    def out_mean_sd(self):
        """Return the model's (output_mean, output_scale) as Python floats."""
        res = (self.model.output_mean.item(), self.model.output_scale.item())
        return res

    def loss_fn(
        self, m_dist, t_dist
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Batch-averaged distance loss.

        Args:
            - m_dist: model output, log space
            - t_dist: target distance, linear, from sample

        Returns:
            (mean loss over the batch, named sub-losses — currently none)
        """
        t_dist = self.dist_to_nn_output(t_dist)
        batch_size = m_dist.shape[0]
        batch_sum = sdf.dist_loss(m_dist, t_dist)
        batch_ave = batch_sum / batch_size
        return batch_ave, {}

    def _forward(
        self, sample
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass, with loss calculation.

        Returns a tuple:
            - model output
            - total loss
            - named sub-losses
        """
        masked_snippet, cell_id = self._to_model_inputs(
            sample["snippet"], sample["cell_id"]
        )
        tdist = sample["dist"].float().cuda()
        m_out = self.model(masked_snippet, cell_id)
        total_loss, named_losses = self.loss_fn(m_out, t_dist=tdist)
        return m_out, total_loss, named_losses

    def forward(self, sample):
        """Return (model output, total loss) for a batch."""
        m_out, total_loss, _ = self._forward(sample)
        return m_out, total_loss

    def forward_no_loss(self, sample):
        """Model output only; slightly faster if the loss calc is costly."""
        masked_snippet, cell_id = self._to_model_inputs(
            sample["snippet"], sample["cell_id"]
        )
        return self.model(masked_snippet, cell_id)

    def infer_dist(self, in_snippet, cell_id):
        """Run the model and return the predicted linear distance array."""
        in_snippet, cell_id = self._to_model_inputs(in_snippet, cell_id)
        model_output = self.model(in_snippet, cell_id)
        return self.nn_output_to_dist(model_output)

    def _loss_results(self, ds, dl_fn):
        """Loss-only evaluation of ``ds`` via ``dl_fn``, constrained to eval_len."""
        dl = kdai.datasets.ConstrainedIterable(dl_fn(ds), self.eval_len)
        return {"metrics": self.calc_loss(dl)}

    def evaluate_train(self, dl_fn):
        """Evaluate on the train ds; currently only the basic loss eval."""
        return self._loss_results(self.ds_mgr.train_ds(), dl_fn)

    def evaluate_val(self, dl_fn):
        """Evaluate on the val ds according to ``eval_mode``.

        Raises:
            ValueError: if ``eval_mode`` is not "loss" or "info".
        """
        if self.eval_mode == "loss":
            return self._loss_results(self.ds_mgr.val_ds(), dl_fn)
        if self.eval_mode == "info":
            return self.evaluate_val_detail(dl_fn, full_metrics=False)
        raise ValueError(f"Unknown eval mode: {self.eval_mode}")

    def _eval_recs(self, ds):
        """
        Select the recordings, restricted to the evaluation cells.

        Recordings named in ``eval_rec_cids`` but absent from ``ds`` are
        skipped.

        Raises:
            ValueError: if a selected recording contains none of its
                requested cells.
        """
        if self.eval_rec_cids is None or len(self.eval_rec_cids) == 0:
            # Just take the first cell of the first recording. A set literal
            # is required: set(cell_id) would try to iterate the id itself.
            first_rec = ds.recordings[0]
            return [first_rec.cells({first_rec.cell_ids[0]})]
        rec_to_cids = defaultdict(set)
        for rec_name, c_id in self.eval_rec_cids:
            rec_to_cids[rec_name].add(c_id)
        eval_recs = []
        for r in ds.recordings:
            if r.name not in rec_to_cids:
                continue
            rec = r.cells(rec_to_cids[r.name])
            if len(rec.cell_ids) == 0:
                raise ValueError(
                    f"Cells ({rec_to_cids[r.name]}) not in rec {r.name}"
                )
            eval_recs.append(rec)
        return eval_recs

    def _spike_train_metrics(self, ds, stride):
        """
        Spike-train metrics for each evaluation cell of ``ds``.

        Spikes are inferred with ``kdtpp.inferspikes.from_dist`` using
        ``stride`` (in bins). Returns a dict of metric name -> value; names
        are tagged with the cell id and, when more than one recording is
        evaluated, the recording name.
        """
        # Derived from the dataset rather than hard-coding 1000 / 992, so
        # other sample rates are handled (identical at 992 Hz).
        bin_ms = self.sample_period_ms
        recs = self._eval_recs(ds)
        multi_rec = len(recs) > 1
        res = {}
        for rec in recs:
            for cid in rec.cell_ids:
                r = rec.cells({cid})
                assert r.cell_ids == [cid], r.cell_ids
                pred, s, e = kdtpp.inferspikes.from_dist(self, r, stride)
                gt = einops.rearrange(r.spikes[s:e], "l 1 -> l", l=len(pred))
                van_rossum = kdtpp.metrics.van_rossum(
                    gt, pred, bin_ms=bin_ms, tau_ms=60
                )
                pcorr = kdtpp.metrics.smooth_pcorr(
                    gt, pred, bin_ms=bin_ms, sigma_ms=60
                )
                schreiber = kdtpp.metrics.schreiber(
                    gt, pred, bin_ms=bin_ms, sigma_ms=60
                )
                tag = f"r{r.name}_c{cid}" if multi_rec else f"c{cid}"
                res[f"van_rossum_τ60_{tag}"] = van_rossum
                res[f"pcorr_σ60_{tag}"] = pcorr
                res[f"schreiber_σ60_{tag}"] = schreiber
        return res

    def evaluate_val_detail(self, dl_fn, full_metrics=False):
        """
        Calculate many metrics and produce figures.

        A long stride is used for quicker results, so this method is not
        suitable for model selection. Actually, two strides are used, one
        for the loss calc, and one for the other metrics and figures.
        There are two strides as the metrics' calculation requires a known
        stride in milliseconds and may be set longer than the stride for loss.

        Args:
            dl_fn: callable mapping a dataset to a dataloader. The batch
                size, worker count and pin_memory of its output are reused
                for the other dataloaders created here.
            full_metrics: if True, also compute ``detailed_metrics``.

        Returns:
            dict with "metrics" (list of Metric) and "model-io-frames".
        """
        dl = kdai.datasets.ConstrainedIterable(
            dl_fn(self.ds_mgr.val_ds()), self.eval_len
        )
        # A long stride for loss calc.
        metrics = self.calc_loss(dl)

        # Below is old code that manually sets the stride. Need to figure out
        # which usages are needed, and which can be delegated to the
        # ConstrainedIterable.

        # A reasonably long stride, set in terms of milliseconds for metrics.
        strided_ds = self.ds_mgr.val_ds()
        stride_ms = 60
        stride = self.ms_to_bins(stride_ms)
        strided_ds.stride = stride
        # The code assumes 60 ms maps to 60 bins (true for a ~1 kHz rate).
        assert strided_ds.stride == 60
        if full_metrics:
            # Copy the dl options from the dl_fn()'s output dl.
            dist_spikes_tuple = _calc_output_arrays(
                self,
                strided_ds,
                dl.batch_size,
                self.max_bin_dist,
                self.refactory_len,
                dl.num_workers,
                dl.pin_memory,
            )
            metrics += detailed_metrics(
                *dist_spikes_tuple,
                bin_ms=self.sample_period_ms,
            )
        # A stride that is small enough to produce a smooth video.
        strided_ds.stride = 10
        # frames = [t, x,  t,  t,  t,  x,   t,   t,   x,   x]
        frames = [0, 43, 47, 48, 50, 99, 103, 106, 107, 108]
        fig_frames = self.create_io_frames(
            strided_ds,
            frames=frames,
            batch_size=dl.batch_size,
            # Making the video seems to be a bit more memory intensive per
            # worker. Or maybe there is a leak? Clamp at 0: DataLoader
            # rejects a negative worker count.
            num_workers=max(0, dl.num_workers - 3),
        )

        results = {
            "metrics": metrics,
            "model-io-frames": kdai._logging.PlotlyFigureList(fig_frames),
        }
        # Spike train metrics.
        spike_metrics = self._spike_train_metrics(self.ds_mgr.val_ds(), stride)
        results["metrics"].extend(
            kdai._logging.Metric(k, v) for k, v in spike_metrics.items()
        )
        return results

    @torch.no_grad()
    def calc_loss(self, dl) -> List[kdai._logging.Metric]:
        """
        Returns one or more loss metrics, averaged over all batches of ``dl``.

        The first loss metric must be the total loss; named sub-losses follow
        in the order they are first seen. Runs without autograd, as the
        losses are only read as Python floats.

        Raises:
            ValueError: if ``dl`` yields no batches.
        """
        loss_meter = kdai._logging.Meter("loss")
        other_meters = {}
        num_batches = 0
        for sample in dl:
            _, loss, named_outputs = self._forward(sample)
            loss_meter.update(loss.item())
            for name, output in named_outputs.items():
                if name not in other_meters:
                    other_meters[name] = kdai._logging.Meter(name)
                other_meters[name].update(output.item())
            num_batches += 1
        if num_batches == 0:
            raise ValueError("Cannot calculate the loss: no batches in dl.")
        metrics = [kdai._logging.loss_metric(loss_meter.avg)]
        metrics.extend(
            kdai._logging.Metric(name, meter.avg)
            for name, meter in other_meters.items()
        )
        return metrics

    @torch.no_grad()
    def create_io_frames(
        self, ds, frames: Sequence[int], batch_size, num_workers
    ):
        """
        Create one model input/output figure per requested dataset index.

        Args:
            ds: the dataset to draw samples from.
            frames: indices into ``ds`` of the samples to plot.
            batch_size: the batch size to use for the dataloader.
            num_workers: the number of workers to use for the dataloader.

        Returns:
            A list of figures from ``io_frame_fn``, one per entry of
            ``frames``, titled by the entry's position in ``frames``.
        """

        class FrameDs(torch.utils.data.Dataset):
            def __init__(self, ds, frames):
                self.ds = ds
                self.frames = frames

            def __getitem__(self, idx):
                return self.ds[self.frames[idx]]

            def __len__(self):
                return len(self.frames)

        dl = torch.utils.data.DataLoader(
            FrameDs(ds, frames),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        m_out = []
        lhs_spikes = []
        target_dist = []
        target_spikes = []
        in_spikes = []

        for sample in dl:
            in_spikes.append(sample["snippet"][:, -1].float())
            model_output = self.forward_no_loss(sample)
            m_out.append(model_output)
            target_dist.append(self.dist_to_nn_output(sample["dist"].float()))
            target_spikes.append(sample["target_spikes"].float())
            lhs_spikes.append(
                sdf.lhs_spike(
                    sample["snippet"][:, -1].float(), self.max_bin_dist
                )
            )
        del dl

        infer_inputs = [
            torch.cat(in_spikes).cpu(),
            torch.cat(target_spikes).cpu(),
            torch.cat(target_dist).cpu(),
            torch.cat(m_out).cpu().clone(),
            torch.cat(lhs_spikes).cpu(),
        ]

        figs = [
            io_frame_fn(
                self,
                ds.dist_prefix_len,
                infer_inputs[0][i],
                infer_inputs[1][i],
                infer_inputs[2][i],
                infer_inputs[3][i],
                infer_inputs[4][i].item(),
                f"frame: {i:04d}",
            )
            for i in range(len(frames))
        ]
        return figs

    def model_summary(self, batch_size: int):
        """Return a torchinfo summary of the model, driven by one train batch."""
        # Use the dataset manager, like every other method of this class.
        dl = torch.utils.data.DataLoader(
            self.ds_mgr.train_ds(), batch_size=batch_size
        )
        sample = next(iter(dl))
        masked_snippet, cell_id = self._to_model_inputs(
            sample["snippet"], sample["cell_id"]
        )
        res = torchinfo.summary(
            self.model,
            input_data=(masked_snippet, cell_id),
            col_names=["input_size", "output_size", "mult_adds", "num_params"],
            device=self.in_device(),
            depth=4,
        )
        return res


@torch.no_grad()
def _calc_output_arrays(
    trainable,
    ds,
    batch_size,
    max_bin_dist,
    refactory_len,
    num_workers,
    pin_memory,
):
    """
    Concats model output and inferred spikes from multiple forward() calls.

    Each distance array and spike prediction array will be clipped to
    a length equal to the dataset's stride, so that concatenation maintains
    the integrity of the time axis.

    Currently, the distance array is clipped as [0:ds.stride], but it may
    be preferable to clip like:

        [ds.dist_prefix_len:ds.dist_prefix_len + ds.stride]

    Returns:
        (dist_actual, dist_pred, spikes_actual, spikes_pred), each flattened
        to 1D. Both distance arrays are in the trainable's model-output space
        (via ``dist_to_nn_output`` / ``forward_no_loss``).

    Raises:
        ValueError: if the trainable lacks a required method, the dataset's
            recording does not have exactly one cluster, or the dataset is
            empty.
    """
    for method in ("dist_to_nn_output", "forward_no_loss", "nn_output_to_dist"):
        if not callable(getattr(trainable, method, None)):
            raise ValueError(f"The trainable must have a {method}() method.")
    num_clusters = ds.recording.num_clusters()
    if num_clusters != 1:
        raise ValueError(
            "Only 1 cluster is supported. The recording had "
            f"({num_clusters})."
        )
    stride = ds.stride
    dl = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    dist_actual, dist_pred, lhs_spikes, spikes_actual = [], [], [], []
    for sample in dl:
        dist_actual.append(trainable.dist_to_nn_output(sample["dist"].float()))
        spikes_actual.append(sample["target_spikes"].float())
        dist_pred.append(trainable.forward_no_loss(sample).cpu())
        lhs_spikes.append(
            sdf.lhs_spike(sample["snippet"][:, -1].float(), max_bin_dist)
        )
    del dl
    if not dist_pred:
        raise ValueError("The dataset produced no samples.")
    # Convert to 2D array [N, len]
    dist_actual = torch.cat(dist_actual, dim=0)
    dist_pred = torch.cat(dist_pred, dim=0)
    lhs_spikes = torch.cat(lhs_spikes, dim=0)
    spikes_actual = torch.cat(spikes_actual, dim=0)

    def infer_fn(i):
        # sdf.predict takes linear distances, whereas dist_pred holds model
        # outputs; convert first (as io_frame_fn does).
        return sdf.predict(
            trainable.nn_output_to_dist(dist_pred[i]),
            int(lhs_spikes[i].item()),
            max_bin_dist,
            dist_prefix_len=ds.dist_prefix_len,
            refactory=refactory_len,
        )[:stride]

    # Convert [N, len] to [N, stride], then flatten.
    dist_actual_flat = dist_actual[:, 0:stride].flatten()
    spikes_actual_flat = spikes_actual[:, 0:stride].flatten()

    # I couldn't get the multiprocessing to work, so just use plain
    # python.
    spike_pred_len = len(dist_pred[0]) - ds.dist_prefix_len
    if spike_pred_len > 0:
        spikes_pred = torch.cat(
            [infer_fn(i) for i in range(len(dist_pred))], dim=0
        )
    else:
        # No prediction window: predict no spikes, with the same length as
        # the actual spikes so that the two can be compared.
        spikes_pred = torch.zeros_like(spikes_actual_flat)

    dist_pred_flat = dist_pred[:, 0:stride].flatten()
    return dist_actual_flat, dist_pred_flat, spikes_actual_flat, spikes_pred


# Python requires functions to be module scope in order to pickle, which is
# a requirement for multiprocessing.
def io_frame_fn(
    obj: DistTrainable,
    dist_prefix_len: int,
    in_spikes,
    target_spikes,
    target_dist,
    model_out,
    lhs_spike: int,
    title: str,
) -> plotly.graph_objects.Figure:
    """
    Build the model input/output figure for a single sample.

    Intended to be run by a worker pool, hence module scope.

    Args:
        obj: the trainable, used for distance conversion, ``max_bin_dist``,
            ``refactory_len`` and ``sample_period_ms``.
        dist_prefix_len: number of distance bins before t=0.
        in_spikes, target_spikes, target_dist, model_out: tensors for one
            sample; ``model_out`` is in the model's output space.
        lhs_spike: passed to ``sdf.predict`` as its second argument.
        title: figure title.
    """
    has_pred_window = len(target_dist) > dist_prefix_len
    if not has_pred_window:
        # Zero spike prediction length may occur if we are only interested in
        # predicting past distance field, which is something that may be done
        # for pretraining.
        pred_spikes = torch.tensor([])
    else:
        # The prediction function takes the normal linear distance, not the
        # model's output space.
        linear_dist = obj.nn_output_to_dist(model_out)
        pred_spikes = sdf.predict(
            linear_dist,
            lhs_spike,
            max_dist=obj.max_bin_dist,
            dist_prefix_len=dist_prefix_len,
            refactory=obj.refactory_len,
        )

    def to_np(t):
        return t.cpu().numpy()

    return dist_model_out(
        in_spikes=to_np(in_spikes),
        target_spikes=to_np(target_spikes),
        target_dist=to_np(target_dist),
        model_out=to_np(model_out),
        pred_spikes=to_np(pred_spikes),
        x_start_ms=-200,
        dist_prefix_len=dist_prefix_len,
        bin_duration_ms=obj.sample_period_ms,
        title=title,
    )


def dist_model_out(
    in_spikes: np.ndarray,
    target_spikes: np.ndarray,
    target_dist: np.ndarray,
    model_out: np.ndarray,
    pred_spikes: np.ndarray,
    x_start_ms: float,
    dist_prefix_len: int,
    bin_duration_ms: float,
    title=None,
):
    """Plot comparing actual and output distance array.

      +--------------------+
      | spikes |  _ dist   |
      +--------------------+
      ^        ^  ^
    x_start    |  |
               | x0=t0
       (x0 - dist_prefix_len)

       Example:
           (x_start, dist_prefix_len) = (-200, 26)

    The x-axis is in milliseconds, with 0 at the first bin after the input.

    Args:
        in_spikes: input spike train; its last element is plotted as the
            bin just before t=0.
        target_spikes: actual spikes from t=0 onwards.
        target_dist: actual distance array, starting ``dist_prefix_len``
            bins before t=0.
        model_out: predicted distance array, aligned with ``target_dist``.
        pred_spikes: predicted spikes from t=0 onwards.
        x_start_ms: left edge of the x-axis; must not be after the start of
            the distance array.
        dist_prefix_len: number of distance bins before t=0 (>= 0).
        bin_duration_ms: duration of one bin in milliseconds.
        title: figure title.

    Returns:
        A plotly Figure.

    Raises:
        ValueError: if ``dist_prefix_len`` is negative, or ``x_start_ms`` is
            after the start of the distance array.
    """
    # Settings
    line_width = 1.0
    # Calculate the shared x-axis range.
    if dist_prefix_len < 0:
        raise ValueError(
            "The distance prefix length should be positive. "
            f"Got ({dist_prefix_len})."
        )
    dist_prefix_ms = dist_prefix_len * bin_duration_ms
    if x_start_ms > -dist_prefix_ms:
        raise ValueError(
            "The x-axis should start at a bin before the "
            "beginning of the distance output. (x_start_ms: "
            f"{x_start_ms}, dist_prefix (ms): {dist_prefix_ms})"
        )
    x_end_ms = math.ceil((len(model_out) - dist_prefix_len) * bin_duration_ms)
    dist_xs = (np.arange(len(model_out)) - dist_prefix_len) * bin_duration_ms

    fig = go.Figure()

    def add_dist():
        fig.add_trace(
            go.Scatter(
                x=dist_xs,
                y=model_out,
                name="pred",
                mode="lines",
                line_color="gray",
                line_width=line_width,
            )
        )
        fig.add_trace(
            go.Scatter(
                x=dist_xs,
                y=target_dist,
                name="actual",
                mode="lines",
                line_color="tomato",
                line_width=line_width,
            )
        )

    def add_actual_spikes():
        # Number of trailing input bins that fall inside the x-range. Clamp
        # to [0, len(in_spikes)]: a plain in_spikes[x_start_idx:] would take
        # the whole input when x_start_idx == 0, and would misplace spikes
        # when the input is shorter than the requested range.
        num_hist_bins = min(
            max(0, -math.ceil(x_start_ms / bin_duration_ms)), len(in_spikes)
        )
        hist_spikes = in_spikes[len(in_spikes) - num_hist_bins :]
        actual_spikes = np.concatenate([hist_spikes, target_spikes])
        index_of_spikes = np.flatnonzero(actual_spikes > 0)
        #
        #      |       |-----remainder------|
        # x_start_ms   spike_start_ms (first plotted input bin)
        spike_start_ms = -num_hist_bins * bin_duration_ms
        spike_loc_ms = index_of_spikes * bin_duration_ms + spike_start_ms
        for loc_ms in spike_loc_ms:
            assert x_start_ms <= loc_ms <= x_end_ms, (
                f"Spike location must be within ({x_start_ms}, {x_end_ms})."
                f" Got {loc_ms}."
            )
            fig.add_vline(
                x=loc_ms,
                line_color="tomato",
                line_width=line_width,
                line_dash="dot",
            )

    def add_pred_spikes():
        index_of_spikes = np.flatnonzero(pred_spikes > 0)
        spike_loc_ms = index_of_spikes * bin_duration_ms
        for loc_ms in spike_loc_ms:
            assert x_start_ms <= loc_ms <= x_end_ms, (
                f"Spike location must be within ({x_start_ms}, {x_end_ms})."
                f" Got {loc_ms}."
            )
            fig.add_vline(
                x=loc_ms,
                line_color="gray",
                line_width=line_width,
                line_dash="dot",
            )

    def add_zone():
        # Shade the prediction window [0, len(pred_spikes)) bins.
        fig.add_vrect(
            x0=0,
            x1=len(pred_spikes) * bin_duration_ms,
            fillcolor="aqua",
            opacity=0.1,
            line_width=0,
            layer="below",
        )

    def add_title():
        fig.update_layout(
            {
                "title": {
                    "text": (
                        f'<span style="font-size:75%">{title}</span><br>'
                    ),
                }
            }
        )

    def layout():
        # NOTE(review): default_fig_layout is neither defined nor imported in
        # this module; unless it is provided elsewhere this raises NameError.
        fig.update_layout(default_fig_layout())
        fig.update_layout(
            {
                "showlegend": False,
                "height": 300,
                "width": 800,
                "xaxis": {"range": [x_start_ms, x_end_ms]},
                "yaxis": {"range": [-1.0, 5.5]},
            }
        )

    add_dist()
    add_actual_spikes()
    add_pred_spikes()
    add_zone()
    add_title()
    layout()

    return fig



def detailed_metrics(
    dist_actual,
    dist_pred,
    spikes_actual,
    spikes_pred,
    bin_ms,
    label_prefix=None,
) -> List[kdai._logging.Metric]:
    """
    Calculates metrics comparing the model's output to the ground truth.

    The four inputs are treated as single, already-concatenated traces (one
    row of a polars DataFrame), such as those returned by
    ``_calc_output_arrays``.

    The results include:
        - Pearson correlation and MSE of the distance arrays.
        - The distance arrays split into 3 chunks (beginning, middle, end),
          then Pearson correlation per chunk (no smoothing, as we are dealing
          with the distance). This metric is aimed at trying to see if the
          model is better or worse at predicting the distance at different
          parts of the output.
        - Pearson correlation of the spike trains, unbinned and binned
          (``binned_pcorr``) at roughly 5, 10, 20, 40 and 80 ms.
        - Schreiber correlation for various degrees of smoothing.
        - van Rossum distance for various time constants.

    Args:
        dist_actual: 1D tensor/array of target distances.
        dist_pred: 1D tensor/array of predicted distances.
        spikes_actual: 1D tensor/array of target spikes.
        spikes_pred: 1D tensor/array of predicted spikes.
        bin_ms: duration of one bin in milliseconds.
        label_prefix: optional string prepended to every metric name.

    Returns:
        A list of ``kdai._logging.Metric``, one per computed statistic.
    """
    label_prefix = label_prefix if label_prefix else ""
    N_CHUNK = 3
    df = pl.DataFrame(
        data=[
            (
                dist_actual.tolist(),
                dist_pred.tolist(),
                spikes_actual.tolist(),
                spikes_pred.tolist(),
            )
        ],
        schema=["d_actual", "d_pred", "s_actual", "s_pred"],
    )

    def van_rossum(row, τ_ms):
        return kdtpp.metrics.van_rossum(
            row["s_actual"], row["s_pred"], bin_ms, τ_ms
        )

    def pcorr(row, num_bins):
        return kdtpp.metrics.binned_pcorr(
            row["s_actual"], row["s_pred"], num_bins
        )

    def schreiber(row, σ_ms):
        return kdtpp.metrics.schreiber(
            row["s_actual"], row["s_pred"], bin_ms, σ_ms
        )

    def chunk_pcorr(row, N, i):
        # Pearson correlation of the i-th of N near-equal chunks.
        chunk_actual = np.array_split(row["d_actual"], N)[i]
        chunk_pred = np.array_split(row["d_pred"], N)[i]
        return kdtpp.metrics.pcorr(chunk_actual, chunk_pred)

    # NOTE(review): `Expr.apply` was renamed `map_elements` in newer polars
    # releases; update if the pinned polars version is bumped.
    stats_df = df.with_columns(
        [
            # Distance correlation. Should this be done in stats fn?
            pl.struct(["d_actual", "d_pred"])
            .apply(lambda x: kdtpp.metrics.pcorr(x["d_actual"], x["d_pred"]))
            .alias("distf_pcorr"),
            # For seeing if there is change over time.
            # Chunked distance correlation.
            *(
                pl.struct(["d_actual", "d_pred"])
                .apply(functools.partial(chunk_pcorr, N=N_CHUNK, i=i))
                .alias(f"chunk_distf_pcorr-chunk{i}of{N_CHUNK}")
                for i in range(N_CHUNK)
            ),
            # MSE
            pl.struct(["d_actual", "d_pred"])
            .apply(lambda x: kdtpp.metrics.mse(x["d_actual"], x["d_pred"]))
            .alias("distf_mse"),
            pl.struct(["s_actual", "s_pred"])
            .apply(lambda x: kdtpp.metrics.pcorr(x["s_actual"], x["s_pred"]))
            .alias(f"pcorr-{round(bin_ms)}_ms"),
            # NOTE(review): `b` is a bin count, not a σ in ms, despite the
            # label; kept as-is so logged metric names stay stable.
            *(
                pl.struct(["s_actual", "s_pred"])
                .apply(functools.partial(pcorr, num_bins=b))
                .alias(f"pcorr-σ{b}_ms")
                for b in [round(s / bin_ms) for s in (5, 10, 20, 40, 80)]
            ),
            *(
                pl.struct(["s_actual", "s_pred"])
                .apply(functools.partial(schreiber, σ_ms=s))
                .alias(f"schreiber-σ{s}_ms")
                for s in (1, 2, 5, 10)
            ),
            *(
                pl.struct(["s_actual", "s_pred"])
                .apply(functools.partial(van_rossum, τ_ms=s))
                .alias(f"vrossum-τ{s}_ms")
                for s in (25, 50, 100, 200)
            ),
        ]
    ).select(pl.all().exclude(["s_actual", "s_pred", "d_actual", "d_pred"]))
    stats = stats_df.to_dicts()
    assert len(stats) == 1, "Only 1 row!"
    metrics = [
        kdai._logging.Metric(f"{label_prefix}{k}", v)
        for k, v in stats[0].items()
    ]
    return metrics
