from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Literal, Optional

import torch

from .hvp import hvp


Tensor = torch.Tensor


@dataclass
class HutchinsonOptions:
    """Bundle of settings for the Hutchinson estimators.

    The defaults of ``samples`` and ``wrt`` mirror the keyword defaults of
    ``hutchinson_trace``.  ``kind`` and ``unbiased`` are not read by any
    function visible in this module.
    """

    # Number of random probe vectors to average over.
    samples: int = 8
    # Which input the Hessian is taken with respect to.
    wrt: Literal["state", "action"] = "action"
    # Probe distribution name, matching the values ``_sample_z`` understands.
    kind: Literal["rademacher", "normal"] = "rademacher"
    # NOTE(review): semantics not established in this file -- confirm with the consumer.
    unbiased: bool = True


def _sample_z(shape, kind: str, device=None, dtype=None) -> Tensor:
    if kind == "rademacher":
        z = torch.randint(0, 2, shape, device=device)
        z = z.to(dtype or torch.float32) * 2 - 1
        return z
    return torch.randn(shape, device=device, dtype=dtype)


def hutchinson_trace(
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    states: torch.Tensor,
    actions: torch.Tensor,
    wrt: Literal["state", "action"] = "action",
    samples: int = 8,
) -> torch.Tensor:
    """Estimate the per-row Hessian trace of ``q_fn`` with random probes.

    For each of ``samples`` Rademacher probes ``z`` the quantity
    ``sum(z * hvp(z), dim=-1)`` is accumulated, and the mean over probes is
    returned.  ``hvp`` comes from the sibling ``.hvp`` module and is assumed
    to return a Hessian-vector product shaped like ``z`` -- confirm there.

    Args:
        q_fn: Callable taking ``(states, actions)``.
        states: 2-D tensor of states.
        actions: 2-D tensor of actions.
        wrt: ``"action"`` differentiates with respect to ``actions``; any
            other value is treated as ``"state"``.
        samples: Number of probe vectors to average over; must be >= 1.

    Returns:
        A 1-D tensor with one estimate per row of the differentiated input.

    Raises:
        ValueError: If ``samples`` is smaller than 1, or the differentiated
            input is not 2-D.
    """
    # samples == 0 used to produce 0/0 == NaN; negative counts returned
    # signed zeros.  Neither is a meaningful estimate, so fail loudly.
    if samples < 1:
        raise ValueError(f"samples must be a positive integer, got {samples!r}")

    wrt_action = wrt == "action"
    # Work on detached copies so the caller's tensors are never modified
    # and no graph from upstream computations is retained.
    if wrt_action:
        var = actions.clone().detach().requires_grad_(True)
        other = states.detach()
    else:
        var = states.clone().detach().requires_grad_(True)
        other = actions.detach()

    if var.dim() != 2:
        raise ValueError(
            f"expected a 2-D (batch, dim) tensor to differentiate, got shape {tuple(var.shape)}"
        )

    tr = torch.zeros(var.shape[0], device=var.device, dtype=var.dtype)
    for _ in range(samples):
        z = _sample_z(var.shape, "rademacher", device=var.device, dtype=var.dtype)
        # hvp always receives (states, actions) in that order; only the
        # differentiated slot holds the grad-enabled copy.
        if wrt_action:
            hz = hvp(q_fn, other, var, z, wrt="action", create_graph=False)
        else:
            hz = hvp(q_fn, var, other, z, wrt="state", create_graph=False)
        tr = tr + (z * hz).sum(dim=-1)
    return tr / float(samples)


def hutchinson_diag(
    q_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    states: torch.Tensor,
    actions: torch.Tensor,
    wrt: Literal["state", "action"] = "action",
    samples: int = 16,
) -> torch.Tensor:
    """Estimate the per-row Hessian diagonal of ``q_fn`` with random probes.

    For each of ``samples`` Rademacher probes ``z`` the elementwise product
    ``hvp(z) * z`` is accumulated, and the mean over probes is returned.
    ``hvp`` comes from the sibling ``.hvp`` module and is assumed to return
    a Hessian-vector product shaped like ``z`` -- confirm there.

    Args:
        q_fn: Callable taking ``(states, actions)``.
        states: 2-D tensor of states.
        actions: 2-D tensor of actions.
        wrt: ``"action"`` differentiates with respect to ``actions``;
            otherwise ``states`` is differentiated.  The value is forwarded
            unchanged to ``hvp``.
        samples: Number of probe vectors to average over; must be >= 1.

    Returns:
        A tensor with the same shape as the differentiated input.

    Raises:
        ValueError: If ``samples`` is smaller than 1, or the differentiated
            input is not 2-D.
    """
    # samples == 0 used to produce 0/0 == NaN; negative counts returned
    # signed zeros.  Neither is a meaningful estimate, so fail loudly.
    if samples < 1:
        raise ValueError(f"samples must be a positive integer, got {samples!r}")

    wrt_action = wrt == "action"
    # Work on detached copies so the caller's tensors are never modified.
    if wrt_action:
        var = actions.clone().detach().requires_grad_(True)
        other = states.detach()
    else:
        var = states.clone().detach().requires_grad_(True)
        other = actions.detach()

    if var.dim() != 2:
        raise ValueError(
            f"expected a 2-D (batch, dim) tensor to differentiate, got shape {tuple(var.shape)}"
        )

    # hvp always receives (states, actions) in that order; only the
    # differentiated slot holds the grad-enabled copy.
    if wrt_action:
        state_arg, action_arg = other, var
    else:
        state_arg, action_arg = var, other

    diag = torch.zeros(var.shape, device=var.device, dtype=var.dtype)
    for _ in range(samples):
        z = _sample_z(var.shape, "rademacher", device=var.device, dtype=var.dtype)
        hz = hvp(q_fn, state_arg, action_arg, z, wrt=wrt, create_graph=False)
        diag = diag + hz * z
    return diag / float(samples)


def _demo():
    """Print a trace estimate for a simple quadratic in the actions.

    The toy function is ``0.5 * ||a||^2`` per row and ignores the states.
    Its action Hessian is the identity, so the printed values should
    presumably all equal the action dimension (3) -- this relies on ``hvp``
    behaving as a true Hessian-vector product.
    """

    def quadratic_q(s: Tensor, a: Tensor) -> Tensor:
        squared_norm = a.pow(2).sum(dim=-1)
        return 0.5 * squared_norm

    # States are drawn before actions to keep the RNG consumption order.
    demo_states = torch.randn(4, 2)
    demo_actions = torch.randn(4, 3)
    estimate = hutchinson_trace(
        quadratic_q,
        demo_states,
        demo_actions,
        wrt="action",
        samples=16,
    )
    print(estimate)


# Run the demo only when this file is executed as a script, not on import.
# NOTE(review): the module uses a relative import (``from .hvp import hvp``),
# so it presumably has to be launched with ``python -m <package>.<module>``.
if __name__ == "__main__":
    _demo()
                
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
