import pytest
import torch
import kdtpp.trainables
import kdtpp.datasets as ds
import math
import scipy
import numpy as np
from itertools import product


def gen_clipped_poisson(n_events, max_val, rng):
    """Draw clipped unit-exponential samples and their log-densities.

    Args:
        n_events: number of samples to draw.
        max_val: upper clamp applied to every sample (the lower clamp is 0).
        rng: random state forwarded to SciPy's ``rvs``.

    Returns:
        A ``(dts, log_probs)`` tuple. ``dts`` holds the samples after
        clamping to ``[0, max_val]``; ``log_probs`` is the exponential
        log-density (loc=0, scale=1) evaluated at those clamped values.
    """
    unit_exponential = scipy.stats.expon(loc=0, scale=1)
    raw_gaps = unit_exponential.rvs(size=n_events, random_state=rng)
    clipped_gaps = np.clip(raw_gaps, 0, max_val)
    # The density is taken *after* clamping, so a sample that was cut down to
    # max_val is scored with the density at max_val.
    return clipped_gaps, unit_exponential.logpdf(clipped_gaps)


def poisson_proc_ds_mgr(n_events, max_val, rng):
    """Build a ``RandProcessDatasets`` from clipped exponential samples.

    The samples and their log-densities come from ``gen_clipped_poisson``
    called with the same ``n_events``, ``max_val`` and ``rng``; the dataset
    manager is created with ``model_in_len=5``.
    """
    samples, sample_log_probs = gen_clipped_poisson(n_events, max_val, rng)
    return ds.RandProcessDatasets(samples, sample_log_probs, model_in_len=5)


def binomial_proc_ds_mgr(n_events, max_val, rng):
    """Build a ``RandProcessDatasets`` from shifted binomial samples.

    Draws ``n_events`` values from Binomial(n=max_val - 1, p=0.5), so the raw
    draws lie in ``0..max_val - 1``. The log-probabilities are computed on the
    raw draws, and the draws are then shifted up by one so the stored values
    lie in ``1..max_val``. The dataset manager is created with
    ``model_in_len=5``.
    """
    binom_params = dict(n=max_val - 1, p=0.5)
    raw_counts = scipy.stats.binom.rvs(
        size=n_events, random_state=rng, **binom_params
    )
    # Log-pmf must be evaluated before the +1 shift below.
    count_log_probs = scipy.stats.binom.logpmf(raw_counts, **binom_params)
    shifted_counts = raw_counts + 1
    return ds.RandProcessDatasets(
        shifted_counts, count_log_probs, model_in_len=5
    )


class DummyModel:
    """Stand-in model that returns the same value for every output bin.

    The output is an all-ones tensor; the tests in this file treat it as a
    uniform prediction over ``out_resolution`` bins.
    """

    def __init__(self, out_resolution):
        """Store the size of the last output dimension."""
        self.out_resolution = out_resolution

    def forward(self, x, mask):
        """Return ones of shape ``(batch, seq, out_resolution)``.

        Args:
            x: tensor of shape ``(batch, seq, 1)``; only its shape (and, on
                CPU-only machines, its device) is used.
            mask: accepted for interface compatibility and ignored.

        Returns:
            An all-ones tensor placed on the current CUDA device when CUDA is
            available, otherwise on ``x``'s device.

        Raises:
            AssertionError: if the last dimension of ``x`` is not 1.
        """
        batch_size, seq_len, n_channels = x.shape
        assert n_channels == 1
        # Allocate directly on the target device. Falling back to x.device
        # keeps the model usable on machines without a GPU, where an
        # unconditional .cuda() call would raise.
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = x.device
        return torch.ones(
            (batch_size, seq_len, self.out_resolution), device=device
        )

    def __call__(self, x, mask):
        """Make the instance callable like a module by delegating to forward."""
        return self.forward(x, mask)


def test_DiscreteTrainable_eval_metrics_mass(np_rng):
    """
    Check DiscreteTrainable.eval_metrics in "mass" bin mode, using a model
    whose output is identical for every bin.

    Tests that:
       1. pred_nll, interval_pred_nll and mean_abs_err_mode have the expected
          values for a few different bin counts.
       2. a ValueError is raised when the queried interval length does not
          match the bin width (non-integer length, length 2 with the default
          bin width, and length 1 with bin width 2).
    """

    def create_trainable(n_bins):
        # max_val is tied to n_bins so that every generated value falls
        # inside the model's bins; values past the last bin would give an
        # infinite nll.
        binomial_mgr = binomial_proc_ds_mgr(
            n_events=int(1e5), max_val=n_bins, rng=np_rng
        )
        trainable = kdtpp.trainables.DiscreteTrainable(
            binomial_mgr,
            DummyModel(n_bins),
            "uniform-model",
            bin_mode="mass",
            causal=True,
        )
        loader = torch.utils.data.DataLoader(
            binomial_mgr.train_ds(),
            batch_size=1024,
            num_workers=3,
            drop_last=True,
        )
        return trainable, loader

    # The model spreads its prediction evenly over b bins, so the expected
    # pred_nll is log(b). Only small bin counts are exercised here.
    for b in (5, 10):
        tr, dl = create_trainable(b)
        # Mass mode only supports integer interval lengths. The lengths are
        # not used in the probability calculation.
        tr.ds_mgr.density_interval_len = 1
        res = tr.eval_metrics(dl)
        print(res)

        # 1a. pred_nll
        pred_nll = res["pred_nll"]
        assert pred_nll == pytest.approx(math.log(b), abs=1e-6)

        # 1b. interval_pred_nll
        # By convention this equals pred_nll, as the interval probability
        # only really makes sense for density mode.
        assert res["interval_pred_nll"] == pytest.approx(
            pred_nll, abs=1e-4
        ), f"{b=}"

        # 1c. mean_abs_err (mode)
        # With all bins tied, argmax picks element 0, i.e. the first bin,
        # whose value is always 1.
        first_bin_value = 1.0
        mode_err = np.abs(first_bin_value - tr.ds_mgr.train_data).mean()
        assert res["mean_abs_err_mode"] == pytest.approx(
            mode_err, abs=1e-3
        ), f"{b=} "

        # 2. An error is raised if the queried interval doesn't match the
        # bin length.
        for mismatched_len in (0.5, 2):
            tr.ds_mgr.density_interval_len = mismatched_len
            with pytest.raises(ValueError):
                tr.eval_metrics(dl)
        # Same mismatch the other way round: widen the bin, query length 1.
        tr.bin_width = 2
        tr.ds_mgr.density_interval_len = 1
        with pytest.raises(ValueError):
            tr.eval_metrics(dl)


def test_DiscreteTrainable_eval_metrics_density(np_rng):
    """
    Check DiscreteTrainable.eval_metrics in "density" bin mode, using a model
    whose output is identical for every bin.

    For the continuous case, the number of bins in the distribution doesn't
    matter; really, the bins are not bins but samples—samples of a continuous
    distribution. Integrating over the distribution can be done numerically.

    For every combination of bin count and queried interval length, tests
    that pred_nll, interval_pred_nll and mean_abs_err_mode match the values
    expected for a uniform prediction.
    """

    def create_trainable(n_bins):
        # We need to clip values to the maximum bin value, otherwise we will
        # get infinite nll.
        ds_mgr = poisson_proc_ds_mgr(
            n_events=int(1e5), max_val=n_bins, rng=np_rng
        )
        m = DummyModel(n_bins)
        trainable = kdtpp.trainables.DiscreteTrainable(
            ds_mgr,
            m,
            "uniform-model",
            bin_mode="density",
            causal=True,
        )
        dl = torch.utils.data.DataLoader(
            ds_mgr.train_ds(), batch_size=1024, num_workers=3, drop_last=True
        )
        return trainable, dl

    n_bins = [5, 10]
    interval_lens = [0.5, 1, 1.3, 2, 20]
    for b, ilen in product(n_bins, interval_lens):
        # A fresh trainable (and freshly sampled dataset) per combination.
        tr, dl = create_trainable(b)
        tr.ds_mgr.density_interval_len = ilen
        res = tr.eval_metrics(dl)
        print(res)
        # 1. pred_nll
        # The assertion below encodes the expected value
        #   log(b) + log(tr.bin_width)
        # for a model that assigns 1/b to each of the b bins.
        assert res["pred_nll"] == pytest.approx(
            math.log(b) + math.log(tr.bin_width), abs=1e-6
        ), f"{b=}, {ilen=}"
        # 2. interval_pred_nll
        # If we double the queried interval length, we expect the
        # probability covered by the uniform distribution to double. A
        # probability cannot exceed 1, so the expected nll is floored at 0
        # (this applies when ilen exceeds b).
        expected_interval_nlprob = max(0, math.log(b) - math.log(ilen))
        # Queries that extend partially past the end bins shift the result
        # slightly, hence the looser absolute tolerance.
        assert res["interval_pred_nll"] == pytest.approx(
            expected_interval_nlprob, abs=1e-2
        ), f"{b=}, {ilen=}"
        # 3. mean_abs_err (mode)
        # Uniform dist will have argmax evaluate to 0th element, which
        # corresponds to the first bin.
        # NOTE(review): 0.5 is the prediction this test expects for the
        # first bin — presumably its midpoint at unit bin width; confirm
        # against DiscreteTrainable.
        first_bin_pred = 0.5  # always
        expected_mode_err = np.abs(
            first_bin_pred - tr.ds_mgr.train_data
        ).mean()
        assert res["mean_abs_err_mode"] == pytest.approx(
            expected_mode_err, abs=1e-3
        ), f"{b=}, {ilen=}"


def test_VarBinTrainable_edges_from_quantiles():
    """Regression test built from a failure case.

    Passes the quantile list below to
    ``VarBinTrainable.bin_edges_from_quantiles``. The test passes as long as
    that call does not raise; the return value is not inspected.

    NOTE(review): the list contains many repeated values (e.g. 5.0, 7.0,
    10.0). Presumably those repeats triggered the original failure — confirm
    against the history of ``bin_edges_from_quantiles``.
    """
    # The literal is kept exactly as captured; formatting is disabled so the
    # formatter does not reflow it.
    # fmt: off
    quantiles = [
        2.0, 3.85, 4.18333333, 5.0, 5.0, 5.73333333, 6.0, 6.0, 6.5, 7.0, 7.0, 
        7.0, 7.45, 7.9, 8.0, 8.0, 8.13333333, 8.55, 8.96666667, 9.0, 9.0, 9.1,
        9.5, 9.88333333, 10.0, 10.0, 10.0, 10.35, 10.73333333, 11.0, 11.0, 11.0,
        11.18333333, 11.55, 11.93333333, 12.0, 12.0, 12.0, 12.38333333, 
        12.76666667, 13.0, 13.0, 13.0, 13.25, 13.63333333, 14.0, 14.0, 14.0,
        14.18333333, 14.58333333, 15.0, 15.0, 15.0, 15.2, 15.61666667, 16.0, 
        16.0, 16.0, 16.33333333, 16.78333333, 17.0, 17.0, 17.13333333, 
        17.61666667, 18.0, 18.0, 18.06666667, 18.58333333, 19.0, 19.0, 19.15,
        19.71666667, 20.0, 20.0, 20.41666667, 21.0, 21.0, 21.26666667, 
        21.91666667, 22.0, 22.25, 22.98333333, 23.0, 23.43333333, 24.0, 
        24.01666667, 24.86666667, 25.0, 25.61666667, 26.0, 26.53333333, 27.0,
        27.61666667, 28.0, 28.93333333, 29.16666667, 30.0, 30.91666667, 
        31.38333333, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.41666667, 40.0,
        41.21666667, 43.0, 45.0, 47.0, 49.3, 52.0, 55.0, 59.0, 63.0, 68.0,
        74.03333333, 82.0, 92.0, 105.75, 125.0, 154.93333333, 210.51666667,
        316.75, 478.0, 1016.28158762,
    ]
    # fmt: on
    # The call itself is the test: it must complete without raising.
    kdtpp.trainables.VarBinTrainable.bin_edges_from_quantiles(quantiles)
