"""
Tests for Weighted Power Mean (WPM) solver.

Key invariants:
1. Allocation probabilities sum to num_alloc
2. All probabilities in [0, 1]
3. Water-filling produces valid allocations
4. Known analytical solutions for special cases
"""

import pytest
import torch
import numpy as np

from src.wpm import wpm_swf, WPMSolver


class TestWPMSWF:
    """Tests for the wpm_swf function (weighted power mean of utilities)."""

    def test_uniform_weights_pow_1_is_sum(self):
        """With pow=1 and uniform weights, WPM equals sum(u) / n, i.e. it is proportional to the sum."""
        u = torch.tensor([0.2, 0.3, 0.5])
        weights = torch.tensor([1/3, 1/3, 1/3])
        result = wpm_swf(u, weights, pow=1)
        expected = u.sum() / 3  # weighted sum with uniform weights
        assert torch.isclose(result, expected, atol=1e-6)

    def test_pow_0_is_geometric_mean(self):
        """With pow=0, WPM is the weighted geometric mean prod(u_i ** w_i)."""
        u = torch.tensor([2.0, 8.0])
        weights = torch.tensor([0.5, 0.5])
        result = wpm_swf(u, weights, pow=0)
        expected = torch.tensor(4.0)  # sqrt(2 * 8) = 4
        assert torch.isclose(result, expected, atol=1e-6)

    def test_pow_neg_inf_is_min(self):
        """With pow=-inf, WPM is the minimum of u, even with non-uniform weights."""
        u = torch.tensor([0.1, 0.5, 0.3])
        weights = torch.tensor([0.2, 0.3, 0.5])
        result = wpm_swf(u, weights, pow=-torch.inf)
        assert torch.isclose(result, torch.tensor(0.1), atol=1e-6)

    def test_pow_1_is_weighted_sum(self):
        """With pow=1, WPM is the weighted arithmetic mean sum(w_i * u_i)."""
        u = torch.tensor([1.0, 2.0, 3.0])
        weights = torch.tensor([0.5, 0.3, 0.2])
        result = wpm_swf(u, weights, pow=1)
        expected = 0.5 * 1.0 + 0.3 * 2.0 + 0.2 * 3.0
        assert torch.isclose(result, torch.tensor(expected), atol=1e-6)

    @pytest.mark.parametrize("pow_val", [-2.0, -1.0, 0.0, 0.5])
    def test_monotonic_in_u(self, pow_val):
        """SWF should strictly increase when any u_i increases (all u_i > 0)."""
        weights = torch.tensor([0.3, 0.3, 0.4])
        u1 = torch.tensor([0.2, 0.3, 0.4])
        u2 = torch.tensor([0.2, 0.3, 0.5])  # last element increased

        # Parametrized (rather than a loop) so each exponent is reported
        # independently; also avoids shadowing the builtin ``pow``.
        swf1 = wpm_swf(u1, weights, pow_val)
        swf2 = wpm_swf(u2, weights, pow_val)
        assert swf2 > swf1, f"Failed monotonicity for pow={pow_val}"


class TestWPMSolverInvariants:
    """Test invariants that should hold for any valid input."""

    # Power parameters swept by the parametrized invariant tests below.
    POW_VALUES = [-torch.inf, -2.0, -1.0, -0.5, 0.0, 0.5]

    @pytest.fixture
    def solver_setup(self):
        """Common setup: returns (n_arms, num_alloc, weights, generator).

        A fixed-seed local generator makes the random weights and utilities
        reproducible, so a failure can be replayed exactly. A local generator
        is used instead of ``torch.manual_seed`` so that global RNG state
        seen by other tests is left untouched.
        """
        gen = torch.Generator().manual_seed(0)
        n_arms = 5
        num_alloc = 3
        weights = torch.softmax(
            torch.randn(n_arms, generator=gen, dtype=torch.float64), dim=0
        )
        return n_arms, num_alloc, weights, gen

    @pytest.mark.parametrize("pow_val", POW_VALUES)
    def test_allocation_sums_to_num_alloc(self, solver_setup, pow_val):
        """Allocation probabilities must sum to num_alloc."""
        n_arms, num_alloc, weights, gen = solver_setup
        solver = WPMSolver(weights, pow_val, num_alloc)

        # Keep utilities strictly positive: for pow <= 0 the power mean is
        # degenerate when any utility is zero.
        u = torch.rand(n_arms, generator=gen, dtype=torch.float64) + 0.1
        probs = solver.get_allocation_probabilities(u)

        assert torch.isclose(probs.sum(), torch.tensor(num_alloc, dtype=torch.float64), atol=1e-6), \
            f"Sum={probs.sum()}, expected={num_alloc}, u={u}"

    @pytest.mark.parametrize("pow_val", POW_VALUES)
    def test_probabilities_in_valid_range(self, solver_setup, pow_val):
        """All probabilities must be in [0, 1] (up to a small numerical slack)."""
        n_arms, num_alloc, weights, gen = solver_setup
        solver = WPMSolver(weights, pow_val, num_alloc)

        u = torch.rand(n_arms, generator=gen, dtype=torch.float64) + 0.1  # strictly positive
        probs = solver.get_allocation_probabilities(u)

        assert torch.all(probs >= -1e-8), f"Negative probability: {probs.min()}, u={u}"
        assert torch.all(probs <= 1 + 1e-8), f"Probability > 1: {probs.max()}, u={u}"

    def test_water_filling_respects_capacity(self):
        """Water filling must cap each arm at 1.0 while still distributing all num_alloc units."""
        weights = torch.tensor([0.5, 0.3, 0.2], dtype=torch.float64)
        num_alloc = 2
        solver = WPMSolver(weights, pow=-1.0, num_alloc=num_alloc)

        u = torch.tensor([0.9, 0.5, 0.1], dtype=torch.float64)
        probs = solver.get_allocation_probabilities(u)

        assert torch.all(probs <= 1.0 + 1e-8), f"Probability > 1: {probs.max()}"
        # Capping must not lose mass: the excess has to be redistributed.
        assert torch.isclose(probs.sum(), torch.tensor(float(num_alloc), dtype=torch.float64), atol=1e-6), \
            f"Sum={probs.sum()}, expected={num_alloc}"


class TestWPMSolverEdgeCases:
    """Test edge cases and boundary conditions."""

    def test_single_arm(self):
        """A single arm should get the whole allocation (capped at 1)."""
        weights = torch.tensor([1.0], dtype=torch.float64)
        solver = WPMSolver(weights, pow=-1.0, num_alloc=1)

        u = torch.tensor([0.5], dtype=torch.float64)
        probs = solver.get_allocation_probabilities(u)

        assert torch.isclose(probs[0], torch.tensor(1.0, dtype=torch.float64), atol=1e-6)

    def test_num_alloc_equals_n_arms(self):
        """When num_alloc == n_arms, every arm gets probability 1."""
        n_arms = 3
        weights = torch.tensor([0.5, 0.3, 0.2], dtype=torch.float64)
        solver = WPMSolver(weights, pow=-1.0, num_alloc=n_arms)

        u = torch.tensor([0.3, 0.5, 0.2], dtype=torch.float64)
        probs = solver.get_allocation_probabilities(u)

        assert torch.allclose(probs, torch.ones(n_arms, dtype=torch.float64), atol=1e-6)

    def test_equal_utilities_symmetric_weights(self):
        """Equal utilities with symmetric weights should give equal allocations."""
        n_arms = 4
        weights = torch.ones(n_arms, dtype=torch.float64) / n_arms
        solver = WPMSolver(weights, pow=-1.0, num_alloc=2)

        u = torch.ones(n_arms, dtype=torch.float64) * 0.5
        probs = solver.get_allocation_probabilities(u)

        # All probabilities should be equal (2/4 = 0.5 each)
        assert torch.allclose(probs, torch.ones(n_arms, dtype=torch.float64) * 0.5, atol=1e-6)

    def test_pow_1_selects_top_k(self):
        """With pow=1 (utilitarian), the solver should select the top num_alloc arms by weighted utility."""
        weights = torch.tensor([0.4, 0.3, 0.2, 0.1], dtype=torch.float64)
        solver = WPMSolver(weights, pow=1.0, num_alloc=2)

        # weighted utilities: [0.4, 0.6, 0.4, 0.2]
        u = torch.tensor([1.0, 2.0, 2.0, 2.0], dtype=torch.float64)
        probs = solver.get_allocation_probabilities(u)

        # Arm 1 is strictly highest and arm 3 strictly lowest. Arms 0 and 2
        # tie for second place, so how the remaining unit splits between
        # them depends on tie-breaking and is not asserted.
        # Solver output is floating point: compare with a tolerance, not ==.
        assert torch.isclose(probs.sum(), torch.tensor(2.0, dtype=torch.float64), atol=1e-6)
        assert torch.isclose(probs[1], torch.tensor(1.0, dtype=torch.float64), atol=1e-6)  # highest weighted utility
        assert torch.isclose(probs[3], torch.tensor(0.0, dtype=torch.float64), atol=1e-6)  # strictly below the tied pair
