from __future__ import annotations

import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class StrategyInfo:
    """Record pairing a strategy identifier with its metadata dictionary."""

    # Identifier string supplied by whoever creates the record.
    strategy_id: str
    # Free-form metadata; its keys and values are not constrained here.
    metadata: Dict


def alpha_rank_stationary(
    pairwise_payoff: np.ndarray,
    alpha: float = 8,
    num_iters: int = 10_000,
    tol: float = 1e-10,
) -> np.ndarray:
    """Return the stationary distribution of a logistic-response Markov chain.

    A row-stochastic transition matrix ``P`` over the ``K`` strategies is built
    from the pairwise payoffs: from strategy ``i`` every other strategy ``j``
    is reached with probability

        sigmoid(alpha * (payoff[j, i] - payoff[i, j])) / (K - 1)

    and the remaining probability mass stays on ``i``.  The stationary
    distribution is then approximated by power iteration started from the
    uniform distribution.

    Args:
        pairwise_payoff: Square ``(K, K)`` array-like; entry ``[a, b]`` is the
            payoff of strategy ``a`` against strategy ``b``.
        alpha: Scale applied to the payoff difference before the sigmoid.
        num_iters: Maximum number of power-iteration steps.
        tol: The iteration stops once the L1 distance between two successive
            iterates drops below this value.

    Returns:
        A length-``K`` array that sums to 1.  An empty array is returned for
        ``K == 0`` and ``[1.0]`` for ``K == 1``.  If ``num_iters`` is exhausted
        the last iterate is returned as is.

    Raises:
        ValueError: If the input is a scalar, or if ``K >= 2`` and the input
            is not a square two-dimensional matrix.
    """
    payoff = np.asarray(pairwise_payoff, dtype=float)
    if payoff.ndim == 0:
        raise ValueError("pairwise_payoff must be a square 2-D matrix, got a scalar")

    K = payoff.shape[0]
    if K == 0:
        return np.array([])
    if K == 1:
        return np.array([1.0])
    if payoff.ndim != 2 or payoff.shape[1] != K:
        raise ValueError(
            f"pairwise_payoff must be a square 2-D matrix, got shape {payoff.shape}"
        )

    # advantage[i, j] = alpha * (payoff[j, i] - payoff[i, j]); clipping keeps
    # np.exp well inside the float64 range.
    advantage = np.clip(alpha * (payoff.T - payoff), -50.0, 50.0)
    P = 1.0 / (1.0 + np.exp(-advantage)) / (K - 1)

    # Off-diagonal entries are the move probabilities; whatever is left of
    # each row stays on the diagonal so every row sums to one.
    np.fill_diagonal(P, 0.0)
    np.fill_diagonal(P, np.maximum(0.0, 1.0 - P.sum(axis=1)))

    pi = np.ones(K, dtype=float) / K
    for _ in range(num_iters):
        candidate = pi @ P
        total = candidate.sum()
        if total <= 0:
            # Degenerate iterate: restart from the uniform distribution.
            candidate = np.ones(K, dtype=float) / K
        else:
            candidate = candidate / total

        converged = np.linalg.norm(candidate - pi, ord=1) < tol
        pi = candidate
        if converged:
            break

    return pi


class MetaGame:
    """Utility matrix between solver strategies and generator strategies.

    ``utilities[h, g]`` stores the value recorded for solver strategy ``h``
    (row) against generator strategy ``g`` (column).  The solving routines
    treat that value as a gap: the solver side prefers it low, the generator
    side prefers it high.
    """

    def __init__(self) -> None:
        """Create an empty game with no strategies and a ``(0, 0)`` matrix."""
        self.solver_strategies: List[StrategyInfo] = []
        self.generator_strategies: List[StrategyInfo] = []
        self.utilities: np.ndarray = np.zeros((0, 0), dtype=float)

    def add_solver(self, strategy_id: str, metadata: Dict) -> int:
        """Append a solver strategy (new zero row) and return its row index."""
        self.solver_strategies.append(StrategyInfo(strategy_id, metadata))
        self._resize()
        return len(self.solver_strategies) - 1

    def add_generator(self, strategy_id: str, metadata: Dict) -> int:
        """Append a generator strategy (new zero column) and return its column index."""
        self.generator_strategies.append(StrategyInfo(strategy_id, metadata))
        self._resize()
        return len(self.generator_strategies) - 1

    def remove_solver(self, idx: int) -> None:
        """Drop solver strategy ``idx`` together with its utility row."""
        self.solver_strategies.pop(idx)
        self.utilities = np.delete(self.utilities, idx, axis=0)

    def remove_generator(self, idx: int) -> None:
        """Drop generator strategy ``idx`` together with its utility column."""
        self.generator_strategies.pop(idx)
        self.utilities = np.delete(self.utilities, idx, axis=1)

    def _resize(self) -> None:
        """Reallocate ``utilities`` to match the strategy lists.

        Existing entries keep their position; new cells are zero.
        """
        n_h = len(self.solver_strategies)
        n_g = len(self.generator_strategies)
        old = self.utilities
        self.utilities = np.zeros((n_h, n_g), dtype=float)
        if old.size:
            self.utilities[: old.shape[0], : old.shape[1]] = old

    def set_utility(self, h_idx: int, g_idx: int, value: float) -> None:
        """Store ``value`` at ``utilities[h_idx, g_idx]``.

        A warning line is printed when an exact ``0.0`` is written anywhere
        other than cell ``(0, 0)``.
        """
        if value == 0.0 and (h_idx > 0 or g_idx > 0):
            print(f"      [MetaGame.set_utility] WARNING: Setting utility ({h_idx}, {g_idx}) = 0.0")
        self.utilities[h_idx, g_idx] = value

    @staticmethod
    def _solve_mixture_lp(linprog, signed_payoff: np.ndarray, v_sign: float, label: str) -> np.ndarray:
        """Solve one player's LP and return that player's mixed strategy.

        The LP variables are the player's ``n`` mixture weights followed by
        the game value ``v``.  It minimises ``v_sign * v`` subject to
        ``signed_payoff @ w - v_sign * v <= 0``, ``sum(w) == 1`` and
        ``w >= 0``.

        Args:
            linprog: The ``scipy.optimize.linprog`` callable.
            signed_payoff: ``(m, n)`` matrix with one row per opposing
                strategy, already multiplied by ``v_sign``.
            v_sign: ``+1`` to minimise the value, ``-1`` to maximise it.
            label: Player name used in the failure message.

        Returns:
            A length-``n`` distribution.  The uniform distribution is returned
            when the LP reports failure or its weights sum to (almost) zero.
        """
        n_opp, n_own = signed_payoff.shape

        c = np.zeros(n_own + 1)
        c[-1] = v_sign

        A_ub = np.hstack([signed_payoff, -v_sign * np.ones((n_opp, 1))])
        b_ub = np.zeros(n_opp)

        A_eq = np.hstack([np.ones((1, n_own)), np.zeros((1, 1))])
        b_eq = np.ones(1)

        # Weights are non-negative; the game value is unbounded.
        bounds = [(0, None)] * n_own + [(None, None)]

        res = linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method='highs')

        if not res.success:
            print(f" {label} LP did not converge: {res.message}")
            return np.ones(n_own) / n_own

        # Clean up tiny negative weights and renormalise.
        weights = np.maximum(res.x[:-1], 0)
        total = weights.sum()
        if total > 1e-12:
            return weights / total
        return np.ones(n_own) / n_own

    def solve_ne(self) -> Tuple[np.ndarray, np.ndarray]:
        """Solve the zero-sum mixed Nash equilibrium by linear programming.

        The solver mixes over rows to minimise the worst-case utility; the
        generator mixes over columns to maximise the worst-case utility.

        Returns:
            ``(x, y)``: the solver distribution over rows and the generator
            distribution over columns.  Both are empty when either side has no
            strategies.  If SciPy cannot be imported or anything else raises,
            the error is printed and both distributions are uniform.
        """
        n_h, n_g = self.utilities.shape
        if n_h == 0 or n_g == 0:
            return np.array([]), np.array([])
        if n_h == 1 and n_g == 1:
            return np.array([1.0]), np.array([1.0])
        try:
            from scipy.optimize import linprog

            U = self.utilities  # U[i,j] = gap

            # Solver: min v subject to sum_i x_i * U[i,j] <= v for all j.
            x = self._solve_mixture_lp(linprog, U.T, 1.0, "Solver")

            # Generator: max v subject to sum_j y_j * U[i,j] >= v for all i,
            # expressed as min -v subject to sum_j y_j * (-U)[i,j] + v <= 0.
            y = self._solve_mixture_lp(linprog, -U, -1.0, "Generator")

            return x, y
        except Exception as e:
            print(f" Error solving NE: {e}")
            return np.ones(n_h) / n_h, np.ones(n_g) / n_g

    # ==================  Alpha-Rank Meta-Strategy  ================== #

    @staticmethod
    def _pairwise_win_matrix(scores: np.ndarray) -> np.ndarray:
        """Return head-to-head win rates between the rows of ``scores``.

        Entry ``[i, j]`` is the fraction of columns in which row ``i`` scores
        strictly higher than row ``j``, with ties counted as half a win.  The
        diagonal is 0.5.

        With at most one row a matrix of ones is returned.  With no columns to
        compare on, every pair is treated as tied (0.5) rather than dividing
        by zero.
        """
        n_rows, n_cols = scores.shape
        if n_rows <= 1:
            return np.ones((n_rows, n_rows), dtype=float)
        if n_cols == 0:
            return np.full((n_rows, n_rows), 0.5, dtype=float)

        P = np.empty((n_rows, n_rows), dtype=float)
        for i in range(n_rows):
            # Compare row i against every row at once.
            wins = (scores[i] > scores).sum(axis=1)
            ties = (scores[i] == scores).sum(axis=1)
            P[i] = (wins + 0.5 * ties) / float(n_cols)

        np.fill_diagonal(P, 0.5)
        return P

    def _solver_pairwise_matrix(self) -> np.ndarray:
        """Pairwise win rates between solver strategies (lower utility wins)."""
        # score = -gap
        return self._pairwise_win_matrix(-self.utilities)

    def _generator_pairwise_matrix(self) -> np.ndarray:
        """Pairwise win rates between generator strategies (higher utility wins)."""
        # Transpose so that each generator strategy becomes a row.
        return self._pairwise_win_matrix(self.utilities.T)

    def solve_alpha_rank(
        self,
        alpha: float = 8.0,
        num_iters: int = 10_000,
        tol: float = 1e-10,
        debug: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute Alpha-Rank style distributions for both populations.

        Each side's pairwise win-rate matrix is passed to
        ``alpha_rank_stationary`` with the given ``alpha``, ``num_iters`` and
        ``tol``.

        Args:
            alpha: Forwarded to ``alpha_rank_stationary``.
            num_iters: Forwarded to ``alpha_rank_stationary``.
            tol: Forwarded to ``alpha_rank_stationary``.
            debug: When true, print the utilities, both pairwise matrices and
                both resulting distributions.

        Returns:
            ``(sigma_h, sigma_g)``: distributions over solver and generator
            strategies.  Both are empty when either side has no strategies.
        """
        n_h, n_g = self.utilities.shape
        if n_h == 0 or n_g == 0:
            return np.array([]), np.array([])
        if n_h == 1 and n_g == 1:
            return np.array([1.0]), np.array([1.0])

        if debug:
            print(f"     Utilities matrix shape: {self.utilities.shape}")
            print(f"     Utilities matrix:\n{self.utilities}")

        P_h = self._solver_pairwise_matrix()
        if debug:
            print(f"     Solver pairwise matrix P_h:\n{P_h}")
        sigma_h = alpha_rank_stationary(
            P_h,
            alpha=alpha,
            num_iters=num_iters,
            tol=tol,
        )
        if debug:
            print(f"     Solver Alpha-Rank distribution σ_H: {sigma_h}")

        P_g = self._generator_pairwise_matrix()
        if debug:
            print(f"     Generator pairwise matrix P_g:\n{P_g}")
        sigma_g = alpha_rank_stationary(
            P_g,
            alpha=alpha,
            num_iters=num_iters,
            tol=tol,
        )
        if debug:
            print(f"    📊 Generator Alpha-Rank distribution σ_G: {sigma_g}")

        return sigma_h, sigma_g



