# two_signals.py  (revised)
import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional, Sequence

# Scalar type accepted for numeric model parameters (plain Python float or int).
Number = Union[float, int]

def _extract_s_b2(
    t: Tensor,
    *,
    s_index: int = 0,
    b_indices: Optional[Sequence[int]] = None,
    b_start: int = 1
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Extract the scalars s, b1, b2 from an arbitrary-shaped tensor `t`.

    `t` is flattened first, so every index below is a flat index.

    - `s_index` selects s. A negative value counts from the end.
    - If `b_indices` is provided, it must have length 2 and gives the flat
      indices for (b1, b2); both must be non-negative.
    - Otherwise b1, b2 are the two consecutive entries starting at `b_start`.
      A negative `b_start` counts from the end (e.g. -2 selects the last two
      entries), but the two-element window must never run past the end of
      `t`, so it cannot wrap around to the front.

    Returns:
        (s, b1, b2) as 0-dim tensors indexed out of the flattened `t`.

    Raises:
        ValueError: if `b_indices` does not have exactly two entries.
        IndexError: if any requested position lies outside `t`.
    """
    t_flat = t.flatten()
    n = t_flat.numel()
    if s_index >= n or s_index < -n:
        raise IndexError(f"s_index={s_index} out of bounds for t with {n} elements.")
    s = t_flat[s_index]

    if b_indices is not None:
        if len(b_indices) != 2:
            raise ValueError("b_indices must have length 2 (for b1 and b2).")
        i1, i2 = int(b_indices[0]), int(b_indices[1])
        if i1 < 0 or i2 < 0 or i1 >= n or i2 >= n:
            raise IndexError("Some b_indices are out of bounds for t.")
        b1, b2 = t_flat[i1], t_flat[i2]
    else:
        # Resolve a negative start relative to the end before bounds-checking.
        # Checking only `b_start + 2 > n` would let b_start=-1 through and
        # silently read (t[-1], t[0]), i.e. wrap b2 around to the front.
        start = b_start + n if b_start < 0 else b_start
        if start < 0 or start + 2 > n:
            raise IndexError(
                f"Not enough entries in t to read 2 slopes starting at b_start={b_start} "
                f"(t has {n} elements)."
            )
        b1, b2 = t_flat[start], t_flat[start + 1]
    return s, b1, b2


def u1(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma1: float,
    sigma2: float,
    v: float,
    s_index: int = 0,
    b_indices: Optional[Sequence[int]] = None,
    b_start: int = 1,
    **kwargs
) -> Tensor:
    """
    Principal-side certainty equivalent with two independent signals:
        u1 = v*a - 0.5*r*(b1^2*sigma1^2 + b2^2*sigma2^2) - 0.5*c*a^2

    The slopes b1, b2 are read out of `t` via `_extract_s_b2`; the intercept
    `s` does not enter the formula. `a` and `t` may have any shape: a
    non-scalar result is summed, so the return value is always 0-dim.
    Extra keyword arguments are accepted and ignored.
    """
    def on_a(value):
        # Scalar parameters adopt the device and dtype of the action tensor.
        return torch.as_tensor(value, device=a.device, dtype=a.dtype)

    extracted = _extract_s_b2(
        t, s_index=s_index, b_indices=b_indices, b_start=b_start
    )
    b1, b2 = extracted[1], extracted[2]

    gain_w, risk_w, cost_w, sd1, sd2 = (
        on_a(param) for param in (v, r, c, sigma1, sigma2)
    )

    wage_variance = b1**2 * sd1**2 + b2**2 * sd2**2
    benefit = gain_w * a
    risk_premium = 0.5 * risk_w * wage_variance
    effort_cost = 0.5 * cost_w * (a**2)
    total = benefit - risk_premium - effort_cost

    if total.ndim == 0:
        return total
    return total.sum()


def u2(
    a: Tensor,
    t: Tensor,
    *,
    r: float,
    c: float,
    sigma1: float,
    sigma2: float,
    s_index: int = 0,
    b_indices: Optional[Sequence[int]] = None,
    b_start: int = 1,
    **kwargs
) -> Tensor:
    """
    Agent-side certainty equivalent when m1=a+ε1 and m2=a+ε2 are independent:
        u2 = s + (b1+b2)*a - 0.5*r*(b1^2*sigma1^2 + b2^2*sigma2^2) - 0.5*c*a^2

    The intercept s and slopes b1, b2 are read out of `t` via
    `_extract_s_b2`. `a` and `t` may have any shape: a non-scalar result is
    summed, so the return value is always 0-dim. Extra keyword arguments are
    accepted and ignored.
    """
    def on_a(value):
        # Scalar parameters adopt the device and dtype of the action tensor.
        return torch.as_tensor(value, device=a.device, dtype=a.dtype)

    s, b1, b2 = _extract_s_b2(
        t, s_index=s_index, b_indices=b_indices, b_start=b_start
    )

    risk_w, cost_w, sd1, sd2 = (on_a(param) for param in (r, c, sigma1, sigma2))

    wage_variance = b1**2 * sd1**2 + b2**2 * sd2**2
    expected_pay = s + (b1 + b2) * a
    risk_premium = 0.5 * risk_w * wage_variance
    effort_cost = 0.5 * cost_w * (a**2)
    total = expected_pay - risk_premium - effort_cost

    if total.ndim == 0:
        return total
    return total.sum()


def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Union[float, int]]] = None,
) -> Tuple[Union[float, Tensor], Tensor, Union[float, Tensor]]:
    """
    Closed-form optimum of the two-signal linear contract.

    Derivation (consistent with `u1` / `u2` in this module): the agent's
    first-order condition on u2 gives a = β/c with β = b1 + b2; for a given β
    the wage variance b1^2 σ1^2 + b2^2 σ2^2 is minimised by weighting each
    slope with its signal precision, which yields β^2 σ_eff^2; maximising
    v a - 0.5 r β^2 σ_eff^2 - 0.5 c a^2 over β then gives:

      σ_eff^2 = (1/σ1^2 + 1/σ2^2)^{-1}
      β* = b1*+b2* = v / (1 + r c σ_eff^2)
      b_i* = β* * (σ_i^{-2}) / (σ1^{-2} + σ2^{-2})
      a* = β*/c
      s* = U_res - [ β* a* - 0.5 r (b1*^2 σ1^2 + b2*^2 σ2^2) - 0.5 c a*^2 ]

    Args:
        t: Optional reference tensor. When given, all three results are
            returned as tensors on its device with its dtype.
        setting_parameters: Mapping with keys "r", "c", "sigma1", "sigma2",
            "U_res" and optionally "v" (defaults to 1.0).

    Returns:
        (a*, torch.tensor([b1*, b2*]), s*). Without `t`, a* and s* are Python
        floats and the slopes are a default-dtype CPU tensor.

    Raises:
        ValueError: if `setting_parameters` is missing, c <= 0, r < 0, or a
            sigma is negative.
        KeyError: if a required key is absent from `setting_parameters`.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")
    r = float(setting_parameters["r"])
    c = float(setting_parameters["c"])
    v = float(setting_parameters.get("v", 1.0))
    sigma1 = float(setting_parameters["sigma1"])
    sigma2 = float(setting_parameters["sigma2"])
    U_res = float(setting_parameters["U_res"])

    # guards
    if c <= 0:
        raise ValueError("Parameter `c` must be > 0.")
    if sigma1 < 0 or sigma2 < 0:
        raise ValueError("sigma1 and sigma2 must be >= 0.")
    if r < 0:
        raise ValueError("Parameter `r` must be >= 0.")

    # The small epsilon keeps the precisions finite when a sigma is exactly 0.
    inv1 = 1.0 / (sigma1**2 + 1e-12)
    inv2 = 1.0 / (sigma2**2 + 1e-12)
    invsum = inv1 + inv2
    sigma_eff2 = 1.0 / invsum

    # First-order condition of the surplus in β: v/c - β/c - r β σ_eff^2 = 0.
    # The denominator is >= 1 given the guards above, so this never divides
    # by zero (the previous `v + r c σ_eff^2` did for v = 0, r = 0, and only
    # matched this expression at v = 1).
    beta = v / (1.0 + r * c * sigma_eff2)

    b1_star = beta * (inv1 / invsum)
    b2_star = beta * (inv2 / invsum)
    a_star = beta / c
    var_term = (b1_star**2) * (sigma1**2) + (b2_star**2) * (sigma2**2)
    # Intercept that makes the agent's participation constraint bind at U_res.
    s_star = U_res - (beta * a_star - 0.5 * r * var_term - 0.5 * c * a_star**2)

    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
        a_star = torch.as_tensor(a_star, device=device, dtype=dtype)
        b_star = torch.stack([
            torch.as_tensor(b1_star, device=device, dtype=dtype),
            torch.as_tensor(b2_star, device=device, dtype=dtype),
        ])
        s_star = torch.as_tensor(s_star, device=device, dtype=dtype)
    else:
        b_star = torch.tensor([b1_star, b2_star])

    return a_star, b_star, s_star
