"""
Tests for Generalized Gini Index solver.

Key properties:
1. Weights must be non-increasing
2. SWF is weighted sum of sorted utilities
3. Allocation probabilities sum to num_alloc
4. Solver reaches (to tolerance) the same objective value as the CVXPY LP solution
"""

import pytest
import torch
import numpy as np
import cvxpy as cp

from src.gini import gini_swf, GiniSolver


def solve_gini_lp(mu, w, k, solver=None, verbose=False):
    """
    Solve the reference Gini allocation LP with CVXPY and return the allocation.

    Solves:
        max   sum_i w_i * mu_i * p_i
        s.t.  0 <= p_i <= 1
              sum_i p_i = k
              mu_{i-1} p_{i-1} <= mu_i p_i,  i=2..n

    Parameters
    ----------
    mu : array-like, shape (n,)
        Nonnegative utilities, assumed sorted nondecreasing. Anything accepted by
        ``np.asarray`` works (lists, NumPy arrays, CPU torch tensors).
    w : array-like, shape (n,)
        Weights (rank weights or whatever matches your formulation).
    k : float or int
        Total allocation mass (typically integer in [0,n], but LP allows any 0<=k<=n).
    solver : str or None
        Optional CVXPY solver name, e.g. "ECOS", "CLARABEL", "GLPK", "GUROBI", "MOSEK".
        When None, ECOS is used if it is installed; otherwise CVXPY picks its default.
    verbose : bool
        Solver verbosity.

    Returns
    -------
    p_star : np.ndarray, shape (n,)
        Optimal allocation (float64).

    Raises
    ------
    AssertionError
        If ``mu`` and ``w`` differ in length, or ``k`` is outside ``[0, n]``.
    ValueError
        If ``mu`` is not sorted nondecreasing.
    RuntimeError
        If the solver does not report an (possibly inaccurate) optimal solution.
    """
    mu = np.asarray(mu, dtype=float).reshape(-1)
    w = np.asarray(w, dtype=float).reshape(-1)
    n = mu.size
    assert w.size == n, "mu and w must have the same length"
    assert 0 <= k <= n, "feasibility requires 0 <= k <= n"

    # Sanity check: the ordering constraints below only make sense for sorted mu.
    if np.any(mu[1:] < mu[:-1] - 1e-12):
        raise ValueError(
            "mu must be sorted nondecreasing for this formulation (or sort and permute w)."
        )

    p = cp.Variable(n)
    constraints = [p >= 0, p <= 1, cp.sum(p) == k]
    if n > 1:
        # Monotonicity of the products mu_i p_i, as one vectorized constraint
        # instead of n-1 separate constraint objects (slow to build for large n).
        constraints.append(cp.multiply(mu[:-1], p[:-1]) <= cp.multiply(mu[1:], p[1:]))

    objective = cp.Maximize(cp.sum(cp.multiply(w * mu, p)))
    prob = cp.Problem(objective, constraints)

    # Prefer ECOS (handles LPs well) when available; recent CVXPY releases no
    # longer install it by default, so otherwise let CVXPY choose.
    if solver is None and "ECOS" in cp.installed_solvers():
        solver = "ECOS"
    prob.solve(solver=solver, verbose=verbose)

    # Without this check a failed solve leaves p.value as None, which would be
    # returned as an object array and only fail later with a confusing error.
    if prob.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE) or p.value is None:
        raise RuntimeError(f"CVXPY failed to solve the Gini LP (status={prob.status!r})")

    return np.asarray(p.value, dtype=float).reshape(-1)


class TestGiniSWF:
    """Unit tests for ``gini_swf``: a rank-weighted sum of ascending-sorted utilities."""

    @staticmethod
    def _f64(values):
        """Build a float64 tensor from a scalar or a Python sequence."""
        return torch.tensor(values, dtype=torch.float64)

    def test_uniform_weights_is_sum(self):
        """Uniform weights reduce the Gini SWF to a plain sum."""
        utilities = self._f64([0.3, 0.1, 0.2])
        swf = gini_swf(utilities, self._f64([1.0, 1.0, 1.0]))
        assert torch.isclose(swf, utilities.sum(), atol=1e-6)

    def test_weighted_sum_of_sorted(self):
        """The SWF pairs the largest weight with the smallest utility."""
        utilities = self._f64([0.3, 0.1, 0.2])
        rank_weights = self._f64([0.5, 0.3, 0.2])  # non-increasing
        # Ascending utilities [0.1, 0.2, 0.3] dotted with the weights:
        # 0.5*0.1 + 0.3*0.2 + 0.2*0.3 = 0.05 + 0.06 + 0.06 = 0.17
        assert torch.isclose(gini_swf(utilities, rank_weights), self._f64(0.17), atol=1e-6)

    def test_rejects_increasing_weights(self):
        """Increasing weights violate the Gini SWF contract and must be rejected."""
        ascending = self._f64([0.2, 0.3, 0.5])  # increasing: invalid
        with pytest.raises(AssertionError):
            gini_swf(self._f64([0.3, 0.1, 0.2]), ascending)

    def test_maximin_weights(self):
        """All weight on the first rank turns the SWF into the minimum utility."""
        utilities = self._f64([0.5, 0.1, 0.3])
        swf = gini_swf(utilities, self._f64([1.0, 0.0, 0.0]))
        assert torch.isclose(swf, self._f64(0.1), atol=1e-6)  # the minimum

    def test_prefers_equality(self):
        """With more weight on low ranks, an equal profile beats an unequal one."""
        rank_weights = self._f64([0.6, 0.3, 0.1])
        # Both profiles total 0.6; only their dispersion differs.
        swf_equal, swf_unequal = (
            gini_swf(self._f64(profile), rank_weights)
            for profile in ([0.2, 0.2, 0.2], [0.5, 0.1, 0.0])
        )
        assert swf_equal > swf_unequal


class TestGiniSolverInvariants:
    """Test invariants for GiniSolver on small random instances.

    All randomness comes from a seeded per-test ``torch.Generator``, so failures
    are reproducible and the global torch RNG state is left untouched.
    """

    @pytest.fixture
    def rng(self):
        """Seeded generator shared by ``solver_setup`` and the test body."""
        return torch.Generator().manual_seed(0)

    @pytest.fixture
    def solver_setup(self, rng):
        """Return ``(n_arms, num_alloc, weights)`` with non-increasing weights."""
        n_arms = 5
        num_alloc = 3
        # Random positive weights summing to 1, sorted into the non-increasing
        # order GiniSolver requires.
        weights = torch.softmax(
            torch.randn(n_arms, dtype=torch.float64, generator=rng), dim=0
        )
        weights, _ = torch.sort(weights, descending=True)
        return n_arms, num_alloc, weights

    def test_allocation_sums_to_num_alloc(self, solver_setup, rng):
        """Allocation probabilities must sum to num_alloc."""
        n_arms, num_alloc, weights = solver_setup
        solver = GiniSolver(weights, num_alloc)

        # Shift away from zero so every arm has strictly positive utility.
        u = torch.rand(n_arms, dtype=torch.float64, generator=rng) + 0.1
        probs = solver.get_allocation_probabilities(u)

        assert torch.isclose(
            probs.sum(), torch.tensor(num_alloc, dtype=torch.float64), atol=1e-6
        ), f"Sum={probs.sum()}, expected={num_alloc}"

    def test_probabilities_in_valid_range(self, solver_setup, rng):
        """All probabilities must be in [0, 1] (up to a small numerical slack)."""
        n_arms, num_alloc, weights = solver_setup
        solver = GiniSolver(weights, num_alloc)

        u = torch.rand(n_arms, dtype=torch.float64, generator=rng) + 0.1
        probs = solver.get_allocation_probabilities(u)

        assert torch.all(probs >= -1e-8), f"Negative probability: {probs.min()}"
        assert torch.all(probs <= 1 + 1e-8), f"Probability > 1: {probs.max()}"

    def test_rejects_increasing_weights(self):
        """Solver should reject weights that are not non-increasing."""
        weights = torch.tensor([0.2, 0.3, 0.5], dtype=torch.float64)  # increasing
        with pytest.raises(AssertionError):
            GiniSolver(weights, num_alloc=2)


class TestGiniSolverEdgeCases:
    """Test edge cases for GiniSolver, plus a cross-check against a CVXPY LP."""

    def test_single_arm(self):
        """Single arm gets all allocation."""
        weights = torch.tensor([1.0], dtype=torch.float64)
        solver = GiniSolver(weights, num_alloc=1)

        u = torch.tensor([0.5], dtype=torch.float64)
        probs = solver.get_allocation_probabilities(u)

        assert torch.isclose(probs[0], torch.tensor(1.0, dtype=torch.float64), atol=1e-6)

    def test_uniform_weights_equal_utilities(self):
        """Equal utilities with uniform weights should give equal allocations."""
        n_arms = 4
        weights = torch.ones(n_arms, dtype=torch.float64)
        solver = GiniSolver(weights, num_alloc=2)

        u = torch.ones(n_arms, dtype=torch.float64) * 0.5
        probs = solver.get_allocation_probabilities(u)

        # All should be equal: 2/4 = 0.5
        assert torch.allclose(probs, torch.ones(n_arms, dtype=torch.float64) * 0.5, atol=1e-6)

    def test_compare_solution_with_cvxpy(self):
        """GiniSolver should reach the same LP objective value as CVXPY.

        Only objective values are compared: an LP optimum need not be unique,
        so the two allocation vectors may legitimately differ.
        """
        n_arms = 700
        num_alloc = 5
        # Local generator: reproducible without reseeding the global torch RNG.
        gen = torch.Generator().manual_seed(42)

        # Utilities bounded away from zero, sorted non-decreasing as the LP
        # formulation in solve_gini_lp requires.
        u = torch.rand(n_arms, dtype=torch.float64, generator=gen) + 0.1
        u = torch.sort(u).values
        # Geometric (strictly decreasing) weights, normalized to sum to 1.
        weights = 0.9 ** torch.arange(n_arms, dtype=torch.float64)
        weights = weights / weights.sum()

        # Solve via GiniSolver
        solver = GiniSolver(weights, num_alloc)
        probs_gini = solver.get_allocation_probabilities(u)
        val_gini = (weights * u * probs_gini).sum().item()

        # Solve the reference LP via CVXPY on plain NumPy inputs.
        mu_np = u.numpy()
        w_np = weights.numpy()
        probs_cvxpy = solve_gini_lp(mu_np, w_np, num_alloc)
        val_cvxpy = float((w_np * mu_np * probs_cvxpy).sum())

        assert np.isclose(val_gini, val_cvxpy, atol=1e-6), (
            f"GiniSolver objective {val_gini} != CVXPY objective {val_cvxpy}"
        )
