from __future__ import annotations

from typing import Callable, Tuple, Literal

import torch
import torch.nn as nn


def compute_hvp(
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    states: torch.Tensor,
    actions: torch.Tensor,
    vector: torch.Tensor,
) -> torch.Tensor:
    """
    Batched Hessian-vector product H_a(Q) @ v w.r.t. actions.

    The product is obtained by double back-propagation: first the gradient
    of sum(Q) w.r.t. ``actions`` is taken with ``create_graph=True``, then
    the scalar <grad, vector> is differentiated w.r.t. ``actions`` again.

    q_fn: (states[B, ...], actions[B, A]) -> q[B] (scalar per sample)
    actions: requires_grad=True
    vector: same shape as actions
    returns: hvp with the same shape as actions ([B, A])

    If the gradient does not depend on ``actions`` (e.g. Q is linear in the
    actions), the Hessian is identically zero and a zero tensor shaped like
    ``actions`` is returned instead of letting autograd raise.

    Raises AssertionError if ``actions`` does not require grad.
    """
    assert actions.requires_grad, "actions must require grad for HVP"
    q_total = q_fn(states, actions).sum()
    # create_graph=True keeps the gradient differentiable for the second pass.
    grad_a = torch.autograd.grad(
        q_total, actions, create_graph=True, retain_graph=True
    )[0]
    dot = (grad_a * vector).sum()
    if not dot.requires_grad:
        # The gradient is a constant: differentiating it again would raise
        # "element 0 of tensors does not require grad"; the true HVP is 0.
        return torch.zeros_like(actions)
    # allow_unused covers the case where `dot` is differentiable only through
    # something other than `actions` (e.g. trainable parameters of q_fn).
    hvp = torch.autograd.grad(
        dot, actions, retain_graph=True, allow_unused=True
    )[0]
    if hvp is None:
        return torch.zeros_like(actions)
    return hvp


def compute_hvp_wrt(
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    states: torch.Tensor,
    actions: torch.Tensor,
    vector: torch.Tensor,
    wrt: Literal["action", "state"] = "action",
) -> torch.Tensor:
    """
    Hessian-vector product of sum(Q) w.r.t. either the actions or the states.

    q_fn: (states, actions) -> q; the output is summed to a scalar before
        differentiation.
    vector: tensor multiplied element-wise with the gradient, so it must be
        broadcast-compatible with the selected variable.
    wrt: "action" differentiates w.r.t. ``actions``; any other value
        differentiates w.r.t. ``states``.
    returns: tensor with the shape of the selected variable.

    If the gradient does not depend on the selected variable (e.g. Q is
    linear in it), the Hessian is identically zero and a zero tensor is
    returned instead of letting autograd raise.

    Raises AssertionError if the selected variable does not require grad.
    """
    if wrt == "action":
        assert actions.requires_grad, "actions must require grad for HVP"
        grad_var = actions
    else:
        assert states.requires_grad, "states must require grad for HVP"
        grad_var = states

    q_total = q_fn(states, actions).sum()
    # create_graph=True keeps the gradient differentiable for the second pass.
    grad = torch.autograd.grad(
        q_total, grad_var, create_graph=True, retain_graph=True
    )[0]
    dot = (grad * vector).sum()
    if not dot.requires_grad:
        # Constant gradient -> zero Hessian; a second autograd.grad call
        # would raise because `dot` has no graph.
        return torch.zeros_like(grad_var)
    # allow_unused covers the case where `dot` is differentiable only through
    # something other than `grad_var` (e.g. trainable parameters of q_fn).
    hvp = torch.autograd.grad(
        dot, grad_var, retain_graph=True, allow_unused=True
    )[0]
    if hvp is None:
        return torch.zeros_like(grad_var)
    return hvp


def estimate_min_eigenvalue_lanczos(
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    states: torch.Tensor,
    actions: torch.Tensor,
    num_steps: int = 6,
    tol: float = 1e-8,
    max_restarts: int = 1,
    wrt: Literal["action", "state"] = "action",
) -> torch.Tensor:
    """
    Simple per-sample Lanczos estimate of the minimum eigenvalue of the
    Hessian of Q w.r.t. the actions (``wrt="action"``) or the states
    (any other ``wrt`` value).

    For simplicity and clarity, we loop over batch samples. For each sample a
    random unit start vector is drawn, up to ``num_steps`` Lanczos steps are
    run using Hessian-vector products, and the smallest eigenvalue of the
    resulting tridiagonal matrix is returned. No re-orthogonalization is
    performed.

    q_fn: (states, actions) -> q, evaluated here on single-sample slices.
    num_steps: maximum Krylov dimension per sample.
    tol: the iteration for a sample stops early once the residual norm
        (beta) falls below this value.
    max_restarts: accepted for interface compatibility; currently unused.
    returns: detached tensor [B] of estimated lambda_min. A sample for which
        no Lanczos step was run (``num_steps <= 0``) gets 0.
    """
    if wrt == "action":
        actions = actions.clone().detach().requires_grad_(True)
        states = states.detach()
        var = actions
    else:
        states = states.clone().detach().requires_grad_(True)
        actions = actions.detach()
        var = states

    batch_size = actions.shape[0]
    lambdas = []
    for i in range(batch_size):
        # Keep the batch axis so q_fn still sees batched inputs.
        a_i = actions[i : i + 1]
        s_i = states[i : i + 1]
        var_i = a_i if wrt == "action" else s_i

        # Start vector lives on the differentiated variable's device/dtype.
        v = torch.randn(var_i.shape, device=var.device, dtype=var.dtype)
        v = v / (v.norm(p=2) + 1e-12)
        v_prev = torch.zeros_like(v)
        alphas = []  # alphas[j] = <v_j, H v_j>
        betas = []   # betas[j]  = ||residual of step j||, couples v_j and v_{j+1}
        for _ in range(num_steps):
            hv = compute_hvp_wrt(q_fn, s_i, a_i, v, wrt=wrt)
            # Full inner product over every element of the single sample.
            alpha = (v * hv).sum()
            w = hv - alpha * v
            if betas:
                w = w - betas[-1] * v_prev
            beta = w.norm(p=2)
            alphas.append(alpha)
            betas.append(beta)
            if beta.item() < tol:
                # Krylov space exhausted (invariant subspace found).
                break
            v_prev = v
            v = w / (beta + 1e-12)

        k = len(alphas)
        if k == 0:
            lambdas.append(torch.zeros((), device=var.device, dtype=var.dtype))
            continue

        # Tridiagonal T: diagonal alphas[j], off-diagonal (j, j+1) = betas[j].
        # The final beta only measures the last residual and is not part of T.
        T = torch.zeros((k, k), device=var.device, dtype=var.dtype)
        for j in range(k):
            T[j, j] = alphas[j]
            if j + 1 < k:
                T[j, j + 1] = betas[j]
                T[j + 1, j] = betas[j]
        evals = torch.linalg.eigvalsh(T)  # ascending order
        lambdas.append(evals[0].detach())

    if not lambdas:
        # Empty batch: torch.stack would raise on an empty list.
        return torch.zeros(0, device=var.device, dtype=var.dtype)
    return torch.stack(lambdas, dim=0)


def estimate_min_eigenvalue_power(
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    states: torch.Tensor,
    actions: torch.Tensor,
    num_steps: int = 8,
    wrt: Literal["action", "state"] = "action",
) -> torch.Tensor:
    """
    Batched power-iteration estimate of the minimum eigenvalue of the Hessian
    of Q w.r.t. the actions (``wrt="action"``) or the states (otherwise).

    Plain power iteration converges to the eigenvalue of largest magnitude,
    which is lambda_min only when the most negative eigenvalue dominates.
    To also cover Hessians whose dominant eigenvalue is positive, two passes
    are run:

    1. power iteration on H, giving a Rayleigh quotient ``lam_dom``;
    2. power iteration on ``H - max(lam_dom, 0) * I``, whose dominant
       eigenvector is the one of lambda_min when the shift is close to
       lambda_max.

    Both passes report Rayleigh quotients of H itself, which never fall below
    lambda_min, so the per-sample minimum of the two is returned.

    Cost: ``2 * num_steps`` Hessian-vector products.

    NOTE(review): the batched HVP differentiates sum(Q), so this presumably
    relies on q_fn treating samples independently -- confirm for the models
    used.

    returns: detached tensor [B] of approximated lambda_min (zeros when
        ``num_steps <= 0`` or the differentiated variable is empty).
    """
    if wrt == "action":
        actions = actions.clone().detach().requires_grad_(True)
        states = states.detach()
        var = actions
    else:
        states = states.clone().detach().requires_grad_(True)
        actions = actions.detach()
        var = states

    batch = var.shape[0]
    if var.numel() == 0:
        return torch.zeros(batch, device=var.device, dtype=var.dtype)

    # Shape that broadcasts a per-sample scalar [B] against `var`.
    bcast_shape = (batch,) + (1,) * (var.dim() - 1)

    def per_sample_dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Inner product over all non-batch elements -> [B].
        return (x * y).reshape(batch, -1).sum(dim=-1)

    def to_unit(x: torch.Tensor) -> torch.Tensor:
        # Normalize each sample to (approximately) unit Euclidean norm.
        norms = x.reshape(batch, -1).norm(dim=-1).reshape(bcast_shape)
        return x / (norms + 1e-12)

    def rayleigh_after_power(shift: torch.Tensor) -> torch.Tensor:
        # Power iteration on (H - shift * I); returns the Rayleigh quotient
        # of H at the iterate used for the last Hessian-vector product.
        v = to_unit(torch.randn_like(var))
        lam = torch.zeros(batch, device=var.device, dtype=var.dtype)
        shift_b = shift.reshape(bcast_shape)
        for _ in range(num_steps):
            hv = compute_hvp_wrt(q_fn, states, actions, v, wrt=wrt).detach()
            lam = per_sample_dot(v, hv) / (per_sample_dot(v, v) + 1e-12)
            v = to_unit(hv - shift_b * v)
        return lam

    lam_dom = rayleigh_after_power(torch.zeros_like(lam_dom_init(batch, var)))
    lam_shifted = rayleigh_after_power(torch.clamp(lam_dom, min=0.0))
    return torch.minimum(lam_dom, lam_shifted).detach()


def lam_dom_init(batch: int, var: torch.Tensor) -> torch.Tensor:
    """Return a zero per-sample shift [batch] matching ``var``'s device/dtype."""
    return torch.zeros(batch, device=var.device, dtype=var.dtype)


def calculate_geometric_risk(
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    states: torch.Tensor,
    actions: torch.Tensor,
    curvature_weight: float = 1.0,
    lanczos_steps: int = 6,
    method: Literal["power", "lanczos"] = "power",
    mode: Literal["action", "state", "joint"] = "action",
    state_weight: float = 1.0,
    action_weight: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute kappa = grad_norm + curvature_weight * concavity, where for a
    variable x (actions or states)

        grad_norm(x) = ||grad_x Q||_2
        concavity(x) = max(0, -lambda_min(H_x(Q)))

    mode:
        "action" / "state": both terms are taken w.r.t. that variable.
        anything else (joint): both terms are the weighted sums
            action_weight * term(action) + state_weight * term(state).
    method: "power" selects estimate_min_eigenvalue_power; any other value
        selects estimate_min_eigenvalue_lanczos.
    lanczos_steps: iteration count forwarded as ``num_steps`` to whichever
        estimator is selected (not only Lanczos).

    Inputs are detached internally, so no gradient flows back to the caller.

    Returns (kappa, grad_norm, concavity), each detached and shaped [B].
    """

    def grad_norm_wrt(var: str) -> torch.Tensor:
        # Fresh detached leaves so the caller's graph is never touched.
        s = states.detach()
        a = actions.detach()
        if var == "action":
            a = a.clone().requires_grad_(True)
            target = a
        else:
            s = s.clone().requires_grad_(True)
            target = s
        q = q_fn(s, a)
        g = torch.autograd.grad(
            q.sum(), target, create_graph=False, retain_graph=True
        )[0]
        if g.dim() > 2:
            # Variables with several feature dims: take the norm over all of
            # them so the result stays [B].
            g = g.flatten(start_dim=1)
        return torch.norm(g, dim=-1)

    def lambda_min_wrt(var: str) -> torch.Tensor:
        if method == "power":
            return estimate_min_eigenvalue_power(
                q_fn, states, actions, num_steps=lanczos_steps, wrt=var
            )
        return estimate_min_eigenvalue_lanczos(
            q_fn, states, actions, num_steps=lanczos_steps, wrt=var
        )

    def concavity_from(lam: torch.Tensor) -> torch.Tensor:
        # Only negative curvature contributes to the risk.
        return torch.clamp(-lam, min=0.0)

    if mode == "action" or mode == "state":
        grad_norm = grad_norm_wrt(mode)
        concavity = concavity_from(lambda_min_wrt(mode))
    else:
        # Joint mode. Evaluation order (gradients first, then action before
        # state eigenvalue estimates) is kept stable because the estimators
        # draw random start vectors.
        g_a = grad_norm_wrt("action")
        g_s = grad_norm_wrt("state")
        lam_a = lambda_min_wrt("action")
        lam_s = lambda_min_wrt("state")
        grad_norm = action_weight * g_a + state_weight * g_s
        concavity = (
            action_weight * concavity_from(lam_a)
            + state_weight * concavity_from(lam_s)
        )

    kappa = grad_norm + curvature_weight * concavity
    return kappa.detach(), grad_norm.detach(), concavity.detach()
