import math
import logging
import kdtpp.mea as mea
import numpy as np
import einops
import torch
import pdb
import kdtpp.spikedistance as sdf
import kdtpp.mea
import kdtpp.metrics
import kdai._logging
from typing import Sequence, Optional, Literal, Dict
from collections import defaultdict
from dataclasses import dataclass
from tqdm import tqdm
import concurrent.futures
from functools import partial

# Sentinel for bins of a generated spike train that were not produced by the
# model: the warm-up input segment copied from the recording, and (where the
# inference routine sets it) bins past the end of the prediction.
UNSET = -1

# Module-level logger.
_logger = logging.getLogger(__name__)


def y_to_t_range(y, win_len):
    """Convert next-spike targets into the bin interval containing the spike.

    The models predict, in discrete 1-bin steps, when the next spike occurs.
    A target ``y`` inside the prediction window ``[0, win_len)`` means the
    spike falls in bin ``y``, i.e. the interval ``[y, y + 1)``.

    Any other target only tells us that the spike is after the prediction
    window, so it maps to ``[win_len, inf)``. This covers ``y >= win_len``
    and also ``y < 0``, which signals that there is no further spike in the
    rest of the recording.

    Note that the returned ``t_min`` is therefore never negative (for a
    non-negative ``win_len``); negative targets are folded into the
    "after the window" interval.

    Args:
        y: Tensor of targets (steps until the next spike).
        win_len: Length of the prediction window, in bins.

    Returns:
        ``(t_min, t_max)``: tensors with the elementwise lower and upper
        bounds of the interval that contains the next spike.
    """
    in_pred_window = (y >= 0) & (y < win_len)
    # Outside the window only a lower bound on the spike time is known.
    t_min = torch.where(in_pred_window, y, win_len)
    t_max = torch.where(in_pred_window, y + 1, float("inf"))
    return t_min, t_max


def to_autoreg_sample(t_idx, end, in_len, recording, dts, pred_spikes, stride):
    """Build one model input whose spike channel comes from past predictions.

    Single-cell only: ``recording.spikes`` must have exactly one column.

    Args:
        t_idx: Oldest input timestep of the window; must be ``< end``.
        end: One past the last usable window start.
        in_len: Model input length, in bins.
        recording: Recording providing ``stimulus``, ``spikes`` and
            ``cell_ids``.
        dts: Steps-until-next-spike values, indexed by timestep.
        pred_spikes: Spike train generated so far (used instead of the
            recorded spikes); the window must contain no UNSET bins.
        stride: Prediction window length, in bins.

    Returns:
        ``(x, cell_id, t_min, t_max)`` where ``x`` stacks the stimulus
        channels and the spike channel as (batch=1, channels, time).
    """
    assert t_idx < end, t_idx
    n_steps, n_cells = recording.spikes.shape
    assert n_cells == 1, n_cells
    assert n_steps == len(recording)
    lo, hi = t_idx, t_idx + in_len

    window_stim = recording.stimulus[lo:hi].astype(np.float32)
    # Autoregressive: read spikes from our own predictions, not the recording.
    window_spikes = pred_spikes[lo:hi].astype(np.float32)
    assert np.all(window_spikes != UNSET) and window_spikes.shape == (in_len,)
    window_spikes = einops.rearrange(window_spikes, "t -> 1 1 t", t=in_len)
    window_stim = einops.rearrange(window_stim, "t c -> 1 c t", c=4, t=in_len)

    # dts at the last input bin counts the steps to the next spike (excluding
    # that bin). Subtract 1 because the output's reference point is one bin
    # further on.
    target = torch.tensor(dts[hi - 1] - 1)
    lower, upper = y_to_t_range(target, stride)

    stacked = np.concatenate((window_stim, window_spikes), axis=1)
    return (
        torch.from_numpy(stacked).contiguous(),
        torch.tensor([recording.cell_ids[0]], dtype=torch.int32),
        torch.tensor([lower], dtype=torch.float32),
        torch.tensor([upper], dtype=torch.float32),
    )


def jobs_remaining(t_lhs, in_len, stride, rec_len):
    """Count cells whose next input + prediction window is still unfinished.

    A cell is finished once ``t_lhs + in_len + stride >= rec_len``. This is
    the same test ``to_autoreg_sample2`` uses to select its jobs, so a
    non-zero result means that function will have at least one job.

    Args:
        t_lhs: Array of per-cell window start (oldest input) indices.
        in_len: Model input length, in bins.
        stride: Prediction window length, in bins.
        rec_len: Length of the recording, in bins.

    Returns:
        Number of unfinished cells, as a plain ``int``.
    """
    end_win = np.asarray(t_lhs) + in_len + stride
    is_finished = end_win >= rec_len
    return int(np.count_nonzero(~is_finished))


def to_autoreg_sample2(t_lhs, in_len, recording, dts, pred_spikes, stride):
    """Create an input batch with spikes taken from previous predictions.

    Multi-cell version of ``to_autoreg_sample``. A cell contributes one row
    to the batch if its window ``t_lhs + in_len + stride`` ends strictly
    before the end of the recording.

    Args:
        t_lhs: Array of shape (C,), the oldest input timestep for each cell.
        in_len: Model input length, in bins.
        recording: Recording with ``spikes`` of shape (T, C), a time-major
            ``stimulus`` and ``cell_ids``.
        dts: Steps-until-next-spike values indexed as ``dts[t, cell]``.
        pred_spikes: (T, C) spike train generated so far; used as the spike
            input channel instead of the recorded spikes.
        stride: Prediction window length, in bins.

    Returns:
        ``(job_idxs, x, cell_id, t_min, t_max)``:

        - ``job_idxs``: the cell (column) indices of the B batch rows.
        - ``x``: a float32 tensor of shape (B, 4 + 1, in_len), i.e. the
          stimulus channels plus one spike channel.
        - ``cell_id``: a float tensor holding the rows' cell ids.
        - ``t_min``/``t_max``: (B,) bounds of the bin containing each
          cell's next spike.
    """
    T, C = recording.spikes.shape
    assert t_lhs.shape == (C,)
    assert T == len(recording)
    end_win = t_lhs + in_len + stride
    # `>=` (not `>`) keeps index end_win inside the recording; the caller
    # may write to pred_begin + stride == end_win when no spike is predicted.
    is_finished = end_win >= T
    # nonzero returns a tuple: 1 array per dimension (just 1 in this case).
    job_idxs = (~is_finished).nonzero()[0]
    assert len(job_idxs)
    B = len(job_idxs)  # batch size.
    # Anything used as an index must be int.
    in_start = t_lhs[job_idxs].astype(int)
    in_end = in_start + in_len
    assert in_start.shape == in_end.shape == (B,)

    # Create a batch of input stimuli.
    # stim_select[b, k] is the k-th input timestep of batch row b.
    stim_select = (np.arange(in_len) + in_start[:, None]).astype(int)
    assert stim_select.shape == (B, in_len)
    stim = recording.stimulus[stim_select, :].astype(np.float32)
    assert stim.shape == (B, in_len, recording.stimulus.shape[1])

    # Similar for spikes: the second index broadcasts each row's cell column.
    spikes = pred_spikes[stim_select, job_idxs[:, None]].astype(np.float32)
    assert spikes.shape == (B, in_len)
    assert np.all(spikes != UNSET)

    # Stack to make input, x. Exactly 4 stimulus channels are expected.
    stim = einops.rearrange(stim, "b t c -> b c t", b=B, c=4, t=in_len)
    spikes = einops.rearrange(spikes, "b t -> b 1 t", b=B, t=in_len)
    x = np.concatenate((stim, spikes), axis=1)

    # Get the target (used for likelihood calculation).
    # dts at the last input bin stores the number of timesteps until the
    # next spike (excluding that bin). We want this number minus 1, as we
    # will be shifting our reference point forward by 1 when considering the
    # output. Remember that the discrete bins are indexed starting at zero:
    # the earliest a spike can be is in the 0th bin, which corresponds to the
    # interval [0, 1).
    time_until_spike = dts[in_end - 1, job_idxs] - 1
    assert time_until_spike.shape == (B,), time_until_spike.shape

    # Get the interval in which the spike falls, e.g. [0, 1) or [80, \infty)
    # Torch from here.
    t_min, t_max = y_to_t_range(
        torch.from_numpy(time_until_spike).float(), stride
    )
    assert torch.all((t_min >= 0) & (t_max > t_min)), (t_min, t_max)
    assert t_min.shape == t_max.shape == (B,)
    # To (batch, channels, time)
    x = torch.from_numpy(x).contiguous()
    cell_id = torch.tensor(recording.cell_ids)[job_idxs].float()
    # job_idxs are cell_idxs (columns of recording.spikes). Could rename.
    return job_idxs, x, cell_id, t_min, t_max


@torch.no_grad()
def prob_auto_infer(trainable, recording: mea.SpikeRecording, stride):
    """Generates a spike train autoregressively from probability based model.

    Works with SpikesDiscrete and LogMixForSpikes. Can work with any trainable
    that has a median_and_ll() method, which calculates the median prediction
    and the log-likelihood of the ground truth wrt the model output.

    Works per recording, as gaps between recordings prevent output spike trains
    from being concatenated.

    Manual steps are needed because next-spike prediction models can't step
    further than their next predicted spike. After each step the window is
    moved to just past the predicted spike, or forward by ``stride`` if no
    spike is predicted inside the window.

    Args:
        trainable: Model wrapper exposing ``model_in_len``, ``median_and_ll``
            and ``infer_median`` (and optionally ``model_out_len``).
        recording: Single-cell recording; ``recording.spikes`` is (T, 1).
        stride: Prediction window length, in bins.

    Returns:
        ``(pred_spikes, pred_start, pred_end, lls)``:

        - ``pred_spikes``: the generated spike train for the recording
          range ``[pred_start, pred_end)``.
        - ``pred_start``, ``pred_end``: the bounds of that range.
        - ``lls``: an array of per-step log-likelihoods of the ground
          truth.

    Raises:
        ValueError: If the recording has more than one cell, the stride
            exceeds the model output length, or the recording is too short
            to hold one input + prediction window.
    """
    if recording.num_cells() > 1:
        raise ValueError(f"Only one cell at a time. {recording.num_cells()=}")
    dts = recording.time_until_spike()
    in_len = trainable.model_in_len
    if hasattr(trainable, "model_out_len"):
        if stride > trainable.model_out_len:
            raise ValueError("Stride must be less than output length.")
    snippet_len = in_len + stride
    # 1 past the last usable start timestep.
    end = len(recording) - snippet_len + 1
    if end <= 0:
        # Otherwise the loop below never runs and `end_win` is unbound.
        raise ValueError(
            f"Recording too short for one window: {len(recording)=}, "
            f"{snippet_len=}."
        )
    pred_spikes = np.full(len(recording), fill_value=UNSET, dtype=int)
    # Seed the warm-up segment with the recorded spikes of the single cell.
    # spikes is (T, 1), so take column 0; taking row 0 would broadcast the
    # spike at t=0 over the whole segment.
    pred_spikes[:in_len] = recording.spikes[:in_len, 0]

    lls = []
    # The current lhs of the snippet (oldest input timestep).
    t_lhs = 0
    while t_lhs < end:
        end_win = t_lhs + snippet_len
        x, cell_id, t_min, t_max = to_autoreg_sample(
            t_lhs, end, in_len, recording, dts, pred_spikes, stride
        )
        if t_min < 0:
            # Defensive fallback. y_to_t_range maps every out-of-window
            # target (including "no further spike") to t_min = stride, so a
            # negative t_min is not expected. If it does happen, skip the ll
            # calculation but keep generating the spike train.
            median = trainable.infer_median(x, cell_id)
        else:
            median, ll = trainable.median_and_ll(x, cell_id, t_min, t_max)
            lls.append(ll.item())
        # Convert from float to int to keep the 1ms bin structure of the
        # spikes the same for the output. We floor, as bin zero covers
        # [0, 1). Clamp at 0 (as prob_auto_infer2 does) so the window always
        # moves forward; a negative median would otherwise loop forever.
        median = max(0, math.floor(median))
        # First, clear output to zero (no spikes).
        pred_spikes[t_lhs + in_len : end_win] = 0
        if median < stride:
            # We will use this spike.
            spike_idx = t_lhs + in_len + median
            pred_spikes[spike_idx] = 1
            # Move the window so its newest input bin is the spike.
            t_lhs = spike_idx - in_len + 1
            assert pred_spikes[t_lhs + in_len - 1] == 1, "Just put a spike here"
        else:
            # No spikes. Move forward by stride.
            t_lhs += stride

    # Set the initial input back to UNSET.
    pred_spikes[:in_len] = UNSET
    pred_end = end_win
    pred_start = in_len
    assert np.all(pred_spikes[:pred_start] == UNSET), "warm-up must be UNSET"
    assert np.all(pred_spikes[in_len:pred_end] != UNSET)
    assert np.all(pred_spikes[pred_end:] == UNSET)
    pred_spikes = pred_spikes[pred_start:pred_end]
    lls = np.array(lls)
    return pred_spikes, pred_start, pred_end, lls


@torch.no_grad()
def prob_auto_infer2(
    trainable,
    recording: mea.SpikeRecording,
    stride,
    cell_dt_min,
    cell_max,
    rate_lims,
    rate_wins,
):
    """Generates a spike train autoregressively from probability based model.

    Parallel version of prob_auto_infer (along a batch dimension): each cell
    advances its own window independently, and all unfinished cells are
    evaluated in one forward pass per step.

    Works with SpikesDiscrete and LogMixForSpikes. Can work with any trainable
    that has a median_and_ll() method, which calculates the median prediction
    and the log-likelihood of the ground truth wrt the model output.

    Works per recording, as gaps between recordings prevent output spike trains
    from being concatenated.

    Args:
        trainable: Model wrapper exposing ``model_in_len`` and
            ``median_and_ll`` (and optionally ``model_out_len``).
        recording: Multi-cell recording; ``recording.spikes`` is (T, C).
        stride: Prediction window length, in bins.
        cell_dt_min: Array of shape (C, n_last). Per-cell minimum ISI-style
            constraints, enforced after each step by moving the window
            forward (see the constraint section below).
        cell_max: Optional per-cell array, or None. When the time since a
            cell's last predicted spike exceeds it and the model predicts no
            spike in the window, a spike is forced in the first bin.
        rate_lims: Array with one row per cell. It is appended to the query
            passed to ``sdf.last_n_idx`` (assumed: spike-count limits per
            rate window -- TODO confirm).
        rate_wins: Per-cell rate window lengths; ``np.max(rate_wins)`` bins
            of history are inspected after each step.

    Returns:
        ``(pred_spikes, pred_start, pred_ends, lls)``:

        - ``pred_spikes``: a (T, C) array. It is UNSET before ``pred_start``
          and from ``pred_ends[i]`` onwards for cell i.
        - ``pred_start``: the first predicted recording index.
        - ``pred_ends``: the per-cell end indices.
        - ``lls``: a list with one array of per-step log-likelihoods per
          cell.
    """
    # We need to do manual steps, as next-spike prediction models can't step
    # further than their next predicted spike.
    dts = recording.time_until_spike()
    in_len = trainable.model_in_len
    if hasattr(trainable, "model_out_len"):
        if stride > trainable.model_out_len:
            raise ValueError("Stride must be less than output length.")
    snippet_len = in_len + stride
    n_cells = recording.num_cells()
    # pred_spikes will be populated and returned as a result. Bins start at
    # zero (no spike), so skipped bins never need clearing.
    pred_spikes = np.zeros((len(recording), n_cells), dtype=int)
    # For the initial segment, copy the spikes from the recording.
    pred_spikes[:in_len, :] = recording.spikes[:in_len]

    # likelihood arrays can be different lengths per cell, so list of lists.
    lls = [[] for _ in range(n_cells)]
    # LHS marks the oldest element being used as input to the forward.
    # This will be moved forward, independently for each cell.
    t_lhs = np.zeros(n_cells)
    to_npy = lambda t: t.cpu().numpy()
    # Index of each cell's most recent predicted spike (0 before the first).
    last_spikes = np.full(n_cells, fill_value=0)
    n_last = cell_dt_min.shape[1]
    # Per-cell query for sdf.last_n_idx: the values 1..n_last (one column
    # each), followed by the rate_lims columns.
    last_n_query = np.concatenate(
        [
            np.array([np.ones(n_cells, dtype=int) * i for i in range(1, n_last + 1)]).T,
            rate_lims,  # max spikes in 50 bins
        ],
        axis=1,
    )
    while jobs_remaining(t_lhs, in_len, stride, len(recording)):
        # job_idxs[b] is the cell (column) index of batch row b.
        job_idxs, x, cell_id, t_min, t_max = to_autoreg_sample2(
            t_lhs, in_len, recording, dts, pred_spikes, stride
        )

        median, ll = trainable.median_and_ll(x, cell_id, t_min, t_max)
        median, ll = to_npy(median), to_npy(ll)
        # Likelihoods: the batch only holds unfinished cells, so map each row
        # back to its cell. Indexing lls by row would attribute likelihoods
        # to the wrong cells as soon as any cell has finished.
        for row, cell_idx in enumerate(job_idxs):
            lls[cell_idx].append(ll[row])

        # Spike prediction.
        # Convert from float to int. This is a decision to keep the 1ms bin
        # structure of the spikes the same for the output. We floor, as bin
        # zero covers [0, 1).
        # Clip to stride to prevent out of bounds indexing; pred == stride
        # means "no spike in this window". Ideally we wouldn't clip here, but
        # we can't make an actual per-cell branch on the median value.
        pred = np.clip(np.floor(median).astype(int), a_min=0, a_max=stride)

        # Add a spike if it's in the prediction window.
        pred_begin = (t_lhs[job_idxs] + in_len).astype(int)
        if cell_max is not None:
            max_exceeded = (
                t_lhs[job_idxs] + in_len - last_spikes[job_idxs]
                > cell_max[job_idxs]
            )
            # Use prior to maintain close to minimum ISI.
            pred = np.where(max_exceeded & (pred >= stride), 0, pred)
        spike_idx = (pred_begin + pred).astype(int)
        # When there is no spike this writes 0 at pred_begin + stride. That
        # bin is in range (to_autoreg_sample2 requires end_win < T) and has
        # not been written yet, so it is already 0.
        pred_spikes[spike_idx, job_idxs] = np.where(pred < stride, 1, 0)
        # Move up to spike or forward by stride (if no spike).
        move_up_by = np.where(pred < stride, pred + 1, stride)
        t_lhs_for_jobs = t_lhs[job_idxs] + move_up_by
        last_spikes[job_idxs] = np.where(
            pred < stride, spike_idx, last_spikes[job_idxs])

        # Check rate and ISI constraints.
        # Move forward again to enforce minimum ISIs and maximum rates.
        # This can be generalized to handle N levels of ISI and M rate windows,
        # but currently, it's hardcoded to ISI(0, -1) and ISI(0,-2) and one
        # window of 50 bins.
        # Use the largest rate window for selection.
        W = np.max(rate_wins)
        win_s = t_lhs_for_jobs + in_len - W
        win_select = (np.arange(W) + win_s[:, None]).astype(int)
        # If there was a spike, it will be the last element in the window.
        win = pred_spikes[win_select, job_idxs[:, None]]
        last_ns = sdf.last_n_idx(win, last_n_query[job_idxs])

        move_up_on_spike = np.maximum(
            0,
            np.maximum(
                # Enforce minimum ISIs.
                (cell_dt_min[job_idxs] - last_ns[:, :n_last]).max(axis=1),
                # Enforce maximum rates.
                (rate_wins[job_idxs] - last_ns[:, n_last:]).max(axis=1),
            ),
        )
        t_lhs_for_jobs += move_up_on_spike
        t_lhs[job_idxs] = t_lhs_for_jobs

    # Set the initial input back to UNSET.
    pred_spikes[:in_len] = UNSET
    # And the segment end.
    # NOTE(review): the loop only stops once t_lhs + snippet_len >= T for
    # every cell, so pred_ends is typically >= T and this loop is a no-op.
    # Bins after the last predicted window stay 0.
    pred_ends = (t_lhs + snippet_len).astype(int)
    for i in range(n_cells):
        pred_spikes[pred_ends[i] :, i] = UNSET
    pred_start = in_len
    assert np.all(pred_spikes[:pred_start] == UNSET), "warm-up must be UNSET"
    # TODO: vectorize checks that [pred_start, pred_ends[i]) has no UNSET.
    lls = [np.array(a) for a in lls]
    return pred_spikes, pred_start, pred_ends, lls


@dataclass
class AutoStatsResult:
    """Metrics comparing one cell's generated spike train to ground truth.

    Produced by ``auto_stats`` (single cell) and ``auto_stats2`` (one per
    cell).
    """

    van_rossum: float  # kdtpp.metrics.van_rossum(gt, pred, ...).
    pcorr: float  # kdtpp.metrics.smooth_pcorr(gt, pred, ...).
    schreiber: float  # kdtpp.metrics.schreiber(gt, pred, ...).
    # Mean per-step log-likelihood; None when the model gives none
    # (DistTrainable in auto_stats).
    ll: Optional[float]
    n_pred: int  # Number of predicted spikes in the compared range.
    n_gt: int  # Number of ground-truth spikes in the compared range.
    gt_spikes: np.ndarray  # Full-length ground-truth spike train of the cell.
    # Predicted train: sliced to [pred_start, pred_end) by auto_stats; the
    # full-length column, including UNSET entries, in auto_stats2.
    pred_spikes: np.ndarray
    pred_start: int  # First recording index covered by the prediction.
    pred_end: int  # One past the last recording index compared.
    # Optional cell id. NOTE(review): auto_stats2 does not set it here.
    cid: Optional[int] = None  # used in parallel version.


@torch.no_grad()
def auto_stats(
    trainable, recording: mea.SpikeRecording, stride, bin_ms, sigma_ms
) -> AutoStatsResult:
    """Generate a spike train for a single-cell recording and score it.

    Model handling:

    - SpikesDiscrete and LogMixForSpikes trainables go through
      ``prob_auto_infer`` and also report their mean log-likelihood.
    - A DistTrainable goes through ``dist_auto_infer`` and gets
      ``ll=None``.

    The prediction is compared with the ground truth over the predicted
    range using the van Rossum distance, the smoothed Pearson correlation
    and the Schreiber similarity.
    """
    model_kind = type(trainable).__name__
    if model_kind in ["SpikesDiscrete", "LogMixForSpikes"]:
        pred, s, e, step_lls = prob_auto_infer(trainable, recording, stride)
        mean_ll = step_lls.mean().item()
    else:
        assert model_kind == "DistTrainable"
        pred, s, e = dist_auto_infer(trainable, recording, stride)
        mean_ll = None

    gt_full = einops.rearrange(recording.spikes, "l 1 -> l", l=len(recording))
    gt = gt_full[s:e]
    smoothing = dict(bin_ms=bin_ms, sigma_ms=sigma_ms)
    # Keyword arguments are evaluated left to right, so the metrics run in
    # the same order as before: van Rossum, pcorr, Schreiber, then counts.
    return AutoStatsResult(
        van_rossum=kdtpp.metrics.van_rossum(
            gt, pred, bin_ms=bin_ms, tau_ms=sigma_ms
        ),
        pcorr=kdtpp.metrics.smooth_pcorr(gt, pred, **smoothing),
        schreiber=kdtpp.metrics.schreiber(gt, pred, **smoothing),
        ll=mean_ll,
        n_pred=np.sum(pred),
        n_gt=np.sum(gt),
        gt_spikes=gt_full,
        pred_spikes=pred,
        pred_start=s,
        pred_end=e,
    )


def worker_fn(args):
    """Compute spike-train metrics for one (ground truth, prediction) pair.

    ``auto_stats2`` maps this over cells with a ProcessPoolExecutor, so it
    must stay a picklable module-level function taking a single argument.

    Args:
        args: Tuple ``(gt, pred, bin_ms, sigma_ms)``.

    Returns:
        ``(van_rossum, pcorr, schreiber, n_gt, n_pred)``; the spike counts
        ignore UNSET bins.
    """
    gt, pred, bin_ms, sigma_ms = args
    metrics = kdtpp.metrics
    van_rossum = metrics.van_rossum(gt, pred, bin_ms=bin_ms, tau_ms=sigma_ms)
    pcorr = metrics.smooth_pcorr(gt, pred, bin_ms=bin_ms, sigma_ms=sigma_ms)
    schreiber = metrics.schreiber(gt, pred, bin_ms=bin_ms, sigma_ms=sigma_ms)

    def count_spikes(train):
        # Sum only the bins that were actually set.
        return np.sum(train[train != UNSET])

    return (van_rossum, pcorr, schreiber, count_spikes(gt), count_spikes(pred))


@torch.no_grad()
def auto_stats2(
    trainable,
    recording: mea.SpikeRecording,
    stride,
    bin_ms,
    sigma_ms,
    cell_dt_min,
    cell_dt_max,
    rate_lims,
    rate_wins,
) -> Sequence[AutoStatsResult]:
    """Parallel (multi-cell) version of ``auto_stats``.

    Spike trains for every cell are generated at once with
    ``prob_auto_infer2``, with ``cell_dt_max`` passed as its ``cell_max``.
    Each cell is then scored in a separate process with ``worker_fn``.

    Only SpikesDiscrete / LogMixForSpikes trainables are supported; a
    DistTrainable raises NotImplementedError.

    Returns:
        A list with one AutoStatsResult per cell, in column order. Unlike
        ``auto_stats``, ``pred_spikes`` is the full-length column and keeps
        its UNSET entries.
    """
    _logger.info("Inferring spike train")
    model_kind = type(trainable).__name__
    if model_kind not in ["SpikesDiscrete", "LogMixForSpikes"]:
        assert model_kind == "DistTrainable"
        raise NotImplementedError()
    # In this parallel version pred is (T, C) and keeps its UNSET values, so
    # it can stay one 2D array rather than a list of arrays.
    pred, start, ends, cell_lls = prob_auto_infer2(
        trainable,
        recording,
        stride,
        cell_dt_min,
        cell_dt_max,
        rate_lims,
        rate_wins,
    )
    mean_lls = [per_cell.mean() for per_cell in cell_lls]

    gt_full = recording.spikes
    _logger.info("Calculating spike train metrics")
    jobs = [
        (gt_full[start : ends[c], c], pred[start : ends[c], c], bin_ms, sigma_ms)
        for c in range(recording.num_cells())
    ]
    with concurrent.futures.ProcessPoolExecutor() as executor:
        stats = list(executor.map(worker_fn, jobs))
    assert len(stats) == recording.num_cells()

    return [
        AutoStatsResult(
            van_rossum=van_rossum,
            pcorr=pcorr,
            schreiber=schreiber,
            ll=mean_lls[c],
            n_pred=n_pred,
            n_gt=n_gt,
            gt_spikes=gt_full[:, c],
            pred_spikes=pred[:, c],
            pred_start=start,
            pred_end=ends[c],
        )
        for c, (van_rossum, pcorr, schreiber, n_gt, n_pred) in enumerate(stats)
    ]


@torch.no_grad()
def rec_ll(
    trainable,
    recordings: Sequence[mea.SpikeRecording],
    pred_win_len,
    batch_size,
    n_workers,
    reduction: Literal["mean", "none"] = "mean",
    stride=1,
):
    """Calculate the non-autoregressive log-likelihood for recordings.

    Every sample of a ``kdtpp.mea.NextSpikeDataset`` built from
    ``recordings`` is scored with ``trainable.ll``. The score is taken on the
    interval that contains the sample's next spike (see ``y_to_t_range``).
    Likelihood here is mass, not density.

    Args:
        trainable: Model wrapper exposing ``train_ds()`` and ``ll()``.
        recordings: Recordings to evaluate; ideally all with the same cells.
        pred_win_len: Prediction window length, in bins.
        batch_size: DataLoader batch size.
        n_workers: DataLoader worker count.
        reduction: "mean" to average per cell, "none" to keep every sample.
        stride: Stride passed to the dataset.

    Returns:
        Dict mapping cell id to either the mean log-likelihood ("mean") or
        the array of per-sample log-likelihoods ("none").

    Raises:
        ValueError: If ``reduction`` is not "mean" or "none".
    """
    # Validate up front rather than after a full pass over the data.
    if reduction not in ("mean", "none"):
        raise ValueError(
            f"reduction must be 'mean' or 'none', got {reduction!r}"
        )
    ref_ds = trainable.train_ds()
    cell_ids = recordings[0].cell_ids
    for rec in recordings:
        # np.array_equal handles lists and arrays alike. `!=` on arrays gives
        # an elementwise result whose truth value is ambiguous in an `if`.
        if not np.array_equal(rec.cell_ids, cell_ids):
            _logger.debug("Detected different cell ids across recordings.")
    res_by_cid = defaultdict(list)
    ds = kdtpp.mea.NextSpikeDataset(recordings, ref_ds.input_len, stride=stride)
    dl = torch.utils.data.DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=n_workers,
        shuffle=False,
        # 1 epoch, so no need for pinning.
        pin_memory=False,
    )
    for x, cell_id, y in tqdm(dl, total=len(dl)):
        # Copy cell ids to numpy before the model call, which may move
        # cell_id to the GPU.
        cids = cell_id.numpy()
        t_min, t_max = y_to_t_range(y, pred_win_len)
        interval_ll = trainable.ll(x, cell_id, t_min, t_max).cpu().numpy()
        for cid, sample_ll in zip(cids, interval_ll):
            res_by_cid[cid].append(sample_ll)
    if reduction == "mean":
        return {cid: np.array(lls).mean() for cid, lls in res_by_cid.items()}
    return {cid: np.array(lls) for cid, lls in res_by_cid.items()}


@torch.no_grad()
def dist_auto_infer(
    trainable, recording: mea.SpikeRecording, stride, refactory=0
):
    """Generate a spike train autoregressively from a spike-distance model.

    The model (a DistTrainable) infers a distance function for each input
    window. ``sdf.predict`` turns it into spikes, and the first ``stride``
    of them are written to the output before the window advances by exactly
    ``stride``.

    Args:
        trainable: Model wrapper exposing ``model_in_len``, ``infer_dist``,
            ``max_bin_dist`` and ``dist_prefix_len``.
        recording: Single-cell recording; ``recording.spikes`` is (T, 1).
        stride: Number of predicted bins used per step.
        refactory: Forwarded to ``sdf.predict`` (presumably a refractory
            period in bins -- TODO confirm).

    Returns:
        ``(pred_spikes, pred_start, pred_end)``: the generated spike train
        for the recording range ``[pred_start, pred_end)``, plus the bounds
        of that range. The train is empty if the recording is too short for
        one window.

    Raises:
        ValueError: If the recording has more than one cell.
    """
    if recording.num_cells() > 1:
        raise ValueError(f"Only one cell at a time. {recording.num_cells()=}")
    in_len = trainable.model_in_len
    # 1 past the last usable start timestep.
    end = len(recording) - in_len - stride + 1
    pred_spikes = np.full(len(recording), fill_value=UNSET, dtype=int)
    # Seed the warm-up segment with the recorded spikes of the single cell.
    # spikes is (T, 1), so take column 0; taking row 0 would broadcast the
    # spike at t=0 over the whole segment.
    pred_spikes[:in_len] = recording.spikes[:in_len, 0]

    def to_input(idx):
        """Build ``(x, cell_id, lhs_spike)`` for the window starting at idx."""
        assert idx < end, idx
        in_slice = slice(idx, idx + in_len)
        stim = recording.stimulus[in_slice].astype(np.float32)
        # Autoregressive: spikes come from our own predictions, not from
        # recording.spikes.
        spikes = pred_spikes[in_slice].astype(np.float32)
        assert np.all(spikes != UNSET) and spikes.shape == (in_len,)
        spikes = torch.from_numpy(
            einops.rearrange(spikes, "t -> 1 1 t", t=in_len)
        )
        stim = torch.from_numpy(
            einops.rearrange(stim, "t c -> 1 c t", c=4, t=in_len)
        )
        x = torch.cat((stim, spikes), dim=1).float().contiguous()
        cell_id = torch.tensor([recording.cell_ids[0]], dtype=torch.int32)

        # We also need the LHS spike of the window.
        lhs_spike = (
            sdf.lhs_spike(
                einops.rearrange(spikes, "1 1 t -> 1 t"), trainable.max_bin_dist
            )
            .cpu()
            .item()
        )
        return x, cell_id, lhs_spike

    # The current lhs of the snippet (oldest input timestep).
    t_lhs = 0
    while t_lhs < end:
        x, cell_id, lhs_spike = to_input(t_lhs)
        # Supports tensors, but not batches.
        distf = trainable.infer_dist(x, cell_id)[0]
        mle_spikes = sdf.predict(
            distf,
            lhs_spike,
            trainable.max_bin_dist,
            dist_prefix_len=trainable.dist_prefix_len,
            refactory=refactory,
        )
        pred = mle_spikes[0:stride].cpu().numpy()
        out_slice = slice(t_lhs + in_len, t_lhs + in_len + stride)
        assert out_slice.stop <= len(pred_spikes), out_slice.stop
        pred_spikes[out_slice] = pred
        t_lhs += stride

    # Set the initial input back to UNSET.
    last_t_lhs = t_lhs - stride
    pred_spikes[:in_len] = UNSET
    pred_end = last_t_lhs + in_len + stride
    pred_start = in_len
    assert np.all(pred_spikes[:pred_start] == UNSET), "warm-up must be UNSET"
    assert np.all(pred_spikes[in_len:pred_end] != UNSET)
    assert np.all(pred_spikes[pred_end:] == UNSET)
    pred_spikes = pred_spikes[pred_start:pred_end]
    return pred_spikes, pred_start, pred_end
