import numpy as np
import pytest
from scipy.stats import norm


def test_quantile_function():
    """Check QuantileFunction on a tiny sample and on large uniform / normal draws."""
    from soe.utils import QuantileFunction

    # --- Tiny sample: the requested levels map onto the sorted observations.
    data = np.array([1, 2, 3, 4, 5])
    qf = QuantileFunction(data)

    for level, expected in ((0.0, 1), (0.25, 2), (0.5, 3), (1.0, 5)):
        assert qf(level) == expected
    levels = [0.0, 0.25, 0.5, 0.75, 1.0]
    assert np.allclose(qf(levels), data)

    # --- Uniform(0, 1): median close to 0.5, extreme levels hit the sample min / max.
    np.random.seed(0)
    data = np.random.rand(int(1e6))
    qf = QuantileFunction(data)

    median_err = np.abs(qf(0.5) - 0.5)
    assert median_err < 1e-2
    assert qf(0.0) == np.min(data)
    assert qf(1.0) == np.max(data)

    # --- Standard normal: mean absolute deviation from the exact ppf on a 1%..99% grid.
    np.random.seed(0)
    data = np.random.randn(int(1e6))
    qf = QuantileFunction(data)

    grid = np.linspace(0.01, 0.99, 99)
    deviations = np.abs(norm.ppf(grid) - qf(grid))
    assert deviations.mean() < 1e-2


def test_second_quantile():
    """Check SecondQuantileFunction against a cumulative sum of the normal ppf."""
    from soe.utils import SecondQuantileFunction

    # The function must evaluate to exactly zero at p = 0, for small and large samples.
    sq = SecondQuantileFunction(np.array([1, 2, 3, 4]))
    assert sq(0.0) == 0.0

    np.random.seed(0)
    sq = SecondQuantileFunction(np.random.randn(int(1e6)))
    assert sq(0.0) == 0.0

    # Reference: running sum of ppf(p) * dp on the interior of a uniform grid over [0, 1];
    # the endpoints are dropped because norm.ppf is infinite at 0 and 1.
    # NOTE(review): linspace(0, 1, n) has spacing 1 / (n - 1), not exactly `dp`.
    step = sq.dp
    n_points = int(1.0 / step)
    p = np.linspace(0.0, 1.0, n_points)[1:-1]
    reference = np.cumsum(norm.ppf(p) * step)

    worst = np.abs(reference - sq(p)).max()
    assert worst < 1e-2


def test_ecdf():
    """Check ECDF on a tiny sample and against the exact standard-normal CDF."""
    from soe.utils import ECDF

    # Tiny sample: 0 below the minimum, then k/5 at the k-th observation.
    obs = np.array([1, 2, 3, 4, 5])
    ecdf = ECDF(obs)

    assert ecdf(0.0) == 0.0
    expected_steps = np.arange(1, 6) / 5
    assert np.allclose(ecdf(obs), expected_steps)

    # Large standard-normal sample: mean absolute error vs. norm.cdf on [-5, 5].
    np.random.seed(0)
    obs = np.random.randn(int(1e6))
    ecdf = ECDF(obs)

    grid = np.linspace(-5, 5, 100)
    abs_err = np.abs(norm.cdf(grid) - ecdf(grid))
    assert np.mean(abs_err) < 1e-2


def test_integrated_ecdf():
    """Check IntegratedECDF on a tiny sample and against the exponential closed form.

    For X ~ Exp(1) the CDF is F(x) = 1 - exp(-x) for x >= 0 (and 0 below),
    so the integrated CDF is  int_0^x F(t) dt = x - 1 + exp(-x).
    """
    from soe.utils import IntegratedECDF

    # small sample
    samples = np.array([1, 2, 3, 4, 5])
    icdf = IntegratedECDF(samples)

    assert icdf(0.0) == 0.0

    # exponential distribution with rate 1 (scale=1.0)
    np.random.seed(0)

    # The empirical integrated CDF at x equals mean((x - X_i)_+), whose sampling std is
    # sqrt(Var[(x - X)_+] / n) ~= 0.66 / sqrt(n) at x = 2.  With n = 1000 that is ~0.021,
    # above the 1.5e-2 tolerance of `decimal=2`, so the check only passed for lucky
    # seeds.  n = 1e5 puts the tolerance at roughly 7 standard deviations.
    samples = np.random.exponential(scale=1.0, size=int(1e5))
    icdf = IntegratedECDF(samples)

    for x in [0.5, 1.0, 1.5, 2.0]:
        true_icdf = x - 1 + np.exp(-x)
        np.testing.assert_almost_equal(icdf(x), true_icdf, decimal=2)

    # Vectorized form
    x = np.array([0.5, 1.0, 1.5, 2.0])
    true_icdf = x - 1 + np.exp(-x)
    np.testing.assert_almost_equal(icdf(x), true_icdf, decimal=2)


# Run the test with multiple (x0, x1) integration bounds.
@pytest.mark.parametrize("x0, x1", [(0.0, 1.0), (2.0, 4.0)])
def test_num_integrate(x0, x1):
    """Compare num_integrate / num_integrate_func with the closed form of int x dx."""
    from soe.utils import num_integrate, num_integrate_func

    step = 0.001

    def identity(x):
        return x

    def antiderivative(x):
        # $\int x \, dx = x^2 / 2$
        return x**2 / 2

    # Definite integral over [x0, x1].
    approx = num_integrate(identity, x0, x1, step)
    exact = antiderivative(x1) - antiderivative(x0)
    assert np.abs(approx - exact) < 1e-2

    # Indefinite integral anchored at x0, checked at 10 points spanning [x0, x1].
    running = num_integrate_func(identity, x0, x1, step)
    for point in np.linspace(x0, x1, 10):
        expected = antiderivative(point) - antiderivative(x0)
        assert np.abs(running(point) - expected) < 1e-2

def test_pdist():
    """Check pdist for NumPy and torch inputs and for the 'logistic' metric.

    The default metric is compared with the squared Euclidean distance; the
    'logistic' metric with sum_k log(1 + exp(-beta * (x_k - y_k))).
    """
    import torch
    from soe.utils import pdist

    # Seed both RNGs so any failure is reproducible (the other tests here seed too).
    np.random.seed(0)
    torch.manual_seed(0)

    # NumPy (float64) inputs: squared Euclidean distance.
    X, Y = np.random.randn(3, 2), np.random.randn(5, 2)
    d = pdist(X, Y)
    assert d.shape == (X.shape[0], Y.shape[0])
    assert np.allclose(d, ((X[:, None] - Y) ** 2).sum(axis=-1), atol=1e-8)

    # torch.randn yields float32 (machine epsilon ~1.2e-7); an absolute tolerance of
    # 1e-8 is below what that dtype can resolve, so use a float32-appropriate one.
    float32_atol = 1e-5

    X, Y = torch.randn(3, 2), torch.randn(5, 2)
    d = pdist(X, Y)
    assert d.shape == (X.shape[0], Y.shape[0])
    assert torch.allclose(d, ((X[:, None] - Y) ** 2).sum(dim=-1), atol=float32_atol)

    X, Y = torch.randn(3, 2), torch.randn(5, 2)
    beta = 8.0
    d = pdist(X, Y, metric='logistic', beta=beta)
    assert d.shape == (X.shape[0], Y.shape[0])
    # Reference: softplus(-beta * diff) = log(1 + exp(-beta * diff)) summed over features
    # (i.e. -log sigmoid(beta * diff)).  Computed in float64 via np.logaddexp(0, z), which
    # is exactly log(1 + exp(z)) without the cancellation of log(1 + tiny) in float32 and
    # without relying on NumPy ufuncs dispatching on torch tensors.
    diff = (X[:, None] - Y).numpy().astype(np.float64)
    d_ref = np.logaddexp(0.0, -beta * diff).sum(axis=-1)
    assert np.allclose(d.numpy(), d_ref, atol=float32_atol)
