import pytest
import numpy as np
import numpy.testing as npt
import kdtpp.metropolis as metropolis
import scipy.stats as stats


def test_quantile_boundaries():
    """
    Test the quantile_boundaries function.

    Checks that:
        1. Known small cases produce the expected boundaries.
           1.1. One state is bounded by -∞ and ∞.
           1.2. Two states are bounded by -∞, 0 and ∞.
        2. Invalid arguments raise ValueError.
           2.1. n must be a positive integer.
           2.2. tail_prob must be between 0 and 0.5 (0.5 itself is rejected).
        3. Invariants hold.
           3.1. n-2 states built with tail_prob = 1/n have the same
                boundaries as the inner boundaries of n states built
                without tail_prob.
    """
    # 1.1. A single state spans the whole real line.
    npt.assert_array_equal(
        metropolis.quantile_boundaries(1), np.array([-np.inf, np.inf])
    )
    # 1.2. Two states split at 0.
    npt.assert_array_equal(
        metropolis.quantile_boundaries(2), np.array([-np.inf, 0, np.inf])
    )

    # 2.1. Non-positive or non-integer n is rejected.
    for bad_n in (-1, 0, 1.2):
        with pytest.raises(ValueError):
            metropolis.quantile_boundaries(bad_n)

    # 2.2. Out-of-range tail probabilities are rejected.
    for bad_tail in (-0.1, 0.6, 0.5, 1.1):
        with pytest.raises(ValueError):
            metropolis.quantile_boundaries(10, bad_tail)

    # 3.1. Trimming the two outermost states via tail_prob = 1/n leaves the
    # inner boundaries of the n-state partition unchanged.
    for n in range(3, 100):
        full = metropolis.quantile_boundaries(n)
        trimmed = metropolis.quantile_boundaries(n - 2, 1 / n)
        npt.assert_array_almost_equal(
            full[1:-1], trimmed, decimal=5, err_msg=f"{n=}"
        )


def test_from_rel_pos(np_rng):
    """Test from_rel_pos, the matrix version of from_rel_pos_single.

    Checks that:
        1. A hand-computed three-boundary case gives the expected positions.
        2. Invariants
            2.1. The matrix version agrees with calling from_rel_pos_single
                 once per state.
    """
    # 1. Relative position 0.5 in each of two unit-width states.
    boundaries = np.array([-1, 0, 1])
    npt.assert_array_almost_equal(
        metropolis.from_rel_pos(0.5, boundaries), np.array([-0.5, 0.5])
    )

    # 2.1. Randomised agreement with the single-state implementation.
    max_states = 100
    for t in range(100):
        n_states = int(np_rng.integers(2, max_states))
        boundaries = metropolis.quantile_boundaries(n_states, tail_prob=0.0001)
        rel_pos = np_rng.uniform(0, 1)

        vectorised = metropolis.from_rel_pos(rel_pos, boundaries)
        per_state = np.array(
            [
                metropolis.from_rel_pos_single(rel_pos, boundaries, state)
                for state in range(n_states)
            ]
        )
        npt.assert_array_almost_equal(vectorised, per_state, err_msg=f"{t=}")


def test_GridMetropolis_minimal(np_rng):
    """Sanity-check GridMetropolis with a single state.

    There is only one state (1x1 identity transition matrix) on [-10, 10],
    and both the target and inner pdfs are the standard normal. The
    resulting samples should pass a normality test.
    """
    sampler = metropolis.GridMetropolis(
        target_pdf=stats.norm.pdf,
        boundaries=np.array([-10, 10]),
        T=np.identity(1),
        initial_x=0,
        n_warmup=200,
        rng=np_rng,
        inner_sample_fn=np_rng.normal,
        inner_pdf=stats.norm.pdf,
    )
    draws, _ = sampler.sample(size=2000)
    _, p_value = stats.normaltest(draws)
    assert p_value > 0.05, "Samples should be normal"


@pytest.mark.parametrize("scale", [0.25, 0.5, 1])  # scale >= 2 currently fails
def test_GridMetropolis_minimal_scale(np_rng, scale):
    """Single-state GridMetropolis check with a scaled inner distribution.

    This is the same setup as test_GridMetropolis_minimal, except that the
    inner sampler and inner pdf use a normal distribution with standard
    deviation ``scale``. The chain's own rng is seeded separately (42); the
    np_rng fixture only drives the inner sampler.
    """

    def inner_sample_fn(loc):
        return np_rng.normal(loc=loc, scale=scale)

    def inner_pdf(x):
        return stats.norm.pdf(x, loc=0, scale=scale)

    sampler = metropolis.GridMetropolis(
        target_pdf=stats.norm.pdf,
        boundaries=np.array([-100, 100]),
        T=np.identity(1),
        initial_x=0,
        n_warmup=200,
        rng=np.random.default_rng(42),
        inner_sample_fn=inner_sample_fn,
        inner_pdf=inner_pdf,
    )
    draws, _ = sampler.sample(size=5000)
    _, p_value = stats.normaltest(draws)
    assert p_value > 0.05, "Samples should be normal"


def test_GridMetropolis():
    """End-to-end check of GridMetropolis on 16 quantile states.

    The target and inner pdfs are both the standard normal, sampling runs
    with no warmup, and the samples should pass a normality test.
    """
    n_states = 16
    grid = metropolis.quantile_boundaries(n_states, tail_prob=0.00001)
    transitions = metropolis.primary_plus_distance_T(n_states, primary_p=0.6)

    sampler = metropolis.GridMetropolis(
        target_pdf=stats.norm.pdf,
        boundaries=grid,
        T=transitions,
        initial_x=0,
        n_warmup=0,
        rng=np.random.default_rng(42),
        inner_sample_fn=np.random.default_rng(43).normal,
        inner_pdf=stats.norm.pdf,
    )
    draws, _ = sampler.sample(size=5000)
    _, p_value = stats.normaltest(draws)
    assert p_value > 0.05, "Samples should be normal"

def test_GridMetropolis_sample():
    """Test the sample method of GridMetropolis.

    Tests that:
        1. Passing an array of locations with size=3 returns samples whose
           leading axis has length 3, followed by the shape of the
           location array.
    """
    n_states = 16
    gm = metropolis.GridMetropolis(
        target_pdf=stats.norm.pdf,
        boundaries=metropolis.quantile_boundaries(n_states, tail_prob=0.00001),
        T=metropolis.primary_plus_distance_T(n_states, primary_p=0.6),
        initial_x=0,
        n_warmup=0,
        rng=np.random.default_rng(42),
        inner_sample_fn=np.random.default_rng(43).normal,
        inner_pdf=stats.norm.pdf,
    )

    # Test 1: array-valued locations.
    loc = np.zeros((10, 40, 2))
    samples, _ = gm.sample(loc, size=3)
    assert samples.shape == (3, 10, 40, 2)


@pytest.mark.parametrize("scale", [0.25, 0.5, 2])
def test_transition_prob(np_rng, scale):
    """Test transition_prob function.

    Tests that:
        1. A simple case works. (TODO: not yet written.)
        2. Invariants
            2.1. With the identity transition matrix and a pdf that is
                 symmetric (normal, mean 0, std ``scale``), the forward and
                 backward transition probabilities are equal. Both x and
                 xx are drawn between the leftmost and rightmost
                 boundaries; points outside them are not tested here.
    """
    # Test 1.
    # A simple case works.
    # TODO

    # Test 2.1.
    # The pdf does not depend on the loop variables, so build it once
    # instead of re-creating a lambda on every trial.
    def pdf(x):
        return stats.norm.pdf(x, loc=0, scale=scale)

    n_trials = 500
    MAX_STATES = 100
    for t in range(n_trials):
        n_states = int(np_rng.integers(2, MAX_STATES))
        bs = metropolis.quantile_boundaries(n_states, tail_prob=0.0001)
        x = np_rng.uniform(bs[0], bs[-1])
        xx = np_rng.uniform(bs[0], bs[-1])
        T = np.eye(n_states)
        # As T is identity, the inter state must be x's state.
        inter_state = metropolis.pos_to_state(x, bs)
        forward, backward = metropolis.transition_prob(
            x, inter_state, xx, bs, T, pdf
        )
        assert np.isclose(
            forward, backward
        ), f"{t=} | {x=}, {xx=}, {inter_state=}, {bs=}"


def test_primary_plus_distance_T():
    """Test primary_plus_distance_T function.

    Test generated by Claude Sonnet 3.5.

    Tests that:
        1. Basic properties of transition matrices are maintained
           1.1. Row sums equal 1
           1.2. All entries are non-negative
           1.3. Matrix shape is correct
        2. Primary and distance-based transition properties
           2.1. Each row has one primary transition state. It is taken to be
                the entry that deviates most from the scaled inverse-distance
                profile, and the primary states across rows form a
                permutation.
           2.2. The remaining entries are inversely related to the distance
                from the row's own state (checked between neighbouring
                entries).
           2.3. Works with different primary_p values
        3. Error cases are handled correctly
           3.1. n must be a positive integer
           3.2. primary_p must be strictly between 0 and 1
        4. Random seed produces consistent results
    """
    # Test 1: Basic properties
    n = 5
    T = metropolis.primary_plus_distance_T(n)

    # Test 1.1: Row sums
    npt.assert_array_almost_equal(
        T.sum(axis=1), np.ones(n), decimal=10, err_msg="Row sums must equal 1"
    )

    # Test 1.2: Non-negative entries
    assert np.all(T >= 0), "All entries must be non-negative"

    # Test 1.3: Shape
    assert T.shape == (n, n), f"Shape must be ({n}, {n})"

    # Test 2: Primary and distance-based transitions
    primary_p = 0.8
    T = metropolis.primary_plus_distance_T(
        n, primary_p=primary_p, rng=np.random.default_rng(42)
    )

    # Test 2.1 & 2.2: For each row, verify structure
    next_states = []
    for i in range(n):
        distances = np.abs(np.arange(n) - i)
        # Inverse-distance weights. The row's own state (distance 0) gets
        # weight 0; using `where` avoids the divide-by-zero RuntimeWarning
        # that `1 / distances` would raise for that entry.
        base_probs = np.divide(
            1.0, distances, out=np.zeros(n), where=distances != 0
        )
        base_probs = base_probs / np.sum(base_probs)

        # The state where actual probability differs most from distance-based
        # should be the primary transition state
        diff = np.abs(T[i] - (1 - primary_p) * base_probs)
        next_state = np.argmax(diff)
        next_states.append(next_state)

        # Verify that all other states follow distance-based pattern
        other_states = [j for j in range(n) if j != next_state]
        if len(other_states) > 1:
            other_probs = T[i, other_states]
            distances = np.abs(np.array(other_states) - i)
            # Check probabilities are inversely related to distance
            for k in range(len(other_states) - 1):
                if distances[k] > distances[k + 1]:
                    assert (
                        other_probs[k] < other_probs[k + 1]
                    ), "Probabilities should be inversely related to distance"
                elif distances[k] < distances[k + 1]:
                    assert (
                        other_probs[k] > other_probs[k + 1]
                    ), "Probabilities should be inversely related to distance"

    # Verify next_states is a permutation
    assert (
        len(set(next_states)) == n
    ), "Primary transitions should be a permutation"

    # Test 2.3: Different primary_p values work
    for p in [0.1, 0.5, 0.9]:
        T = metropolis.primary_plus_distance_T(n, primary_p=p)
        assert np.allclose(
            T.sum(axis=1), 1
        ), f"Rows must sum to 1 for primary_p={p}"

    # Test 3: Error cases
    # Test 3.1: Invalid n
    for invalid_n in [-1, 0, 1.5]:
        with pytest.raises(ValueError):
            metropolis.primary_plus_distance_T(invalid_n)

    # Test 3.2: Invalid primary_p (the endpoints 0 and 1 are also rejected)
    for invalid_p in [-0.1, 0, 1, 1.1]:
        with pytest.raises(ValueError):
            metropolis.primary_plus_distance_T(n, primary_p=invalid_p)

    # Test 4: Random seed consistency
    rng1 = np.random.default_rng(42)
    rng2 = np.random.default_rng(42)
    T1 = metropolis.primary_plus_distance_T(n, rng=rng1)
    T2 = metropolis.primary_plus_distance_T(n, rng=rng2)
    npt.assert_array_equal(
        T1, T2, err_msg="Same seed should produce same results"
    )


def test_pos_to_state():
    """Test pos_to_state function.

    Test generated by Claude Sonnet 3.5.

    Tests that:
        1. Basic functionality
           1.1. Correct state assignment for interior points
           1.2. Correct state assignment for points on a boundary (the
                rightmost boundary belongs to the last state)
        2. Edge cases
           2.1. Values beyond the outer boundaries are clipped to the
                first/last state
           2.2. Infinite outer boundaries are handled
    """
    # Test 1: Basic functionality
    boundaries = np.array([-2, -1, 0, 1, 2])

    # Test 1.1: Interior points
    assert metropolis.pos_to_state(-1.5, boundaries) == 0
    assert metropolis.pos_to_state(-0.5, boundaries) == 1
    assert metropolis.pos_to_state(0.5, boundaries) == 2
    assert metropolis.pos_to_state(1.5, boundaries) == 3

    # Test 1.2: Boundary points. An inner boundary belongs to the state on
    # its right; the rightmost boundary belongs to the last state.
    assert metropolis.pos_to_state(-2, boundaries) == 0
    assert metropolis.pos_to_state(-1, boundaries) == 1
    assert metropolis.pos_to_state(0, boundaries) == 2
    assert metropolis.pos_to_state(1, boundaries) == 3
    assert metropolis.pos_to_state(2, boundaries) == 3

    # Test 2: Edge cases
    # Test 2.1: Clipping
    assert metropolis.pos_to_state(-1000, boundaries) == 0
    assert metropolis.pos_to_state(1000, boundaries) == 3

    # Test 2.2: Infinite boundaries
    boundaries_with_inf = np.array([-np.inf, -1, 0, 1, np.inf])
    assert metropolis.pos_to_state(-1000, boundaries_with_inf) == 0
    assert metropolis.pos_to_state(1000, boundaries_with_inf) == 3


def test_sample_next_state():
    """Test sample_next_state function.

    Test generated by Claude Sonnet 3.5.

    Tests that:
        1. Basic functionality
           1.1. Returns a state index in range(len(T)); the sampling
                probabilities themselves are not checked
        2. Error handling
           2.1. A transition matrix with an incompatible shape (T with its
                last row dropped) raises ValueError
        3. Random seed consistency
           3.1. Two generators with the same seed give the same result
    """
    boundaries = np.array([-1, 0, 1])  # two states
    T = np.array([[0.7, 0.3], [0.3, 0.7]])
    rng = np.random.default_rng(42)

    # Test 1.1: the sampled next state is a valid state index.
    x = 0.5  # should be in state 1
    next_state = metropolis.sample_next_state(x, boundaries, T, rng)
    assert 0 <= next_state < len(T)

    # Test 2.1: Incompatible shapes
    with pytest.raises(ValueError):
        metropolis.sample_next_state(x, boundaries, T[:-1], rng)

    # Test 3.1: Random seed consistency
    rng1 = np.random.default_rng(42)
    rng2 = np.random.default_rng(42)
    result1 = metropolis.sample_next_state(x, boundaries, T, rng1)
    result2 = metropolis.sample_next_state(x, boundaries, T, rng2)
    assert result1 == result2


def test_to_rel_pos():
    """Test to_rel_pos function.

    Test generated by Claude Sonnet 3.5.

    Covers:
        1. Basic functionality
           1.1. Interior points map to the expected relative position
           1.2. Points on a boundary map to the expected relative position
        2. Edge cases
           2.1. Infinite boundaries raise ValueError
        3. Range validation
           3.1. The result always lies in [0, 1], including for points
                outside the outer boundaries
    """
    boundaries = np.array([-1, 0, 1])

    # 1.1 and 1.2: (position, expected relative position) pairs. The
    # shared boundary 0 maps to 0, and the outer boundaries ±1 map to 1.
    cases = [
        (-0.5, 0.5),
        (0.5, 0.5),
        (-1, 1.0),
        (0, 0.0),
        (1, 1.0),
    ]
    for position, expected in cases:
        npt.assert_almost_equal(
            metropolis.to_rel_pos(position, boundaries), expected
        )

    # 2.1: infinite outer boundaries are rejected.
    with pytest.raises(ValueError):
        metropolis.to_rel_pos(-100, np.array([-np.inf, 0, np.inf]))

    # 3.1: sweep points both inside and outside [-1, 1].
    for x in np.linspace(-2, 2, 10):
        rel_pos = metropolis.to_rel_pos(x, boundaries)
        assert 0 <= rel_pos <= 1, (
            f"Relative position {rel_pos} out of bounds for x={x}"
        )


def test_from_rel_pos_single():
    """Test from_rel_pos_single function.

    Test generated by Claude Sonnet 3.5.

    Tests that:
        1. Basic functionality
           1.1. Correct position for interior points
           1.2. Correct position for boundary points
        2. Edge cases
           2.1. Infinite boundaries raise ValueError
        3. Roundtrip conversion
           3.1. Converting with pos_to_state + to_rel_pos and back with
                from_rel_pos_single recovers the original position
    """
    # Test 1: Basic functionality
    bs = np.array([-1, 0, 1])

    # Test 1.1: Interior points. Relative position 0.5 is the midpoint of
    # each unit-width state.
    npt.assert_almost_equal(metropolis.from_rel_pos_single(0.5, bs, 0), -0.5)
    npt.assert_almost_equal(metropolis.from_rel_pos_single(0.5, bs, 1), 0.5)

    # Test 1.2: Boundary points. In both states, relative position 0 maps
    # to the shared boundary at 0 and relative position 1 maps to the
    # outer boundary.
    npt.assert_almost_equal(metropolis.from_rel_pos_single(0.0, bs, 0), 0.0)
    npt.assert_almost_equal(metropolis.from_rel_pos_single(1.0, bs, 0), -1.0)
    npt.assert_almost_equal(metropolis.from_rel_pos_single(0.0, bs, 1), 0.0)
    npt.assert_almost_equal(metropolis.from_rel_pos_single(1.0, bs, 1), 1.0)

    # Test 2.1: Infinite boundaries are rejected
    boundaries_with_inf = np.array([-np.inf, 0, np.inf])
    with pytest.raises(ValueError):
        metropolis.from_rel_pos_single(0.5, boundaries_with_inf, 0)

    # Test 3.1: Roundtrip conversion
    test_points = [-0.5, 0.0, 0.5]
    for x in test_points:
        state = metropolis.pos_to_state(x, bs)
        rel_pos = metropolis.to_rel_pos(x, bs)
        x_new = metropolis.from_rel_pos_single(rel_pos, bs, state)
        npt.assert_almost_equal(x, x_new)
