import pytest
import math
import numpy as np
import kdtpp.datasets as ds
import polars as pl
import logging

_logger = logging.getLogger(__name__)


def test_epoch_opts():
    """
    Check ``ds.epoch_opts`` against a table of known input/output pairs.

    The expected values in the table are consistent with
    ``batch = min(train_len // min_steps_per_epoch, max_batch_size)`` and
    ``epochs = min(max_epochs, n_samples // train_len)`` for the fixed
    configuration below.
    """
    # Fixed configuration shared by every case.
    n_samples = 2**27
    max_batch_size = 2048
    min_steps_per_epoch = 128
    max_epochs = 512

    # (training-set length, expected batch size, expected epoch count)
    known_cases = [
        (2**10, 8, 512),
        (2**11, 16, 512),
        (2**12, 32, 512),
        (2**13, 64, 512),
        (2**14, 128, 512),
        (2**15, 256, 512),
        (2**16, 512, 512),
        (2**17, 1024, 512),
        (2**18, 2048, 512),
        (2**19, 2048, 256),
        (2**20, 2048, 128),
        (2**21, 2048, 64),
        (2**22, 2048, 32),
        (2**23, 2048, 16),
        (2**24, 2048, 8),
        (2**25, 2048, 4),
        (2**26, 2048, 2),
        (2**27, 2048, 1),
    ]

    for train_len, want_batch, want_epochs in known_cases:
        # epoch_opts returns (n_epochs, batch_size) in that order.
        got_epochs, got_batch = ds.epoch_opts(
            train_len, n_samples, max_epochs, max_batch_size, min_steps_per_epoch
        )
        assert got_batch == want_batch
        assert got_epochs == want_epochs


def test_trim_to(np_rng):
    """
    Tests that the function correctly trims the sequences.

    Tests known cases, and some random cases.

    More details:
        1. Test various trim lengths for a fixed list of sequences.
        2. Test edge cases associated with empty sequences in the list.
        3. Random list length with random sequences and random request lengths.

    The known cases show that ``trim_to`` keeps a prefix of the input
    sequences and shortens only the last kept sequence. The random cases
    check that structure as well as the total number of elements.
    """
    # 1. Known cases
    # Setup
    lens = [10, 7, 5, 3, 4, 12]
    #  Σ = [10, 17, 22, 25, 29, 41]
    seqs = [np.ones(n) for n in lens]
    request_response = [
        [1, [1]],
        [2, [2]],
        [10, [10]],
        [12, [10, 2]],
        [16, [10, 6]],
        [17, [10, 7]],
        [40, [10, 7, 5, 3, 4, 11]],
        [41, [10, 7, 5, 3, 4, 12]],
    ]
    # Non-positive requests and requests larger than Σ = 41 are invalid.
    error_cases = [-1, 0, 42]

    # Test
    # No exceptions
    for req, resp in request_response:
        trimmed_seqs = ds.trim_to(seqs, req)
        assert [len(s) for s in trimmed_seqs] == resp
    # Should raise
    for ec in error_cases:
        with pytest.raises(ValueError):
            ds.trim_to(seqs, ec)

    # 2. Known cases associated with empty lists.
    # Some of these cases can be relaxed from errors if a use case is found.
    # You must provide sequences, even if you request zero.
    with pytest.raises(ValueError):
        ds.trim_to([], 0)
    with pytest.raises(ValueError):
        ds.trim_to([[]], 0)
    # Furthermore, no sequence is allowed to be empty, even one that the
    # request would never reach:
    with pytest.raises(ValueError):
        ds.trim_to([[1, 1, 1], [1, 1, 1], []], 2)
    with pytest.raises(ValueError):
        ds.trim_to([[1, 1, 1], [], [1, 1, 1]], 2)
    # It is not even okay to request zero events.
    with pytest.raises(ValueError):
        ds.trim_to([[1, 1, 1], [1, 1]], 0)

    # 3. Random cases
    # Setup
    n_runs = 100
    MAX_N_SEQS = 2000
    MAX_LEN = 100
    MAX_ELEM_VAL = int(1e6)  # Actual values shouldn't really matter.
    n_requests = 40
    for _ in range(n_runs):
        n_seqs = np_rng.integers(1, MAX_N_SEQS)
        seqs = [
            np_rng.integers(1, MAX_ELEM_VAL, size=np_rng.integers(1, MAX_LEN))
            for _ in range(n_seqs)
        ]
        seq_lens = [len(s) for s in seqs]
        n_elements = sum(seq_lens)
        # endpoint=True draws from the inclusive range [1, n_elements]. This
        # exercises the "keep everything" boundary and stays valid when
        # n_elements == 1, where integers(1, 1) would raise (low >= high).
        requests = np_rng.integers(1, n_elements, size=n_requests, endpoint=True)
        _logger.debug(f"{n_seqs=}, {n_elements=}, {requests=}")
        for req in requests:
            res = ds.trim_to(seqs, req)
            res_lens = [len(s) for s in res]
            assert sum(res_lens) == req
            # As in the known cases, the result is a prefix of the input in
            # which only the final (non-empty) sequence may be shortened.
            k = len(res_lens)
            assert res_lens[:-1] == seq_lens[: k - 1]
            assert 1 <= res_lens[-1] <= seq_lens[k - 1]
        with pytest.raises(ValueError):
            ds.trim_to(seqs, n_elements + 1)
        with pytest.raises(ValueError):
            ds.trim_to(seqs, 0)


def test_gen_poisson(np_rng):
    """
    Check that ``ds.gen_poisson``:
        1. runs to completion, and
        2. yields log probabilities whose negated mean is close to the
           analytic entropy of the exponential inter-event distribution.

    The entropy of an exponential with rate λ is ``1 - log(λ)``. The expected
    value of 1 below corresponds to unit rate (presumably gen_poisson's rate;
    TODO confirm).
    """
    n_events = 100000
    unit_rate_entropy = 1

    _, event_log_probs = ds.gen_poisson(n_events, np_rng)

    assert math.isclose(-np.mean(event_log_probs), unit_rate_entropy, abs_tol=0.01)


def test_gen_nonstationary_poisson(np_rng):
    """
    Check ``ds.gen_nonstationary_poisson`` in several regimes:
        1. The call completes without error.
        2. With a very long period and the offset ``c`` set to a quarter
           period, sampling starts at the intensity peak and the intensity
           barely moves. The negated mean log probability should therefore
           match the entropy of an exponential with rate 2, i.e.
           ``1 - log(2)``.
        3. With a period of 1 and no offset, the log probability of the first
           event matches the closed form for intensity ``sin(2πt/T) + 1``.
        4. For large n at the default period, the negated mean log
           probability matches a previously recorded value.
    """
    n_events = 100000

    # Tests 1 & 2: slow oscillation started at its peak, so λ ≈ 2 throughout.
    slow_period = int(1e10)
    peak_rate = 2
    expected_entropy = 1 - math.log(peak_rate)
    _, log_probs = ds.gen_nonstationary_poisson(
        n_events, slow_period, np_rng, c=slow_period / 4
    )
    assert math.isclose(-np.mean(log_probs), expected_entropy, abs_tol=0.01)

    # Test 3: unit period and no offset; check the first event analytically.
    unit_period = 1
    ts, log_probs = ds.gen_nonstationary_poisson(
        n_events, unit_period, rng=np_rng, c=0
    )
    t0 = ts[0]
    phase = 2 * math.pi * t0 / unit_period
    # Intensity λ(t) = sin(2πt/T) + 1 and its integral from 0 to t,
    # Λ(t) = -T/(2π) · (cos(2πt/T) - 1) + t; log p(t0) = log λ(t0) - Λ(t0).
    rate_at_t0 = math.sin(phase) + 1
    cumulative_rate = -unit_period / (2 * math.pi) * (math.cos(phase) - 1) + t0
    expected_log_p0 = math.log(rate_at_t0) - cumulative_rate
    assert math.isclose(log_probs[0], expected_log_p0, abs_tol=0.01)

    # Test 4: default period. The reference value was recorded from a single
    # earlier run rather than derived, so this mainly guards against changes
    # in behaviour rather than proving correctness.
    _, log_probs = ds.gen_nonstationary_poisson(n_events, rng=np_rng)
    recorded_entropy = 0.69339
    assert math.isclose(-np.mean(log_probs), recorded_entropy, abs_tol=0.01)


def test_gen_hawkes1():
    """
    Smoke test: ``ds.gen_hawkes1`` must complete without raising for large
    event counts.

    Added after some zero time deltas were observed.
    """
    for n_events in (2**25, 2**24):
        ds.gen_hawkes1(n_events=n_events)
