import math
from typing import Dict, Tuple, Union, Optional
import torch
from torch import Tensor
from torch.distributions import Distribution, constraints
# Scalar type accepted for setting-parameter values.
Number = Union[float, int]

# Initial contract t = (lambda, mu) = (1, 1).
DEFAULT_T_INIT = torch.tensor([1.0, 1.0], dtype=torch.float32)
# Initial agent effort a = 0.
DEFAULT_A_INIT = torch.tensor([0.0], dtype=torch.float32)
# Default setting parameters; merge_with_defaults fills missing keys from here.
DEFAULT_PARAMS = {
    "s": 1.0,          # noise scale in X = a + s Z (also the wage sigmoid's scale)
    "c": 0.25,         # effort-cost coefficient: cost = 0.5 * c * a**2
    "a0": 0.0,         # centre of the wage sigmoid
    "w_min": 0.25,     # wage floor, keeps log(w) finite
    "nsamples": 1024,  # Monte Carlo draws of Z
}

# Number of entries in the contract vector t (2: lambda and mu).
_TLEN = DEFAULT_T_INIT.numel()
# All ones: both lambda and mu are trained.
T_TRAIN_MASK = torch.ones(_TLEN, dtype=torch.float32)
# All ones: both lambda and mu are included in distances.
T_METRIC_MASK = torch.ones(_TLEN, dtype=torch.float32)
# Names of the entries of t, in order.
T_LABELS = ["lambda", "mu"]
def merge_with_defaults(params: Optional[Dict[str, Number]]) -> Dict[str, Number]:
    """Return a new dict of DEFAULT_PARAMS overridden by the entries of ``params``.

    Keys absent from ``params`` keep their default value. ``None`` or an empty
    dict yields a plain copy of the defaults. Neither DEFAULT_PARAMS nor
    ``params`` is mutated.
    """
    merged = dict(DEFAULT_PARAMS)
    merged.update(params or {})
    return merged

class Logistic(Distribution):
    """
    Logistic distribution with location ``loc`` and scale ``scale``.

    Density: f(x) = exp(-z) / (scale * (1 + exp(-z))**2), with z = (x - loc) / scale.
    Sampling goes through the inverse CDF (loc + scale * logit(u)), so it is
    reparameterized and gradients flow to ``loc`` and ``scale``.

    ``loc`` and ``scale`` are stored as given, not broadcast against each other.
    The batch shape is their broadcast shape.
    """
    # No parameter constraints are declared, so validate_args performs no
    # checks on loc / scale.
    arg_constraints = {}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        loc = torch.as_tensor(loc)
        scale = torch.as_tensor(scale)
        # Integer parameters (e.g. Logistic(0, 1)) would make torch.rand fail
        # in rsample, so promote them to the default floating dtype.
        if not loc.is_floating_point():
            loc = loc.to(torch.get_default_dtype())
        if not scale.is_floating_point():
            scale = scale.to(torch.get_default_dtype())
        self.loc = loc
        self.scale = scale
        batch_shape = torch.broadcast_shapes(self.loc.shape, self.scale.shape)
        super().__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self):
        return self.loc.expand(self.batch_shape)

    @property
    def variance(self):
        # Var = scale^2 * pi^2 / 3 for the logistic distribution.
        return (self.scale ** 2 * (math.pi ** 2 / 3.0)).expand(self.batch_shape)

    def rsample(self, sample_shape=torch.Size()):
        # _extended_shape accepts any sequence (tuple, list, torch.Size).
        shape = self._extended_shape(sample_shape)
        u = torch.rand(shape, device=self.loc.device, dtype=self.loc.dtype)
        # Keep u strictly inside (0, 1) so the logit below stays finite.
        u = u.clamp_(1e-7, 1 - 1e-7)
        return self.loc + self.scale * torch.log(u / (1 - u))

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        # log f = -z - 2*log(1 + exp(-z)) - log(scale), with softplus for stability.
        return -(z + 2 * torch.nn.functional.softplus(-z) + torch.log(self.scale))

    def cdf(self, value):
        return torch.sigmoid((value - self.loc) / self.scale)

    def icdf(self, value):
        # Inverse CDF: loc + scale * logit(p).
        return self.loc + self.scale * (torch.log(value) - torch.log1p(-value))

    def entropy(self):
        # Differential entropy of the logistic distribution: log(scale) + 2.
        return (torch.log(self.scale) + 2.0).expand(self.batch_shape)

def sample_logistic_z(nsamples: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """Draw ``nsamples`` standard Logistic(0, 1) samples via the inverse CDF.

    Args:
        nsamples: Number of samples (a 1-D tensor of this length is returned).
        dtype: Floating dtype of the samples.
        device: Device on which the samples are drawn.

    Returns:
        Tensor of shape ``(nsamples,)`` with finite entries log(u) - log(1 - u),
        where u ~ Uniform[0, 1).
    """
    u = torch.rand(nsamples, device=device, dtype=dtype)
    # torch.rand samples from [0, 1), and an exact 0 would produce a -inf sample.
    # Raise it to the smallest positive normal value; all other draws are untouched.
    u = u.clamp_min(torch.finfo(dtype).tiny)
    return torch.log(u) - torch.log1p(-u)
def _unpack_lambda_mu(t: Tensor) -> Tuple[Tensor, Tensor]:
    """Here t = (lambda, mu)."""
    t_flat = t.flatten()
    if t_flat.numel() < 2:
        raise ValueError("t must contain at least (lambda, mu).")
    return t_flat[0], t_flat[1]
def wage(x: Tensor, lam: Tensor, mu: Tensor, s: float, a0: float, w_min: float) -> Tensor:
    """
    Evaluate the contract wage at outcome(s) ``x``.

        w(x) = max(lam + mu * sigmoid((x - a0) / s), w_min)

    The result is at least ``w_min`` elementwise, so a positive floor keeps
    ``log(w)`` finite for callers. The tensor arguments follow torch
    broadcasting.
    """
    bonus_share = torch.sigmoid((x - a0) / float(s))
    raw_wage = lam + mu * bonus_share
    return raw_wage.clamp(min=float(w_min))

def get_t_bounds(setting_parameters: Dict[str, Number]) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """Return the feasible box for the contract parameters (lambda, mu).

    lambda ranges over [w_min, w_min + 8] and mu over [0, 8]. ``w_min`` is
    read from ``setting_parameters["w_min"]``, which must be present.

    Returns:
        ``((lam_lo, lam_hi), (mu_lo, mu_hi))`` as Python floats.
    """
    span = 8.0
    floor = float(setting_parameters["w_min"])
    return (floor, floor + span), (0.0, span)
def project_t_box(t: Tensor, lam_bounds: Tuple[float, float], mu_bounds: Tuple[float, float]) -> Tensor:
    """Clamp the (lambda, mu) entries of ``t`` into the given box.

    ``t`` is read in flattened order: its first entry is lambda and its second
    is mu. Any further entries are copied through unchanged. The result is a
    new tensor with the same shape as ``t``; ``t`` itself is not modified.

    Args:
        t: Contract tensor with at least two elements.
        lam_bounds: ``(lower, upper)`` bounds for lambda.
        mu_bounds: ``(lower, upper)`` bounds for mu.

    Raises:
        ValueError: if ``t`` holds fewer than two elements.
    """
    t_flat = t.flatten()
    # Fail with the same descriptive error as _unpack_lambda_mu rather than
    # an opaque IndexError from indexing a too-short tensor.
    if t_flat.numel() < 2:
        raise ValueError("t must contain at least (lambda, mu).")
    lb_lam, ub_lam = lam_bounds
    lb_mu, ub_mu = mu_bounds
    out = t_flat.clone()  # clone so the caller's tensor is never written to
    out[0] = t_flat[0].clamp(min=float(lb_lam), max=float(ub_lam))
    out[1] = t_flat[1].clamp(min=float(lb_mu), max=float(ub_mu))
    return out.view_as(t)
def project_t_box_default(
    t: Tensor,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tensor:
    """Clamp (lambda, mu) in ``t`` into the box derived from the settings.

    ``setting_parameters`` (may be None) is merged over DEFAULT_PARAMS. The box
    comes from ``get_t_bounds`` and the clamping from ``project_t_box``.
    """
    merged = merge_with_defaults(setting_parameters)
    lam_box, mu_box = get_t_bounds(merged)
    return project_t_box(t, lam_box, mu_box)
def project_a_box(a: Tensor, a_bounds: Optional[Tuple[float, float]] = None,
                  setting_parameters: Optional[Dict[str, Number]] = None) -> Tensor:
    """
    Clamp effort ``a`` elementwise into ``[lo, hi]``.

    When ``a_bounds`` is None, the box defaults to ``[a0 - 6 s, a0 + 6 s]``,
    with ``s`` and ``a0`` taken from ``setting_parameters`` merged over
    DEFAULT_PARAMS.
    """
    merged = merge_with_defaults(setting_parameters or {})
    if a_bounds is not None:
        lo, hi = a_bounds
    else:
        half_width = 6.0 * float(merged["s"])
        center = float(merged["a0"])
        lo, hi = center - half_width, center + half_width
    return a.clamp(min=float(lo), max=float(hi))

def u2(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,
    a0: float = 0.0,
    w_min: float = 1e-6,
    nsamples: int = 1024,
    z: Optional[Tensor] = None,
    **kwargs
) -> Tensor:
    """
    Agent's expected utility, estimated by Monte Carlo:
        E[ log w_t(X) | a ] - 0.5 * c * a**2,  X = a + s Z,  Z ~ Logistic(0, 1).

    Args:
        a: Agent effort; broadcast against the noise draws.
        t: Contract tensor whose first two (flattened) entries are (lambda, mu).
        s: Noise scale (also passed to ``wage`` as the sigmoid scale).
        c: Quadratic effort-cost coefficient.
        a0, w_min: Wage parameters forwarded to ``wage``.
        nsamples: Number of noise draws when ``z`` is not supplied.
        z: Optional pre-drawn noise samples; when given, ``nsamples`` is ignored.
        **kwargs: Extra setting entries, ignored.

    Returns:
        Tensor with ``a``'s shape. The expectation term is a scalar sample
        mean; the cost term is elementwise in ``a``.
    """
    lam, mu = _unpack_lambda_mu(t)
    # Work in a's floating dtype, or the default float dtype for integer effort.
    # This gives float64 callers float64 noise and stops c being truncated to an int.
    dtype = a.dtype if a.is_floating_point() else torch.get_default_dtype()
    if z is None:
        dist = Logistic(loc=torch.zeros((), device=a.device, dtype=dtype), scale=1.0)
        z = dist.rsample((nsamples,))
    x = a + s * z
    w = wage(x, lam, mu, s=s, a0=a0, w_min=w_min)
    expected_log_wage = torch.log(w).mean()
    cost = 0.5 * torch.as_tensor(c, device=a.device, dtype=dtype) * (a ** 2)
    return expected_log_wage - cost
def u1(
    a: Tensor,
    t: Tensor,
    *,
    s: float,
    c: float,
    a0: float = 0.0,
    w_min: float = 1e-6,
    nsamples: int = 1024,
    z: Optional[Tensor] = None,
    **kwargs
) -> Tensor:
    """
    Principal's expected utility, estimated by Monte Carlo:
        E[ X - w_t(X) | a ],  X = a + s Z,  Z ~ Logistic(0, 1).

    Args:
        a: Agent effort; broadcast against the noise draws.
        t: Contract tensor whose first two (flattened) entries are (lambda, mu).
        s: Noise scale (also passed to ``wage`` as the sigmoid scale).
        c: Accepted for signature parity with ``u2``; not used here.
        a0, w_min: Wage parameters forwarded to ``wage``.
        nsamples: Number of noise draws when ``z`` is not supplied.
        z: Optional pre-drawn noise samples; when given, ``nsamples`` is ignored.
        **kwargs: Extra setting entries, ignored.

    Returns:
        0-dim tensor: the sample mean of ``X - w(X)``.
    """
    lam, mu = _unpack_lambda_mu(t)
    if z is None:
        # Draw noise in a's floating dtype (the default float dtype for integer
        # effort) so float64 callers are not silently given float32 samples.
        noise_dtype = a.dtype if a.is_floating_point() else torch.get_default_dtype()
        dist = Logistic(loc=torch.zeros((), device=a.device, dtype=noise_dtype), scale=1.0)
        z = dist.rsample((nsamples,))
    x = a + s * z
    w = wage(x, lam, mu, s=s, a0=a0, w_min=w_min)
    return (x - w).mean()


def get_theoretical_optimum(
    t: Optional[torch.Tensor] = None,
    setting_parameters: Optional[Dict[str, Number]] = None,
) -> Tuple[Tensor, Tensor]:
    """
    Approximate the optimum (a*, t*) by nested grid search.

    For every (lambda, mu) on the grid, the agent's best response a* is the
    effort grid point that maximizes the Monte Carlo estimate of
        E[log w(X)] - 0.5 * c * a**2,   X = a + s Z.
    The principal then picks the (lambda, mu) that maximizes E[X - w(X)] at
    that a*. All estimates share one draw of Logistic(0, 1) noise ``z``.

    Args:
        t: Optional tensor. Only its device and dtype are used for the grids
           and noise; a non-floating dtype falls back to float32. Without
           ``t``, CPU / float32 is used.
        setting_parameters: Setting dict, merged over DEFAULT_PARAMS. An
           optional ``"grid"`` entry may override ``"a"``, ``"lam"`` and
           ``"mu"`` with ``(min, max, num_points)`` triples. The defaults are
           ``[a0 - 6 s, a0 + 6 s]`` with 200 points for a, and the
           ``get_t_bounds`` box with 100 points each for lambda and mu.

    Returns:
        ``(best_a, best_t)``: ``best_a`` has shape (1,) and ``best_t`` has
        shape (2,), holding (lambda, mu). If no grid point yields a principal
        value greater than -inf (e.g. all are NaN), both are returned as zeros.

    Raises:
        ValueError: if ``setting_parameters`` is None.
    """
    if setting_parameters is None:
        raise ValueError("`setting_parameters` dict must be provided.")
    sp = merge_with_defaults(setting_parameters)
    s = float(sp["s"])
    c = float(sp["c"])
    a0 = float(sp["a0"])
    w_min = float(sp["w_min"])
    ns = int(sp["nsamples"])

    # Grid specs are (min, max, num_points) triples.
    grid = dict(sp.get("grid", {}))
    lam_bounds, mu_bounds = get_t_bounds(sp)
    amin, amax, n_a = grid.get("a", (a0 - 6.0 * s, a0 + 6.0 * s, 200))
    lmin, lmax, n_l = grid.get("lam", (lam_bounds[0], lam_bounds[1], 100))
    mmin, mmax, n_m = grid.get("mu", (mu_bounds[0], mu_bounds[1], 100))

    if isinstance(t, torch.Tensor):
        device = t.device
        # Integer tensors cannot parameterize linspace grids or uniform draws.
        dtype = t.dtype if t.is_floating_point() else torch.float32
    else:
        device, dtype = torch.device("cpu"), torch.float32

    a_grid = torch.linspace(float(amin), float(amax), int(n_a), device=device, dtype=dtype)
    lam_grid = torch.linspace(float(lmin), float(lmax), int(n_l), device=device, dtype=dtype)
    mu_grid = torch.linspace(float(mmin), float(mmax), int(n_m), device=device, dtype=dtype)

    best_u1 = -float("inf")
    best_a = torch.tensor([0.0], device=device, dtype=dtype)
    best_t = torch.zeros(2, device=device, dtype=dtype)
    z = sample_logistic_z(ns, dtype=dtype, device=device)

    # Loop-invariant parts of the agent's problem: the outcome for every
    # (noise draw, effort) pair and the quadratic effort cost. They do not
    # depend on (lambda, mu), so compute them once, not n_l * n_m times.
    X = a_grid[None, :] + s * z[:, None]  # [ns, n_a]
    effort_cost = 0.5 * c * (a_grid ** 2)  # [n_a]

    for lam in lam_grid:
        for mu in mu_grid:
            W = wage(X, lam, mu, s=s, a0=a0, w_min=w_min)  # [ns, n_a]
            u2_vals = torch.log(W).mean(0) - effort_cost
            idx = int(torch.argmax(u2_vals).item())
            a_star_t = a_grid[idx]
            X_outer = a_star_t + s * z
            W_outer = wage(X_outer, lam, mu, s=s, a0=a0, w_min=w_min)
            u1_val = (X_outer - W_outer).mean().item()
            # Strict comparison keeps the first maximizer on ties.
            if u1_val > best_u1:
                best_u1 = u1_val
                best_a = a_star_t.reshape(1)  # (1,)
                best_t = torch.stack([lam, mu])  # (2,)
    return best_a, best_t