import torch
from torch import Generator

from smlmsim.dynamics.discretize import batch_discretize_periods, discretize_periods
from smlmsim.dynamics.sample import batch_sample_periods, sample_periods
from smlmsim.dynamics.statistics import average_num_activations, average_lifespan_on


def test_max_duration(N=1000):
    """Check that ``sample_periods`` never returns a time beyond ``max_duration``.

    For ``N`` random cut-offs drawn uniformly from ``[0, 5)`` (with a fixed
    seed), the largest entry of the returned ``periods`` tensor must not exceed
    the ``max_duration`` that was passed in.
    """
    time_on = 2.5
    time_off = 3.0
    time_bleach = 1.5
    gen = Generator().manual_seed(0)

    for _ in range(N):
        # 0-d tensor in [0, 5); passed to ``sample_periods`` unchanged.
        max_duration = 5.0 * torch.rand(size=(), generator=gen)
        periods = sample_periods(
            time_on=time_on,
            time_off=time_off,
            time_bleach=time_bleach,
            max_duration=max_duration,
            gen=gen,
        )
        # Compare plain floats so the assertion does not rely on tensor truthiness.
        latest = periods.max().item()
        limit = max_duration.item()
        assert latest <= limit, f"sampled time {latest} exceeds max_duration={limit}"


def test_batch_sample_periods(N=20000):
    """Compare the batched sampler/discretizer with the one-at-a-time versions.

    ``N`` traces are produced one by one via ``sample_periods`` followed by
    ``discretize_periods`` and averaged over dim 0. That average must match,
    within ``rtol=0.01, atol=0.01``, the dim-0 average of ``N`` traces produced
    by a single ``batch_sample_periods`` + ``batch_discretize_periods`` call.
    Both paths draw from the same seeded generator, one after the other.
    """
    frames = 10
    timings = dict(time_on=2.5, time_off=1.5, time_bleach=3.0)
    rng = Generator().manual_seed(0)

    # Reference path: sample and discretize each emitter individually.
    per_emitter_mean = torch.stack(
        [
            discretize_periods(sample_periods(**timings, gen=rng), n_frames=frames)
            for _ in range(N)
        ],
        dim=0,
    ).mean(dim=0)

    # Batched path: the whole population in one call (generator continues).
    batched_periods = batch_sample_periods(
        n=N, **timings, n_periods_max=10, gen=rng
    )
    batched_mean = batch_discretize_periods(batched_periods, n_frames=frames).mean(
        dim=0
    )

    torch.testing.assert_close(per_emitter_mean, batched_mean, rtol=0.01, atol=0.01)


def _relative_difference(a, b):
    """Return ``|a - b|`` divided by the arithmetic mean of ``a`` and ``b``."""
    return abs(a - b) / (0.5 * (a + b))


def test_average_times(N=100000):
    """Check Monte-Carlo estimates of blinking statistics against theory.

    Draws ``N`` independent period tensors with ``sample_periods`` (fixed
    seed) and compares:

    * the mean number of rows (activations) per draw with
      ``average_num_activations``;
    * the mean of ``sum(periods[:, 1] - periods[:, 0])`` per draw with
      ``average_lifespan_on``.

    Each pair must agree to within 1 % relative difference.
    """
    time_on = 2.5
    time_off = 3.0
    time_bleach = 1.5
    tolerance = 0.01
    gen = Generator().manual_seed(0)

    theoretical_num_activations = average_num_activations(
        time_off=time_off, time_bleach=time_bleach
    )
    theoretical_lifespan_on = average_lifespan_on(
        time_off=time_off, time_bleach=time_bleach
    )

    # Accumulate as Python numbers so the final comparisons are plain floats.
    total_activations = 0
    total_on_time = 0.0
    for _ in range(N):
        periods = sample_periods(
            time_on=time_on,
            time_off=time_off,
            time_bleach=time_bleach,
            gen=gen,
        )
        total_activations += periods.shape[0]
        total_on_time += (periods[:, 1] - periods[:, 0]).sum().item()
    estimated_num_activations = total_activations / N
    estimated_lifespan_on = total_on_time / N

    diff = _relative_difference(estimated_num_activations, theoretical_num_activations)
    assert diff <= tolerance, (
        f"num_activations: estimated {estimated_num_activations}, "
        f"theoretical {theoretical_num_activations} (rel. diff {diff})"
    )

    diff = _relative_difference(estimated_lifespan_on, theoretical_lifespan_on)
    assert diff <= tolerance, (
        f"lifespan_on: estimated {estimated_lifespan_on}, "
        f"theoretical {theoretical_lifespan_on} (rel. diff {diff})"
    )

    # TODO: a check of the mean total lifespan (periods[-1, 1] - periods[0, 0])
    # against a closed-form value was disabled here; it referenced
    # ``smlmsim.dynamics.average_lifespan`` — re-enable once that is confirmed.
