import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

# Scalar type accepted for entries of a settings dict.
Number = Union[float, int]


# ---------------------------------------------------------------------------
# Defaults. The contract vector is laid out as t = (s, b, d).
# ---------------------------------------------------------------------------
DEFAULT_T_INIT = torch.tensor(
    [0.0, 0.5, 0.0],  # [s, b, d]
    dtype=torch.float32,
)
DEFAULT_A_INIT = torch.tensor(
    [0.0],  # agent starts at 0 effort
    dtype=torch.float32,
)
DEFAULT_PARAMS = {
    "r": 1.0,        # risk aversion
    "c": 1.0,        # effort cost
    "sigma": 0.2,    # idiosyncratic noise
    "tau": 0.2,      # common noise
    "v": 1.0,        # marginal value of output
    "U_res": 0.0,    # reservation utility
    "a_peer": 0.1,   # peer's effort baseline
}


# The transfer s (index 0) is NOT trained, and it is also left out of the
# t-distance plots; both masks therefore zero index 0 and keep the rest.
_TLEN = int(DEFAULT_T_INIT.numel())  # 3
T_TRAIN_MASK = torch.tensor(
    [0.0 if idx == 0 else 1.0 for idx in range(_TLEN)],
    dtype=torch.float32,
)  # [0, 1, 1]
T_METRIC_MASK = torch.tensor(
    [0.0 if idx == 0 else 1.0 for idx in range(_TLEN)],
    dtype=torch.float32,
)  # [0, 1, 1]



def _extract_sbd(t: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Split a contract tensor into its three components (s, b, d).

    The input is flattened first, so any shape holding exactly three
    elements is accepted. An assertion fails for any other element count
    (note: assertions are skipped when Python runs with -O).

    Returns:
        (s, b, d): the entries at flat indices 0, 1 and 2.
    """
    flat = t.reshape(-1)
    assert flat.numel() == 3, (
        f"Relative-performance setting expects t to have 3 entries (s, b, d). "
        f"Got shape {tuple(flat.shape)}."
    )
    s, b, d = flat[0], flat[1], flat[2]
    return s, b, d


def project_t_box(t: Tensor) -> Tensor:
    """
    Project the (b, d) entries of a contract tensor onto a fixed box.

    Works on a flattened copy of `t`, so the input is never modified:
      - flat index 1 (b) is clamped to [0.0, 1.5]
      - flat index 2 (d) is clamped to [-1.5, 0.0]
      - flat index 0 (s) and any entries past index 2 are left unchanged
    Per the original note, theory gives b* in (0, 1] and d* <= 0, so this
    box is meant to be conservative.

    Tensors with fewer than 3 elements are returned as an unclamped copy.

    Returns:
        A new tensor with the same shape as `t`.
    """
    flat = t.reshape(-1).clone()
    if flat.numel() >= 3:
        flat[1] = torch.clamp(flat[1], min=0.0, max=1.5)    # b
        flat[2] = torch.clamp(flat[2], min=-1.5, max=0.0)   # d
    return flat.view_as(t)

def project_t_box_default(t: Tensor) -> Tensor:
    """
    Alias for `project_t_box` with an identical signature and result.

    Kept as a separate name for driver code (the original note mentions
    experiment.py as the caller).
    """
    projected = project_t_box(t)
    return projected



def u1(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    tau: float,
    v: float,
    a_peer: float = 0.0,
    **kwargs
) -> Tensor:
    """
    Principal's certainty-equivalent for agent i.

        Var(w_i) = (b^2 + d^2)(σ^2 + τ^2) + 2 b d τ^2
        u1       = v a - 0.5 r Var(w_i) - 0.5 c a^2

    `t` must hold exactly three entries (s, b, d); s does not enter u1.
    `a_peer` and any extra keyword arguments are accepted but not used here.
    The scalars v, r and c are converted to tensors on `a`'s device and dtype.
    """
    _, slope_own, slope_peer = _extract_sbd(t)

    idio_var = sigma**2
    common_var = tau**2
    wage_variance = (
        (slope_own**2 + slope_peer**2) * (idio_var + common_var)
        + 2.0 * slope_own * slope_peer * common_var
    )

    like_a = {"device": a.device, "dtype": a.dtype}
    v_t = torch.as_tensor(v, **like_a)
    r_t = torch.as_tensor(r, **like_a)
    c_t = torch.as_tensor(c, **like_a)

    output_value = v_t * a
    risk_premium = 0.5 * r_t * wage_variance
    effort_cost = 0.5 * c_t * (a**2)
    return output_value - risk_premium - effort_cost


def u2(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    tau: float,
    a_peer: float = 0.0,
    **kwargs
) -> Tensor:
    """
    Agent's certainty-equivalent.

        E[w_i]   = s + b a + d a_peer
        Var(w_i) = (b^2 + d^2)(σ^2 + τ^2) + 2 b d τ^2
        u2       = E[w_i] - 0.5 r Var(w_i) - 0.5 c a^2

    `t` must hold exactly three entries (s, b, d). `a_peer` is converted with
    float() before use; extra keyword arguments are accepted but not used.
    The scalars r and c are converted to tensors on `a`'s device and dtype.
    """
    transfer, slope_own, slope_peer = _extract_sbd(t)

    idio_var = sigma**2
    common_var = tau**2
    wage_variance = (
        (slope_own**2 + slope_peer**2) * (idio_var + common_var)
        + 2.0 * slope_own * slope_peer * common_var
    )

    like_a = {"device": a.device, "dtype": a.dtype}
    r_t = torch.as_tensor(r, **like_a)
    c_t = torch.as_tensor(c, **like_a)

    expected_wage = transfer + slope_own * a + slope_peer * float(a_peer)
    risk_premium = 0.5 * r_t * wage_variance
    effort_cost = 0.5 * c_t * (a**2)
    return expected_wage - risk_premium - effort_cost


def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Closed-form optimum of the relative-performance contract.

    Derivation (consistent with u1/u2 in this module):
        agent best response to u2:   a(b) = b / c
        variance-minimising d:       d* = - b* τ^2 / (σ^2 + τ^2)
        residual variance per b^2:   σ_eff^2 = σ^2(σ^2 + 2τ^2)/(σ^2 + τ^2)
        principal maximises  v b/c - 0.5 r σ_eff^2 b^2 - 0.5 b^2/c, so
                                     b* = v / (1 + r c σ_eff^2)
        a* = b*/c
        s* = U_res - [ b* a* + d* a_peer - 0.5 r Var(w_i) - 0.5 c a*^2 ]

    Note: the first-order condition gives `1 + r c σ_eff^2` in the
    denominator of b*, not `v + r c σ_eff^2`; the two coincide only at v = 1.

    Args:
        t: optional tensor used only to pick the device and dtype of the
            outputs; when it is not a tensor, CPU / float32 are used.
        setting_parameters: dict with required keys "r", "c", "sigma", "tau",
            "U_res" and optional keys "v" (default 1.0) and "a_peer"
            (default 0.0).

    Returns:
        a_star: shape (1,) tensor
        t_star: shape (3,) tensor with t_star = [s*, b*, d*]

    Raises:
        ValueError: if `setting_parameters` is None.
        KeyError: if a required key is missing.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    r      = float(setting_parameters["r"])
    c      = float(setting_parameters["c"])
    sigma  = float(setting_parameters["sigma"])
    tau    = float(setting_parameters["tau"])
    U_res  = float(setting_parameters["U_res"])
    v      = float(setting_parameters.get("v", 1.0))
    a_peer = float(setting_parameters.get("a_peer", 0.0))

    sig2, tau2 = sigma**2, tau**2
    # The 1e-12 keeps the ratios finite when both noise terms are zero.
    noise_total = sig2 + tau2 + 1e-12
    sigma_eff2 = (sig2 * (sig2 + 2.0 * tau2)) / noise_total
    # First-order condition of u1 along the agent's best response a = b/c.
    b_star_val = v / (1.0 + r * c * sigma_eff2)
    d_star_val = - b_star_val * (tau2 / noise_total)
    a_star_val = b_star_val / c

    var_w = (b_star_val**2 + d_star_val**2) * (sig2 + tau2) + 2.0 * b_star_val * d_star_val * tau2
    # s* makes the agent's certainty-equivalent equal to U_res exactly.
    s_star_val = U_res - (b_star_val * a_star_val + d_star_val * a_peer
                          - 0.5 * r * var_w - 0.5 * c * a_star_val**2)

    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
    else:
        device, dtype = torch.device("cpu"), torch.float32

    a_star = torch.tensor([a_star_val], device=device, dtype=dtype)
    t_star = torch.tensor([s_star_val, b_star_val, d_star_val], device=device, dtype=dtype)
    return a_star, t_star


def u1_at_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tensor:
    """
    Compute u1(a*, t*) at the closed-form optimum.

    Args:
        t: optional tensor forwarded to `get_theoretical_optimum`, which uses
            it to choose the device and dtype of a* and t*.
        setting_parameters: dict with keys "r", "c", "sigma", "tau" and
            optionally "v" (default 1.0), plus whatever
            `get_theoretical_optimum` requires.

    Returns:
        A 0-dim Tensor on the device/dtype of a*. The result is always a
        Tensor because `get_theoretical_optimum` always returns tensors.

    Raises:
        Whatever `get_theoretical_optimum` raises (it is called first, with
        `setting_parameters` passed through unchanged).
    """
    a_star, t_star = get_theoretical_optimum(t=t, setting_parameters=setting_parameters)
    _, b_star, d_star = _extract_sbd(t_star)

    r = float(setting_parameters["r"])
    c = float(setting_parameters["c"])
    sigma = float(setting_parameters["sigma"])
    tau = float(setting_parameters["tau"])
    v = float(setting_parameters.get("v", 1.0))

    sig2, tau2 = sigma**2, tau**2
    var_w = (b_star**2 + d_star**2) * (sig2 + tau2) + 2.0 * b_star * d_star * tau2

    # Take the single effort entry as a 0-dim tensor so the result is scalar.
    a0 = a_star.reshape(-1)[0]
    r_t = torch.as_tensor(r, device=a_star.device, dtype=a_star.dtype)
    c_t = torch.as_tensor(c, device=a_star.device, dtype=a_star.dtype)
    v_t = torch.as_tensor(v, device=a_star.device, dtype=a_star.dtype)
    return v_t * a0 - 0.5 * r_t * var_w - 0.5 * c_t * (a0**2)
