from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Iterable, Literal, Optional, Tuple

import torch


# Short alias for ``torch.Tensor`` used in the type annotations of this module.
Tensor = torch.Tensor


@dataclass
class HVPConfig:
    """Options bundle read by ``hvp_batched`` and ``hvp_chunked``."""
    # Which input to differentiate; forwarded to ``hvp`` as ``wrt``.
    wrt: Literal["state", "action"] = "action"
    # Forwarded to ``hvp`` as its ``create_graph`` argument.
    create_graph: bool = True
    # Forwarded to ``hvp`` as its ``retain_graph`` argument.
    retain_graph: bool = True
    # Not read anywhere in the visible part of this module; presumably a
    # finite-difference step size -- TODO confirm against callers.
    finite_diff_eps: float = 1e-3
    # Rows per ``hvp`` call in ``hvp_chunked``; ``None`` (or 0) means one chunk
    # spanning the whole batch.
    chunk_size: Optional[int] = None


def _ensure_requires_grad(x: Tensor) -> Tensor:
    """Return a tensor that autograd can differentiate with respect to.

    A tensor that already requires grad is returned unchanged (same object).
    Otherwise a detached copy is made and flagged as requiring grad, so the
    caller's tensor is never modified.
    """
    if not x.requires_grad:
        leaf = x.clone().detach()
        leaf.requires_grad_(True)
        return leaf
    return x


def _select_vars(states: Tensor, actions: Tensor, wrt: Literal["state", "action"]) -> Tuple[Tensor, Tensor, Tensor]:
    """Prepare ``(states, actions, var)`` for differentiation.

    ``var`` is the input that will be differentiated (and is the same object
    as the corresponding entry of the returned pair); the other input is
    detached. ``wrt == "action"`` selects the actions; any other value
    selects the states.
    """
    if wrt == "action":
        action_var = _ensure_requires_grad(actions)
        return states.detach(), action_var, action_var

    state_var = _ensure_requires_grad(states)
    return state_var, actions.detach(), state_var


def hvp(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    vector: Tensor,
    wrt: Literal["state", "action"] = "action",
    create_graph: bool = True,
    retain_graph: bool = True,
) -> Tensor:
    """Return the Hessian-vector product of ``q_fn(states, actions).sum()``.

    The product is taken with respect to ``actions`` (``wrt="action"``) or
    ``states`` (any other ``wrt``) using two reverse-mode passes: the gradient
    ``g`` of the summed output is computed first, then ``sum(g * vector)`` is
    differentiated again.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``; its output is
            summed to a scalar before differentiation.
        states: State tensor; detached when differentiating w.r.t. actions.
        actions: Action tensor; detached when differentiating w.r.t. states.
        vector: Multiplied elementwise with the gradient, so it must broadcast
            against the differentiated input.
        wrt: Which input to differentiate.
        create_graph: Accepted for interface compatibility. The first-order
            gradient is always built with a graph because the second pass
            cannot run without one; the returned product itself is not
            attached to a graph.
        retain_graph: Forwarded to the second ``torch.autograd.grad`` call.

    Returns:
        Tensor shaped like the differentiated input. It is all zeros when the
        gradient does not depend on that input (identically zero Hessian).
    """
    s, a, var = _select_vars(states, actions, wrt)
    q = q_fn(s, a)
    # The gradient must stay differentiable: with create_graph=False here the
    # second autograd.grad call below would fail because ``dot`` would have no
    # grad_fn. Hence the graph is always created for this first pass.
    g = torch.autograd.grad(q.sum(), var, create_graph=True, retain_graph=True)[0]
    dot = (g * vector).sum()
    if not dot.requires_grad:
        # The gradient is a constant (e.g. q is linear in ``var``), so the
        # Hessian is zero and there is nothing to backpropagate through.
        return torch.zeros_like(var)
    hv = torch.autograd.grad(dot, var, retain_graph=retain_graph, allow_unused=True)[0]
    if hv is None:
        # ``dot`` is differentiable only through tensors other than ``var``.
        return torch.zeros_like(var)
    return hv


def hvp_batched(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    vectors: Tensor,
    cfg: Optional[HVPConfig] = None,
) -> Tensor:
    """Call ``hvp`` once on the whole batch using options from ``cfg``.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``.
        states: State tensor passed through to ``hvp``.
        actions: Action tensor passed through to ``hvp``.
        vectors: Passed to ``hvp`` as its ``vector`` argument.
        cfg: Supplies ``wrt``, ``create_graph`` and ``retain_graph``; a default
            ``HVPConfig()`` is used when omitted.

    Returns:
        Whatever ``hvp`` returns for these arguments.
    """
    config = HVPConfig() if cfg is None else cfg
    return hvp(
        q_fn,
        states,
        actions,
        vectors,
        wrt=config.wrt,
        create_graph=config.create_graph,
        retain_graph=config.retain_graph,
    )


def hvp_chunked(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    vectors: Tensor,
    cfg: Optional[HVPConfig] = None,
) -> Tensor:
    """Compute ``hvp`` over slices of the leading dimension and concatenate.

    ``states``, ``actions`` and ``vectors`` are sliced with the same index
    range, so they are expected to share their leading (batch) dimension.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)`` on each slice.
        states: State tensor, sliced along dim 0.
        actions: Action tensor, sliced along dim 0.
        vectors: Direction tensor, sliced along dim 0; its leading size drives
            the iteration.
        cfg: Supplies ``wrt``, ``create_graph``, ``retain_graph`` and
            ``chunk_size``; a default ``HVPConfig()`` is used when omitted.
            A ``chunk_size`` of ``None`` or 0 means one chunk for the batch.

    Returns:
        The per-chunk results concatenated along dim 0. For an empty batch an
        empty tensor shaped like ``vectors`` is returned.

    Raises:
        ValueError: If ``cfg.chunk_size`` is negative.
    """
    if cfg is None:
        cfg = HVPConfig()
    total = vectors.shape[0]
    if total == 0:
        # Nothing to differentiate; also avoids range(0, 0, 0) below.
        return torch.zeros_like(vectors)
    cs = cfg.chunk_size or total
    if cs < 0:
        raise ValueError(f"chunk_size must be positive, got {cs}")
    outs = [
        hvp(
            q_fn,
            states[i : i + cs],
            actions[i : i + cs],
            vectors[i : i + cs],
            wrt=cfg.wrt,
            create_graph=cfg.create_graph,
            retain_graph=cfg.retain_graph,
        )
        for i in range(0, total, cs)
    ]
    return torch.cat(outs, dim=0)


def finite_diff_hvp(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    vector: Tensor,
    wrt: Literal["state", "action"] = "action",
    eps: float = 1e-3,
) -> Tensor:
    """Gradient-free approximation of a Hessian-vector product.

    A second-order central difference along ``vector`` estimates the
    curvature ``v^T H v``; the result returned is ``(v^T H v / ||v||^2) * v``.
    That is only the component of ``H v`` along ``v``: it equals ``H v``
    exactly when ``v`` is an eigenvector of ``H`` and is otherwise a rank-one
    approximation.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``. It may return
            one value per row either with or without a trailing singleton
            dimension.
        states: State tensor (perturbed when ``wrt`` is not ``"action"``).
        actions: Action tensor (perturbed when ``wrt == "action"``).
        vector: Perturbation direction, shaped like the perturbed input.
        wrt: Which input to perturb.
        eps: Finite-difference step size.

    Returns:
        Tensor shaped like ``vector``.
    """
    if wrt == "action":
        f_plus = q_fn(states, actions + eps * vector)
        f_minus = q_fn(states, actions - eps * vector)
    else:
        f_plus = q_fn(states + eps * vector, actions)
        f_minus = q_fn(states - eps * vector, actions)
    f_center = q_fn(states, actions)

    # Second central difference: (f(x+ev) + f(x-ev) - 2 f(x)) / e^2 ~ v^T H v.
    vt_h_v = (f_plus + f_minus - 2.0 * f_center) / (eps ** 2)
    if vt_h_v.dim() == vector.dim() - 1:
        # q_fn dropped the feature dimension (e.g. shape (B,)). Without this,
        # dividing by the (B, 1) norm would broadcast to (B, B).
        vt_h_v = vt_h_v.unsqueeze(-1)
    denom = vector.norm(dim=-1, keepdim=True) ** 2 + 1e-12
    return (vt_h_v / denom) * vector


def random_directions(shape: Tuple[int, int], kind: Literal["normal", "rademacher", "sphere"] = "normal", device=None, dtype=None) -> Tensor:
    """Sample a tensor of random directions with the given ``shape``.

    Args:
        shape: Shape of the returned tensor.
        kind: ``"normal"`` draws standard-normal entries; ``"rademacher"``
            draws entries from {-1, +1}; any other value draws normal entries
            and rescales each row (last dimension) to unit Euclidean norm.
        device: Device passed to the torch sampling call.
        dtype: Output dtype; Rademacher samples fall back to ``float32`` when
            it is ``None``.

    Returns:
        The sampled tensor.
    """
    if kind == "rademacher":
        bits = torch.randint(0, 2, shape, device=device)
        signs = bits.to(torch.float32 if dtype is None else dtype)
        return signs * 2 - 1

    gaussian = torch.randn(shape, device=device, dtype=dtype)
    if kind == "normal":
        return gaussian
    # Small constant guards against division by a zero-norm row.
    row_norm = gaussian.norm(dim=-1, keepdim=True) + 1e-12
    return gaussian / row_norm


def hutchinson_trace(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    wrt: Literal["state", "action"] = "action",
    samples: int = 8,
    kind: Literal["normal", "rademacher"] = "rademacher",
) -> Tensor:
    """Stochastic estimate of the Hessian trace, one value per batch row.

    For each of ``samples`` random probes ``z`` the quantity
    ``sum(hvp(z) * z, dim=-1)`` is accumulated; the mean over probes is
    returned. The differentiated input must be 2-D, since its shape is
    unpacked as ``(B, D)``.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``.
        states: State tensor.
        actions: Action tensor.
        wrt: Which input the Hessian is taken with respect to.
        samples: Number of random probes to average over.
        kind: ``"rademacher"`` uses +/-1 probes; any other value uses
            standard-normal probes.

    Returns:
        Tensor of shape ``(B,)``.
    """
    s, a, var = _select_vars(states, actions, wrt)
    batch, dim = var.shape
    use_rademacher = kind == "rademacher"
    total = torch.zeros(batch, device=var.device, dtype=var.dtype)
    for _ in range(samples):
        if use_rademacher:
            bits = torch.randint(0, 2, (batch, dim), device=var.device)
            probe = bits.to(var.dtype) * 2 - 1
        else:
            probe = torch.randn(batch, dim, device=var.device, dtype=var.dtype)
        hess_probe = hvp(q_fn, s, a, probe, wrt=wrt, create_graph=False, retain_graph=True)
        total = total + (hess_probe * probe).sum(dim=-1)
    return total / float(samples)


def diagonal_hutchinson(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    wrt: Literal["state", "action"] = "action",
    samples: int = 16,
) -> Tensor:
    """Stochastic estimate of the Hessian diagonal using +/-1 probes.

    For each of ``samples`` Rademacher probes ``z`` the elementwise product
    ``hvp(z) * z`` is accumulated; the mean over probes is returned. The
    differentiated input must be 2-D, since its shape is unpacked as
    ``(B, D)``.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``.
        states: State tensor.
        actions: Action tensor.
        wrt: Which input the Hessian is taken with respect to.
        samples: Number of random probes to average over.

    Returns:
        Tensor of shape ``(B, D)``.
    """
    s, a, var = _select_vars(states, actions, wrt)
    batch, dim = var.shape
    accum = torch.zeros(batch, dim, device=var.device, dtype=var.dtype)
    for _ in range(samples):
        bits = torch.randint(0, 2, (batch, dim), device=var.device)
        probe = bits.to(var.dtype) * 2 - 1
        hess_probe = hvp(q_fn, s, a, probe, wrt=wrt, create_graph=False, retain_graph=True)
        accum = accum + hess_probe * probe
    return accum / float(samples)


def check_hvp_consistency(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    wrt: Literal["state", "action"] = "action",
    eps: float = 1e-3,
    trials: int = 3,
    kind: Literal["normal", "rademacher", "sphere"] = "normal",
) -> Tuple[float, float]:
    """Compare ``hvp`` against ``finite_diff_hvp`` on random directions.

    Each trial draws a direction with ``random_directions``, evaluates both
    products and records the batch mean of
    ``||exact - fd|| / (||exact|| + 1e-9)`` (norms over the last dimension).

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``.
        states: State tensor.
        actions: Action tensor.
        wrt: Which input the Hessian is taken with respect to.
        eps: Step size forwarded to ``finite_diff_hvp``.
        trials: Number of random directions; must be non-zero because the
            average divides by it.
        kind: Sampling scheme forwarded to ``random_directions``.

    Returns:
        ``(max_rel_err, avg_rel_err)`` across trials.
    """
    s, a, var = _select_vars(states, actions, wrt)
    batch, dim = var.shape
    worst = 0.0
    running = 0.0
    for _ in range(trials):
        direction = random_directions((batch, dim), kind=kind, device=var.device, dtype=var.dtype)
        exact = hvp(q_fn, s, a, direction, wrt=wrt, create_graph=False, retain_graph=True)
        approx = finite_diff_hvp(q_fn, s, a, direction, wrt=wrt, eps=eps)
        gap = (exact - approx).norm(dim=-1)
        scale = exact.norm(dim=-1) + 1e-9
        trial_err = (gap / scale).mean().item()
        worst = max(worst, float(trial_err))
        running += trial_err
    running /= float(trials)
    return float(worst), float(running)


def jvp(
    f: Callable[[Tensor], Tensor],
    x: Tensor,
    v: Tensor,
    create_graph: bool = True,
) -> Tensor:
    """Directional derivative of ``f(x).sum()`` along ``v``.

    The gradient of the summed output is computed by reverse mode and then
    contracted with ``v`` over the last dimension. Because the output is
    summed first, a multi-output ``f`` yields the derivative of the sum of its
    outputs, not one value per output.

    Args:
        f: Function of a single tensor.
        x: Evaluation point; copied if it does not already require grad.
        v: Direction, broadcastable against the gradient.
        create_graph: Forwarded to ``torch.autograd.grad``.

    Returns:
        ``sum(grad * v, dim=-1, keepdim=True)``.
    """
    x = _ensure_requires_grad(x)
    summed = f(x).sum()
    (grad_x,) = torch.autograd.grad(summed, x, create_graph=create_graph, retain_graph=True)
    return torch.sum(grad_x * v, dim=-1, keepdim=True)


def vjp(
    f: Callable[[Tensor], Tensor],
    x: Tensor,
    v: Tensor,
    create_graph: bool = True,
) -> Tensor:
    """Vector-Jacobian product of ``f`` at ``x`` with cotangent ``v``.

    Args:
        f: Function of a single tensor.
        x: Evaluation point; copied if it does not already require grad.
        v: Passed as ``grad_outputs`` for ``f(x)``.
        create_graph: Forwarded to ``torch.autograd.grad``.

    Returns:
        Gradient of ``f(x)`` with respect to ``x`` weighted by ``v``.
    """
    x = _ensure_requires_grad(x)
    output = f(x)
    (pulled_back,) = torch.autograd.grad(
        output,
        x,
        grad_outputs=v,
        create_graph=create_graph,
        retain_graph=True,
    )
    return pulled_back


def hvp_from_jvp(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    vector: Tensor,
    wrt: Literal["state", "action"] = "action",
) -> Tensor:
    """Hessian-vector product of ``q_fn(states, actions).sum()``.

    Despite the name, this does not call ``jvp``: it takes the gradient of the
    summed output with a graph attached, forms ``sum(grad * vector)`` and
    differentiates that again.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``.
        states: State tensor; detached when differentiating w.r.t. actions.
        actions: Action tensor; detached when differentiating w.r.t. states.
        vector: Multiplied elementwise with the gradient, so it must broadcast
            against the differentiated input.
        wrt: Which input to differentiate.

    Returns:
        Tensor shaped like the differentiated input. It is all zeros when the
        gradient does not depend on that input (identically zero Hessian).
    """
    s, a, var = _select_vars(states, actions, wrt)
    q = q_fn(s, a)
    # First pass keeps its graph so the gradient can be differentiated again.
    g = torch.autograd.grad(q.sum(), var, create_graph=True, retain_graph=True)[0]
    dot = (g * vector).sum()
    if not dot.requires_grad:
        # Constant gradient (e.g. q linear in ``var``): autograd.grad would
        # raise here, but the correct product is simply zero.
        return torch.zeros_like(var)
    hv = torch.autograd.grad(dot, var, retain_graph=True, allow_unused=True)[0]
    if hv is None:
        # ``dot`` is differentiable only through tensors other than ``var``.
        return torch.zeros_like(var)
    return hv


def hvp_suite(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    vector: Tensor,
    methods: Iterable[str] = ("autograd", "finite_diff", "jvp"),
    wrt: Literal["state", "action"] = "action",
    eps: float = 1e-3,
) -> dict:
    """Run the selected Hessian-vector-product implementations side by side.

    Args:
        q_fn: Callable evaluated as ``q_fn(states, actions)``.
        states: State tensor.
        actions: Action tensor.
        vector: Direction passed to every selected implementation.
        methods: Names to run. Recognised names are ``"autograd"`` (``hvp``),
            ``"finite_diff"`` (``finite_diff_hvp``) and ``"jvp"``
            (``hvp_from_jvp``); other names are ignored. Any iterable is
            accepted, including one-shot iterators.
        wrt: Which input the Hessian is taken with respect to.
        eps: Step size forwarded to ``finite_diff_hvp``.

    Returns:
        Dict mapping each recognised, requested name to its result tensor.
    """
    # Materialise once: repeated ``in`` tests would exhaust a generator after
    # the first check. A plain string keeps its original substring semantics.
    selected = methods if isinstance(methods, str) else tuple(methods)
    out: dict[str, Tensor] = {}
    if "autograd" in selected:
        out["autograd"] = hvp(q_fn, states, actions, vector, wrt=wrt, create_graph=False, retain_graph=True)
    if "finite_diff" in selected:
        out["finite_diff"] = finite_diff_hvp(q_fn, states, actions, vector, wrt=wrt, eps=eps)
    if "jvp" in selected:
        out["jvp"] = hvp_from_jvp(q_fn, states, actions, vector, wrt=wrt)
    return out


def _demo():
    """Run the HVP helpers on a small quadratic Q and print the comparison.

    Calls ``hvp``, ``finite_diff_hvp`` and ``check_hvp_consistency`` on random
    inputs and prints the largest absolute gap between the first two together
    with the reported relative errors.
    """

    def q_fn(s: Tensor, a: Tensor) -> Tensor:
        # Quadratic in the action plus a bilinear state-action coupling.
        quadratic = 0.5 * a.pow(2).sum(dim=-1)
        coupling = (s * a).sum(dim=-1)
        return quadratic + coupling

    batch, state_dim, action_dim = 4, 3, 3
    demo_states = torch.randn(batch, state_dim)
    demo_actions = torch.randn(batch, action_dim)
    direction = torch.randn(batch, action_dim)

    exact = hvp(q_fn, demo_states, demo_actions, direction, wrt="action", create_graph=False)
    approx = finite_diff_hvp(q_fn, demo_states, demo_actions, direction, wrt="action")
    max_rel, avg_rel = check_hvp_consistency(q_fn, demo_states, demo_actions, wrt="action")
    largest_gap = (exact - approx).abs().max().item()
    print("hvp ok", largest_gap, max_rel, avg_rel)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    _demo()
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
