import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional, List

Number = Union[float, int]

# ======================= Defaults =======================
# By default: K=2 tasks, with slopes b1, b2 plus scalar s
# Contract vector laid out as (s, b1, b2): scalar s first, then one slope per task.
DEFAULT_T_INIT = torch.tensor((0.0, 0.5, 0.5), dtype=torch.float32)
# One entry per task, both starting at 0.
DEFAULT_A_INIT = torch.zeros(2, dtype=torch.float32)
DEFAULT_PARAMS = {
    "r": 1.0,
    "U_res": 0.0,
    "c_vec": [1.0, 1.0],      # identical effort costs
    "sigma_vec": [0.2, 0.2],  # identical noise
    "v_vec": [1.0, 1.0],      # identical values
}

# ======================= Masks =======================
# The transfer s (index 0) is typically NOT trained, and it is also ignored in
# the t-distance metric. Only the slopes b_i are trainable / compared.
# With DEFAULT_T_INIT of length 3 both masks come out as [0, 1, 1].
_TLEN = int(DEFAULT_T_INIT.numel())
T_TRAIN_MASK = torch.tensor(
    [float(i > 0) for i in range(max(_TLEN, 1))], dtype=torch.float32
)
T_METRIC_MASK = torch.tensor(
    [float(i > 0) for i in range(max(_TLEN, 1))], dtype=torch.float32
)

# ======================= Helpers =======================
def _extract_sb(t: Tensor, K: int) -> Tuple[Tensor, Tensor]:
    """
    Split a contract tensor t into (s, b_vec).

    t is flattened and must then hold exactly K + 1 entries: the scalar s at
    flat index 0 followed by the K slopes.

    Args:
        t: tensor of any shape whose total number of elements is K + 1.
        K: number of slopes expected in t.

    Returns:
        Tuple (s, b): s is the 0-dim entry at flat index 0 and b is the
        1-D remainder of length K.

    Raises:
        AssertionError: if the flattened t does not have exactly K + 1 entries.
    """
    flat = t.reshape(-1)
    # Raised explicitly rather than through an `assert` statement so the check
    # is not stripped when Python runs with -O. AssertionError is kept so that
    # existing callers observe the same exception type and message.
    if flat.numel() != K + 1:
        raise AssertionError(
            f"Multitask separable setting expects t to have {K+1} entries (s + {K} slopes). "
            f"Got shape {tuple(flat.shape)}."
        )
    return flat[0], flat[1:]

# ======================= Projection =======================
def project_t_box(t: Tensor, *, b_min: float = 0.0, b_max: float = 1.5) -> Tensor:
    """
    Clamp the b-coordinates of t to the box [b_min, b_max]; leave s as-is.

    t is read as a flattened (s, b_1, ..., b_K) vector: flat index 0 is kept
    untouched and every later entry is clamped. The input tensor is not
    modified; a new tensor with the same shape as t is returned.

    Economics: with positive {r, c_i, sigma_i^2, v_i}, each b_i* ∈ (0, 1],
    so the default [0, 1.5] (or [0, 2]) is a conservative box.

    Args:
        t: tensor of any shape; the first flattened entry is s.
        b_min: lower bound applied to every b-coordinate (default 0.0).
        b_max: upper bound applied to every b-coordinate (default 1.5).

    Returns:
        A clamped copy of t with the same shape.

    Raises:
        ValueError: if b_min is greater than b_max.
    """
    if b_min > b_max:
        raise ValueError(f"b_min ({b_min}) must not exceed b_max ({b_max}).")

    # Work on a flat copy so the caller's tensor is never mutated.
    flat = t.reshape(-1).clone()
    if flat.numel() > 1:
        # Index 0 is s and stays untouched; only the slopes are clamped.
        flat[1:] = flat[1:].clamp(b_min, b_max)
    return flat.view_as(t)

def project_t_box_default(t: Tensor) -> Tensor:
    """
    Driver-friendly wrapper name used by experiment.py.

    Forwards t unchanged to project_t_box and returns its result.
    """
    return project_t_box(t)

# ======================= Utilities =======================
def u1(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c_vec: List[Number],
    sigma_vec: List[Number],
    v_vec: List[Number],
    **kwargs
) -> Tensor:
    """
    Principal's certainty-equivalent:
        u1 = sum_i v_i a_i - 0.5 r sum_i b_i^2 sigma_i^2 - 0.5 sum_i c_i a_i^2
    where t encodes (s, b_vec). (s unused in u1.)

    Args:
        a: tensor with K elements in total (any shape; it is flattened).
        t: tensor with K + 1 elements, read as (s, b_1, ..., b_K).
        r: coefficient r in the formula above.
        c_vec: length-K sequence of c_i.
        sigma_vec: length-K sequence of sigma_i (squared inside the formula).
        v_vec: length-K sequence of v_i.
        **kwargs: accepted and ignored.

    Returns:
        0-dim tensor holding u1.

    Raises:
        AssertionError: from _extract_sb if t does not have K + 1 entries.
        ValueError: if c_vec, sigma_vec, v_vec, a and b differ in length.
    """
    # reshape (not view) so a non-contiguous `a` is accepted, the same way
    # _extract_sb flattens t.
    a_vec = a.reshape(-1)
    K = a_vec.numel()
    _, b = _extract_sb(t, K)

    # Parameters are materialised on a's device/dtype.
    v = torch.as_tensor(v_vec, device=a.device, dtype=a.dtype)
    c = torch.as_tensor(c_vec, device=a.device, dtype=a.dtype)
    sigma = torch.as_tensor(sigma_vec, device=a.device, dtype=a.dtype)
    r_t = torch.as_tensor(r, device=a.device, dtype=a.dtype)

    if not (len(v) == len(c) == len(sigma) == K == b.numel()):
        raise ValueError("Length mismatch between parameters, a, and b.")

    benefit = v @ a_vec
    risk_term = 0.5 * r_t * (b**2 * sigma**2).sum()
    effort_term = 0.5 * (c * a_vec**2).sum()
    return benefit - risk_term - effort_term

def u2(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c_vec: List[Number],
    sigma_vec: List[Number],
    **kwargs
) -> Tensor:
    """
    Agent's certainty-equivalent:
        u2 = s + sum_i b_i a_i - 0.5 r sum_i b_i^2 sigma_i^2 - 0.5 sum_i c_i a_i^2
    where t encodes (s, b_vec).

    Args:
        a: tensor with K elements in total (any shape; it is flattened).
        t: tensor with K + 1 elements, read as (s, b_1, ..., b_K).
        r: coefficient r in the formula above.
        c_vec: length-K sequence of c_i.
        sigma_vec: length-K sequence of sigma_i (squared inside the formula).
        **kwargs: accepted and ignored.

    Returns:
        0-dim tensor holding u2.

    Raises:
        AssertionError: from _extract_sb if t does not have K + 1 entries.
        ValueError: if c_vec, sigma_vec, a and b differ in length.
    """
    # reshape (not view) so a non-contiguous `a` is accepted, the same way
    # _extract_sb flattens t.
    a_vec = a.reshape(-1)
    K = a_vec.numel()
    s, b = _extract_sb(t, K)

    # Parameters are materialised on a's device/dtype.
    c = torch.as_tensor(c_vec, device=a.device, dtype=a.dtype)
    sigma = torch.as_tensor(sigma_vec, device=a.device, dtype=a.dtype)
    r_t = torch.as_tensor(r, device=a.device, dtype=a.dtype)

    if not (len(c) == len(sigma) == K == b.numel()):
        raise ValueError("Length mismatch between parameters, a, and b.")

    pay = s + (b @ a_vec)
    risk_term = 0.5 * r_t * (b**2 * sigma**2).sum()
    effort_term = 0.5 * (c * a_vec**2).sum()
    return pay - risk_term - effort_term

# ======================= Closed-form optimum =======================
def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Closed-form optimum (task-by-task, separable):
        b_i* = v_i / (v_i + r c_i sigma_i^2)
        a_i* = b_i* / c_i
        s*   = U_res - [ sum_i b_i* a_i* - 0.5 r sum_i b_i*^2 sigma_i^2 - 0.5 sum_i c_i a_i*^2 ]

    NOTE(review): substituting a_i = b_i / c_i into u1 as defined in this file
    and maximising over b_i gives b_i = v_i / (1 + r c_i sigma_i^2), which
    coincides with the formula above only when v_i = 1 (as in DEFAULT_PARAMS).
    Confirm which model is intended before using v_i != 1.

    Args:
        t: optional tensor used only as a device/dtype template; its values
            are never read. A non-floating-point dtype falls back to float32,
            because casting the parameters to an integer dtype would truncate
            them (e.g. sigma^2 = 0.04 -> 0). Without a tensor, CPU/float32 is
            used.
        setting_parameters: dict with keys "r", "U_res", "c_vec", "sigma_vec"
            and "v_vec"; the three *_vec entries must have equal length K.

    Returns:
        a_star: shape (K,) tensor
        t_star: shape (K+1,) tensor with t_star[0] = s*, t_star[1:] = b*

    Raises:
        ValueError: if setting_parameters is None or the *_vec lengths differ.
        KeyError: if a required key is missing from setting_parameters.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    r = float(setting_parameters["r"])
    U_res = float(setting_parameters["U_res"])
    c_list: List[Number] = list(setting_parameters["c_vec"])
    sigma_list: List[Number] = list(setting_parameters["sigma_vec"])
    v_list: List[Number] = list(setting_parameters["v_vec"])

    if not (len(c_list) == len(sigma_list) == len(v_list)):
        raise ValueError("c_vec, sigma_vec, and v_vec must have the same length.")

    if isinstance(t, torch.Tensor):
        device = t.device
        # An integer/bool template would silently truncate the parameters.
        dtype = t.dtype if t.is_floating_point() else torch.float32
    else:
        device, dtype = torch.device("cpu"), torch.float32

    c = torch.as_tensor(c_list, device=device, dtype=dtype)
    sigma2 = torch.as_tensor([float(sd) ** 2 for sd in sigma_list], device=device, dtype=dtype)
    v = torch.as_tensor(v_list, device=device, dtype=dtype)

    # Replace non-positive denominators with a tiny positive value so the
    # division below stays finite.
    denom = v + r * c * sigma2
    denom = torch.where(denom > 0, denom, torch.full_like(denom, 1e-12))

    b_star = v / denom
    a_star = b_star / c
    s_star = U_res - ((b_star * a_star).sum()
                      - 0.5 * r * (b_star**2 * sigma2).sum()
                      - 0.5 * (c * a_star**2).sum())

    # Every tensor above already carries the chosen device/dtype, so no
    # further conversion is needed before packing t* = (s*, b*).
    t_star = torch.cat([s_star.reshape(1), b_star])
    return a_star, t_star

def u1_at_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Union[float, Tensor]:
    """
    Compute u1(a*, t*) at the closed-form optimum.

    The result is always a 0-dim Tensor: get_theoretical_optimum returns
    tensors whether or not t is given, so there is no float-returning path.
    Its device/dtype are those of the a*, t* produced by
    get_theoretical_optimum.

    Args:
        t: optional tensor forwarded to get_theoretical_optimum.
        setting_parameters: dict with keys "r", "U_res", "c_vec", "sigma_vec"
            and "v_vec".

    Returns:
        0-dim tensor holding u1 evaluated at (a*, t*).

    Raises:
        ValueError: from get_theoretical_optimum if setting_parameters is None
            or its vector entries differ in length.
    """
    a_star, t_star = get_theoretical_optimum(t=t, setting_parameters=setting_parameters)

    # Delegate to u1 so the payoff formula lives in exactly one place.
    return u1(
        a_star,
        t_star,
        r=float(setting_parameters["r"]),
        c_vec=list(setting_parameters["c_vec"]),
        sigma_vec=list(setting_parameters["sigma_vec"]),
        v_vec=list(setting_parameters["v_vec"]),
    )
