from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Literal, Optional, Tuple

import torch

from .hvp import hvp, hutchinson_trace
from .lanczos import batched_lanczos_min, batched_power_min, LanczosOptions, multi_start_lanczos


Tensor = torch.Tensor


@dataclass
class CurvatureOptions:
    """Configuration shared by the curvature estimators in this module.

    Values are validated in ``__post_init__`` so that typos (for example
    ``method="Lanczos"``) fail loudly at construction. Otherwise they would
    silently select another code path or fail much later.
    """

    # Input(s) to measure curvature against: "action", "state", or "joint" (both).
    mode: Literal["action", "state", "joint"] = "action"
    # Weight of the concavity term in robust_kappa: k = grad + c * conc.
    c: float = 1.0
    # Iteration count forwarded to the min-eigenvalue solver.
    steps: int = 8
    # (action_weight, state_weight) used to blend components in "joint" mode.
    weights: Tuple[float, float] = (1.0, 1.0)
    # Min-eigenvalue solver: "lanczos" or "power".
    method: Literal["lanczos", "power"] = "power"
    # Lanczos start count; values > 1 select multi_start_lanczos. Unused for "power".
    starts: int = 1
    # Bounds applied to the final kappa; None leaves that side unbounded.
    clamp_min: Optional[float] = 0.0
    clamp_max: Optional[float] = None
    # When True, curvature_components also returns hutchinson_trace estimates
    # ("tr_a", "tr_s") computed with `trace_samples` samples.
    use_trace: bool = False
    trace_samples: int = 8
    # Reduction mode; not consulted by the functions in this file.
    reduce: Literal["sum", "mean"] = "sum"
    # When True, robust_kappa returns a detached tensor.
    detach: bool = True

    def __post_init__(self) -> None:
        """Validate option values.

        Raises:
            ValueError: on an unknown ``mode``/``method``/``reduce``, a count
                below 1, or ``clamp_min > clamp_max``.
        """
        # Downstream dispatch treats any method other than "lanczos" as
        # "power". An unknown mode yields no components at all, so reject both here.
        if self.mode not in ("action", "state", "joint"):
            raise ValueError(f"mode must be 'action', 'state' or 'joint', got {self.mode!r}")
        if self.method not in ("lanczos", "power"):
            raise ValueError(f"method must be 'lanczos' or 'power', got {self.method!r}")
        if self.reduce not in ("sum", "mean"):
            raise ValueError(f"reduce must be 'sum' or 'mean', got {self.reduce!r}")
        for name in ("steps", "starts", "trace_samples"):
            value = getattr(self, name)
            if value < 1:
                raise ValueError(f"{name} must be >= 1, got {value!r}")
        # torch.clamp with min > max silently maps everything to max.
        if self.clamp_min is not None and self.clamp_max is not None and self.clamp_min > self.clamp_max:
            raise ValueError(
                f"clamp_min ({self.clamp_min!r}) must not exceed clamp_max ({self.clamp_max!r})"
            )


def grad_norm(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    wrt: Literal["state", "action"] = "action",
) -> Tensor:
    """Return the norm of dQ/d(actions) or dQ/d(states) over the last dim.

    ``q_fn(states, actions)`` runs on detached copies of the inputs. Only the
    ``wrt`` input is marked as requiring grad, so the caller's tensors are
    not modified and no graph is retained afterwards.

    Args:
        q_fn: Critic, called as ``q_fn(states, actions)``; its output is summed
            before differentiation.
        states: State batch.
        actions: Action batch.
        wrt: ``"action"`` differentiates w.r.t. ``actions``; any other value
            differentiates w.r.t. ``states``.

    Returns:
        ``torch.linalg.vector_norm(grad, dim=-1)``. If Q does not depend on the
        chosen input, the gradient is treated as zero instead of raising.
    """
    # enable_grad lets the metric be computed from inside a torch.no_grad()
    # block (e.g. evaluation / logging code); otherwise autograd.grad fails.
    with torch.enable_grad():
        if wrt == "action":
            target = actions.detach().clone().requires_grad_(True)
            q = q_fn(states.detach(), target)
        else:
            target = states.detach().clone().requires_grad_(True)
            q = q_fn(target, actions.detach())
        if q.requires_grad:
            # allow_unused: Q may depend on parameters but not on `target`.
            (g,) = torch.autograd.grad(q.sum(), target, allow_unused=True)
        else:
            g = None
    if g is None:
        # Q is independent of `target`, so its gradient is identically zero.
        g = torch.zeros_like(target)
    return torch.linalg.vector_norm(g, dim=-1)


def _min_eig(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    steps: int = 8,
    wrt: Literal["state", "action"] = "action",
    method: Literal["lanczos", "power"] = "power",
    starts: int = 1,
) -> Tensor:
    """Dispatch to one of the min-eigenvalue solvers imported from ``.lanczos``.

    * ``method == "lanczos"`` and ``starts > 1``: ``multi_start_lanczos``.
    * ``method == "lanczos"`` otherwise: ``batched_lanczos_min``.
    * any other ``method``: ``batched_power_min`` with ``iters=steps``.

    The return shape is whatever those helpers produce. Presumably that is one
    eigenvalue estimate per batch element; confirm in ``.lanczos``.
    """
    if method != "lanczos":
        # Anything that is not "lanczos" (including typos) uses power iteration.
        return batched_power_min(q_fn, states, actions, iters=steps, wrt=wrt)
    if starts > 1:
        lanczos_cfg = LanczosOptions(steps=steps, wrt=wrt)
        return multi_start_lanczos(q_fn, states, actions, starts=starts, opts=lanczos_cfg)
    return batched_lanczos_min(q_fn, states, actions, steps=steps, wrt=wrt)


def min_eig(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    steps: int = 8,
    wrt: Literal["state", "action"] = "action",
) -> Tensor:
    """Public shortcut for ``_min_eig`` pinned to the power-iteration path.

    Always passes ``method="power"`` and ``starts=1``, so the call resolves to
    ``batched_power_min(q_fn, states, actions, iters=steps, wrt=wrt)``.
    """
    return _min_eig(q_fn, states, actions, method="power", starts=1, steps=steps, wrt=wrt)


def curvature_components(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    opts: Optional[CurvatureOptions] = None,
) -> Dict[str, Tensor]:
    """Compute per-input gradient and curvature terms of ``q_fn``.

    ``opts.mode`` selects the inputs: "action" gives the ``_a`` keys, "state"
    gives the ``_s`` keys, and "joint" gives both, action first. For each
    selected input the result holds:

    * ``g_*``   -- gradient norm from ``grad_norm``;
    * ``c_*``   -- ``max(-lam_*, 0)``, so a non-negative ``lam_*`` gives 0;
    * ``lam_*`` -- minimum-eigenvalue estimate from ``_min_eig``.

    When ``opts.use_trace`` is set, ``tr_a`` and ``tr_s`` from
    ``hutchinson_trace`` are added for both inputs regardless of ``mode``.
    An unrecognised ``mode`` adds no gradient/eigenvalue entries.
    """
    cfg = CurvatureOptions() if opts is None else opts
    # (wrt, modes that enable it, gradient key, concavity key, eigenvalue key)
    plan = (
        ("action", ("action", "joint"), "g_a", "c_a", "lam_a"),
        ("state", ("state", "joint"), "g_s", "c_s", "lam_s"),
    )
    result: Dict[str, Tensor] = {}
    for wrt, enabling_modes, g_key, c_key, lam_key in plan:
        if cfg.mode not in enabling_modes:
            continue
        grad_mag = grad_norm(q_fn, states, actions, wrt=wrt)
        lam = _min_eig(
            q_fn,
            states,
            actions,
            steps=cfg.steps,
            wrt=wrt,
            method=cfg.method,
            starts=cfg.starts,
        )
        result[g_key] = grad_mag
        result[c_key] = torch.clamp(-lam, min=0.0)
        result[lam_key] = lam
    if cfg.use_trace:
        traces = {
            "tr_a": hutchinson_trace(q_fn, states, actions, wrt="action", samples=cfg.trace_samples),
            "tr_s": hutchinson_trace(q_fn, states, actions, wrt="state", samples=cfg.trace_samples),
        }
        result.update(traces)
    return result


def kappa(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    c: float = 1.0,
    steps: int = 8,
    mode: Literal["action", "state", "joint"] = "action",
    weights: Tuple[float, float] = (1.0, 1.0),
) -> Tensor:
    """Keyword-argument front end for ``robust_kappa``.

    Builds a ``CurvatureOptions`` from ``c``, ``steps``, ``mode`` and
    ``weights``. Every other option keeps its ``CurvatureOptions`` default
    (e.g. ``method="power"``, ``clamp_min=0.0``, ``detach=True``).
    """
    settings = CurvatureOptions(mode=mode, c=c, steps=steps, weights=weights)
    return robust_kappa(q_fn, states, actions, opts=settings)


def robust_kappa(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    opts: Optional[CurvatureOptions] = None,
) -> Tensor:
    """Combine gradient magnitude and concavity into one per-sample score.

    Computes ``k = grad + opts.c * conc``, where ``grad``/``conc`` are the
    ``g_*``/``c_*`` entries from ``curvature_components``. In "joint" mode
    each term is ``wa * action_term + ws * state_term`` with
    ``(wa, ws) = opts.weights``. ``k`` is then clamped to
    ``[clamp_min, clamp_max]`` (a ``None`` bound is open) and detached when
    ``opts.detach`` is true.

    Raises:
        ValueError: if ``opts.mode`` is not "action", "state" or "joint".
    """
    if opts is None:
        opts = CurvatureOptions()
    # Check before spending any gradient/eigenvalue computation. An unknown
    # mode would otherwise surface as a KeyError on a missing component.
    if opts.mode not in ("action", "state", "joint"):
        raise ValueError(f"unknown curvature mode {opts.mode!r}; expected 'action', 'state' or 'joint'")
    comp = curvature_components(q_fn, states, actions, opts)
    if opts.mode == "action":
        grad, conc = comp["g_a"], comp["c_a"]
    elif opts.mode == "state":
        grad, conc = comp["g_s"], comp["c_s"]
    else:
        # Joint mode: curvature_components fills both action and state keys.
        wa, ws = opts.weights
        grad = wa * comp["g_a"] + ws * comp["g_s"]
        conc = wa * comp["c_a"] + ws * comp["c_s"]
    k = grad + opts.c * conc
    if opts.clamp_min is not None or opts.clamp_max is not None:
        lo = opts.clamp_min if opts.clamp_min is not None else float("-inf")
        hi = opts.clamp_max if opts.clamp_max is not None else float("inf")
        k = torch.clamp(k, lo, hi)
    return k.detach() if opts.detach else k


def batch_curvature(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    opts: Optional[CurvatureOptions] = None,
    chunk_size: Optional[int] = None,
) -> Dict[str, Tensor]:
    """Run ``curvature_components``, optionally in chunks along dim 0.

    Chunking (presumably to bound peak memory of the Hessian computations)
    slices ``states`` and ``actions`` into rows ``[i, i + chunk_size)``,
    computes components per chunk, and concatenates each key along dim 0.

    Args:
        chunk_size: Maximum rows per ``curvature_components`` call. ``None``
            processes the whole batch in a single call.

    Raises:
        ValueError: if ``chunk_size`` is not positive, or if chunking is
            required and ``states``/``actions`` have different leading sizes.
    """
    if chunk_size is None:
        return curvature_components(q_fn, states, actions, opts)
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")
    n = states.shape[0]
    if n <= chunk_size:
        # One chunk covers everything. This includes an empty batch, which the
        # chunk loop below would otherwise turn into an empty dict.
        return curvature_components(q_fn, states, actions, opts)
    if actions.shape[0] != n:
        raise ValueError(
            f"states and actions must share the leading dimension, got {n} and {actions.shape[0]}"
        )
    pieces = [
        curvature_components(q_fn, states[i : i + chunk_size], actions[i : i + chunk_size], opts)
        for i in range(0, n, chunk_size)
    ]
    # Every chunk uses the same opts, so the first chunk's keys are representative.
    return {key: torch.cat([piece[key] for piece in pieces], dim=0) for key in pieces[0]}


def summarize_curvature(comp: Dict[str, Tensor]) -> Dict[str, float]:
    """Reduce each tensor in ``comp`` to scalar statistics for logging.

    For every key ``k`` the result holds ``k/mean``, ``k/std``, ``k/max`` and
    ``k/min`` over all elements of the tensor.

    * ``std`` is 0.0 for a single element, where torch's unbiased std is NaN.
    * All four stats are NaN for an empty tensor, where ``max``/``min`` would raise.
    * Integer/bool tensors are promoted to float32 so ``mean``/``std`` work.
      Floating tensors are used as-is.
    """
    out: Dict[str, float] = {}
    for k, v in comp.items():
        # Detach so logging never records an autograd graph.
        v = v.detach()
        if v.numel() == 0:
            nan = float("nan")
            out[f"{k}/mean"] = nan
            out[f"{k}/std"] = nan
            out[f"{k}/max"] = nan
            out[f"{k}/min"] = nan
            continue
        if not v.is_floating_point():
            v = v.float()
        out[f"{k}/mean"] = float(v.mean().item())
        out[f"{k}/std"] = float(v.std().item()) if v.numel() > 1 else 0.0
        out[f"{k}/max"] = float(v.max().item())
        out[f"{k}/min"] = float(v.min().item())
    return out


def _demo():
    """Print curvature summary statistics for a toy critic on random inputs.

    The critic is ``q = 0.5 * ||a||^2 + <s, a>``. Its Hessian w.r.t. the
    action is the identity, and w.r.t. the state it is zero (q is linear in s).
    """
    def q_fn(s: Tensor, a: Tensor) -> Tensor:
        return 0.5 * (a ** 2).sum(dim=-1) + (s * a).sum(dim=-1)

    batch, state_dim, action_dim = 8, 3, 3
    # Draw states before actions (same RNG consumption order as before).
    states = torch.randn(batch, state_dim)
    actions = torch.randn(batch, action_dim)
    demo_opts = CurvatureOptions(mode="joint", method="lanczos", steps=6, starts=2, use_trace=True)
    components = curvature_components(q_fn, states, actions, demo_opts)
    print(summarize_curvature(components))


if __name__ == "__main__":
    # Script entry point: run the small self-demo.
    _demo()
