import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

# Scalar type accepted for numeric settings.
Number = Union[float, int]


# Initial contract vector t = (λ, μ) = (1, 1).
DEFAULT_T_INIT = torch.tensor([1.0, 1.0], dtype=torch.float32)
# Initial agent effort a = 0.
DEFAULT_A_INIT = torch.tensor([0.0], dtype=torch.float32)

# Fallback settings; the per-key notes describe how the functions in this
# file consume each value.
DEFAULT_PARAMS = {
    "s": 1.0,          # multiplies the noise z and divides (x - a0) in the wage sigmoid
    "c": 1.0,          # coefficient of the quadratic effort cost 0.5 * c * a^2
    "a0": 0.0,         # centre of the wage sigmoid
    "w_min": 1e-3,     # lower clamp applied to the wage
    "nsamples": 1024,  # default number of Monte Carlo draws
    "rho": 1.0,        # slope of the agent's sigmoid(rho * (w - theta))
    "theta": 0.0,      # shift of the agent's sigmoid(rho * (w - theta))
}


def merge_with_defaults(params: Optional[Dict[str, Number]]) -> Dict[str, Number]:
    """
    Return a new dict holding DEFAULT_PARAMS overlaid with `params`.

    Keys present in `params` override the defaults; keys absent from `params`
    keep their default value. A falsy `params` (None or empty) yields a plain
    copy of the defaults. Neither input is modified.
    """
    merged = DEFAULT_PARAMS.copy()
    if not params:
        return merged
    merged.update(params)
    return merged


def sample_laplace_z(nsamples: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """
    Draw `nsamples` Laplace(0, 1) variates by inverting the CDF of uniform draws.

    The uniforms are clamped to [1e-7, 1 - 1e-7] before the logarithms are taken.
    Returns a 1-D tensor of length `nsamples` with the requested dtype and device.
    """
    eps = 1e-7
    uniform = torch.rand(nsamples, device=device, dtype=dtype)
    uniform.clamp_(eps, 1 - eps)
    # Inverse CDF: log(2u) for u <= 1/2, and -log(2(1 - u)) otherwise.
    left_tail = torch.log(2 * uniform)
    right_tail = -torch.log(2 * (1 - uniform))
    return torch.where(uniform <= 0.5, left_tail, right_tail)


def _unpack_lambda_mu(t: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Return the first two entries of `t` (flattened) as (lambda, mu).

    Any further entries are ignored. Raises ValueError when `t` holds fewer
    than two elements.
    """
    flat = t.reshape(-1)
    if flat.numel() >= 2:
        return flat[0], flat[1]
    raise ValueError("t must contain at least (lambda, mu).")


def wage(x: Tensor, lam: Tensor, mu: Tensor, *, s: float, a0: float, w_min: float) -> Tensor:
    """
    Wage schedule w(x) = max(lam + mu * sigmoid((x - a0) / s), w_min).

    The result is clamped from below at `w_min` and broadcasts over `x`.
    """
    standardized = (x - a0) / float(s)
    sigmoid_part = torch.sigmoid(standardized)
    unclamped = lam + mu * sigmoid_part
    return unclamped.clamp(min=float(w_min))


def u2(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    rho: float = DEFAULT_PARAMS["rho"],
    theta: float = DEFAULT_PARAMS["theta"],
    **kwargs
) -> Tensor:
    """
    Monte Carlo estimate of the agent's utility

        E[ sigmoid(rho * (w(X) - theta)) ] - 0.5 * c * a^2,   X = a + s * Z,

    where Z ~ Laplace(0, 1) and w is the clamped wage schedule `wage`.

    `t` supplies (lambda, mu) as its first two entries. When `z` is given it
    is used as the noise sample; otherwise `nsamples` fresh draws are taken
    with the dtype and device of `a`. Extra keyword arguments are accepted
    and ignored. The result has the shape of `a`.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = z if z is not None else sample_laplace_z(nsamples, dtype=a.dtype, device=a.device)

    pay = wage(a + s * noise, lam, mu, s=s, a0=a0, w_min=w_min)

    # Lift the scalar settings to tensors that match the wage / effort tensors.
    slope = torch.as_tensor(rho, device=pay.device, dtype=pay.dtype)
    shift = torch.as_tensor(theta, device=pay.device, dtype=pay.dtype)
    expected_payoff = torch.sigmoid(slope * (pay - shift)).mean()

    cost_coeff = torch.as_tensor(c, device=a.device, dtype=a.dtype)
    return expected_payoff - 0.5 * cost_coeff * (a ** 2)


def u1(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    **kwargs
) -> Tensor:
    """
    Monte Carlo estimate of the principal's utility

        E[ X - w(X) ],   X = a + s * Z,   Z ~ Laplace(0, 1),

    where w is the clamped wage schedule `wage` with (lambda, mu) taken from
    the first two entries of `t`.

    When `z` is given it is used as the noise sample; otherwise `nsamples`
    fresh draws are taken with the dtype and device of `a`. `c` and any extra
    keyword arguments are accepted but not used. Returns a 0-dim tensor.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = z if z is not None else sample_laplace_z(nsamples, dtype=a.dtype, device=a.device)

    output = a + s * noise
    pay = wage(output, lam, mu, s=s, a0=a0, w_min=w_min)
    return torch.mean(output - pay)


# ---------- Projection ----------
def project_t_box(
    t: Tensor,
    lam_bounds: Tuple[float, float],
    mu_bounds: Tuple[float, float]
) -> Tensor:
    """
    Project (λ, μ) onto the box lam_bounds × mu_bounds.

    Works on a flattened copy of `t`: entry 0 is clamped to `lam_bounds`,
    entry 1 to `mu_bounds`, and any remaining entries are left untouched.
    The input is not modified; the result has the shape of `t`.
    """
    projected = t.reshape(-1).clone()

    lam_lo, lam_hi = float(lam_bounds[0]), float(lam_bounds[1])
    mu_lo, mu_hi = float(mu_bounds[0]), float(mu_bounds[1])

    projected[0] = torch.clamp(projected[0], min=lam_lo, max=lam_hi)
    projected[1] = torch.clamp(projected[1], min=mu_lo, max=mu_hi)
    return projected.view_as(t)


# ---------- Grid-search approximation of the optimum (a*, t*) ----------
def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Approximate (a*, t*) by nested grid search:
      - maximize u2 over a for each (lambda, mu),
      - then pick (lambda, mu) maximizing principal's u1 at that a*.

    A single Laplace noise sample is drawn once and shared by every grid
    point, so all candidates are compared on the same draws.

    Args:
        t: Optional tensor used only for its device and dtype; its values are
            never read. Without it the search runs on CPU in float32.
        setting_parameters: Model settings, merged over DEFAULT_PARAMS. The
            keys read here are "s", "c", "a0", "w_min", "rho", "theta" and
            "nsamples", plus an optional "grid" mapping with entries
            "a", "lam", "mu" (each a ``(low, high, num_points)`` triple) and
            "nsamples" (number of noise draws; default ``8 * nsamples``).

    Returns:
        ``(best_a, best_t)`` where ``best_a`` has shape (1,) and ``best_t``
        has shape (2,) holding (lambda, mu). If no candidate yields a
        principal utility greater than -inf (for example an empty lambda or
        mu grid, or NaN utilities), both are returned as zeros.

    Raises:
        ValueError: If `setting_parameters` is None.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    sp = merge_with_defaults(setting_parameters)

    s      = float(sp["s"])
    c      = float(sp["c"])
    a0     = float(sp["a0"])
    w_min  = float(sp["w_min"])
    rho    = float(sp["rho"])
    theta  = float(sp["theta"])

    grid = dict(sp.get("grid", {}))
    a_cfg = grid.get("a",   (a0 - 6.0 * s, a0 + 6.0 * s, 200))
    l_cfg = grid.get("lam", (w_min,         w_min + 8.0, 100))
    m_cfg = grid.get("mu",  (0.0,           8.0,         100))
    star_nsamples = int(grid.get("nsamples", sp["nsamples"] * 8))

    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
    else:
        device, dtype = torch.device("cpu"), torch.float32

    amin, amax, n_a = float(a_cfg[0]), float(a_cfg[1]), int(a_cfg[2])
    lmin, lmax, n_l = float(l_cfg[0]), float(l_cfg[1]), int(l_cfg[2])
    mmin, mmax, n_m = float(m_cfg[0]), float(m_cfg[1]), int(m_cfg[2])

    a_grid   = torch.linspace(amin, amax, n_a, device=device, dtype=dtype)
    lam_grid = torch.linspace(lmin, lmax, n_l, device=device, dtype=dtype)
    mu_grid  = torch.linspace(mmin, mmax, n_m, device=device, dtype=dtype)

    z = sample_laplace_z(star_nsamples, dtype=dtype, device=device)
    rho_t = torch.as_tensor(rho, device=device, dtype=dtype)
    th_t  = torch.as_tensor(theta, device=device, dtype=dtype)

    # ---- Quantities that do not depend on (lam, mu): computed once ----
    # Outcomes for every (noise draw, effort) pair.
    X = a_grid[None, :] + s * z[:, None]                               # [ns, n_a]
    # Sigmoid factor of the wage schedule. This is the same expression wage()
    # evaluates; only the affine map lam + mu * (.) and the floor vary with
    # the contract, so the expensive part is shared across the whole grid.
    wage_sigmoid = torch.sigmoid((X - a0) / s)                         # [ns, n_a]
    del X  # only needed to build wage_sigmoid; release before the loop
    effort_cost = 0.5 * c * (a_grid ** 2)                              # [n_a]
    scaled_noise = s * z                                               # [ns]

    best_u1 = -float("inf")
    best_a  = torch.zeros(1, device=device, dtype=dtype)
    best_t  = torch.zeros(2, device=device, dtype=dtype)

    for lam in lam_grid:
        for mu in mu_grid:
            # Wage on the full grid: clamp(lam + mu * sigmoid(.), w_min).
            W = torch.clamp(lam + mu * wage_sigmoid, min=w_min)        # [ns, n_a]

            # Agent utility grid over a; argmax keeps the first maximizer.
            u2_vals = torch.sigmoid(rho_t * (W - th_t)).mean(dim=0) - effort_cost
            idx = int(torch.argmax(u2_vals).item())
            a_star_t = a_grid[idx]

            # Principal objective at that a*
            X_outer = a_star_t + scaled_noise
            W_outer = wage(X_outer, lam, mu, s=s, a0=a0, w_min=w_min)
            u1_val = (X_outer - W_outer).mean().item()

            # Strict '>' keeps the earliest grid point on ties.
            if u1_val > best_u1:
                best_u1 = u1_val
                best_a  = a_star_t.unsqueeze(0)                         # (1,)
                best_t  = torch.stack([lam, mu])                        # (2,)

    return best_a, best_t
