"""
Tests for Kolm social welfare function and solver.

Key properties:
1. pow=0 gives weighted sum (utilitarian)
2. pow=-inf gives minimum (egalitarian)
3. Allocation probabilities sum to num_alloc
"""

import pytest
import torch
import numpy as np

from src.kolm import kolm_swf, KolmSolver


class TestKolmSWF:
    """Tests for the kolm_swf function.

    Covers the two limiting exponents (pow=0 and pow=-inf), inequality
    aversion for finite negative exponents, and rejection of positive
    exponents.
    """

    def test_pow_0_is_weighted_sum(self):
        """With pow=0, Kolm SWF is weighted sum."""
        u = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64)
        weights = torch.tensor([0.5, 0.3, 0.2], dtype=torch.float64)
        result = kolm_swf(u, weights, pow=0)
        expected = weights @ u
        assert torch.isclose(result, expected, atol=1e-6)

    def test_pow_neg_inf_is_min(self):
        """With pow=-inf, Kolm SWF is minimum."""
        u = torch.tensor([0.5, 0.1, 0.3], dtype=torch.float64)
        weights = torch.tensor([0.3, 0.3, 0.4], dtype=torch.float64)
        result = kolm_swf(u, weights, pow=-torch.inf)
        # Expected value is min(u) = 0.1.
        assert torch.isclose(result, torch.tensor(0.1, dtype=torch.float64), atol=1e-6)

    @pytest.mark.parametrize("pow_val", [-0.5, -1.0, -2.0])
    def test_negative_pow_reduces_inequality_preference(self, pow_val):
        """For negative pow, an equal split must score above an unequal one.

        Both utility vectors have the same weighted sum (0.5) under equal
        weights, so any strict difference in the score comes from how the
        utilities are spread, not from their total.
        """
        weights = torch.tensor([0.5, 0.5], dtype=torch.float64)

        # Equal distribution
        u_equal = torch.tensor([0.5, 0.5], dtype=torch.float64)
        # Unequal distribution (same sum)
        u_unequal = torch.tensor([0.9, 0.1], dtype=torch.float64)

        swf_equal = kolm_swf(u_equal, weights, pow_val)
        swf_unequal = kolm_swf(u_unequal, weights, pow_val)
        assert swf_equal > swf_unequal, f"Failed inequality aversion for pow={pow_val}"

    def test_rejects_positive_pow(self):
        """Kolm SWF requires pow <= 0."""
        u = torch.tensor([0.5, 0.5], dtype=torch.float64)
        weights = torch.tensor([0.5, 0.5], dtype=torch.float64)
        # This test expects the rejection to surface as an AssertionError.
        with pytest.raises(AssertionError):
            kolm_swf(u, weights, pow=1.0)


class TestKolmSolverInvariants:
    """Test invariants for KolmSolver.

    Random weights and utilities are drawn from a seeded, test-local
    generator so that any failure can be reproduced exactly and the global
    torch RNG state is left untouched.
    """

    @pytest.fixture
    def rng(self):
        """Return a freshly seeded generator shared by all draws in one test."""
        return torch.Generator().manual_seed(0)

    @pytest.fixture
    def solver_setup(self, rng):
        """Return (n_arms, num_alloc, weights) with softmax-normalised weights."""
        n_arms = 5
        num_alloc = 3
        weights = torch.softmax(
            torch.randn(n_arms, generator=rng, dtype=torch.float64), dim=0
        )
        return n_arms, num_alloc, weights

    @pytest.mark.parametrize("pow_val", [-torch.inf, -2.0, -1.0, -0.5])
    def test_allocation_sums_to_num_alloc(self, solver_setup, rng, pow_val):
        """Allocation probabilities must sum to num_alloc."""
        n_arms, num_alloc, weights = solver_setup
        solver = KolmSolver(weights, pow_val, num_alloc)

        # Offset by 0.1 so every utility is strictly positive.
        u = torch.rand(n_arms, generator=rng, dtype=torch.float64) + 0.1
        probs = solver.get_allocation_probabilities(u)

        assert torch.isclose(probs.sum(), torch.tensor(num_alloc, dtype=torch.float64), atol=1e-6), \
            f"Sum={probs.sum()}, expected={num_alloc}"

    @pytest.mark.parametrize("pow_val", [-torch.inf, -2.0, -1.0, -0.5])
    def test_probabilities_in_valid_range(self, solver_setup, rng, pow_val):
        """All probabilities must be in [0, 1]."""
        n_arms, num_alloc, weights = solver_setup
        solver = KolmSolver(weights, pow_val, num_alloc)

        # Offset by 0.1 so every utility is strictly positive.
        u = torch.rand(n_arms, generator=rng, dtype=torch.float64) + 0.1
        probs = solver.get_allocation_probabilities(u)

        # Small slack on both bounds absorbs floating-point round-off.
        assert torch.all(probs >= -1e-8), f"Negative probability: {probs.min()}"
        assert torch.all(probs <= 1 + 1e-8), f"Probability > 1: {probs.max()}"


class TestKolmSolverEdgeCases:
    """Test edge cases for KolmSolver."""

    def test_single_arm(self):
        """Single arm gets all allocation."""
        weights = torch.tensor([1.0], dtype=torch.float64)
        solver = KolmSolver(weights, pow=-1.0, num_alloc=1)

        u = torch.tensor([0.5], dtype=torch.float64)
        probs = solver.get_allocation_probabilities(u)

        assert torch.isclose(probs[0], torch.tensor(1.0, dtype=torch.float64), atol=1e-6)

    def test_pow_0_selects_top_k(self):
        """With pow=0, should greedily select top num_alloc by weighted utility."""
        num_alloc = 2
        weights = torch.tensor([0.4, 0.3, 0.2, 0.1], dtype=torch.float64)
        solver = KolmSolver(weights, pow=0, num_alloc=num_alloc)

        # weighted utilities: [0.4, 0.6, 0.4, 0.2]
        u = torch.tensor([1.0, 2.0, 2.0, 2.0], dtype=torch.float64)
        probs = solver.get_allocation_probabilities(u)

        def approx(value, target):
            """Tolerance comparison in the dtype the solver returned."""
            return torch.isclose(
                value, torch.tensor(target, dtype=probs.dtype), atol=1e-6
            )

        # Compare with a tolerance rather than exact float equality.
        assert approx(probs.sum(), float(num_alloc)), \
            f"Sum={probs.sum()}, expected={num_alloc}"

        # Arms 0 and 2 tie at 0.4 for the second slot, so their individual
        # probabilities depend on tie-breaking. Only the unambiguous arms are
        # pinned: arm 1 is strictly the best, arm 3 is strictly the worst.
        assert approx(probs[1], 1.0), f"Top arm not selected: {probs[1]}"
        assert approx(probs[3], 0.0), f"Worst arm was selected: {probs[3]}"
