import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

Number = Union[float, int]


# ======================= Defaults =======================
# Initial contract t = (lambda, mu) = (1, 1).
DEFAULT_T_INIT = torch.ones(2, dtype=torch.float32)
# Initial agent effort a = 0.
DEFAULT_A_INIT = torch.zeros(1, dtype=torch.float32)
# Default model / simulation settings (overlaid by merge_with_defaults).
DEFAULT_PARAMS = {
    "s": 1.0,          # noise scale in X = a + s*Z; also the sigmoid scale in wage()
    "c": 1.0,          # quadratic effort-cost coefficient in the agent utility
    "a0": 0.0,         # centre of the wage sigmoid
    "w_min": 1e-2,     # wage floor; slightly higher floor for sqrt stability
    "nsamples": 1024,  # number of Monte Carlo noise samples
}


def merge_with_defaults(params: Optional[Dict[str, Number]]) -> Dict[str, Number]:
    """Return a new dict holding DEFAULT_PARAMS overlaid with *params*.

    Keys supplied in *params* override the defaults; keys not supplied keep
    their default value. ``None`` (or an empty mapping) yields a plain copy
    of the defaults. Neither DEFAULT_PARAMS nor *params* is modified.
    """
    merged: Dict[str, Number] = DEFAULT_PARAMS.copy()
    if not params:
        return merged
    merged.update(params)
    return merged


# -------- Sampling --------
def sample_logistic_z(
    nsamples: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
) -> Tensor:
    """Draw ``nsamples`` i.i.d. Logistic(0, 1) samples by inverse-CDF sampling.

    Uniform draws ``u`` are mapped through the logit ``log(u) - log(1 - u)``.
    ``u`` is first clamped away from 0 and 1 so every sample is finite.

    Args:
        nsamples: number of samples (length of the returned 1-D tensor).
        dtype: floating-point dtype of the samples.
        device: device on which the samples are allocated.
        generator: optional ``torch.Generator`` for reproducible draws;
            ``None`` (default) uses the global torch RNG as before.

    Returns:
        1-D tensor of shape ``(nsamples,)``.
    """
    u = torch.rand(nsamples, generator=generator, device=device, dtype=dtype)
    # The upper clamp must be strictly below 1 *in the target dtype*:
    # 1 - 1e-7 rounds to exactly 1.0 in float16/bfloat16, which would make
    # log1p(-u) = -inf. max(1e-7, eps) leaves float32/float64 unchanged
    # (in float32, 1 - 1e-7 already rounds to 1 - eps).
    margin = max(1e-7, torch.finfo(dtype).eps)
    u.clamp_(1e-7, 1.0 - margin)
    return torch.log(u) - torch.log1p(-u)


# -------- Contract helpers --------
def _unpack_lambda_mu(t: Tensor) -> Tuple[Tensor, Tensor]:
    """Unpack (lambda, mu) from tensor t."""
    t_flat = t.reshape(-1)
    if t_flat.numel() < 2:
        raise ValueError("t must contain at least (lambda, mu).")
    return t_flat[0], t_flat[1]


def wage(x: Tensor, lam: Tensor, mu: Tensor, *, s: float, a0: float, w_min: float) -> Tensor:
    """
    Contract wage w(x) = lam + mu * sigmoid((x - a0) / s), with every entry
    raised to at least ``w_min`` so that sqrt/log of the wage stay defined.
    """
    scaled = (x - a0) / float(s)
    raw = lam + mu * scaled.sigmoid()
    return raw.clamp(min=float(w_min))


# -------- Utilities --------
def u2(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    **kw
) -> Tensor:
    """
    Agent utility for effort ``a`` under contract ``t = (lambda, mu)``:

        E[sqrt(w_t(X)) | a] - 0.5 * c * a**2,   X = a + s*Z,  Z ~ Logistic(0, 1).

    The expectation is a Monte Carlo mean over the noise ``z``; when ``z`` is
    not supplied, ``nsamples`` fresh draws are taken on ``a``'s device/dtype.
    Extra keyword arguments (``**kw``) are accepted and ignored.
    (Not globally concave in a; OK for experiments.)
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = sample_logistic_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z
    outcomes = a + s * noise
    expected_sqrt_wage = torch.sqrt(wage(outcomes, lam, mu, s=s, a0=a0, w_min=w_min)).mean()
    cost_coef = torch.as_tensor(c, device=a.device, dtype=a.dtype)
    return expected_sqrt_wage - 0.5 * cost_coef * (a ** 2)


def u1(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,                      # kept for signature consistency
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    **kw
) -> Tensor:
    """
    Principal utility for effort ``a`` under contract ``t = (lambda, mu)``:

        E[X - w_t(X) | a],   X = a + s*Z,  Z ~ Logistic(0, 1).

    The expectation is a Monte Carlo mean over the noise ``z``; when ``z`` is
    not supplied, ``nsamples`` fresh draws are taken on ``a``'s device/dtype.
    ``c`` is unused here; it is accepted only so u1 and u2 share a signature.
    Extra keyword arguments (``**kw``) are accepted and ignored.
    """
    lam, mu = _unpack_lambda_mu(t)
    if z is not None:
        noise = z
    else:
        noise = sample_logistic_z(nsamples, dtype=a.dtype, device=a.device)
    outcomes = a + s * noise
    payments = wage(outcomes, lam, mu, s=s, a0=a0, w_min=w_min)
    return (outcomes - payments).mean()


# -------- Projection (box) --------
def project_t_box(
    t: Tensor,
    lam_bounds: Tuple[float, float],
    mu_bounds: Tuple[float, float]
) -> Tensor:
    """Return a copy of *t* with λ (flat index 0) clamped into ``lam_bounds``
    and μ (flat index 1) clamped into ``mu_bounds``.

    *t* itself is not modified; any further entries are copied unchanged and
    the result has the same shape as *t*.
    """
    projected = t.reshape(-1).clone()
    for idx, bounds in enumerate((lam_bounds, mu_bounds)):
        lo, hi = float(bounds[0]), float(bounds[1])
        projected[idx] = projected[idx].clamp(lo, hi)
    return projected.view_as(t)


# -------- Grid-search “theoretical” optimum (approximate) --------
def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Dict[str, Number] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Approximate the principal's optimal contract by nested grid search.

    For each (λ, μ) on a grid, pick a*(t) = argmax_a u2(a, t) over a discrete
    effort grid, then choose the (λ, μ) maximizing u1(a*(t), t). Both
    expectations are Monte Carlo means over a single shared draw of
    Logistic(0, 1) noise, taken from the global torch RNG.

    Args:
        t: optional tensor whose ``device`` and ``dtype`` are used for all
            grids (its values are not used). Defaults to CPU / float32.
        setting_parameters: model settings, merged over DEFAULT_PARAMS via
            merge_with_defaults. An optional ``"grid"`` entry may override
            ``"a"``, ``"lam"``, ``"mu"`` (each ``(start, stop, n_points)``)
            and ``"nsamples"`` (default: 8 * nsamples).

    Returns:
        (a_star, t_star) with shapes:
          a_star: (1,)
          t_star: (2,)  [lambda, mu]
        Both are zeros if no grid point produces a u1 value greater than
        -inf (e.g. an empty λ or μ grid).

    Raises:
        ValueError: if ``setting_parameters`` is None.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    sp = merge_with_defaults(setting_parameters)

    s = float(sp["s"])
    c = float(sp["c"])
    a0 = float(sp["a0"])
    w_min = float(sp["w_min"])

    grid = dict(sp.get("grid", {}))
    a_cfg = grid.get("a",   (a0 - 6.0 * s, a0 + 6.0 * s, 200))
    l_cfg = grid.get("lam", (w_min,        w_min + 8.0,   100))
    m_cfg = grid.get("mu",  (0.0,          8.0,           100))
    star_nsamples = int(grid.get("nsamples", sp["nsamples"] * 8))

    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
    else:
        device, dtype = torch.device("cpu"), torch.float32

    a_grid  = torch.linspace(float(a_cfg[0]), float(a_cfg[1]), int(a_cfg[2]), device=device, dtype=dtype)
    lam_g   = torch.linspace(float(l_cfg[0]), float(l_cfg[1]), int(l_cfg[2]), device=device, dtype=dtype)
    mu_g    = torch.linspace(float(m_cfg[0]), float(m_cfg[1]), int(m_cfg[2]), device=device, dtype=dtype)
    z       = sample_logistic_z(star_nsamples, dtype=dtype, device=device)

    best_u1 = -float("inf")
    best_a  = torch.zeros(1, device=device, dtype=dtype)
    best_t  = torch.zeros(2, device=device, dtype=dtype)

    # Pure evaluation: no autograd graph is needed for the search.
    with torch.no_grad():
        # Outcomes and effort cost do not depend on (λ, μ); build them once
        # instead of once per grid point.
        X = a_grid[None, :] + s * z[:, None]            # [nsamples, n_a]
        effort_cost = 0.5 * c * (a_grid ** 2)           # [n_a]

        for lam in lam_g:
            for mu in mu_g:
                W = wage(X, lam, mu, s=s, a0=a0, w_min=w_min)   # [nsamples, n_a]

                # Inner problem: agent's best response on the effort grid.
                u2_vals = torch.sqrt(W).mean(dim=0) - effort_cost  # [n_a]
                idx = int(torch.argmax(u2_vals).item())

                # Outer objective at a*(t). Column idx of X is exactly
                # a_grid[idx] + s*z and column idx of W is its wage, so reuse
                # them rather than recomputing.
                u1_val = (X[:, idx] - W[:, idx]).mean().item()

                if u1_val > best_u1:
                    best_u1 = u1_val
                    best_a = a_grid[idx].unsqueeze(0)       # shape (1,)
                    best_t = torch.stack([lam, mu])         # shape (2,)

    return best_a, best_t
