import numpy as np
import torch
import torchaudio
from hypothesis import given, settings
from hypothesis import strategies as st
from scipy.integrate import quad

from pdisvae import utils


@given(st.data())
@settings(deadline=None)
def test_exp_basis(data):
    """Verify that ``utils.exp_basis`` is normalized to unit area.

    The summed basis is scaled by the bin width ``time_span / window_size``
    (a rectangle-rule integral), and the result must equal 1.

    Parameters
    ----------
    data : _hypothesis.strategies.data._DataObject
        Hypothesis data object used to draw the decay, bin count and span.
    """
    decay_rate = data.draw(st.floats(min_value=0.1, max_value=10))
    n_bins = data.draw(st.integers(min_value=1, max_value=10))
    span = data.draw(st.floats(min_value=0.1, max_value=10))

    kernel = utils.exp_basis(decay_rate, n_bins, span)
    # Rectangle-rule area: sum of samples times bin width (span / n_bins).
    area = kernel.sum() * span / n_bins
    torch.testing.assert_close(area, torch.tensor(1.0))


@given(st.data())
@settings(deadline=None)
def test_convolve_spikes_with_basis(data):
    """Compare ``utils.convolve_spikes_with_basis`` with a per-trace reference.

    The reference left-pads every (sample, neuron) spike train with a single
    zero, runs a full 1-D convolution against the basis with torchaudio, and
    keeps the first ``n_time_bins`` outputs. Because of the padding, output
    bin ``t`` depends only on spikes in bins strictly before ``t``.

    Parameters
    ----------
    data : _hypothesis.strategies.data._DataObject
        Hypothesis data object used to draw the sample and neuron counts.
    """
    torch.manual_seed(0)
    n_samples: int = data.draw(st.integers(1, 3))
    n_time_bins: int = 100
    n_neurons: int = data.draw(st.integers(1, 3))

    basis = utils.exp_basis(1, 5, 5)

    # Indexed below as [sample, time_bin, neuron].
    spikes = torch.poisson(torch.rand((n_samples, n_time_bins, n_neurons)))
    actual = utils.convolve_spikes_with_basis(spikes, basis)

    expected = torch.zeros_like(actual)
    leading_zero = torch.zeros([1])
    for s_idx in range(n_samples):
        for n_idx in range(n_neurons):
            train = spikes[s_idx, :, n_idx]
            # Shift by one bin so the current bin is excluded from its own history.
            shifted = torch.concat([leading_zero, train], dim=0)
            full = torchaudio.functional.convolve(shifted, basis, mode="full")
            expected[s_idx, :, n_idx] = full[:n_time_bins]

    torch.testing.assert_close(actual, expected)


def test_logcosh():
    """Test numerically stable logcosh.

    For moderate inputs ``utils.logcosh`` must agree with the naive
    ``log(cosh(x))``. For the largest inputs the naive formula overflows to
    ``inf`` in float32, so there the stable version is instead checked to be
    finite and equal to the asymptote ``|x| - log(2)``.
    """
    # Explicit float input: avoids relying on int -> float promotion.
    x = torch.tensor([-2.0, 0.0, 2.0, 3.0, 5.0, 8.0, 10.0, 20.0, 50.0, 100.0, 1000.0])
    naive = torch.cosh(x).log()
    stable = utils.logcosh(x)

    # Where the naive formula is still finite, both must agree.
    torch.testing.assert_close(stable[:-2], naive[:-2])

    # The last two inputs really do overflow the naive formula ...
    torch.testing.assert_close(naive[-2:], torch.tensor([torch.inf, torch.inf]))
    # ... so check the implementation under test there: it must stay finite and
    # match |x| - log(2). The neglected term log(1 + exp(-2|x|)) is far below
    # float32 resolution for these inputs.
    assert torch.isfinite(stable).all()
    torch.testing.assert_close(stable[-2:], x[-2:].abs() - torch.tensor(2.0).log())


def test_logcosh_log_prob():
    """Test logcosh log probability satisfies desired distribution property.

    By numerical integration over the whole real line, checks that the
    density ``exp(utils.logcosh_log_prob(z))``:

    - integrates to 1,
    - has zero mean,
    - has unit variance.
    """

    def pdf(z: float) -> float:
        # quad passes plain floats and expects a plain float back.
        return utils.logcosh_log_prob(torch.tensor(z)).exp().item()

    # Total mass over the full real line. Integrating only [0, inf) against 0.5
    # would silently assume symmetry and never look at the negative axis.
    mass, _ = quad(pdf, -np.inf, np.inf)
    np.testing.assert_allclose(mass, 1.0, rtol=1e-6)

    # Mean is 0, so the raw second moment below equals the variance.
    mean, _ = quad(lambda z: z * pdf(z), -np.inf, np.inf)
    np.testing.assert_allclose(mean, 0.0, atol=1e-6)

    # Check that the variance is 1.
    variance, _ = quad(lambda z: z**2 * pdf(z), -np.inf, np.inf)
    np.testing.assert_allclose(variance, 1.0, rtol=1e-5)
