# imperfect_signal.py
import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

Number = Union[float, int]

def _pick_sb(t: Tensor, s_index: int, b_index: int) -> Tuple[Tensor, Tensor]:
    """Extract (s, b) from an arbitrary-shaped t using provided indices."""
    t_flat = t.flatten()
    if s_index >= t_flat.numel() or b_index >= t_flat.numel():
        raise IndexError(f"s_index={s_index} or b_index={b_index} out of bounds for t with {t_flat.numel()} elements.")
    s = t_flat[s_index]
    b = t_flat[b_index]
    return s, b

def u1(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    v: float,
    s_index: int = 0,
    b_index: int = 1,
    **kwargs
) -> Tensor:
    """
    Principal's certainty-equivalent (imperfect signal z = α a + ε, but α only affects u2's mean):
        u1(a,t) = v*a - 0.5 * r * b^2 * sigma^2 - 0.5 * c * a^2

    Args:
        a: Effort tensor of any shape.
        t: Contract tensor of any shape; b is read at flat position ``b_index``
           (s is located via ``s_index`` but does not enter u1).
        r, c, sigma, v: Model parameters; anything ``torch.as_tensor`` accepts.
        s_index, b_index: Flat indices into ``t`` (see ``_pick_sb``).
        **kwargs: Ignored, so callers may pass extra keys (e.g. ``alpha``).

    Returns:
        A 0-d tensor: the expression above, summed over all elements when it
        is not already 0-d (keeps a scalar objective for autograd).

    Raises:
        IndexError: if ``s_index`` or ``b_index`` is out of range for ``t``.
    """
    _, b = _pick_sb(t, s_index=s_index, b_index=b_index)
    # Parameters must live in a floating dtype: casting them to an integer
    # `a.dtype` would silently truncate fractional values (e.g. r=0.5 -> 0).
    if a.is_floating_point() or a.is_complex():
        dtype = a.dtype
    else:
        dtype = torch.get_default_dtype()
    v_t = torch.as_tensor(v, device=a.device, dtype=dtype)
    r_t = torch.as_tensor(r, device=a.device, dtype=dtype)
    c_t = torch.as_tensor(c, device=a.device, dtype=dtype)
    sigma_t = torch.as_tensor(sigma, device=a.device, dtype=dtype)

    val = v_t * a - 0.5 * r_t * (b**2) * (sigma_t**2) - 0.5 * c_t * (a**2)
    return val if val.ndim == 0 else val.sum()

def u2(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma: float,
    alpha: float,
    s_index: int = 0,
    b_index: int = 1,
    **kwargs
) -> Tensor:
    """
    Agent's certainty-equivalent with signal z = α a + ε:
        u2(a,t) = s + b*(α a) - 0.5 * r * b^2 * sigma^2 - 0.5 * c * a^2

    Args:
        a: Effort tensor of any shape.
        t: Contract tensor of any shape; s and b are read at flat positions
           ``s_index`` and ``b_index``.
        r, c, sigma, alpha: Model parameters; anything ``torch.as_tensor``
           accepts. ``alpha`` may be scalar; it broadcasts against ``a``.
        s_index, b_index: Flat indices into ``t`` (see ``_pick_sb``).
        **kwargs: Ignored, so callers may pass extra keys (e.g. ``v``).

    Returns:
        A 0-d tensor: the expression above, summed over all elements when it
        is not already 0-d.

    Raises:
        IndexError: if ``s_index`` or ``b_index`` is out of range for ``t``.

    Notes:
    - We only *read* (s,b) from `t` using indices; `t` can be any shape.
    """
    s, b = _pick_sb(t, s_index=s_index, b_index=b_index)
    # Parameters must live in a floating dtype: casting them to an integer
    # `a.dtype` would silently truncate fractional values (e.g. r=0.5 -> 0).
    if a.is_floating_point() or a.is_complex():
        dtype = a.dtype
    else:
        dtype = torch.get_default_dtype()
    r_t = torch.as_tensor(r, device=a.device, dtype=dtype)
    c_t = torch.as_tensor(c, device=a.device, dtype=dtype)
    sigma_t = torch.as_tensor(sigma, device=a.device, dtype=dtype)
    alpha_t = torch.as_tensor(alpha, device=a.device, dtype=dtype)

    val = s + b * (alpha_t * a) - 0.5 * r_t * (b**2) * (sigma_t**2) - 0.5 * c_t * (a**2)
    return val if val.ndim == 0 else val.sum()

def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Union[float, int]]] = None,
) -> Tuple[Union[float, Tensor], Union[float, Tensor], Union[float, Tensor]]:
    """
    Closed-form optimum of the linear contract (s, b) consistent with u1/u2.

    Derivation (scalar-slope imperfect signal z = α a + ε):
        agent's FOC of u2 in a:       a(b) = b α / c
        u1 along a(b):                v b α / c - 0.5 r b^2 sigma^2 - 0.5 b^2 α^2 / c
        maximizing over b gives:
        b* = (v α) / (α^2 + r c sigma^2)
        a* = (b* α) / c
        s* = U_res - [ b* α a* - 0.5 r b*^2 sigma^2 - 0.5 c a*^2 ]   (u2 equals U_res)

    Args:
        t: Optional tensor; if given, the results are returned as 0-d tensors
           on its device/dtype. Otherwise Python floats are returned.
        setting_parameters: Mapping with required keys "r", "c", "sigma",
           "U_res", "alpha" and optional "v" (default 1.0).

    Returns:
        (a*, b*, s*). Indices (s_index, b_index) are only for reading in
        u1/u2, not needed here.

    Raises:
        ValueError: if `setting_parameters` is missing, or c <= 0, sigma < 0, r < 0.
        KeyError: if a required key is absent from `setting_parameters`.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    # pull params as floats
    r = float(setting_parameters["r"])
    c = float(setting_parameters["c"])
    sigma = float(setting_parameters["sigma"])
    U_res = float(setting_parameters["U_res"])
    alpha = float(setting_parameters["alpha"])
    v = float(setting_parameters.get("v", 1.0))

    # simple guardrails
    if c <= 0:
        raise ValueError("Parameter `c` must be > 0.")
    if sigma < 0:
        raise ValueError("Parameter `sigma` must be >= 0.")
    if r < 0:
        raise ValueError("Parameter `r` must be >= 0.")

    if alpha == 0.0:
        # Effort does not move the signal, so the agent's best response is a=0
        # and any b > 0 only adds the risk term to u1: optimum is b*=0, a*=0,
        # and s* simply pays the reservation utility.
        b_star = 0.0
        a_star = 0.0
        s_star = U_res
    else:
        # With alpha != 0 and r, c, sigma^2 >= 0, denom >= alpha**2 > 0,
        # so no division-by-zero guard is needed.
        denom = alpha**2 + r * c * sigma**2
        b_star = (v * alpha) / denom
        a_star = (b_star * alpha) / c
        s_star = U_res - (b_star * alpha * a_star - 0.5 * r * (b_star**2) * (sigma**2) - 0.5 * c * (a_star**2))

    # match device/dtype if `t` given
    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
        a_star = torch.as_tensor(a_star, device=device, dtype=dtype)
        b_star = torch.as_tensor(b_star, device=device, dtype=dtype)
        s_star = torch.as_tensor(s_star, device=device, dtype=dtype)

    return a_star, b_star, s_star
