import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional

# Scalar type accepted for entries of a parameter dict.
Number = Union[float, int]


# ======================= Defaults =======================
# Initial contract t = (λ, μ) = (1, 1).
DEFAULT_T_INIT = torch.tensor([1.0, 1.0], dtype=torch.float32)
# Initial agent effort a = 0.
DEFAULT_A_INIT = torch.tensor([0.0], dtype=torch.float32)
DEFAULT_PARAMS = {
    "s": 1.0,          # noise scale in X = a + s Z; also the sigmoid scale in the wage
    "c": 1.0,          # effort-cost coefficient in 0.5 * c * a^2
    "a0": 0.0,         # centre of the wage sigmoid
    "w_min": 1e-3,     # lower floor applied to the wage
    "nsamples": 1024,  # Monte Carlo sample count
    "alpha": 1.0,      # HARA α
    "beta": 1.0,       # HARA β
    "gamma": 1.1,      # HARA γ
}


def merge_with_defaults(params: Optional[Dict[str, Number]]) -> Dict[str, Number]:
    """
    Return a new dict holding DEFAULT_PARAMS overlaid with `params`.

    Entries in `params` override the defaults; keys absent from `params`
    keep their default value. A falsy `params` (None or empty) yields a
    plain copy of the defaults. Neither input is modified.
    """
    merged = DEFAULT_PARAMS.copy()
    if not params:
        return merged
    merged.update(params)
    return merged


# -------- Sampling --------
def sample_normal_z(nsamples: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """Return a 1-D tensor of `nsamples` standard-normal draws, Z ~ N(0,1)."""
    shape = (nsamples,)
    return torch.randn(shape, dtype=dtype, device=device)


# -------- Contract helpers --------
def _unpack_lambda_mu(t: Tensor) -> Tuple[Tensor, Tensor]:
    """
    Split the contract tensor into its first two flattened entries (λ, μ).

    `t` may have any shape; it is flattened first and any entries beyond the
    second are ignored. Raises ValueError when fewer than two are present.
    """
    flat = t.reshape(-1)
    if flat.numel() >= 2:
        lam, mu = flat[0], flat[1]
        return lam, mu
    raise ValueError("t must contain at least (lambda, mu).")


def wage(x: Tensor, lam: Tensor, mu: Tensor, *, s: float, a0: float, w_min: float) -> Tensor:
    """
    Sigmoid wage schedule with a lower floor.

    Computes lam + mu * sigmoid((x - a0) / s) elementwise over `x`, then
    clamps the result from below at `w_min` for numerical safety.
    """
    standardized = (x - a0) / float(s)
    raw_wage = lam + mu * torch.sigmoid(standardized)
    floor = float(w_min)
    return torch.clamp(raw_wage, min=floor)


# -------- Preferences --------
def hara_utility(w: Tensor, *, alpha: float, beta: float, gamma: float) -> Tensor:
    """
    HARA utility applied elementwise to `w`.

      u(w) = (α + βw)^(1 - γ) / (1 - γ)   for γ ≠ 1
      u(w) = log(α + βw)                  when |γ - 1| < 1e-8

    The base α + βw must be positive; it is clamped from below at 1e-8
    before the power / log is taken. All scalars are converted to the
    device and dtype of `w`.
    """
    like_w = {"device": w.device, "dtype": w.dtype}
    alpha_t = torch.as_tensor(alpha, **like_w)
    beta_t = torch.as_tensor(beta, **like_w)
    gamma_t = torch.as_tensor(gamma, **like_w)
    unity = torch.tensor(1.0, **like_w)
    tol = torch.tensor(1e-8, **like_w)

    positive_base = torch.clamp(alpha_t + beta_t * w, min=1e-8)

    # γ is a scalar, so the branch is decided once for the whole tensor.
    gamma_is_one = bool((gamma_t - unity).abs() < tol)
    if gamma_is_one:
        return torch.log(positive_base)

    exponent = unity - gamma_t
    return positive_base.pow(exponent) / exponent


# -------- Utilities --------
def u2(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    alpha: float = DEFAULT_PARAMS["alpha"],
    beta: float  = DEFAULT_PARAMS["beta"],
    gamma: float = DEFAULT_PARAMS["gamma"],
    **kw
) -> Tensor:
    """
    Agent objective under HARA preferences (Monte Carlo estimate):

      E[ u(w(X)) ] - 0.5 * c * a^2,   X = a + s Z,  Z ~ N(0,1)

    `t` carries the contract (λ, μ) and u is HARA with (α, β, γ). When `z`
    is None, `nsamples` fresh normal draws are taken on the device/dtype of
    `a`; otherwise the supplied draws are used. The expectation is a scalar
    mean, so the result broadcasts to the shape of `a`. Extra keyword
    arguments are accepted and ignored.
    """
    lam, mu = _unpack_lambda_mu(t)

    noise = sample_normal_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z
    outcome = a + s * noise

    pay = wage(outcome, lam, mu, s=s, a0=a0, w_min=w_min)
    expected_utility = hara_utility(pay, alpha=alpha, beta=beta, gamma=gamma).mean()

    cost_coeff = torch.as_tensor(c, device=a.device, dtype=a.dtype)
    effort_cost = 0.5 * cost_coeff * (a ** 2)
    return expected_utility - effort_cost


def u1(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,  # unused but kept for consistent signature
    a0: float = DEFAULT_PARAMS["a0"],
    w_min: float = DEFAULT_PARAMS["w_min"],
    nsamples: int = DEFAULT_PARAMS["nsamples"],
    z: Optional[Tensor] = None,
    **kw
) -> Tensor:
    """
    Principal objective (Monte Carlo estimate):

      E[ X - w_t(X) ],   X = a + s Z,  Z ~ N(0,1)

    `t` carries the contract (λ, μ). When `z` is None, `nsamples` fresh
    normal draws are taken on the device/dtype of `a`; otherwise the
    supplied draws are used. Returns the scalar sample mean. Extra keyword
    arguments are accepted and ignored.
    """
    lam, mu = _unpack_lambda_mu(t)

    noise = sample_normal_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z
    outcome = a + s * noise

    pay = wage(outcome, lam, mu, s=s, a0=a0, w_min=w_min)
    net_profit = outcome - pay
    return net_profit.mean()


# -------- Projection (box) --------
def project_t_box(t: Tensor, lam_bounds: Tuple[float, float], mu_bounds: Tuple[float, float]) -> Tensor:
    """
    Project the contract onto a box: clamp λ (flattened entry 0) into
    `lam_bounds` and μ (flattened entry 1) into `mu_bounds`.

    Works on a flattened copy, so `t` itself is not modified; entries
    beyond the second pass through unchanged. The result has the shape
    of `t`.
    """
    projected = t.reshape(-1).clone()
    for position, bounds in enumerate((lam_bounds, mu_bounds)):
        lower, upper = float(bounds[0]), float(bounds[1])
        projected[position] = projected[position].clamp(lower, upper)
    return projected.view_as(t)


# -------- Grid-search “theoretical” optimum (approximate) --------
def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Approximate the principal's optimum by brute-force grid search.

    For each (λ, μ) on a grid, the agent's best response a*(t) is the effort
    grid point maximizing E[u(w(X))] - 0.5 * c * a^2 with X = a + s Z. The
    returned contract is the (λ, μ) maximizing the principal's objective
    E[X - w(X)] at that best response. One shared set of normal draws is
    used for every evaluation.

    Args:
        t: Used only to pick the device and dtype of the grids; when it is
            not a tensor, CPU / float32 is used.
        setting_parameters: Model parameters, merged over DEFAULT_PARAMS.
            An optional "grid" entry may hold a dict with keys "a", "lam",
            "mu" (each a `(start, end, steps)` triple) and "nsamples".

    Returns:
        (a_star, t_star) with shapes:
          a_star: (1,)
          t_star: (2,)  [λ, μ]
        Both are zeros if no grid point yields a principal objective above
        -inf (for example when a contract grid is empty).

    Raises:
        ValueError: if `setting_parameters` is None.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")

    sp = merge_with_defaults(setting_parameters)

    s, c, a0 = float(sp["s"]), float(sp["c"]), float(sp["a0"])
    w_min = float(sp["w_min"])
    alpha, beta, gamma = float(sp["alpha"]), float(sp["beta"]), float(sp["gamma"])

    # Grid defaults: effort within a0 ± 6s, λ in [w_min, w_min + 8],
    # μ in [0, 8], and 8x the configured sample count.
    grid = dict(sp.get("grid", {}))
    a_cfg = grid.get("a",   (a0 - 6.0 * s, a0 + 6.0 * s, 200))
    l_cfg = grid.get("lam", (w_min,        w_min + 8.0,   100))
    m_cfg = grid.get("mu",  (0.0,          8.0,           100))
    star_nsamples = int(grid.get("nsamples", sp["nsamples"] * 8))

    if isinstance(t, torch.Tensor):
        device, dtype = t.device, t.dtype
    else:
        device, dtype = torch.device("cpu"), torch.float32

    a_grid  = torch.linspace(float(a_cfg[0]), float(a_cfg[1]), int(a_cfg[2]), device=device, dtype=dtype)
    lam_g   = torch.linspace(float(l_cfg[0]), float(l_cfg[1]), int(l_cfg[2]), device=device, dtype=dtype)
    mu_g    = torch.linspace(float(m_cfg[0]), float(m_cfg[1]), int(m_cfg[2]), device=device, dtype=dtype)
    z       = sample_normal_z(star_nsamples, dtype=dtype, device=device)

    # Neither the outcome matrix nor the effort cost depends on (λ, μ), so
    # build them once instead of on every iteration of the contract loop.
    X = a_grid[None, :] + s * z[:, None]  # [ns, n_a]
    effort_cost = 0.5 * c * (a_grid ** 2)  # [n_a]

    best_u1 = -float("inf")
    best_a  = torch.zeros(1, device=device, dtype=dtype)
    best_t  = torch.zeros(2, device=device, dtype=dtype)

    for lam in lam_g:
        for mu in mu_g:
            # Inner problem: agent's best response on the effort grid.
            W = wage(X, lam, mu, s=s, a0=a0, w_min=w_min)
            U2 = hara_utility(W, alpha=alpha, beta=beta, gamma=gamma).mean(dim=0) - effort_cost
            idx = int(torch.argmax(U2).item())
            a_star_t = a_grid[idx]

            # Outer objective at a_star_t
            X_outer = a_star_t + s * z
            W_outer = wage(X_outer, lam, mu, s=s, a0=a0, w_min=w_min)
            u1_val = (X_outer - W_outer).mean().item()

            # Strict '>' keeps the first contract found among ties.
            if u1_val > best_u1:
                best_u1 = u1_val
                best_a  = a_star_t.unsqueeze(0)
                best_t  = torch.stack([lam, mu])

    return best_a, best_t
