"""
Test for utilitarian and egalitarian limits of all solvers.
"""

import pytest
import torch

from src.kolm import KolmSolver
from src.wpm import WPMSolver
from src.gini import GiniSolver

class TestSolverLimits:
    """Cross-check the WPM, Kolm and Gini solvers at their limiting cases.

    Each test draws one reproducible random utility vector, runs every solver
    on it with parameters intended to select the same limiting social welfare
    function, and asserts that the resulting welfare values agree.
    """

    @staticmethod
    def _random_utilities(n_arms, seed):
        """Return a reproducible float64 utility vector with entries in [1, 11).

        A dedicated generator is used so the utilities do not depend on (or
        disturb) the global RNG state that the solvers may consume.
        """
        generator = torch.Generator().manual_seed(seed)
        # Shift by +1 to avoid zero utilities.
        return torch.rand(n_arms, generator=generator, dtype=torch.float64) * 10 + 1

    def test_utilitarian_limit(self):
        """Utilitarian limit: WPM with pow=1, Kolm with pow=0, uniform-weight Gini.

        All three are expected to reduce to the same utilitarian welfare,
        sum(probs * u), for the given utility vector.
        """
        seed = 42
        n_arms = 50
        num_alloc = 10
        weights = torch.ones(n_arms, dtype=torch.float64) / n_arms
        # Seeded so that a failure can be reproduced with the same utilities.
        u = self._random_utilities(n_arms, seed)

        # Re-seed before each solver in case solvers consume global RNG state.
        torch.manual_seed(seed)
        wpm_pow = 1.0
        solver = WPMSolver(weights, wpm_pow, num_alloc)
        wpm_probs = solver.get_allocation_probabilities(u)
        wpm_swf = (wpm_probs * u).sum()

        torch.manual_seed(seed)
        kolm_pow = 0.0
        solver = KolmSolver(weights, kolm_pow, num_alloc)
        kolm_probs = solver.get_allocation_probabilities(u)
        kolm_swf = (kolm_probs * u).sum()

        torch.manual_seed(seed)
        solver = GiniSolver(weights, num_alloc)
        gini_probs = solver.get_allocation_probabilities(u)
        gini_swf = (gini_probs * u).sum()

        details = (
            f"WPM SWF: {wpm_swf.item()}, Kolm SWF: {kolm_swf.item()}, "
            f"Gini SWF: {gini_swf.item()}"
        )
        assert torch.isclose(wpm_swf, kolm_swf, atol=1e-5), (
            f"WPM and Kolm utilitarian SWFs differ. {details}"
        )
        assert torch.isclose(wpm_swf, gini_swf, atol=1e-5), (
            f"WPM and Gini utilitarian SWFs differ. {details}"
        )

    def test_egalitarian_limit(self):
        """Egalitarian limit: WPM and Kolm with pow=-inf, Gini with a one-hot weight.

        All three are expected to agree on the egalitarian welfare,
        min(probs * u), for the given utility vector.
        """
        seed = 42
        n_arms = 50
        num_alloc = 10
        weights = torch.ones(n_arms, dtype=torch.float64) / n_arms
        # Seeded so that a failure can be reproduced with the same utilities.
        u = self._random_utilities(n_arms, seed)

        # Re-seed before each solver in case solvers consume global RNG state.
        torch.manual_seed(seed)
        wpm_pow = -torch.inf
        solver = WPMSolver(weights, wpm_pow, num_alloc)
        wpm_probs = solver.get_allocation_probabilities(u)
        wpm_swf = torch.min(wpm_probs * u)

        torch.manual_seed(seed)
        kolm_pow = -torch.inf
        solver = KolmSolver(weights, kolm_pow, num_alloc)
        kolm_probs = solver.get_allocation_probabilities(u)
        kolm_swf = torch.min(kolm_probs * u)

        torch.manual_seed(seed)
        # All Gini weight on the first position; presumably that position is
        # the worst-off arm after the solver's internal sorting (confirm
        # against GiniSolver), which yields the max-min criterion.
        gini_weights = torch.tensor([1.0] + (n_arms - 1) * [0.0], dtype=torch.float64)
        solver = GiniSolver(gini_weights, num_alloc)
        gini_probs = solver.get_allocation_probabilities(u)
        gini_swf = torch.min(gini_probs * u)

        # Diagnostics are attached to the assertions so they appear on failure.
        details = (
            f"u: {u.numpy()}, WPM SWF: {wpm_swf.item()}, "
            f"Kolm SWF: {kolm_swf.item()}, Gini SWF: {gini_swf.item()}"
        )
        assert torch.isclose(wpm_swf, kolm_swf, atol=1e-5), (
            f"WPM and Kolm egalitarian SWFs differ. {details}"
        )
        assert torch.isclose(wpm_swf, gini_swf, atol=1e-5), (
            f"WPM and Gini egalitarian SWFs differ. {details}"
        )
