import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

Number = Union[float, int]


# ======================= Defaults =======================
# Initial contract t = (λ, μ): both components start at 1.
DEFAULT_T_INIT = torch.ones(2, dtype=torch.float32)
# Initial agent action: a single entry equal to 0.
DEFAULT_A_INIT = torch.zeros(1, dtype=torch.float32)
# Baseline model settings; callers override individual keys via merge_with_defaults.
DEFAULT_PARAMS = {
    "c": 1.0,          # weight of the quadratic effort cost 0.5 * c * a**2
    "a0": 0.0,         # center of the sigmoid inside the wage schedule
    "w_min": 1e-3,     # lower bound applied to the wage
    "nsamples": 1024,  # default number of standard-normal draws
    "rho": 1.0,        # exponent coefficient of the agent's -exp(-rho * w) utility
}


def merge_with_defaults(params: Optional[Dict[str, Number]]) -> Dict[str, Number]:
    """Return a fresh copy of DEFAULT_PARAMS overlaid with ``params``.

    Args:
        params: Overrides to apply on top of the defaults. ``None`` or an
            empty mapping yields an unmodified copy of the defaults.

    Returns:
        A new dict; DEFAULT_PARAMS itself is never mutated.
    """
    merged = DEFAULT_PARAMS.copy()
    if not params:
        return merged
    merged.update(params)
    return merged


# -------- Sampling --------
def sample_normal_z(nsamples: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """Return ``nsamples`` standard-normal draws with the requested dtype and device."""
    draws = torch.randn(nsamples, dtype=dtype, device=device)
    return draws


# -------- Contract helpers --------
def _unpack_lambda_mu(t: Tensor):
    t_flat = t.flatten()
    if t_flat.numel() < 2:
        raise ValueError("t must contain at least (lambda, mu).")
    return t_flat[0], t_flat[1]


def wage(x: Tensor, lam: Tensor, mu: Tensor, *, s_temp: float, a0: float, w_min: float) -> Tensor:
    """Contract w(x) = λ + μ σ((x - a0) / T), bounded below by w_min.

    The temperature actually used is T = max(1.0, s_temp), so any
    ``s_temp`` below 1 behaves exactly like ``s_temp = 1``.
    """
    temperature = max(1.0, float(s_temp))
    squashed = torch.sigmoid((x - a0) / temperature)
    payment = lam + mu * squashed
    return payment.clamp(min=float(w_min))


# -------- Utilities --------
def u2(
    a: Tensor,
    t: Tensor,
    *,
    c: float,
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    rho: float = DEFAULT_PARAMS["rho"],
    s: float = 1.0,
    a0: float = DEFAULT_PARAMS["a0"],
    **kw
) -> Tensor:
    """
    Agent utility: Monte Carlo estimate of E[-exp(-rho * w(X))] - 0.5 * c * a**2.

    Output uses the reparameterized Normal approximation of a Poisson draw,
      X ≈ m + sqrt(m) * Z,  with m = exp(a) and Z ~ N(0, 1).

    Args:
        a: Agent action; the mean output is exp(a).
        t: Contract tensor whose first two flattened entries are (λ, μ).
        c: Weight of the quadratic effort cost.
        w_min: Lower bound applied to the wage.
        nsamples: Number of normal draws when ``z`` is not supplied.
        z: Optional pre-drawn noise; when given, no new samples are drawn.
        rho: Coefficient in the exponential (CARA) utility of the wage.
        s: Accepted but currently unused; the wage is always evaluated
            with ``s_temp=1.0``.
        a0: Center of the sigmoid inside the wage schedule.
        **kw: Extra keyword arguments are accepted and ignored.

    Returns:
        The sample-mean utility of the wage minus the effort cost; the
        result broadcasts to the shape of ``a``.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = sample_normal_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z

    mean_output = torch.exp(a)
    # The clamp keeps the square root well-defined if exp(a) underflows to 0.
    spread = torch.sqrt(torch.clamp(mean_output, min=1e-8))
    output = mean_output + spread * noise
    pay = wage(output, lam, mu, s_temp=1.0, a0=a0, w_min=w_min)

    risk_aversion = torch.as_tensor(rho, device=pay.device, dtype=pay.dtype)
    expected_utility = (-torch.exp(-risk_aversion * pay)).mean()

    cost_weight = torch.as_tensor(c, device=a.device, dtype=a.dtype)
    return expected_utility - 0.5 * cost_weight * (a ** 2)


def u1(
    a: Tensor,
    t: Tensor,
    *,
    c: float,
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    s: float = 1.0,
    a0: float = DEFAULT_PARAMS["a0"],
    **kw
) -> Tensor:
    """
    Principal utility: Monte Carlo estimate of E[X - w(X)].

    Output is simulated as X ≈ exp(a) + sqrt(exp(a)) * Z with Z ~ N(0, 1).

    Args:
        a: Agent action; the mean output is exp(a).
        t: Contract tensor whose first two flattened entries are (λ, μ).
        c: Required by the signature but not used in this function.
        w_min: Lower bound applied to the wage.
        nsamples: Number of normal draws when ``z`` is not supplied.
        z: Optional pre-drawn noise; when given, no new samples are drawn.
        s: Accepted but currently unused; the wage is always evaluated
            with ``s_temp=1.0``.
        a0: Center of the sigmoid inside the wage schedule.
        **kw: Extra keyword arguments are accepted and ignored.

    Returns:
        Scalar tensor holding the sample mean of output minus wage.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = sample_normal_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z

    mean_output = torch.exp(a)
    # The clamp keeps the square root well-defined if exp(a) underflows to 0.
    spread = torch.sqrt(torch.clamp(mean_output, min=1e-8))
    output = mean_output + spread * noise
    profit = output - wage(output, lam, mu, s_temp=1.0, a0=a0, w_min=w_min)
    return profit.mean()


# -------- Projection --------
def project_t_box(t: Tensor, lam_bounds, mu_bounds) -> Tensor:
    """Return a copy of ``t`` with λ and μ clamped to their (low, high) bounds.

    Only the first two flattened entries are touched; the input tensor is
    left unmodified and the result has the same shape as ``t``.
    """
    projected = t.flatten().clone()
    for position, bounds in enumerate((lam_bounds, mu_bounds)):
        current = projected[position]
        low, high = float(bounds[0]), float(bounds[1])
        projected[position] = current.clamp(low, high)
    return projected.view_as(t)


# -------- Grid-search “theoretical” optimum --------
def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
):
    """
    Grid-search approximation of the principal-agent optimum.

    For every contract (λ, μ) on the grid, the agent's best response a* is
    the grid action maximizing
        E[-exp(-rho * w(X))] - 0.5 * c * a**2,
    with X ≈ exp(a) + sqrt(exp(a)) * Z.  The returned contract is the one
    whose best response gives the principal the largest E[X - w(X)].
    The same noise draws Z are reused for every grid point.

    Args:
        t: Optional tensor used only to pick the device and dtype of the
            search; CPU / float32 is used when it is not a tensor.
        setting_parameters: Overrides merged on top of DEFAULT_PARAMS
            (``c``, ``a0``, ``w_min``, ``rho``, ``nsamples``).  An optional
            ``"grid"`` entry may hold a dict with keys ``"a"``, ``"lam"``,
            ``"mu"`` (each ``(start, stop, steps)``) and ``"nsamples"``.
            Defaults: a in [a0 - 6, a0 + 6] with 200 steps, λ in
            [w_min, w_min + 8] and μ in [0, 8] with 100 steps each, and
            8 * ``nsamples`` noise draws.

    Returns:
        (a*, t*) with shapes:
          a*: (1,)
          t*: (2,)  [λ, μ]
    """
    sp = merge_with_defaults(setting_parameters)

    c, a0, w_min = float(sp["c"]), float(sp["a0"]), float(sp["w_min"])
    rho = float(sp["rho"])

    grid = dict(sp.get("grid", {}))
    a_cfg = grid.get("a", (a0 - 6.0, a0 + 6.0, 200))  # tighter around typical actions
    l_cfg = grid.get("lam", (w_min, w_min + 8.0, 100))
    m_cfg = grid.get("mu",  (0.0,   8.0,          100))
    star_nsamples = int(grid.get("nsamples", sp["nsamples"] * 8))

    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
    else:
        device, dtype = torch.device("cpu"), torch.float32

    a_grid   = torch.linspace(float(a_cfg[0]), float(a_cfg[1]), int(a_cfg[2]), device=device, dtype=dtype)
    lam_grid = torch.linspace(float(l_cfg[0]), float(l_cfg[1]), int(l_cfg[2]), device=device, dtype=dtype)
    mu_grid  = torch.linspace(float(m_cfg[0]), float(m_cfg[1]), int(m_cfg[2]), device=device, dtype=dtype)
    z        = sample_normal_z(star_nsamples, dtype=dtype, device=device)

    # Simulated outputs and effort costs depend only on the action grid and the
    # shared noise, not on the contract, so compute them once outside the loops.
    m = torch.exp(a_grid)[None, :]  # [1, n_a]
    X = m + torch.sqrt(torch.clamp(m, min=1e-8)) * z[:, None]  # [n_samples, n_a]
    effort_cost = 0.5 * c * (a_grid ** 2)

    best_u1 = -float("inf")
    best_a  = torch.tensor([0.0], device=device, dtype=dtype)
    best_t  = torch.zeros(2, device=device, dtype=dtype)

    for lam in lam_grid:
        for mu in mu_grid:
            W = wage(X, lam, mu, s_temp=1.0, a0=a0, w_min=w_min)

            # Agent's CARA utility uses the configured risk aversion rho,
            # matching u2 (previously hard-coded to 1.0 here).
            u2_vals = (-torch.exp(-rho * W)).mean(0) - effort_cost
            idx = int(torch.argmax(u2_vals))
            a_star_t = a_grid[idx]

            # Principal's payoff at the agent's best response.
            m_star = torch.exp(a_star_t)
            X_star = m_star + torch.sqrt(torch.clamp(m_star, min=1e-8)) * z
            u1_val = (X_star - wage(X_star, lam, mu, s_temp=1.0, a0=a0, w_min=w_min)).mean().item()

            if u1_val > best_u1:
                best_u1 = u1_val
                best_a  = a_star_t.unsqueeze(0)
                best_t  = torch.stack([lam, mu])

    return best_a, best_t
