from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Literal

import torch


Tensor = torch.Tensor


def directional_second_diff(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    direction: Tensor,
    wrt: Literal["state", "action"] = "action",
    eps: float = 1e-3,
) -> Tensor:
    """Estimate a second directional derivative of ``q_fn`` by central differences.

    ``q_fn`` is evaluated at the base point and at the base point shifted by
    ``+eps * direction`` and ``-eps * direction`` in the argument selected by
    ``wrt``. The result is ``(q_plus + q_minus - 2 * q_center) / eps ** 2``.

    Args:
        q_fn: Callable invoked as ``q_fn(states, actions)``.
        states: Tensor passed as the first argument of ``q_fn``.
        actions: Tensor passed as the second argument of ``q_fn``.
        direction: Perturbation direction. It is scaled by ``eps`` and
            added to / subtracted from the selected argument, so it must be
            broadcast-compatible with that argument. It is not normalised.
        wrt: Which argument is perturbed: ``"action"`` or ``"state"``.
        eps: Finite-difference step size.

    Returns:
        The second-difference quotient, shaped like the output of ``q_fn``.

    Raises:
        ValueError: If ``wrt`` is neither ``"state"`` nor ``"action"``.
    """
    # Literal[...] is not enforced at runtime; without this check a typo such
    # as "actions" would silently differentiate with respect to the state.
    if wrt not in ("state", "action"):
        raise ValueError(f"wrt must be 'state' or 'action', got {wrt!r}")

    step = eps * direction
    if wrt == "action":
        q_plus = q_fn(states, actions + step)
        q_minus = q_fn(states, actions - step)
    else:
        q_plus = q_fn(states + step, actions)
        q_minus = q_fn(states - step, actions)
    q_center = q_fn(states, actions)

    # NOTE(review): dividing by eps ** 2 amplifies rounding error in q_fn's
    # output; with the default eps and low-precision inputs this may be
    # noticeable -- confirm against the precision callers need.
    return (q_plus + q_minus - 2.0 * q_center) / (eps ** 2)


def gradient_fd(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    wrt: Literal["state", "action"] = "action",
    eps: float = 1e-3,
) -> Tensor:
    """Estimate the gradient of ``f`` by central finite differences.

    For each index ``i`` along the last dimension of the selected argument,
    that component is shifted by ``+eps`` and ``-eps`` and the quotient
    ``(f_plus - f_minus) / (2 * eps)`` is recorded.

    Args:
        f: Callable invoked as ``f(states, actions)``.
        states: Tensor passed as the first argument of ``f``.
        actions: Tensor passed as the second argument of ``f``.
        wrt: Which argument to differentiate: ``"action"`` or ``"state"``.
        eps: Finite-difference step size.

    Returns:
        The per-component quotients stacked along a new trailing dimension,
        whose length equals the last dimension of the selected argument.

    Raises:
        ValueError: If ``wrt`` is neither ``"state"`` nor ``"action"``.
    """
    # Literal[...] is not enforced at runtime; without this check a typo such
    # as "actions" would silently differentiate with respect to the state.
    if wrt not in ("state", "action"):
        raise ValueError(f"wrt must be 'state' or 'action', got {wrt!r}")

    wrt_action = wrt == "action"
    base = actions if wrt_action else states

    partials = []
    for i in range(base.shape[-1]):
        # One-hot perturbation along component i of the last dimension.
        basis = torch.zeros_like(base)
        basis[..., i] = 1.0
        step = eps * basis
        if wrt_action:
            f_plus = f(states, actions + step)
            f_minus = f(states, actions - step)
        else:
            f_plus = f(states + step, actions)
            f_minus = f(states - step, actions)
        partials.append((f_plus - f_minus) / (2 * eps))
    return torch.stack(partials, dim=-1)


def hessian_diag_fd(
    f: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    wrt: Literal["state", "action"] = "action",
    eps: float = 1e-3,
) -> Tensor:
    """Estimate the diagonal of the Hessian of ``f`` by finite differences.

    For each index ``i`` along the last dimension of the selected argument,
    ``directional_second_diff`` is called with a one-hot direction for that
    component.

    Args:
        f: Callable invoked as ``f(states, actions)``.
        states: Tensor passed as the first argument of ``f``.
        actions: Tensor passed as the second argument of ``f``.
        wrt: Which argument to differentiate: ``"action"`` or ``"state"``.
        eps: Finite-difference step size, forwarded to
            ``directional_second_diff``.

    Returns:
        The per-component second differences stacked along a new trailing
        dimension, whose length equals the last dimension of the selected
        argument.

    Raises:
        ValueError: If ``wrt`` is neither ``"state"`` nor ``"action"``.
    """
    # Literal[...] is not enforced at runtime; without this check a typo such
    # as "actions" would silently differentiate with respect to the state.
    if wrt not in ("state", "action"):
        raise ValueError(f"wrt must be 'state' or 'action', got {wrt!r}")

    base = actions if wrt == "action" else states

    diag_terms = []
    for i in range(base.shape[-1]):
        # One-hot direction along component i of the last dimension.
        basis = torch.zeros_like(base)
        basis[..., i] = 1.0
        diag_terms.append(
            directional_second_diff(f, states, actions, basis, wrt=wrt, eps=eps)
        )
    return torch.stack(diag_terms, dim=-1)


def _demo():
    """Run both finite-difference helpers on a toy quadratic and print the result shapes."""

    def half_squared_norm(s: Tensor, a: Tensor) -> Tensor:
        # The state is ignored; the value depends only on the action.
        squared = a ** 2
        return 0.5 * squared.sum(dim=-1)

    # States are drawn before actions, so the RNG is consumed in that order.
    sample_states = torch.randn(4, 2)
    sample_actions = torch.randn(4, 3)

    grad = gradient_fd(
        half_squared_norm, sample_states, sample_actions, wrt="action"
    )
    hess_diag = hessian_diag_fd(
        half_squared_norm, sample_states, sample_actions, wrt="action"
    )
    print(grad.shape, hess_diag.shape)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    _demo()
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
