from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Literal, Optional, Tuple

import torch

from .curvature_estimators import CurvatureOptions, robust_kappa


Tensor = torch.Tensor


@dataclass
class LipschitzConfig:
    """Settings shared by the curvature-Lipschitz estimators in this module.

    Raises:
        ValueError: from ``__post_init__`` when ``eps`` is not strictly positive
            (including NaN) or ``trials`` is negative.
    """

    # Scale applied to the Gaussian noise added to states and actions.
    eps: float = 1e-3
    # Number of random perturbation rounds.
    trials: int = 4
    # Forwarded to CurvatureOptions by _kappa.
    mode: Literal["action", "state", "joint"] = "action"
    # Forwarded to CurvatureOptions by _kappa.
    steps: int = 8
    # Forwarded to CurvatureOptions by _kappa.
    c: float = 1.0
    # Norm order (see _norm) used by empirical_lipschitz_kappa_batch for the perturbation size.
    p: int = 2
    # NOTE(review): not read by any function in this file; presumably consumed elsewhere.
    reduce: Literal["max", "mean"] = "max"
    # NOTE(review): not read by any function in this file; presumably consumed elsewhere.
    chunk_size: Optional[int] = None

    def __post_init__(self) -> None:
        # eps == 0 (or NaN) means no real perturbation: the estimators would then
        # divide by their 1e-12 guard and report a meaningless ratio.
        if not self.eps > 0:
            raise ValueError(f"eps must be > 0, got {self.eps!r}")
        if self.trials < 0:
            raise ValueError(f"trials must be >= 0, got {self.trials!r}")


def _norm(x: Tensor, p: int = 2, dim: int = -1) -> Tensor:
    if p == 2:
        return torch.norm(x, dim=dim)
    if p == 1:
        return torch.sum(torch.abs(x), dim=dim)
    if p == 0:
        return (x != 0).sum(dim=dim).to(x.dtype)
    return torch.norm(x, p=p, dim=dim)


def _kappa(q_fn: Callable[[Tensor, Tensor], Tensor], s: Tensor, a: Tensor, cfg: LipschitzConfig) -> Tensor:
    """Evaluate ``robust_kappa`` at ``(s, a)`` with the curvature settings from ``cfg``.

    Only ``cfg.mode``, ``cfg.c`` and ``cfg.steps`` are forwarded (through
    ``CurvatureOptions``); the other ``LipschitzConfig`` fields are not used here.
    """
    options = CurvatureOptions(mode=cfg.mode, c=cfg.c, steps=cfg.steps)
    return robust_kappa(q_fn, s, a, options)


def empirical_lipschitz_kappa(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    c: float = 1.0,
    steps: int = 8,
    mode: Literal["action", "state", "joint"] = "action",
    eps: float = 1e-3,
    trials: int = 4,
) -> float:
    """Estimate a Lipschitz constant of the curvature estimate by random perturbation.

    Each of ``trials`` rounds draws Gaussian noise scaled by ``eps`` for *both*
    ``states`` and ``actions`` (regardless of ``mode``). It then recomputes the
    curvature with ``robust_kappa`` (via ``_kappa``) and records::

        max_batch |kappa(s + ds, a + da) - kappa(s, a)|
        -----------------------------------------------
        max_batch max(||ds||_2, ||da||_2)   (+ 1e-12)

    The numerator and denominator are maximised over the batch separately. The
    result is therefore a ratio of maxima, not the per-sample ratios collected
    by ``empirical_lipschitz_kappa_batch``.

    Args:
        q_fn: Callable ``(states, actions) -> Tensor`` passed to ``robust_kappa``.
        states: State tensor; norms are taken over its last dimension.
        actions: Action tensor; norms are taken over its last dimension.
        c: Forwarded to ``CurvatureOptions``.
        steps: Forwarded to ``CurvatureOptions``.
        mode: Forwarded to ``CurvatureOptions``.
        eps: Scale of the Gaussian perturbation.
        trials: Number of perturbation rounds.

    Returns:
        The largest ratio over all trials, or ``0.0`` when ``trials <= 0``.
    """
    # With no perturbation rounds there is nothing to measure. Mirror the 0.0
    # fallback of the other estimators instead of failing on ``max([])``.
    if trials <= 0:
        return 0.0
    cfg = LipschitzConfig(eps=eps, trials=trials, mode=mode, steps=steps, c=c)
    # NOTE(review): this runs under no_grad, which assumes robust_kappa does not
    # need autograd. Confirm against curvature_estimators.
    with torch.no_grad():
        base = _kappa(q_fn, states, actions, cfg)
        ratios = []
        for _ in range(cfg.trials):
            ds = torch.randn_like(states) * cfg.eps
            da = torch.randn_like(actions) * cfg.eps
            pert = _kappa(q_fn, states + ds, actions + da, cfg)
            num = (pert - base).abs().max().item()
            # The max of the two per-input norms acts as the step size in the joint space.
            den = torch.max(_norm(ds, p=2, dim=-1), _norm(da, p=2, dim=-1)).max().item() + 1e-12
            ratios.append(num / den)
    return float(max(ratios))


def empirical_lipschitz_kappa_batch(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    cfg: Optional[LipschitzConfig] = None,
) -> Dict[str, float]:
    """Per-sample empirical Lipschitz statistics of the curvature estimate.

    For each of ``cfg.trials`` rounds, Gaussian noise scaled by ``cfg.eps`` is
    added to both ``states`` and ``actions``, and the curvature is recomputed
    via ``_kappa``. Every element of::

        |kappa(s + ds, a + da) - kappa(s, a)| / (max(||ds||_p, ||da||_p) + 1e-12)

    is collected, with ``p = cfg.p`` and norms taken over the last dimension.

    Args:
        q_fn: Callable ``(states, actions) -> Tensor`` passed to ``robust_kappa``.
        states: State tensor (batched, or a single unbatched sample).
        actions: Action tensor matching ``states``.
        cfg: Estimator settings; a default ``LipschitzConfig()`` when ``None``.

    Returns:
        ``{"L/max": ..., "L/mean": ...}`` over all collected ratios. Both are
        ``0.0`` when no ratio was collected (e.g. ``cfg.trials <= 0``).
    """
    if cfg is None:
        cfg = LipschitzConfig()
    stats = {"L/max": 0.0, "L/mean": 0.0}
    # Nothing to measure: skip the base curvature pass entirely.
    if cfg.trials <= 0:
        return stats
    vals: list[float] = []
    with torch.no_grad():
        base = _kappa(q_fn, states, actions, cfg)
        for _ in range(cfg.trials):
            ds = torch.randn_like(states) * cfg.eps
            da = torch.randn_like(actions) * cfg.eps
            pert = _kappa(q_fn, states + ds, actions + da, cfg)
            num = (pert - base).abs()
            den = torch.max(_norm(ds, p=cfg.p, dim=-1), _norm(da, p=cfg.p, dim=-1)) + 1e-12
            # reshape(-1) matters for unbatched inputs: a 0-dim ratio's .tolist()
            # returns a bare float, which list.extend cannot consume.
            vals.extend((num / den).reshape(-1).cpu().tolist())
    if vals:
        stats["L/max"] = float(max(vals))
        stats["L/mean"] = float(sum(vals) / len(vals))
    return stats


def local_kappa_stats(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    cfg: Optional[LipschitzConfig] = None,
) -> Dict[str, float]:
    """Summarise the curvature estimate at ``(states, actions)``.

    Returns the mean, standard deviation, maximum and minimum over all elements
    of the ``_kappa`` result, keyed ``kappa/mean``, ``kappa/std``, ``kappa/max``
    and ``kappa/min``. The std is reported as ``0.0`` when there is only a
    single element. ``cfg`` defaults to ``LipschitzConfig()``.
    """
    config = LipschitzConfig() if cfg is None else cfg
    with torch.no_grad():
        kappa = _kappa(q_fn, states, actions, config)
        spread = kappa.std().item() if kappa.numel() > 1 else 0.0
        return {
            "kappa/mean": float(kappa.mean().item()),
            "kappa/std": float(spread),
            "kappa/max": float(kappa.max().item()),
            "kappa/min": float(kappa.min().item()),
        }


def gradient_lipschitz_estimate(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    wrt: Literal["state", "action"] = "action",
    eps: float = 1e-3,
    trials: int = 4,
) -> float:
    """Estimate the Lipschitz constant of ``grad f`` w.r.t. states or actions.

    The gradient of ``f(states, actions).sum()`` with respect to the ``wrt``
    input is taken at the unperturbed point and at ``trials`` Gaussian
    perturbations of that input (noise scaled by ``eps``). The other input is
    held fixed. Each trial records::

        max_batch ||g(x + dx) - g(x)||_2 / (max_batch ||dx||_2 + 1e-12)

    Norms are over the last dimension; numerator and denominator are maximised
    over the batch separately. Gradients are computed under
    ``torch.enable_grad()``, so this also works when called inside a
    ``torch.no_grad()`` block.

    Args:
        f: Differentiable callable ``(states, actions) -> Tensor``.
        states: State tensor.
        actions: Action tensor.
        wrt: Which input to differentiate and perturb: ``"state"`` or ``"action"``.
        eps: Scale of the Gaussian perturbation.
        trials: Number of perturbed gradient evaluations.

    Returns:
        The largest ratio over all trials, or ``0.0`` when ``trials <= 0``.

    Raises:
        ValueError: if ``wrt`` is neither ``"state"`` nor ``"action"``.
    """
    if wrt not in ("state", "action"):
        raise ValueError(f"wrt must be 'state' or 'action', got {wrt!r}")
    if trials <= 0:
        return 0.0

    on_actions = wrt == "action"
    fixed_states = states.detach()
    fixed_actions = actions.detach()
    origin = fixed_actions if on_actions else fixed_states

    def _grad_at(point: Tensor) -> Tuple[Tensor, Tensor]:
        # Gradient of f(...).sum() w.r.t. a fresh leaf copy of ``point``. The
        # leaf is returned detached so later arithmetic builds no graph.
        leaf = point.clone().detach().requires_grad_(True)
        out = f(fixed_states, leaf) if on_actions else f(leaf, fixed_actions)
        (grad,) = torch.autograd.grad(out.sum(), leaf)
        return leaf.detach(), grad

    ratios = []
    with torch.enable_grad():
        # The reference gradient is the same in every trial, so compute it once.
        # This assumes f is deterministic (e.g. no active dropout).
        x1, g1 = _grad_at(origin)
        for _ in range(trials):
            x2, g2 = _grad_at(origin + torch.randn_like(origin) * eps)
            num = (g2 - g1).norm(dim=-1).max().item()
            den = (x2 - x1).norm(dim=-1).max().item() + 1e-12
            ratios.append(num / den)
    return float(max(ratios))


def estimate_tau(q: Tensor, q_tgt: Tensor) -> float:
    """Return ``max |q - q_tgt|``, the largest absolute elementwise gap, as a float."""
    gap = torch.abs(q - q_tgt)
    return float(torch.max(gap).item())


def _demo():
    """Print the empirical curvature-Lipschitz estimate for a toy quadratic Q-function."""
    batch, state_dim, action_dim = 8, 3, 3

    def q_fn(s: Tensor, a: Tensor) -> Tensor:
        # Quadratic in the action plus a bilinear state-action coupling.
        return 0.5 * (a ** 2).sum(dim=-1) + (s * a).sum(dim=-1)

    states = torch.randn(batch, state_dim)
    actions = torch.randn(batch, action_dim)
    estimate = empirical_lipschitz_kappa(
        q_fn,
        states,
        actions,
        c=1.0,
        steps=6,
        mode="action",
        eps=1e-3,
        trials=3,
    )
    print(estimate)


if __name__ == "__main__":
    # Running the module directly prints the toy demo estimate.
    _demo()