import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

Number = Union[float, int]


# ======================= Defaults =======================
# Initial contract t = (λ, μ): both coefficients start at 1.
DEFAULT_T_INIT = torch.ones(2, dtype=torch.float32)
# Initial agent effort: a single-element tensor holding 0.
DEFAULT_A_INIT = torch.zeros(1, dtype=torch.float32)
# Baseline model settings; merge_with_defaults() layers caller overrides on top.
DEFAULT_PARAMS = {
    "s": 1.0,          # scale of the output noise; also passed to wage() as s_temp
    "c": 1.0,          # coefficient of the quadratic effort cost 0.5 * c * a^2
    "a0": 0.0,         # centre of the sigmoid in the wage schedule
    "w_min": 1e-3,     # lower bound applied to every wage
    "nsamples": 1024,  # Monte-Carlo sample count
    "gamma": 1.1,      # CRRA parameter
}


def merge_with_defaults(params: Optional[Dict[str, Number]]) -> Dict[str, Number]:
    """Return a new dict: DEFAULT_PARAMS overlaid with the caller's ``params``.

    Keys present in ``params`` override the defaults; keys absent from it keep
    their default value. A ``None`` or empty ``params`` yields a plain copy of
    the defaults. Neither ``params`` nor DEFAULT_PARAMS is modified.
    """
    merged = DEFAULT_PARAMS.copy()
    merged.update(params or {})
    return merged


# -------- Sampling --------
def sample_logistic_z(nsamples: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """Return ``nsamples`` Logistic(0,1) draws obtained through the inverse CDF.

    Uniform draws are clamped to [1e-7, 1 - 1e-7] before the transform
    log(u) - log(1 - u), which keeps both logarithms away from log(0).
    """
    uniform = torch.rand(nsamples, device=device, dtype=dtype)
    uniform = uniform.clamp(1e-7, 1 - 1e-7)
    return uniform.log() - torch.log1p(uniform.neg())


# -------- Contract helpers --------
def _unpack_lambda_mu(t: Tensor):
    """Unpack (λ, μ) from tensor t."""
    t_flat = t.flatten()
    if t_flat.numel() < 2:
        raise ValueError("t must contain at least (lambda, mu).")
    return t_flat[0], t_flat[1]


def wage(x: Tensor, lam: Tensor, mu: Tensor, *, s_temp: float, a0: float, w_min: float) -> Tensor:
    """Evaluate the wage schedule w(x) = max(w_min, λ + μ·σ((x - a0) / s_temp)).

    Args:
        x: outcomes at which the wage is evaluated.
        lam: base pay λ.
        mu: coefficient μ on the sigmoid term.
        s_temp: temperature dividing (x - a0) inside the sigmoid.
        a0: centre of the sigmoid.
        w_min: lower bound applied to the result.

    Returns:
        The elementwise wage; never below ``w_min``.
    """
    scaled = (x - a0) / float(s_temp)
    unfloored = lam + mu * torch.sigmoid(scaled)
    return unfloored.clamp(min=float(w_min))


# -------- Utilities --------
def u2(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    gamma: float = DEFAULT_PARAMS["gamma"],
    **kw
) -> Tensor:
    """
    Agent utility: Monte-Carlo mean of the CRRA utility of the wage, minus the
    quadratic effort cost 0.5 * c * a^2.

      u(w) = w^(1-γ)/(1-γ)   for γ ≠ 1
           = log(w)          when |γ - 1| < 1e-8

    Args:
        a: agent effort.
        t: contract; its first two flattened entries are read as (λ, μ).
        s: noise scale in x = a + s*z; also used as the wage's sigmoid temperature.
        c: effort-cost coefficient.
        a0: centre of the wage sigmoid.
        w_min: wage floor.
        nsamples: number of noise draws, used only when ``z`` is None.
        z: optional pre-drawn noise; when given, no sampling takes place.
        gamma: CRRA parameter γ.
        **kw: accepted and ignored.

    Returns:
        Mean wage utility minus the effort cost.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = sample_logistic_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z
    pay = wage(a + s * noise, lam, mu, s_temp=s, a0=a0, w_min=w_min)

    use_log_utility = abs(gamma - 1.0) < 1e-8
    if use_log_utility:
        wage_utility = torch.log(pay)
    else:
        one_minus_gamma = 1.0 - float(gamma)
        wage_utility = pay.pow(one_minus_gamma) / one_minus_gamma

    effort_cost = 0.5 * torch.as_tensor(c, device=a.device, dtype=a.dtype) * (a ** 2)
    return wage_utility.mean() - effort_cost


def u1(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,  # unused but kept for consistency
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    **kw
) -> Tensor:
    """Principal utility: Monte-Carlo mean of X - w(X), X = a + sZ, Z ~ Logistic(0,1).

    Args:
        a: agent effort.
        t: contract; its first two flattened entries are read as (λ, μ).
        s: noise scale; also used as the wage's sigmoid temperature.
        c: not used by this function.
        a0: centre of the wage sigmoid.
        w_min: wage floor.
        nsamples: number of noise draws, used only when ``z`` is None.
        z: optional pre-drawn noise; when given, no sampling takes place.
        **kw: accepted and ignored.

    Returns:
        Scalar tensor holding the mean of output minus wage over the noise draws.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = sample_logistic_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z
    output = a + s * noise
    payment = wage(output, lam, mu, s_temp=s, a0=a0, w_min=w_min)
    return torch.mean(output - payment)


# -------- Projection --------
def project_t_box(t: Tensor, lam_bounds, mu_bounds) -> Tensor:
    """Return a copy of ``t`` whose (λ, μ) entries are clamped to the given boxes.

    Only the first two entries of the flattened tensor are touched; any further
    entries are copied through unchanged. The input tensor is not modified and
    the result has the same shape as ``t``.

    Args:
        t: contract tensor with at least two elements.
        lam_bounds: indexable (low, high) pair for λ.
        mu_bounds: indexable (low, high) pair for μ.
    """
    projected = t.flatten().clone()
    lam_lo, lam_hi = float(lam_bounds[0]), float(lam_bounds[1])
    projected[0] = torch.clamp(projected[0], min=lam_lo, max=lam_hi)
    mu_lo, mu_hi = float(mu_bounds[0]), float(mu_bounds[1])
    projected[1] = torch.clamp(projected[1], min=mu_lo, max=mu_hi)
    return projected.view_as(t)


# -------- Grid-search “theoretical” optimum --------
def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
):
    """
    Grid-search approximation of the principal's optimal contract.

    Returns (a*, t*) with shapes:
      a*: (1,)
      t*: (2,)  [λ, μ]
    For each (λ, μ) on a grid:
      - find a*(t) = argmax_a u2(a, t) over an effort grid,
      - then pick (λ, μ) maximizing u1(a*(t), t).

    All expectations are Monte-Carlo averages over one Logistic(0,1) sample that
    is drawn once and shared by every grid point. The sample is drawn from the
    global torch RNG, so the result can vary between calls unless the caller
    seeds it.

    Args:
        t: optional tensor used only to choose the device and dtype of the
            result. A missing ``t`` means CPU / float32; a ``t`` with a
            non-floating dtype keeps its device but falls back to float32,
            because the uniform sampler needs a floating dtype.
        setting_parameters: overrides for DEFAULT_PARAMS. An optional "grid"
            entry may hold a dict with keys "a", "lam", "mu" (each a
            (start, stop, steps) triple) and "nsamples". Defaults:
            a in [a0 - 6s, a0 + 6s] with 200 steps, λ in [w_min, w_min + 8]
            and μ in [0, 8] with 100 steps each, nsamples = 8 * sp["nsamples"].

    Returns:
        Tuple (a*, t*). Ties keep the first contract encountered. If no grid
        point yields a principal utility greater than -inf (e.g. all NaN),
        a* = [0.] and t* = [0., 0.] are returned.
    """
    sp = merge_with_defaults(setting_parameters)

    s, c, a0, w_min, gamma = float(sp["s"]), float(sp["c"]), float(sp["a0"]), float(sp["w_min"]), float(sp["gamma"])

    grid = dict(sp.get("grid", {}))
    a_cfg = grid.get("a", (a0 - 6 * s, a0 + 6 * s, 200))
    l_cfg = grid.get("lam", (w_min, w_min + 8.0, 100))
    m_cfg = grid.get("mu",  (0.0,   8.0,          100))
    star_nsamples = int(grid.get("nsamples", sp["nsamples"] * 8))

    if isinstance(t, torch.Tensor):
        device = t.device
        # torch.rand cannot sample into integer/bool dtypes, so only inherit
        # the dtype when it is a floating one.
        dtype = t.dtype if t.is_floating_point() else torch.float32
    else:
        device, dtype = torch.device("cpu"), torch.float32

    a_grid   = torch.linspace(float(a_cfg[0]), float(a_cfg[1]), int(a_cfg[2]), device=device, dtype=dtype)
    lam_grid = torch.linspace(float(l_cfg[0]), float(l_cfg[1]), int(l_cfg[2]), device=device, dtype=dtype)
    mu_grid  = torch.linspace(float(m_cfg[0]), float(m_cfg[1]), int(m_cfg[2]), device=device, dtype=dtype)
    z        = sample_logistic_z(star_nsamples, dtype=dtype, device=device)

    # Loop invariants: none of these depend on (λ, μ), so build them once
    # instead of once per contract.
    X = a_grid[None, :] + s * z[:, None]          # (nsamples, n_effort) outcomes
    effort_cost = 0.5 * c * (a_grid ** 2)         # (n_effort,)
    use_log_utility = abs(gamma - 1.0) < 1e-8

    best_u1 = -float("inf")
    best_a  = torch.tensor([0.0], device=device, dtype=dtype)
    best_t  = torch.zeros(2, device=device, dtype=dtype)

    for lam in lam_grid:
        for mu in mu_grid:
            W = wage(X, lam, mu, s_temp=s, a0=a0, w_min=w_min)

            if use_log_utility:
                u2_vals = torch.log(W).mean(0)
            else:
                u2_vals = (W.pow(1.0 - gamma) / (1.0 - gamma)).mean(0)

            # Agent's best response on the effort grid for this contract.
            idx = int(torch.argmax(u2_vals - effort_cost))

            # Column idx of X / W already holds a* + s*z and its wage, so the
            # principal's payoff needs no second wage evaluation.
            u1_val = (X[:, idx] - W[:, idx]).mean().item()

            # Strict '>' keeps the first maximiser and never accepts NaN.
            if u1_val > best_u1:
                best_u1 = u1_val
                best_a  = a_grid[idx].unsqueeze(0)
                best_t  = torch.stack([lam, mu])

    return best_a, best_t
