"""
Integration tests for the full experiment pipeline.

These tests run small-scale versions of the actual experiments
to catch issues in component interactions.
"""

import pytest
import torch
import numpy as np

from src.wpm import WPMSolver, wpm_swf
from src.kolm import KolmSolver, kolm_swf
from src.gini import GiniSolver, gini_swf
from src.armsets import BetaArmSet
from src.sampler import PiPSSampler
from src.swf_ucb import SWFUCB

# Import the core experiment function for reproducibility tests
from ucb_expt import run_experiments_core, generate_weights, get_solver_and_swf


class TestBetaArmSet:
    """Sanity checks for the BetaArmSet arm environment."""

    def test_means_in_valid_range(self):
        """Every arm mean must lie strictly between 0 and 1."""
        rng = torch.Generator().manual_seed(42)
        arms = BetaArmSet(n_arms=10, gen=rng)

        assert torch.all(arms.means > 0)
        assert torch.all(arms.means < 1)

    def test_samples_in_valid_range(self):
        """Sampled rewards must lie in [0, 1], one per requested index."""
        rng = torch.Generator().manual_seed(42)
        arms = BetaArmSet(n_arms=5, gen=rng)

        # Request a strict subset of the five arms.
        picked = np.array([0, 2, 4])
        draws = arms.sample(picked)

        assert torch.all(draws >= 0)
        assert torch.all(draws <= 1)
        assert len(draws) == len(picked)


class TestSWFUCBIntegration:
    """Integration tests for the SWFUCB algorithm.

    All randomness owned by these tests is drawn from explicitly seeded
    ``torch.Generator`` objects, so a failing run can be replayed exactly.
    """

    @staticmethod
    def _seeded_weights(n_arms, seed=0):
        """Return softmax-normalised Gaussian weights from a dedicated generator.

        A dedicated generator is used instead of the global torch RNG so the
        weights are identical on every run, and instead of the generator
        handed to the arm set and sampler so that drawing the weights does
        not consume values from that stream.

        Args:
            n_arms: Number of weights to generate.
            seed: Seed for the dedicated weight generator.

        Returns:
            A float64 tensor of ``n_arms`` positive weights summing to 1.
        """
        weight_gen = torch.Generator().manual_seed(seed)
        scores = torch.randn(n_arms, dtype=torch.float64, generator=weight_gen)
        return torch.softmax(scores, dim=0)

    @staticmethod
    def _play_rounds(ucb, arm_set, n_rounds):
        """Run ``n_rounds`` iterations of select -> sample -> update."""
        for _ in range(n_rounds):
            inds = ucb.select_arms()
            rewards = arm_set.sample(inds)
            ucb.update(inds, rewards)

    @pytest.mark.parametrize("objective", ["wpm", "kolm", "gini"])
    def test_ucb_runs_without_error(self, objective):
        """UCB should run without errors for all objectives."""
        n_arms = 5
        n_rounds = 50
        num_alloc = 2

        gen = torch.Generator().manual_seed(42)
        arm_set = BetaArmSet(n_arms, gen=gen)

        weights = self._seeded_weights(n_arms)

        if objective == "wpm":
            solver = WPMSolver(weights, pow=-1.0, num_alloc=num_alloc)
        elif objective == "kolm":
            solver = KolmSolver(weights, pow=-1.0, num_alloc=num_alloc)
        elif objective == "gini":
            # The Gini solver is given weights sorted in descending order.
            weights, _ = torch.sort(weights, descending=True)
            solver = GiniSolver(weights, num_alloc=num_alloc)
        else:
            # Fail clearly instead of hitting a NameError on ``solver`` below
            # if the parametrize list grows without a matching branch.
            pytest.fail(f"Unhandled objective: {objective!r}")

        sampler = PiPSSampler(gen)
        ucb = SWFUCB(n_arms, num_alloc, solver, sampler, delta=0.05)

        self._play_rounds(ucb, arm_set, n_rounds)

        # Check final state is valid
        assert ucb.t == n_rounds
        assert torch.all(ucb.arm_counts >= 0)
        assert torch.all(ucb.probs >= 0)
        # Small tolerance for floating-point error in the probabilities.
        assert torch.all(ucb.probs <= 1 + 1e-6)

    def test_ucb_explores_all_arms_initially(self):
        """UCB should explore all arms in the initial phase."""
        n_arms = 6
        num_alloc = 2

        gen = torch.Generator().manual_seed(42)
        arm_set = BetaArmSet(n_arms, gen=gen)
        weights = self._seeded_weights(n_arms)
        solver = WPMSolver(weights, pow=-1.0, num_alloc=num_alloc)
        sampler = PiPSSampler(gen)

        ucb = SWFUCB(n_arms, num_alloc, solver, sampler, delta=0.05)

        # Minimum number of rounds needed to cover all arms:
        # ceil(n_arms / num_alloc), computed with integer arithmetic so it
        # stays correct if the constants above are changed.
        exploration_rounds = (n_arms + num_alloc - 1) // num_alloc
        self._play_rounds(ucb, arm_set, exploration_rounds)

        # All arms should have been pulled at least once
        assert torch.all(ucb.arm_counts >= 1)

    def test_ucb_counts_increase(self):
        """Arm counts should never decrease, per arm and in total."""
        n_arms = 5
        n_rounds = 20
        num_alloc = 2

        gen = torch.Generator().manual_seed(42)
        arm_set = BetaArmSet(n_arms, gen=gen)
        weights = self._seeded_weights(n_arms)
        solver = WPMSolver(weights, pow=-1.0, num_alloc=num_alloc)
        sampler = PiPSSampler(gen)

        ucb = SWFUCB(n_arms, num_alloc, solver, sampler, delta=0.05)

        # Snapshot (clone) so a later in-place update cannot alias the baseline.
        prev_counts = ucb.arm_counts.clone()
        prev_total = 0
        for _ in range(n_rounds):
            inds = ucb.select_arms()
            rewards = arm_set.sample(inds)
            ucb.update(inds, rewards)

            current_total = ucb.arm_counts.sum().item()
            assert current_total >= prev_total
            # The total alone could hide one arm's count dropping while
            # another rises, so also compare arm by arm.
            assert torch.all(ucb.arm_counts >= prev_counts)
            prev_total = current_total
            prev_counts = ucb.arm_counts.clone()


class TestRegretComputation:
    """Test that regret computation makes sense."""

    @staticmethod
    def _random_allocation(n_arms, num_alloc, gen):
        """Draw a random allocation vector in [0, 1] summing to ``num_alloc``.

        Args:
            n_arms: Length of the allocation vector.
            num_alloc: Target sum of the allocation probabilities.
            gen: Seeded ``torch.Generator`` the draw is taken from.

        Returns:
            A float64 tensor with entries in [0, 1]. Its sum is within 1e-6
            of ``num_alloc`` unless no entry strictly inside (0, 1) is left
            to adjust, in which case the loop gives up early.
        """
        probs = torch.rand(n_arms, dtype=torch.float64, generator=gen)
        probs = probs / probs.sum() * num_alloc
        probs = torch.clamp(probs, 0, 1)

        # Clamping can change the sum, so spread the remaining difference
        # evenly over the entries that are still strictly inside (0, 1) and
        # clamp again until the target sum is reached.
        while abs(probs.sum() - num_alloc) > 1e-6:
            diff = num_alloc - probs.sum()
            adjustable = (probs > 0) & (probs < 1)
            n_adjustable = adjustable.sum()
            if n_adjustable == 0:
                break
            probs[adjustable] += diff / n_adjustable
            probs = torch.clamp(probs, 0, 1)
        return probs

    @pytest.mark.parametrize(
        "objective,swf_func,solver_cls,pow_val",
        [
            ("wpm", wpm_swf, WPMSolver, -1.0),
            ("kolm", kolm_swf, KolmSolver, -1.0),
        ],
    )
    def test_optimal_swf_is_achievable(self, objective, swf_func, solver_cls, pow_val):
        """Optimal allocation should give highest SWF.

        The solver's allocation for the true arm means is compared against
        100 random allocations; none may score a higher welfare value
        (up to a 1e-6 tolerance). ``objective`` only labels the test case.
        """
        n_arms = 5
        num_alloc = 2
        n_trials = 100

        gen = torch.Generator().manual_seed(42)
        arm_set = BetaArmSet(n_arms, gen=gen)

        # Test-local randomness (weights and candidate allocations) comes
        # from its own seeded generator rather than the global torch RNG, so
        # a failure can be reproduced exactly.
        rng = torch.Generator().manual_seed(0)
        weights = torch.softmax(
            torch.randn(n_arms, dtype=torch.float64, generator=rng), dim=0
        )

        solver = solver_cls(weights, pow_val, num_alloc)
        opt_probs = solver.get_allocation_probabilities(arm_set.means)
        opt_swf = swf_func(arm_set.means * opt_probs, weights, pow_val)

        # Try some random allocations - they should all be worse
        for _ in range(n_trials):
            random_probs = self._random_allocation(n_arms, num_alloc, rng)
            random_swf = swf_func(arm_set.means * random_probs, weights, pow_val)
            assert random_swf <= opt_swf + 1e-6, (
                f"Random allocation beat optimal: {random_swf} > {opt_swf}"
            )


class TestHelperFunctions:
    """Checks for the helper functions factored out of the experiment pipeline."""

    def test_get_solver_and_swf_wpm(self):
        """The "wpm" objective maps to WPMSolver and wpm_swf."""
        solver_type, welfare_fn = get_solver_and_swf("wpm")
        assert solver_type is WPMSolver
        assert welfare_fn is wpm_swf

    def test_get_solver_and_swf_kolm(self):
        """The "kolm" objective maps to KolmSolver and kolm_swf."""
        solver_type, welfare_fn = get_solver_and_swf("kolm")
        assert solver_type is KolmSolver
        assert welfare_fn is kolm_swf

    def test_get_solver_and_swf_gini(self):
        """The "gini" objective maps to GiniSolver and gini_swf."""
        solver_type, welfare_fn = get_solver_and_swf("gini")
        assert solver_type is GiniSolver
        assert welfare_fn is gini_swf

    def test_get_solver_and_swf_invalid(self):
        """An unrecognised objective name is rejected with ValueError."""
        with pytest.raises(ValueError, match="Unknown objective"):
            get_solver_and_swf("invalid_objective")

    def test_generate_weights_on_simplex(self):
        """Generated weights are non-negative and add up to one."""
        rng = torch.Generator().manual_seed(42)
        simplex_weights = generate_weights(rng, n_arms=10, objective="wpm")
        one = torch.tensor(1.0, dtype=torch.float64)

        assert torch.isclose(simplex_weights.sum(), one)
        assert torch.all(simplex_weights >= 0)

    def test_generate_weights_gini_sorted(self):
        """Weights generated for the Gini objective never increase along the vector."""
        rng = torch.Generator().manual_seed(42)
        gini_weights = generate_weights(rng, n_arms=10, objective="gini")

        # Walk over adjacent pairs (w[i], w[i + 1]).
        for current, following in zip(gini_weights[:-1], gini_weights[1:]):
            assert current >= following, "Gini weights must be non-increasing"


class TestReproducibility:
    """
    Verify that a fixed seed yields repeatable experiments.

    The tests call run_experiments_core itself instead of re-implementing
    its logic, so a change to the experiment code that breaks seeding
    shows up here.
    """

    def test_same_seed_same_weights(self):
        """Two generators seeded identically must yield the same weights."""
        first_gen = torch.Generator().manual_seed(42)
        second_gen = torch.Generator().manual_seed(42)

        first = generate_weights(first_gen, n_arms=10, objective="wpm")
        second = generate_weights(second_gen, n_arms=10, objective="wpm")

        assert torch.allclose(first, second)

    def test_same_seed_same_results(self):
        """
        Two run_experiments_core calls with one seed must agree.

        Compares the raw regrets, the derived statistics and the weights
        of both runs.
        """
        # Deliberately tiny configuration to keep the test fast.
        settings = {
            "n_experiments": 2,
            "n_arms": 5,
            "n_rounds": 30,
            "num_alloc": 2,
            "delta": 0.05,
            "objective": "wpm",
            "pow_val": -1.0,
            "parallel": False,  # sequential, so ordering is deterministic
            "verbose": False,  # keep test output quiet
        }

        run_a = run_experiments_core(seed=42, **settings)
        run_b = run_experiments_core(seed=42, **settings)

        # Raw per-experiment regrets.
        assert np.allclose(run_a.all_regrets, run_b.all_regrets), (
            "Raw regrets differ between runs with same seed"
        )

        # Statistics derived from the raw regrets.
        for stat in ("avg_regret", "avg_cum_regret", "std_error"):
            assert np.allclose(getattr(run_a, stat), getattr(run_b, stat))

        # Weights used by the two runs.
        assert torch.allclose(run_a.weights, run_b.weights)

    def test_different_seeds_different_results(self):
        """Changing the seed must change both regrets and weights."""
        settings = {
            "n_experiments": 1,
            "n_arms": 5,
            "n_rounds": 30,
            "num_alloc": 2,
            "delta": 0.05,
            "objective": "wpm",
            "pow_val": -1.0,
            "parallel": False,
            "verbose": False,
        }

        run_a = run_experiments_core(seed=42, **settings)
        run_b = run_experiments_core(seed=43, **settings)

        # Regrets must not coincide.
        regrets_match = np.allclose(run_a.all_regrets, run_b.all_regrets)
        assert not regrets_match, (
            "Different seeds produced identical results - RNG may not be working"
        )

        # Neither may the weights.
        assert not torch.allclose(run_a.weights, run_b.weights)

    @pytest.mark.parametrize("objective", ["wpm", "kolm", "gini"])
    def test_reproducibility_all_objectives(self, objective):
        """Same-seed runs must agree for every objective function."""
        # The Gini objective does not use pow_val, so a placeholder is passed.
        pow_val = 0.0 if objective == "gini" else -1.0

        settings = {
            "n_experiments": 2,
            "n_arms": 5,
            "n_rounds": 20,
            "num_alloc": 2,
            "delta": 0.05,
            "objective": objective,
            "pow_val": pow_val,
            "parallel": False,
            "verbose": False,
        }

        run_a = run_experiments_core(seed=123, **settings)
        run_b = run_experiments_core(seed=123, **settings)

        assert np.allclose(run_a.all_regrets, run_b.all_regrets), (
            f"Reproducibility failed for objective={objective}"
        )


class TestExperimentResultsStructure:
    """Checks on the shapes and values of the object run_experiments_core returns."""

    def test_results_shapes(self):
        """Each result array must have the shape implied by the run configuration."""
        n_experiments, n_rounds, n_arms = 3, 50, 5

        outcome = run_experiments_core(
            n_experiments=n_experiments,
            n_arms=n_arms,
            n_rounds=n_rounds,
            num_alloc=2,
            delta=0.05,
            seed=42,
            objective="wpm",
            pow_val=-1.0,
            parallel=False,
            verbose=False,
        )

        # Attribute name -> expected shape, checked in this order.
        expected_shapes = {
            "all_regrets": (n_experiments, n_rounds),
            "avg_regret": (n_rounds,),
            "std_error": (n_rounds,),
            "avg_cum_regret": (n_rounds,),
            "std_cum_error": (n_rounds,),
            "weights": (n_arms,),
        }
        for attr, shape in expected_shapes.items():
            assert getattr(outcome, attr).shape == shape

    def test_regret_is_non_negative(self):
        """The mean of the average regret must not be negative."""
        outcome = run_experiments_core(
            n_experiments=2,
            n_arms=5,
            n_rounds=100,
            num_alloc=2,
            delta=0.05,
            seed=42,
            objective="wpm",
            pow_val=-1.0,
            parallel=False,
            verbose=False,
        )

        # Individual steps may dip marginally below zero through numerical
        # precision, so only the mean is checked, with a small tolerance.
        mean_regret = outcome.avg_regret.mean()
        assert mean_regret >= -1e-6, "Average regret should be non-negative"

    def test_cumulative_regret_is_monotonic(self):
        """The averaged cumulative regret must never go down."""
        outcome = run_experiments_core(
            n_experiments=2,
            n_arms=5,
            n_rounds=50,
            num_alloc=2,
            delta=0.05,
            seed=42,
            objective="wpm",
            pow_val=-1.0,
            parallel=False,
            verbose=False,
        )

        # Step-to-step increments; each must be >= 0 up to a small tolerance.
        increments = np.diff(outcome.avg_cum_regret)
        assert np.all(increments >= -1e-6), "Cumulative regret should be non-decreasing"
