import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

Number = Union[float, int]


# Initial contract t = (s, b): index 0 is s, index 1 is the slope b.
DEFAULT_T_INIT = torch.tensor([0.0, 0.5], dtype=torch.float32)
# Initial agent action: a single effort variable a.
DEFAULT_A_INIT = torch.tensor([0.0], dtype=torch.float32)
# Default model parameters consumed by u1/u2/get_theoretical_optimum:
#   r     - multiplies the risk term 0.5 r (1-b)^2 sigma^2
#   c     - effort-cost coefficient in 0.5 c a^2
#   sigma - enters the risk term as sigma^2
#   ell   - baseline loss level in (ell - a)
#   U_res - level the agent's u2 must reach in get_theoretical_optimum
DEFAULT_PARAMS = {
    "r": 1.0,
    "c": 1.0,
    "sigma": 0.5,
    "ell": 1.0,
    "U_res": 0.0,
}

_TLEN = int(DEFAULT_T_INIT.numel())  # 2
# Per-entry masks over t: 0.0 at index 0 (s), 1.0 elsewhere (b).
# NOTE(review): presumably the driver uses these to keep s out of training
# updates and metrics — confirm in experiment.py.
T_TRAIN_MASK = torch.ones(_TLEN, dtype=torch.float32)
T_TRAIN_MASK[0] = 0.0
T_METRIC_MASK = T_TRAIN_MASK.clone()  # independent tensor with the same values


def _extract_sb(t: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Enforce that t encodes (s, b) in a length-2 tensor.
    Returns tuple (s, b).
    """
    t = t.reshape(-1)
    assert t.numel() == 2, (
        f"Insurance-prevention setting expects t to have 2 entries (s, b). "
        f"Got shape {tuple(t.shape)}."
    )
    return t[0], t[1]


def project_t_box(t: Tensor) -> Tensor:
    """
    Return a copy of t with the coinsurance slope b clamped to [0, 1].

    The s entry (index 0) is left unconstrained. Tensors with fewer than two
    elements are returned as an unmodified copy. The output has t's shape.

    Theory gives b* = (r c σ^2)/(1 + r c σ^2) ∈ [0, 1), so [0, 1] is a principled bound.
    """
    flat = t.reshape(-1).clone()  # clone so the caller's tensor is never mutated
    if flat.numel() >= 2:
        flat[1] = torch.clamp(flat[1], min=0.0, max=1.0)
    return flat.view_as(t)

def project_t_box_default(t: Tensor) -> Tensor:
    """
    Alias of `project_t_box` under the name the driver (experiment.py) uses.

    Returns a copy of t with its b entry clamped to [0, 1].
    """
    projected = project_t_box(t)
    return projected


def u1(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    ell: float,
    **kwargs
) -> Tensor:
    """
    Principal's (insurer's) certainty-equivalent objective.

        u1(a, t) = -(ℓ - a) - 0.5 r (1-b)^2 sigma^2 - 0.5 c a^2

    t encodes (s, b); only b is read, so u1 does not depend on s.
    Algebraically u1 == u2 + s - b(ℓ - a), i.e. the agent's certainty-equivalent
    plus the term s - b(ℓ - a) (presumably the insurer's expected margin), in
    which s cancels.

    Scalar parameters are converted to tensors on a's device/dtype; since b is
    0-d, the result has a's shape. Extra keyword arguments (e.g. U_res) are
    accepted and ignored.
    """
    _, slope = _extract_sb(t)
    like_a = {"device": a.device, "dtype": a.dtype}
    r_t, c_t, sigma_t, ell_t = (
        torch.as_tensor(value, **like_a) for value in (r, c, sigma, ell)
    )

    expected_loss = ell_t - a
    risk_premium = 0.5 * r_t * (1.0 - slope)**2 * (sigma_t**2)
    effort_cost = 0.5 * c_t * (a**2)
    return -expected_loss - risk_premium - effort_cost


def u2(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    ell: float,
    **kwargs
) -> Tensor:
    """
    Agent's certainty-equivalent.

        u2(a, t) = -(1-b)(ℓ - a) - s - 0.5 r (1-b)^2 sigma^2 - 0.5 c a^2

    t encodes (s, b): s is subtracted directly, and the agent bears the
    fraction (1 - b) of (ℓ - a) plus the associated risk term.

    Scalar parameters are converted to tensors on a's device/dtype; since s and
    b are 0-d, the result has a's shape. Extra keyword arguments (e.g. U_res)
    are accepted and ignored.
    """
    s, slope = _extract_sb(t)
    like_a = {"device": a.device, "dtype": a.dtype}
    r_t, c_t, sigma_t, ell_t = (
        torch.as_tensor(value, **like_a) for value in (r, c, sigma, ell)
    )

    retained = 1.0 - slope  # share of (ℓ - a) left with the agent
    loss_term = -retained * (ell_t - a)
    risk_premium = 0.5 * r_t * retained**2 * (sigma_t**2)
    effort_cost = 0.5 * c_t * (a**2)
    return loss_term - s - risk_premium - effort_cost


def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Closed-form optimum (insurance with prevention).

        b* = (r c sigma^2) / (1 + r c sigma^2)
        a* = (1 - b*) / c = 1 / (c (1 + r c sigma^2))
        s* = -[U_res + (1-b*)(ℓ - a*) + 0.5 r (1-b*)^2 sigma^2 + 0.5 c (a*)^2]

    a* = (1 - b)/c maximizes u2 over a. b* maximizes u1 along that best
    response. s* makes the agent's participation constraint bind, so that
    u2(a*, t*) == U_res. Because u2 subtracts s
    (u2 = -(1-b)(ℓ - a) - s - ...), the bracketed cost terms enter s* with a
    minus sign.

    Args:
        t: Optional reference tensor. When given, outputs use its device and its
            dtype (float32 if t is not floating-point); otherwise CPU float32.
        setting_parameters: Must contain "r", "c", "sigma", "U_res" and "ell".

    Returns:
        a_star: shape (1,) tensor
        t_star: shape (2,) tensor with t_star[0] = s*, t_star[1] = b*

    Raises:
        ValueError: if `setting_parameters` is None, c <= 0, sigma < 0 or r < 0.
        KeyError: if a required parameter is missing.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    r = float(setting_parameters["r"])
    c = float(setting_parameters["c"])
    sigma = float(setting_parameters["sigma"])
    U_res = float(setting_parameters["U_res"])
    ell = float(setting_parameters["ell"])

    if c <= 0:
        raise ValueError("Parameter `c` must be > 0.")
    if sigma < 0:
        raise ValueError("Parameter `sigma` must be >= 0.")
    if r < 0:
        raise ValueError("Parameter `r` must be >= 0.")

    # With r >= 0 and c > 0 validated above, 1 + r c sigma^2 >= 1: no zero guard needed.
    rcs2 = r * c * sigma**2
    b_star_val = rcs2 / (1.0 + rcs2)
    # Agent's best response: d u2 / d a = (1 - b) - c a = 0.
    a_star_val = (1.0 - b_star_val) / c
    # Binding participation constraint u2(a*, t*) == U_res, with s entering u2 as -s.
    s_star_val = -(
        U_res
        + (1.0 - b_star_val) * (ell - a_star_val)
        + 0.5 * r * (1.0 - b_star_val)**2 * (sigma**2)
        + 0.5 * c * (a_star_val**2)
    )

    if isinstance(t, torch.Tensor):
        device = t.device
        # An integer/bool reference dtype would truncate the optimum; use float32 instead.
        dtype = t.dtype if t.is_floating_point() else torch.float32
    else:
        device, dtype = torch.device("cpu"), torch.float32

    a_star = torch.tensor([a_star_val], device=device, dtype=dtype)
    t_star = torch.tensor([s_star_val, b_star_val], device=device, dtype=dtype)
    return a_star, t_star


def u1_at_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tensor:
    """
    Evaluate the principal's objective u1 at the closed-form optimum (a*, t*).

    Args:
        t: Optional reference tensor forwarded to `get_theoretical_optimum`,
            which picks the device/dtype of (a*, t*) from it.
        setting_parameters: Must contain "r", "c", "sigma", "ell" and "U_res"
            (the last is required by `get_theoretical_optimum`).

    Returns:
        A 0-d Tensor holding u1(a*, t*). A Tensor is returned in every case,
        including when `t` is None.

    Raises:
        ValueError: if `setting_parameters` is None or contains invalid values.
        KeyError: if a required parameter is missing.
    """
    a_star, t_star = get_theoretical_optimum(t=t, setting_parameters=setting_parameters)
    # Delegate to u1 so this value cannot drift from the objective used elsewhere.
    # Passing the scalar effort (not the (1,) tensor) yields a 0-d result.
    return u1(
        a_star.reshape(-1)[0],
        t_star,
        r=float(setting_parameters["r"]),
        c=float(setting_parameters["c"]),
        sigma=float(setting_parameters["sigma"]),
        ell=float(setting_parameters["ell"]),
    )
