# multitask_separable.py  (revised)
import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional, Sequence, List

# Scalar element type used in the per-task parameter annotations (c_vec, sigma_vec, v_vec).
Number = Union[float, int]

def _extract_s_b(
    t: Tensor,
    a: Tensor,
    *,
    s_index: int = 0,
    b_indices: Optional[Sequence[int]] = None,
    b_start: int = 1
) -> Tuple[Tensor, Tensor]:
    """
    Extract scalar s and vector b from arbitrary-shaped t.
    - If b_indices is provided, it is used directly.
    - Else we infer K = a.numel() and take a contiguous block t_flat[b_start : b_start+K].
    """
    t_flat = t.flatten()
    if s_index >= t_flat.numel():
        raise IndexError(f"s_index={s_index} out of bounds for t with {t_flat.numel()} elements.")

    s = t_flat[s_index]

    if b_indices is not None:
        idx = torch.as_tensor(list(b_indices), dtype=torch.long, device=t.device)
        if (idx < 0).any() or (idx >= t_flat.numel()).any():
            raise IndexError("Some b_indices are out of bounds for t.")
        b = t_flat.index_select(0, idx)
    else:
        K = int(a.numel())
        end = b_start + K
        if end > t_flat.numel():
            raise IndexError(
                f"Not enough entries in t to read {K} slopes starting at b_start={b_start} "
                f"(t has {t_flat.numel()} elements)."
            )
        b = t_flat[b_start:end]

    return s, b


def u1(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c_vec: Sequence[Number],
    sigma_vec: Sequence[Number],
    v_vec: Sequence[Number],
    s_index: int = 0,
    b_indices: Optional[Sequence[int]] = None,
    b_start: int = 1,
    **kwargs
) -> Tensor:
    """
    Principal CE:
        u1 = sum_i v_i a_i - 0.5*r*sum_i b_i^2*sigma_i^2 - 0.5*sum_i c_i a_i^2

    Args:
        a: tensor of any shape (contiguous or not); flattened to K = a.numel() entries.
        t: tensor of any shape holding s and b; b is read via `_extract_s_b`,
            either at `b_indices` or as a contiguous block of length K at `b_start`.
        r: coefficient on the variance term.
        c_vec, sigma_vec, v_vec: per-task parameters, each of length K.
        s_index, b_indices, b_start: forwarded to `_extract_s_b`.
        **kwargs: accepted and ignored, so one shared settings dict can be passed.

    Returns:
        0-d tensor on a's device, in the promoted floating dtype of (a, t).

    Raises:
        IndexError: propagated from `_extract_s_b`.
        ValueError: if len(v), len(c), len(sigma), K(a) and K(b) disagree.

    Notes:
    - `s` is not used in u1 (kept for symmetry).
    """
    # extract (s, b); s unused in u1
    _, b = _extract_s_b(t, a, s_index=s_index, b_indices=b_indices, b_start=b_start)

    # Compute in a common floating dtype. With an integer dtype, fractional
    # parameters would be truncated (e.g. r=0.5 -> 0).
    dtype = torch.promote_types(a.dtype, t.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    device = a.device

    # reshape (not view): view() raises on non-contiguous inputs such as a.T.
    a_vec = a.reshape(-1).to(device=device, dtype=dtype)
    b = b.to(device=device, dtype=dtype)
    v = torch.as_tensor(v_vec, device=device, dtype=dtype).reshape(-1)
    c = torch.as_tensor(c_vec, device=device, dtype=dtype).reshape(-1)
    sigma = torch.as_tensor(sigma_vec, device=device, dtype=dtype).reshape(-1)
    r_t = torch.as_tensor(r, device=device, dtype=dtype)

    if not (v.numel() == c.numel() == sigma.numel() == a_vec.numel() == b.numel()):
        raise ValueError(
            f"Length mismatch: len(v)={v.numel()}, len(c)={c.numel()}, len(sigma)={sigma.numel()}, "
            f"K(a)={a_vec.numel()}, K(b)={b.numel()}"
        )

    var_term = (b**2 * sigma**2).sum()
    # Every term is 0-d, so the result is already a scalar tensor.
    return (v @ a_vec) - 0.5 * r_t * var_term - 0.5 * (c * a_vec**2).sum()


def u2(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c_vec: Sequence[Number],
    sigma_vec: Sequence[Number],
    s_index: int = 0,
    b_indices: Optional[Sequence[int]] = None,
    b_start: int = 1,
    **kwargs
) -> Tensor:
    """
    Agent CE:
        u2 = s + sum_i b_i a_i - 0.5*r*sum_i b_i^2*sigma_i^2 - 0.5*sum_i c_i a_i^2

    Args:
        a: tensor of any shape (contiguous or not); flattened to K = a.numel() entries.
        t: tensor of any shape holding s and b; both are read via `_extract_s_b`.
        r: coefficient on the variance term.
        c_vec, sigma_vec: per-task parameters, each of length K.
        s_index, b_indices, b_start: forwarded to `_extract_s_b`.
        **kwargs: accepted and ignored (e.g. a shared settings dict containing v_vec).

    Returns:
        0-d tensor on a's device, in the promoted floating dtype of (a, t).

    Raises:
        IndexError: propagated from `_extract_s_b`.
        ValueError: if len(c), len(sigma), K(a) and K(b) disagree.
    """
    s, b = _extract_s_b(t, a, s_index=s_index, b_indices=b_indices, b_start=b_start)

    # Common floating dtype. `b @ a_vec` requires matching dtypes, so float64 t
    # with float32 a used to crash. An integer dtype would also truncate r.
    dtype = torch.promote_types(a.dtype, t.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    device = a.device

    # reshape (not view): view() raises on non-contiguous inputs such as a.T.
    a_vec = a.reshape(-1).to(device=device, dtype=dtype)
    s = s.to(device=device, dtype=dtype)
    b = b.to(device=device, dtype=dtype)
    c = torch.as_tensor(c_vec, device=device, dtype=dtype).reshape(-1)
    sigma = torch.as_tensor(sigma_vec, device=device, dtype=dtype).reshape(-1)
    r_t = torch.as_tensor(r, device=device, dtype=dtype)

    if not (c.numel() == sigma.numel() == a_vec.numel() == b.numel()):
        raise ValueError(
            f"Length mismatch: len(c)={c.numel()}, len(sigma)={sigma.numel()}, "
            f"K(a)={a_vec.numel()}, K(b)={b.numel()}"
        )

    var_term = (b**2 * sigma**2).sum()
    # Every term is 0-d, so the result is already a scalar tensor.
    return s + (b @ a_vec) - 0.5 * r_t * var_term - 0.5 * (c * a_vec**2).sum()


def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Union[Number, Sequence[Number]]]] = None,
) -> Tuple[Union[Tensor, float], Union[Tensor, float], Union[Tensor, float]]:
    """
    Closed-form optimum (task-by-task, separable) of the model in u1/u2.

    The agent's CE u2 is maximized at a_i = b_i / c_i. Substitute this into
        sum_i v_i a_i - 0.5 r sum_i b_i^2 sigma_i^2 - 0.5 sum_i c_i a_i^2
    and maximize over each b_i:
        b_i* = v_i / (1 + r c_i sigma_i^2)
        a_i* = b_i* / c_i
        s*   = U_res - [ sum_i b_i* a_i* - 0.5 r sum_i b_i*^2 sigma_i^2 - 0.5 sum_i c_i a_i*^2 ]
    s* is chosen so that the agent's CE equals U_res.

    Args:
        t: optional tensor used only for its device and (floating) dtype.
        setting_parameters: dict with keys "r", "U_res", "c_vec", "sigma_vec", "v_vec".

    Returns:
        (a*_vec, b*_vec, s*) as tensors.

    Raises:
        ValueError: if setting_parameters is None or the vectors differ in length.
        KeyError: if a required key is missing from setting_parameters.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")
    r = float(setting_parameters["r"])
    U_res = float(setting_parameters["U_res"])
    c_list: List[Number] = list(setting_parameters["c_vec"])
    sigma_list: List[Number] = list(setting_parameters["sigma_vec"])
    v_list: List[Number] = list(setting_parameters["v_vec"])

    if not (len(c_list) == len(sigma_list) == len(v_list)):
        raise ValueError("c_vec, sigma_vec, and v_vec must have the same length.")

    # Follow t's device and dtype, but never compute in an integer dtype:
    # that would truncate fractional parameters (e.g. sigma=0.5 -> 0).
    if isinstance(t, torch.Tensor):
        device = t.device
        dtype = t.dtype if t.is_floating_point() else torch.get_default_dtype()
    else:
        device, dtype = None, torch.get_default_dtype()

    c = torch.as_tensor(c_list, device=device, dtype=dtype)
    sig2 = torch.as_tensor([float(s)**2 for s in sigma_list], device=device, dtype=dtype)
    v = torch.as_tensor(v_list, device=device, dtype=dtype)

    # NOTE: the denominator is 1 + r c sigma^2, not v + r c sigma^2. The two only
    # coincide when v_i == 1.
    denom = 1.0 + r * c * sig2
    # guard against degenerate nonpositive denom (only possible if r*c*sigma^2 <= -1)
    denom = torch.where(denom > 0, denom, torch.full_like(denom, 1e-12))

    b_star = v / denom
    a_star = b_star / c
    s_star = U_res - (
        (b_star * a_star).sum()
        - 0.5 * r * (b_star**2 * sig2).sum()
        - 0.5 * (c * a_star**2).sum()
    )

    return a_star, b_star, s_star
