import pytest
import math
import numpy as np
import kdtpp.datasets as ds
import kdtpp.experiments as exp
import kdtpp.mea
import kdai.train
import polars as pl
import logging

_logger = logging.getLogger(__name__)


@pytest.mark.skip(
    reason="persistent_workers=True caused the epoch overrun; it is disabled "
    "instead of fixed."
)
def test_epoch_len_persistent_workers():
    """Check that iterating a DataLoader never yields more batches than its len().

    An issue was encountered whereby the Classic Stack Overflow dataset
    ("so-badges") would iterate past the number of samples in the dataset
    when consumed via a DataLoader. This test was written to reproduce the
    problem and then keep it fixed.

    Skipped: persistent_workers=True was the cause of the issue, so the
    option is simply disabled rather than fixed.

    Raises:
        ValueError: if any epoch yields a batch index >= len(train_dl).
    """
    run_spec, ds_mgr = next(
        iter(exp.Classic.for_model_and_ds("gpt-2-4-16-const", "so-badges"))
    )
    trainable = exp.trainable_fns[run_spec.model_name](ds_mgr, "train-info")

    train_dl_fn, _ = kdai.train._create_dataloaders(
        batch_size=run_spec.batch_size,
        eval_batch_size=run_spec.batch_size,
        n_workers=9,
        pin_memory=False,
        # The setting under investigation.
        persistent_workers=True,
        samples_per_epoch=None,
    )

    train_dl = trainable.train_dl(train_dl_fn)
    # Expected batch count for this run spec (value pinned from a known-good run).
    assert len(train_dl) == 14976

    for epoch in range(run_spec.n_epochs):
        _logger.info("Epoch %s/%s", epoch + 1, run_spec.n_epochs)
        for batch_step, _sample in enumerate(train_dl):
            # Fail fast with a clear error instead of dropping into a debugger,
            # which would hang or break a non-interactive test run.
            if batch_step >= len(train_dl):
                raise ValueError(
                    "Batch step exceeds length of DataLoader. "
                    f"batch_step: {batch_step}, len(train_dl): {len(train_dl)}"
                )


# Tag as slow, as it takes a while to run.
@pytest.mark.slow
def test_epoch_len(tmp_path):
    """Run kdai.train.train() on the so-badges run spec and expect no exception.

    Counterpart of test_epoch_len_persistent_workers(), which iterates the
    DataLoader by hand. This test instead drives the same run spec
    ("gpt-2-4-16-const" on "so-badges") through the full kdai.train.train()
    loop with 9 workers. It passes if training completes without raising.

    Args:
        tmp_path: pytest temporary directory, used as the training output
            directory.
    """

    run_spec, ds_mgr = next(
        iter(exp.Classic.for_model_and_ds("gpt-2-4-16-const", "so-badges"))
    )
    trainable = exp.trainable_fns[run_spec.model_name](ds_mgr, "train-info")
    # Start from the experiment's defaults, then override with the run spec.
    train_kwargs = exp.get_train_args(run_spec.ds_name, run_spec.model_name)
    train_kwargs["n_epochs"] = run_spec.n_epochs
    train_kwargs["batch_size"] = run_spec.batch_size
    train_kwargs["steps_til_eval"] = run_spec.steps_til_eval
    train_kwargs["n_workers"] = 9

    # Should be no exception.
    kdai.train.train(trainable, out_dir=tmp_path, **train_kwargs)


def test_per_rec_stats(rec):
    """Cross-check exp.per_rec_stats() against statistics computed directly.

    Reference values are built by decompressing the recording and
    mirror-splitting it with the same downsample factor and split ratio that
    are passed to per_rec_stats(). Only the training segments are used. This
    presumably mirrors what per_rec_stats() does internally; confirm against
    its implementation.

    Args:
        rec: recording supplied as a pytest fixture (defined outside this
            view).
    """
    factor = 18
    ratio = (7, 2, 1)

    # --- Reference values computed by hand from the training split. ---
    decompressed = kdtpp.mea.decompress_recording(rec, factor)
    train_segs, _val_segs, _test_segs = kdtpp.mea.mirror_split(
        decompressed, ratio
    )

    until_all = np.concatenate(
        [seg.time_until_spike() for seg in train_segs]
    )
    until_pos = until_all[until_all > 0]

    isi_parts = []
    for seg in train_segs:
        for cell in range(seg.num_cells()):
            spike_times = kdtpp.mea.compress_spikes(seg.spikes[:, cell])
            isi_parts.append(np.diff(spike_times))
    isis = np.concatenate(isi_parts)
    expected_dt_mean = np.mean(isis)
    expected_log_dt_mean = np.mean(np.log(isis))

    # --- Values under test. ---
    # Rough sanity relation: mean dt ~= 2 * mean time-until-spike (triangle).
    (
        dt_mean,
        log_dt_mean,
        t_until_mean,
        t_until_sd,
        cell_dNt_mean,
        cell_dNt_min,
        cell_dNt_max,
    ) = exp.per_rec_stats(rec, ratio, downsample=factor)

    assert math.isclose(t_until_mean, np.mean(until_pos), rel_tol=1e-1)
    assert math.isclose(t_until_sd, np.std(until_pos), rel_tol=1e-1)
    assert math.isclose(dt_mean, expected_dt_mean, rel_tol=1e-1)
    assert math.isclose(log_dt_mean, expected_log_dt_mean, rel_tol=1e-2)
    # NOTE: these tolerances are loose; the aspiration for log_dt_mean was
    # rel_tol=1e-5.
    _logger.warning("Try reduce the tolerances for these tests.")

    # Every per-cell dict must be keyed by exactly the recording's cell ids.
    cell_ids = set(rec.cell_ids)
    for per_cell in (cell_dNt_min, cell_dNt_max, cell_dNt_mean):
        assert set(per_cell.keys()) == cell_ids
