from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Iterable, Literal, Optional, Tuple

import torch


# Short alias for ``torch.Tensor`` used in this module's type annotations.
Tensor = torch.Tensor


def _validate_sigma(sigma: float) -> float:
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    return float(sigma)


def gaussian_kernel_noise(shape: Tuple[int, ...], sigma: float, device=None, dtype=None) -> Tensor:
    """Draw i.i.d. zero-mean Gaussian noise with standard deviation ``sigma``.

    Args:
        shape: shape of the returned tensor.
        sigma: standard deviation; validated by ``_validate_sigma``.
        device, dtype: forwarded to ``torch.randn``.

    Raises:
        ValueError: if ``sigma <= 0``.
    """
    std = _validate_sigma(sigma)
    unit_noise = torch.randn(shape, device=device, dtype=dtype)
    return unit_noise * std


def laplace_kernel_noise(shape: Tuple[int, ...], b: float, device=None, dtype=None) -> Tensor:
    """Draw i.i.d. zero-mean Laplace noise with scale ``b`` by inverse-CDF sampling.

    Args:
        shape: shape of the returned tensor.
        b: Laplace scale parameter; must be positive.
        device, dtype: forwarded to ``torch.rand``.

    Returns:
        Tensor of ``shape`` holding finite Laplace(0, b) samples.

    Raises:
        ValueError: if ``b <= 0``.
    """
    if b <= 0:
        raise ValueError("b must be positive")
    r = torch.rand(shape, device=device, dtype=dtype)
    # torch.rand samples the half-open interval [0, 1), so r == 0 can occur.
    # It would map to u == -0.5 and log1p(-1) == -inf (an infinite sample).
    # Clamp to machine epsilon: the smallest offset that survives the
    # subtraction below, while moving only a negligible amount of mass.
    r = r.clamp_min(torch.finfo(r.dtype).eps)
    u = r - 0.5
    # Inverse CDF of Laplace(0, b) for u in (-0.5, 0.5).
    return -b * torch.sign(u) * torch.log1p(-2 * torch.abs(u))


def uniform_ball_noise(shape: Tuple[int, ...], radius: float, device=None, dtype=None) -> Tensor:
    """Sample points uniformly from a ball of the given ``radius`` along the last axis.

    Each ``out[..., :]`` is a random direction (a normalised Gaussian vector)
    scaled by ``radius * U ** (1 / shape[-1])`` with ``U ~ Uniform[0, 1)``. This
    radial law makes the points uniform in volume rather than clustered at the
    centre.

    Raises:
        ValueError: if ``radius <= 0``.
    """
    if radius <= 0:
        raise ValueError("radius must be positive")
    gauss = torch.randn(shape, device=device, dtype=dtype)
    # The tiny epsilon guards against dividing by a zero-length vector.
    direction = gauss / (gauss.norm(dim=-1, keepdim=True) + 1e-12)
    u = torch.rand(shape[:-1] + (1,), device=device, dtype=dtype)
    radial = u ** (1.0 / shape[-1])
    return direction * (radial * radius)


@dataclass
class MollifyConfig:
    """Settings shared by the mollification and smoothed-gradient helpers.

    Raises:
        ValueError: on construction, if ``samples < 1`` or ``chunk_size`` is
            given but not positive.
    """

    # Standard deviation of the noise; only read when ``kernel == "gaussian"``.
    sigma: float = 1e-3
    # Requested number of function evaluations per estimate.  With
    # ``antithetic=True`` about half as many noise draws are made, each
    # evaluated at +noise and -noise.
    samples: int = 8
    # Which inputs receive noise.
    wrt: Literal["state", "action", "both"] = "both"
    # Noise distribution.
    kernel: Literal["gaussian", "laplace", "uniform"] = "gaussian"
    # Scale for the non-Gaussian kernels: Laplace scale ``b`` or uniform-ball
    # radius.  NOTE: independent of ``sigma``; its default of 1.0 is much larger
    # than sigma's default, so set it explicitly when switching kernels.
    kernel_param: float = 1.0
    # Rows (dim 0) per chunk in ``mollify_chunked``; None disables chunking.
    chunk_size: Optional[int] = None
    # Evaluate each noise draw at both +noise and -noise.
    antithetic: bool = True
    # "zero": use perturbed inputs as-is; "mean": re-center the perturbed batch
    # across dim 0 (see ``_apply_center``).
    center: Literal["zero", "mean"] = "zero"

    def __post_init__(self) -> None:
        # Fail at construction instead of dividing by zero (or silently
        # returning a float) deep inside the estimators.
        if self.samples < 1:
            raise ValueError("samples must be at least 1")
        if self.chunk_size is not None and self.chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer or None")


def _sample_noise(x: Tensor, cfg: MollifyConfig) -> Tensor:
    """Draw noise shaped like ``x`` (on its device, in its dtype) from ``cfg.kernel``.

    The Gaussian kernel is scaled by ``cfg.sigma``; the Laplace and uniform-ball
    kernels are scaled by ``cfg.kernel_param`` instead.

    Raises:
        ValueError: if ``cfg.kernel`` is not a known kernel name.
    """
    dispatch = (
        ("gaussian", gaussian_kernel_noise, "sigma"),
        ("laplace", laplace_kernel_noise, "kernel_param"),
        ("uniform", uniform_ball_noise, "kernel_param"),
    )
    for name, sampler, scale_attr in dispatch:
        if cfg.kernel == name:
            return sampler(x.shape, getattr(cfg, scale_attr), device=x.device, dtype=x.dtype)
    raise ValueError(f"unknown kernel {cfg.kernel}")


def _apply_center(x: Tensor, ref: Tensor, mode: str) -> Tensor:
    if mode == "zero":
        return x
    if mode == "mean":
        return x - x.mean(dim=0, keepdim=True) + ref
    return x


def gaussian_mollify(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    sigma: float = 1e-3,
    samples: int = 8,
    wrt: Literal["state", "action", "both"] = "both",
) -> Tensor:
    """Shorthand for ``mollify`` with a Gaussian kernel of standard deviation ``sigma``.

    All other settings come from the ``MollifyConfig`` defaults (Gaussian
    kernel, antithetic sampling, ``center="zero"``, no chunking).
    """
    return mollify(f, states, actions, MollifyConfig(sigma=sigma, samples=samples, wrt=wrt))


def mollify(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    cfg: Optional[MollifyConfig] = None,
) -> Tensor:
    """Monte-Carlo estimate of ``f`` smoothed by a noise kernel.

    Averages ``f(states + ds, actions + da)`` over noise draws from
    ``cfg.kernel``.  ``cfg.wrt`` selects which of the two inputs is perturbed;
    the other gets a zero offset.  With ``cfg.antithetic`` every draw is also
    evaluated at the mirrored point ``(states - ds, actions - da)``.  The
    result is not detached, so autograd can flow back into ``states`` and
    ``actions``.

    Args:
        f: function of ``(states, actions)``.  Its outputs are summed, so all
            calls must return tensors of a common (broadcastable) shape.
        states, actions: inputs to ``f``.
        cfg: sampling configuration; ``MollifyConfig()`` when omitted.

    Returns:
        The average of ``f`` over all evaluations.

    Raises:
        ValueError: if ``cfg.samples < 1``.
    """
    if cfg is None:
        cfg = MollifyConfig()
    if cfg.samples < 1:
        raise ValueError("samples must be at least 1")
    perturb_states = cfg.wrt in ("state", "both")
    perturb_actions = cfg.wrt in ("action", "both")
    # Unperturbed inputs still take the same add-and-center path with a zero
    # offset; allocate that offset once rather than on every draw.
    zero_ds = None if perturb_states else torch.zeros_like(states)
    zero_da = None if perturb_actions else torch.zeros_like(actions)
    # Antithetic sampling spends two evaluations per draw.  Always take at
    # least one draw so samples=1 does not end in an empty sum divided by zero.
    n_draws = max(1, cfg.samples // 2) if cfg.antithetic else cfg.samples
    total = 0.0
    for _ in range(n_draws):
        ds = _sample_noise(states, cfg) if perturb_states else zero_ds
        da = _sample_noise(actions, cfg) if perturb_actions else zero_da
        s = _apply_center(states + ds, states, cfg.center)
        a = _apply_center(actions + da, actions, cfg.center)
        total = total + f(s, a)
        if cfg.antithetic:
            # Antithetic partner: the same draw with its sign flipped.
            s_mirror = _apply_center(states - ds, states, cfg.center)
            a_mirror = _apply_center(actions - da, actions, cfg.center)
            total = total + f(s_mirror, a_mirror)
    n_evals = 2 * n_draws if cfg.antithetic else n_draws
    return total / float(n_evals)


def mollify_chunked(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    cfg: Optional[MollifyConfig] = None,
) -> Tensor:
    """Apply ``mollify`` to row-chunks of the batch and concatenate along dim 0.

    At most ``cfg.chunk_size`` rows of ``states``/``actions`` go into each
    ``mollify`` call.  A single un-chunked ``mollify`` call is made when ``cfg``
    or ``cfg.chunk_size`` is None, or when the batch is empty.  Assumes ``f``
    returns one output row per input row along dim 0 (needed for the
    concatenation) — confirm for new callers.

    Note: each chunk draws its own noise, and ``cfg.center == "mean"`` is
    computed per chunk, so results differ from a single ``mollify`` call.

    Raises:
        ValueError: if ``cfg.chunk_size`` is not a positive integer.
    """
    if cfg is None or cfg.chunk_size is None:
        return mollify(f, states, actions, cfg)
    step = cfg.chunk_size
    if step < 1:
        raise ValueError("chunk_size must be a positive integer or None")
    n_rows = states.shape[0]
    if n_rows == 0:
        # No chunks would be produced, and torch.cat rejects an empty list.
        return mollify(f, states, actions, cfg)
    pieces = [
        mollify(f, states[i : i + step], actions[i : i + step], cfg)
        for i in range(0, n_rows, step)
    ]
    return torch.cat(pieces, dim=0)


def gradient_mollify(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    cfg: Optional[MollifyConfig] = None,
    wrt: Literal["state", "action"] = "action",
) -> Tensor:
    """Monte-Carlo estimate of the gradient of the mollified ``f``.

    Two separate "wrt" settings are involved:

    - ``cfg.wrt`` selects which inputs are perturbed by noise.
    - The ``wrt`` argument selects which input is differentiated: ``"action"``
      means the gradient w.r.t. actions; any other value means states.

    For each perturbed point, and its mirror when ``cfg.antithetic`` is set,
    the gradient of ``f(s, a).sum()`` is taken with autograd.  The gradients
    are then averaged.

    The perturbed points are detached, so no gradient flows back into the
    caller's ``states``/``actions`` and the result has no autograd history.
    Gradients are taken under ``torch.enable_grad()``, so this also works when
    called from inside ``torch.no_grad()``.

    Returns:
        The averaged gradient, shaped like the differentiated input.

    Raises:
        ValueError: if ``cfg.samples < 1``.
    """
    if cfg is None:
        cfg = MollifyConfig()
    if cfg.samples < 1:
        raise ValueError("samples must be at least 1")
    diff_actions = wrt == "action"
    perturb_states = cfg.wrt in ("state", "both")
    perturb_actions = cfg.wrt in ("action", "both")
    zero_ds = None if perturb_states else torch.zeros_like(states)
    zero_da = None if perturb_actions else torch.zeros_like(actions)
    # At least one draw, so samples=1 with antithetic sampling does not divide by zero.
    n_draws = max(1, cfg.samples // 2) if cfg.antithetic else cfg.samples
    # +noise, plus the mirrored -noise point for antithetic sampling.
    signs = (1.0, -1.0) if cfg.antithetic else (1.0,)
    grads = 0.0
    for _ in range(n_draws):
        ds = _sample_noise(states, cfg) if perturb_states else zero_ds
        da = _sample_noise(actions, cfg) if perturb_actions else zero_da
        for sign in signs:
            # Detach so each evaluation starts from a fresh autograd leaf.
            s = _apply_center(states + sign * ds, states, cfg.center).detach()
            a = _apply_center(actions + sign * da, actions, cfg.center).detach()
            target = a if diff_actions else s
            target.requires_grad_(True)
            with torch.enable_grad():
                (g,) = torch.autograd.grad(f(s, a).sum(), target)
            grads = grads + g
    n_evals = 2 * n_draws if cfg.antithetic else n_draws
    return grads / float(n_evals)


def mollify_and_grad(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    cfg: Optional[MollifyConfig] = None,
    wrt: Literal["state", "action"] = "action",
) -> Tuple[Tensor, Tensor]:
    """Estimate the mollified value of ``f`` and its gradient from the same samples.

    ``cfg.wrt`` selects which inputs are perturbed.  The ``wrt`` argument
    selects the differentiated input: ``"action"`` means actions; any other
    value means states.  Each noise draw is evaluated at +noise and, when
    ``cfg.antithetic`` is set, also at -noise.  Gradients are taken under
    ``torch.enable_grad()``, so this also works inside ``torch.no_grad()``.

    Returns:
        ``(value, grad)``: the average of ``f`` and the average gradient of
        ``f(s, a).sum()`` over all evaluations.  Both are detached.  The
        perturbed inputs are cut off from the caller's tensors, so the value's
        graph would only reach internal leaves and is not kept.

    Raises:
        ValueError: if ``cfg.samples < 1``.
    """
    if cfg is None:
        cfg = MollifyConfig()
    if cfg.samples < 1:
        raise ValueError("samples must be at least 1")
    diff_actions = wrt == "action"
    perturb_states = cfg.wrt in ("state", "both")
    perturb_actions = cfg.wrt in ("action", "both")
    zero_ds = None if perturb_states else torch.zeros_like(states)
    zero_da = None if perturb_actions else torch.zeros_like(actions)
    # At least one draw, so samples=1 with antithetic sampling does not divide by zero.
    n_draws = max(1, cfg.samples // 2) if cfg.antithetic else cfg.samples
    signs = (1.0, -1.0) if cfg.antithetic else (1.0,)
    val_acc = 0.0
    grad_acc = 0.0
    for _ in range(n_draws):
        ds = _sample_noise(states, cfg) if perturb_states else zero_ds
        da = _sample_noise(actions, cfg) if perturb_actions else zero_da
        for sign in signs:
            s = _apply_center(states + sign * ds, states, cfg.center).detach()
            a = _apply_center(actions + sign * da, actions, cfg.center).detach()
            target = a if diff_actions else s
            target.requires_grad_(True)
            with torch.enable_grad():
                v = f(s, a)
                (g,) = torch.autograd.grad(v.sum(), target)
            # autograd.grad has already released v's graph; keep only its value.
            val_acc = val_acc + v.detach()
            grad_acc = grad_acc + g
    n_evals = float(2 * n_draws if cfg.antithetic else n_draws)
    return val_acc / n_evals, grad_acc / n_evals


def _demo():
    """Run mollify and mollify_and_grad on a small quadratic and print the result shapes."""

    def quadratic(s: Tensor, a: Tensor) -> Tensor:
        # Per-row 0.5 * ||a||^2 + <s, a>.
        return 0.5 * (a ** 2).sum(dim=-1) + (s * a).sum(dim=-1)

    batch, state_dim, action_dim = 8, 3, 3
    states = torch.randn(batch, state_dim)
    actions = torch.randn(batch, action_dim)
    config = MollifyConfig(sigma=1e-2, samples=16, wrt="both", kernel="gaussian", antithetic=True)
    smoothed = mollify(quadratic, states, actions, config)
    value, grad = mollify_and_grad(quadratic, states, actions, config, wrt="action")
    print(smoothed.shape, value.shape, grad.shape)


if __name__ == "__main__":
    # Executing the module directly runs the small demo above.
    _demo()
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
