# relative_performance.py
import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

Number = Union[float, int]

def u1(a: Tensor, t: Tensor, *, r: float, c: float, sigma: float, tau: float, v: float, a_peer: float = 0.0, **kwargs) -> Tensor:
    """
    Principal-side certainty equivalent for agent i, summed over all elements.

    Contract and signals:
        w_i = s + b*y_i + d*y_j,  y_i = a + η + ε_i,  y_j = a_peer + η + ε_j
        Var(w_i) = (b^2 + d^2)*(sigma^2 + tau^2) + 2*b*d*tau^2
        u1 = v*a - 0.5*r*Var(w_i) - 0.5*c*a^2

    Args:
        a: effort level(s) of agent i.
        t: contract; the first three entries along its leading dimension
            are read as (s, b, d). Only b and d enter the result.
        r, c, sigma, tau, v: scalar coefficients used exactly as in the
            formula above (sigma and tau are squared inside).
        a_peer: accepted for signature parity with ``u2``; not used here.
        **kwargs: ignored.

    Returns:
        The formula above evaluated elementwise and reduced with ``.sum()``.
    """
    # The fixed payment s is unpacked (so t[:3] must yield three entries)
    # but it does not enter u1.
    _fixed_pay, own_slope, peer_slope = t[:3]

    idiosyncratic_var = sigma**2
    common_var = tau**2
    wage_variance = (
        (own_slope**2 + peer_slope**2) * (idiosyncratic_var + common_var)
        + 2.0 * own_slope * peer_slope * common_var
    )

    gross_value = v * a
    risk_premium = 0.5 * r * wage_variance
    effort_cost = 0.5 * c * a**2
    return (gross_value - risk_premium - effort_cost).sum()

def u2(a: Tensor, t: Tensor, *, r: float, c: float, sigma: float, tau: float, a_peer: float = 0.0, **kwargs) -> Tensor:
    """
    Agent-side certainty equivalent, summed over all elements.

        E[w_i]   = s + b*a + d*a_peer
        Var(w_i) = (b^2 + d^2)*(sigma^2 + tau^2) + 2*b*d*tau^2
        u2       = E[w_i] - 0.5*r*Var(w_i) - 0.5*c*a^2

    Args:
        a: effort level(s) of agent i.
        t: contract; the first three entries along its leading dimension
            are read as (s, b, d).
        r, c, sigma, tau: scalar coefficients used exactly as in the
            formula above (sigma and tau are squared inside).
        a_peer: peer effort; converted with ``float(...)`` before use, so
            it enters as a plain Python number.
        **kwargs: ignored.

    Returns:
        The formula above evaluated elementwise and reduced with ``.sum()``.
    """
    fixed_pay, own_slope, peer_slope = t[:3]

    idiosyncratic_var = sigma**2
    common_var = tau**2
    wage_variance = (
        (own_slope**2 + peer_slope**2) * (idiosyncratic_var + common_var)
        + 2.0 * own_slope * peer_slope * common_var
    )

    expected_wage = fixed_pay + own_slope * a + peer_slope * float(a_peer)
    risk_premium = 0.5 * r * wage_variance
    effort_cost = 0.5 * c * a**2
    return (expected_wage - risk_premium - effort_cost).sum()

def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Union[float, Tensor], Union[Tensor, float], Union[float, Tensor], Union[float, Tensor]]:
    """
    Closed-form optimum of the relative-performance contract.

    Formulas:
        d* = - b* * tau^2 / (sigma^2 + tau^2)
        σ_eff^2 = σ^2(σ^2 + 2τ^2)/(σ^2 + τ^2)
        b* = v / (1 + r c σ_eff^2)
        a* = b*/c
        s* = U_res - [ b* a* + d* a_peer - 0.5 r Var(w_i) - 0.5 c a*^2 ]

    b* follows from maximizing v*a - 0.5*r*Var(w_i) - 0.5*c*a^2 subject to the
    agent's best response a = b/c and d = d*(b):
        v/c - r*b*σ_eff^2 - b/c = 0.
    For v = 1 (the default) this equals v / (v + r c σ_eff^2).

    Args:
        t: optional reference tensor. Only its device and dtype are used,
            to place the returned tensors; its values are never read.
        setting_parameters: mapping with required keys "r", "c", "sigma",
            "tau", "U_res" and optional keys "v" (default 1.0) and
            "a_peer" (default 0.0). Values are converted with ``float``.

    Returns:
        ``(a_star, bd_star, s_star, d_star)`` where ``bd_star`` is a tensor
        ``[b*, d*]``. If ``t`` is a tensor, ``a_star``, ``bd_star`` and
        ``s_star`` are tensors on ``t``'s device/dtype; otherwise ``a_star``
        and ``s_star`` are Python floats and ``bd_star`` uses torch's
        default dtype. ``d_star`` is always returned as a Python float.

    Raises:
        ValueError: if ``setting_parameters`` is None.
        KeyError: if a required key is missing.
        ZeroDivisionError: if ``c`` is zero.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    r = float(setting_parameters["r"])
    c = float(setting_parameters["c"])
    sigma = float(setting_parameters["sigma"])
    tau = float(setting_parameters["tau"])
    U_res = float(setting_parameters["U_res"])
    v = float(setting_parameters.get("v", 1.0))
    a_peer = float(setting_parameters.get("a_peer", 0.0))

    sig2, tau2 = sigma**2, tau**2
    # The 1e-12 keeps the ratios finite when sigma == tau == 0 (0/0 otherwise).
    total_var = sig2 + tau2 + 1e-12
    sigma_eff2 = (sig2 * (sig2 + 2.0 * tau2)) / total_var

    # First-order condition of the principal's problem; note the constant 1
    # (not v) in the denominator, which only matters when v != 1.
    b_star = v / (1.0 + r * c * sigma_eff2)
    # Peer weight that filters the common shock out of the agent's wage.
    d_star = -b_star * (tau2 / total_var)
    # Agent's best response to the own-signal slope b*.
    a_star = b_star / c

    # Var at optimum
    var_w = (b_star**2 + d_star**2) * (sig2 + tau2) + 2.0 * b_star * d_star * tau2
    # Solve for s* from the binding participation constraint u2 = U_res.
    s_star = U_res - (b_star * a_star + d_star * a_peer - 0.5 * r * var_w - 0.5 * c * a_star**2)

    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
        a_star = torch.as_tensor(a_star, device=device, dtype=dtype)
        bd_star = torch.stack([torch.as_tensor(b_star, device=device, dtype=dtype),
                               torch.as_tensor(d_star, device=device, dtype=dtype)])
        s_star = torch.as_tensor(s_star, device=device, dtype=dtype)
    else:
        bd_star = torch.tensor([b_star, d_star])
    return a_star, bd_star, s_star, d_star
