import math
from typing import Dict, Tuple, Union, Optional

import torch

Number = Union[float, int]

def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Union[float, torch.Tensor], Union[float, torch.Tensor], Union[float, torch.Tensor]]:
    """
    Closed-form optimum for the quadratic principal–agent model:

        u2(a,t) = s + b*a - 0.5 * r * b^2 * sigma^2 - 0.5 * c * a^2
        u1(a,t) = a         - 0.5 * r * b^2 * sigma^2 - 0.5 * c * a^2

    Participation binds: u2(a*, t*) = U_res.

    Maximising u2 over a gives a = b / c. Substituting that into u1 and
    solving the first-order condition in b yields

        b* = 1 / (1 + r * c * sigma^2),    a* = b* / c,
        s* = U_res - (b* a* - 0.5 * r * b*^2 * sigma^2 - 0.5 * c * a*^2).

    Args:
        t: Optional reference tensor. When given, the results are returned as
            0-d tensors on ``t.device``. They use ``t.dtype`` if it is a
            floating-point dtype, otherwise torch's default float dtype, so an
            integer/bool ``t`` does not truncate the fractional optimum.
        setting_parameters: Mapping with keys ``"r"``, ``"c"``, ``"sigma"``
            and ``"U_res"``; each value is converted with ``float``.

    Returns:
        (a_star, b_star, s_star) — as torch.Tensors if `t` is a tensor, else floats.

    Raises:
        ValueError: if ``setting_parameters`` is None, if any parameter is not
            finite (NaN or inf), or if ``c <= 0``, ``sigma < 0`` or ``r < 0``.
        KeyError: if one of the required keys is missing.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    # pull params as floats
    r = float(setting_parameters["r"])
    c = float(setting_parameters["c"])
    sigma = float(setting_parameters["sigma"])
    U_res = float(setting_parameters["U_res"])

    # Check finiteness first: NaN passes every ordered comparison below
    # (``nan <= 0`` is False), and inf leads to ``0 * inf = nan`` in s*.
    for name, value in (("r", r), ("c", c), ("sigma", sigma), ("U_res", U_res)):
        if not math.isfinite(value):
            raise ValueError(f"Parameter `{name}` must be finite.")

    # simple guardrails
    if c <= 0:
        raise ValueError("Parameter `c` must be > 0.")
    if sigma < 0:
        raise ValueError("Parameter `sigma` must be >= 0.")
    if r < 0:
        raise ValueError("Parameter `r` must be >= 0.")

    # closed forms; with r >= 0 and c > 0 the denominator is >= 1,
    # so no division-by-zero guard is needed
    denom = 1.0 + r * c * sigma ** 2
    b_star = 1.0 / denom
    a_star = b_star / c
    s_star = U_res - (b_star * a_star - 0.5 * r * b_star ** 2 * sigma ** 2 - 0.5 * c * a_star ** 2)

    # match device (and dtype, when floating) if `t` given
    if isinstance(t, torch.Tensor):
        # An integer/bool dtype would truncate e.g. b* = 0.5 to 0.
        dtype = t.dtype if t.is_floating_point() else torch.get_default_dtype()
        a_star = torch.as_tensor(a_star, device=t.device, dtype=dtype)
        b_star = torch.as_tensor(b_star, device=t.device, dtype=dtype)
        s_star = torch.as_tensor(s_star, device=t.device, dtype=dtype)

    return a_star, b_star, s_star


def u1_min_at_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Dict[str, Number] = None,
) -> Union[float, torch.Tensor]:
    """
    Return u1 evaluated at the theoretical optimum (a*, b*, s*).

        u1(a*, t*) = a* - 0.5 * r * b*^2 * sigma^2 - 0.5 * c * a*^2

    The optimum comes from ``get_theoretical_optimum``. The value is convenient
    as the level of a dashed horizontal reference line in plots.

    Args:
        t: Optional reference tensor, forwarded to ``get_theoretical_optimum``.
            When a* comes back as a tensor, the result is a tensor on a*'s
            device and dtype.
        setting_parameters: Mapping read for ``"r"``, ``"c"`` and ``"sigma"``
            (and whatever else ``get_theoretical_optimum`` requires).

    Returns:
        u1 at the optimum, as a float or a tensor.
    """
    a_opt, b_opt, _ = get_theoretical_optimum(t=t, setting_parameters=setting_parameters)

    coeffs = [float(setting_parameters[key]) for key in ("r", "c", "sigma")]
    if isinstance(a_opt, torch.Tensor):
        # Bring the scalar coefficients onto a*'s device/dtype before combining.
        coeffs = [torch.as_tensor(v, device=a_opt.device, dtype=a_opt.dtype) for v in coeffs]
    r_val, c_val, sigma_val = coeffs

    sigma_term = 0.5 * r_val * (b_opt ** 2) * (sigma_val ** 2)
    c_term = 0.5 * c_val * (a_opt ** 2)
    return a_opt - sigma_term - c_term
