import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

Number = Union[float, int]

# ======================= Defaults =======================
# Starting contract: t holds a single coordinate, the incentive slope b.
DEFAULT_T_INIT = torch.full((1,), 0.5, dtype=torch.float32)
# Starting effort a.
DEFAULT_A_INIT = torch.zeros(1, dtype=torch.float32)
# Model parameters consumed by u1/u2: r scales the b^2 * sigma^2 term,
# c scales the a^2 effort-cost term, sigma enters that risk term squared.
DEFAULT_PARAMS = {"r": 1.0, "c": 1.0, "sigma": 0.1}

# Masks over t: b is the only coordinate, and it is both trained and
# reported in metrics.
T_TRAIN_MASK = torch.ones(1, dtype=torch.float32)
T_METRIC_MASK = torch.ones(1, dtype=torch.float32)

# ======================= Projections =======================
def project_t_box(t: Tensor) -> Tensor:
    """
    Project the contract tensor onto the safe box for the slope b.

    The first entry of t (the slope b) is clamped into [0, 2]; any further
    entries are copied through unchanged. The input is not modified: a new
    tensor with the same shape as t is returned. An empty t raises
    IndexError.
    """
    flat = t.flatten()
    # Clamp a view of the input, not of `result`, so the in-place write below
    # never touches a tensor that autograd may have saved.
    clamped_b = torch.clamp(flat[0], min=0.0, max=2.0)
    result = flat.clone()
    result[0] = clamped_b
    return result.view_as(t)

def project_t_box_default(t: Tensor, setting_parameters: Optional[Dict[str, Number]] = None) -> Tensor:
    """
    Projection entry point under the name the optimizer driver looks for.

    Delegates to `project_t_box`. `setting_parameters` is accepted only so
    the signature matches what the driver passes; the projection does not
    use it.
    """
    projected = project_t_box(t)
    return projected

# ======================= Utilities =======================
def _extract_b(t: Tensor) -> Tensor:
    # Enforce that t encodes a single slope b
    t = t.reshape(-1)
    assert t.numel() == 1, (
        f"Milgrom-Holmstrom setting expects t to have 1 entry (b). "
        f"Got shape {tuple(t.shape)}."
    )
    return t[0]

def u1(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    **kwargs
) -> Tensor:
    """
    Expected utility of the principal. The fixed transfer s does not appear;
    it is suppressed by participation:

        u1(a, t) = a - 0.5 * r * b^2 * sigma^2 - 0.5 * c * a^2

    `t` must hold exactly one entry, the incentive slope b. The scalar
    parameters are cast to `a`'s device and dtype before use. Extra keyword
    arguments are accepted and ignored.
    """
    slope = _extract_b(t)
    r_t, c_t, sigma_t = (
        torch.as_tensor(value, device=a.device, dtype=a.dtype)
        for value in (r, c, sigma)
    )
    risk_term = 0.5 * r_t * (slope**2) * (sigma_t**2)
    effort_cost = 0.5 * c_t * (a**2)
    return a - risk_term - effort_cost

def u2(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    **kwargs
) -> Tensor:
    """
    Expected utility of the agent. The fixed transfer s does not appear;
    it is suppressed by participation:

        u2(a, t) = b * a - 0.5 * r * b^2 * sigma^2 - 0.5 * c * a^2

    `t` must hold exactly one entry, the incentive slope b. The scalar
    parameters are cast to `a`'s device and dtype before use. Extra keyword
    arguments are accepted and ignored.
    """
    slope = _extract_b(t)
    r_t, c_t, sigma_t = (
        torch.as_tensor(value, device=a.device, dtype=a.dtype)
        for value in (r, c, sigma)
    )
    incentive_pay = slope * a
    risk_term = 0.5 * r_t * (slope**2) * (sigma_t**2)
    effort_cost = 0.5 * c_t * (a**2)
    return incentive_pay - risk_term - effort_cost

def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Closed-form optimum (single-parameter contract slope b):
        b* = 1 / (1 + r c sigma^2)
        a* = b* / c

    Args:
        t: Optional reference tensor; its values are ignored. When given,
            the outputs are placed on its device and use its dtype if that
            dtype is floating point. Otherwise they use float32, because an
            integer/bool dtype would truncate b* and a*. Without t the
            outputs are CPU float32.
        setting_parameters: Mapping with keys "r", "c" and "sigma".

    Returns:
        a_star: shape (1,) tensor
        t_star: shape (1,) tensor with t_star[0] = b*

    Raises:
        ValueError: if `setting_parameters` is None.
        KeyError: if "r", "c" or "sigma" is missing.
        ZeroDivisionError: if c == 0 or 1 + r c sigma^2 == 0.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")
    r = float(setting_parameters["r"])
    c = float(setting_parameters["c"])
    sigma = float(setting_parameters["sigma"])

    denom = 1.0 + r * c * sigma**2
    b_star_val = 1.0 / denom
    # Agent's best response to slope b (maximizing u2 over a) is a = b / c.
    a_star_val = b_star_val / c

    device, dtype = torch.device("cpu"), torch.float32
    if isinstance(t, torch.Tensor):
        device = t.device
        # Only inherit floating dtypes; e.g. an int64 reference tensor would
        # otherwise turn b* = 0.99 into 0.
        if t.dtype.is_floating_point:
            dtype = t.dtype

    a_star = torch.tensor([a_star_val], device=device, dtype=dtype)
    t_star = torch.tensor([b_star_val], device=device, dtype=dtype)
    return a_star, t_star

def u1_at_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tensor:
    """
    Compute the principal's utility u1(a*, t*) at the closed-form optimum.

    `get_theoretical_optimum` always returns tensors, so this always returns
    a 0-dim Tensor, including when t is None. It uses the device/dtype of
    those optimum tensors.

    Raises:
        ValueError: if `setting_parameters` is None (raised by
            get_theoretical_optimum).
    """
    a_star, t_star = get_theoretical_optimum(t=t, setting_parameters=setting_parameters)
    # Evaluate through u1 itself so the reported optimum value can never
    # drift from the objective the principal's utility defines.
    return u1(
        a_star.reshape(-1)[0],
        t_star,
        r=float(setting_parameters["r"]),
        c=float(setting_parameters["c"]),
        sigma=float(setting_parameters["sigma"]),
    )
