import torch
from torch import Tensor
from typing import Dict, Tuple, Union, Optional
# Scalar type for the numeric values of the ``setting_parameters`` mapping
# taken by ``get_theoretical_optimum``.
Number = Union[float, int]

def sample_logistic_z(nsamples: int, dtype: torch.dtype, device: torch.device) -> Tensor:
    """Draw ``nsamples`` standard Logistic(0, 1) variates by inverse-CDF sampling.

    A uniform draw ``u`` is mapped through the logit ``log(u) - log(1 - u)``.
    ``u`` is first clamped to ``[1e-7, 1 - 1e-7]`` so that neither logarithm
    can produce an infinity.

    Returns:
        A 1-D tensor of length ``nsamples`` on ``device`` with ``dtype``.
    """
    eps = 1e-7
    uniform = torch.rand(nsamples, device=device, dtype=dtype)
    uniform = uniform.clamp(eps, 1 - eps)
    # log1p(-u) evaluates log(1 - u).
    return torch.log(uniform) - torch.log1p(-uniform)

def _unpack_lambda_mu(t: Tensor):
    """Return the first two entries of ``t`` (lambda, mu) after flattening.

    Raises:
        ValueError: if ``t`` holds fewer than two elements.
    """
    flat = t.flatten()
    if flat.numel() >= 2:
        return flat[0], flat[1]
    raise ValueError("t must contain at least (lambda, mu).")

def wage(x: Tensor, lam: Tensor, mu: Tensor, s_temp: float, a0: float, w_min: float) -> Tensor:
    """Sigmoid wage schedule floored at ``w_min``.

    Computes ``max(lam + mu * sigmoid((x - a0) / s_temp), w_min)`` elementwise
    over ``x``.
    """
    scaled = (x - a0) / float(s_temp)
    uncapped = lam + mu * torch.sigmoid(scaled)
    return uncapped.clamp(min=float(w_min))

def u2(a: Tensor, t: Tensor, *, s: float, c: float, a0: float=0.0, w_min: float=1e-3,
       nsamples: int=1024, z: Optional[Tensor]=None, rho: float=1.0, **kw) -> Tensor:
    """CARA objective E[-exp(-rho * w)] - 0.5 * c * a**2 with Z ~ Logistic.

    Output is ``x = a + s * Z`` and the wage is ``wage(x, lambda, mu)``, where
    ``(lambda, mu)`` are the first two entries of ``t``. The expectation is a
    Monte-Carlo mean over ``z``. When ``z`` is None, ``nsamples`` fresh draws
    are sampled with ``a``'s dtype and device. Extra keyword arguments are
    accepted and ignored.

    NOTE(review): assumes ``a`` broadcasts against the 1-D noise vector (e.g.
    a scalar or shape-(1,) tensor) — confirm with callers.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = z if z is not None else sample_logistic_z(nsamples, dtype=a.dtype, device=a.device)
    output = a + s * noise
    pay = wage(output, lam, mu, s_temp=s, a0=a0, w_min=w_min)
    rho_t = torch.as_tensor(rho, device=pay.device, dtype=pay.dtype)
    expected_utility = torch.exp(-rho_t * pay).neg().mean()
    effort_cost = 0.5 * torch.as_tensor(c, device=a.device, dtype=a.dtype) * a.pow(2)
    return expected_utility - effort_cost

def u1(a: Tensor, t: Tensor, *, s: float, c: float, a0: float=0.0, w_min: float=1e-3,
       nsamples: int=1024, z: Optional[Tensor]=None, **kw) -> Tensor:
    """Monte-Carlo estimate of E[x - wage(x)] with x = a + s * Z, Z ~ Logistic.

    ``(lambda, mu)`` are read from the first two entries of ``t``. When ``z``
    is None, ``nsamples`` fresh draws are sampled with ``a``'s dtype and
    device. ``c`` and any extra keyword arguments are accepted but unused;
    presumably this lets u1 and u2 share one set of call kwargs.
    """
    lam, mu = _unpack_lambda_mu(t)
    noise = sample_logistic_z(nsamples, dtype=a.dtype, device=a.device) if z is None else z
    output = a + s * noise
    payment = wage(output, lam, mu, s_temp=s, a0=a0, w_min=w_min)
    surplus = output - payment
    return surplus.mean()

def project_t_box(t: Tensor, lam_bounds, mu_bounds) -> Tensor:
    """Clamp lambda (entry 0) and mu (entry 1) of ``t`` into their bounds.

    ``lam_bounds`` and ``mu_bounds`` are indexable ``(low, high)`` pairs. The
    work is done on a flattened copy, so ``t`` itself is not modified. Entries
    after the first two pass through unchanged, and the result has ``t``'s
    shape.
    """
    # The clone matters: flatten() may return a view that shares t's storage.
    flat = t.flatten().clone()
    for pos, bounds in enumerate((lam_bounds, mu_bounds)):
        flat[pos] = flat[pos].clamp(float(bounds[0]), float(bounds[1]))
    return flat.view_as(t)

def get_theoretical_optimum(t=None, setting_parameters: Optional[Dict[str, Number]]=None):
    """Grid-search the (lambda, mu) pair and induced action that maximise E[x - w].

    For each (lambda, mu) on a grid:
      1. Pick the action a* on an action grid that maximises the Monte-Carlo
         CARA objective ``E[-exp(-rho * w)] - 0.5 * c * a**2``.
      2. Score the pair by the Monte-Carlo mean of ``x - w`` at a*.
    Here ``x = a + s * Z`` and ``w = wage(x, lambda, mu)``. The pair with the
    highest score wins. One shared set of logistic draws is used throughout.

    Args:
        t: Optional tensor, used only to choose device and dtype. Without it
            the defaults are CPU and float32.
        setting_parameters: Model settings. Recognised keys are ``s``, ``c``,
            ``a0``, ``w_min`` and ``rho``, plus ``grid``: a mapping with
            optional ``a`` / ``lam`` / ``mu`` entries of the form
            ``(start, stop, num)`` and an optional ``nsamples``. ``None``
            (and ``grid: None``) are treated as empty mappings.

    Returns:
        ``(best_a, best_t)``: a shape-(1,) tensor holding the induced action and
        a shape-(2,) tensor ``[lambda, mu]``. These stay zeros if no grid point
        produces a finite score (e.g. empty grids).
    """
    params = {} if setting_parameters is None else setting_parameters
    s = float(params.get("s", 1.0))
    c = float(params.get("c", 1.0))
    a0 = float(params.get("a0", 0.0))
    w_min = float(params.get("w_min", 1e-3))
    rho = float(params.get("rho", 1.0))
    grid = dict(params.get("grid") or {})
    a_cfg = grid.get("a", (a0 - 4 * s, a0 + 4 * s, 201))
    l_cfg = grid.get("lam", (w_min, w_min + 5.0, 61))
    m_cfg = grid.get("mu", (0.0, 5.0, 61))
    star_nsamples = int(grid.get("nsamples", 8192))

    device = t.device if isinstance(t, torch.Tensor) else torch.device("cpu")
    dtype = t.dtype if isinstance(t, torch.Tensor) else torch.float32

    def _linspace(cfg):
        # (start, stop, num) -> 1-D grid on the chosen device/dtype.
        return torch.linspace(float(cfg[0]), float(cfg[1]), int(cfg[2]), device=device, dtype=dtype)

    a_grid = _linspace(a_cfg)
    lam_grid = _linspace(l_cfg)
    mu_grid = _linspace(m_cfg)
    z = sample_logistic_z(star_nsamples, dtype=dtype, device=device)

    # Output for every (noise draw, action) pair, shape (nsamples, len(a_grid)).
    # It does not depend on (lambda, mu), so it is built once, outside the loops.
    X = a_grid[None, :] + s * z[:, None]
    effort_cost = 0.5 * c * (a_grid ** 2)

    best_u1 = float("-inf")
    best_a = torch.zeros(1, device=device, dtype=dtype)
    best_t = torch.zeros(2, device=device, dtype=dtype)
    for lam in lam_grid:
        for mu in mu_grid:
            W = wage(X, lam, mu, s_temp=s, a0=a0, w_min=w_min)
            u2_vals = (-torch.exp(-rho * W)).mean(0) - effort_cost
            idx = int(torch.argmax(u2_vals))
            # Column idx of X is a_grid[idx] + s*z, and column idx of W is its wage,
            # so the score at a* reuses them instead of recomputing.
            u1_val = (X[:, idx] - W[:, idx]).mean().item()
            if u1_val > best_u1:
                best_u1 = u1_val
                best_a = a_grid[idx].reshape(1).clone()
                best_t = torch.stack([lam, mu])
    return best_a, best_t
