from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Literal, Optional, Tuple

import torch

from .hvp import hvp


# Short alias so the annotations in this module can say ``Tensor``.
Tensor = torch.Tensor


@dataclass
class LanczosOptions:
    """Bundle of settings accepted by ``multi_start_lanczos`` via its ``opts`` argument."""

    # Number of Lanczos iterations; forwarded to ``lanczos`` as ``steps``.
    steps: int = 10
    # Residual-norm threshold below which the iteration is treated as broken
    # down; forwarded to ``lanczos`` as ``tol``.
    tol: float = 1e-8
    # Selects which input is the differentiated variable: "action" uses
    # ``actions``, any other value uses ``states``.
    wrt: Literal["state", "action"] = "action"
    # Forwarded to ``lanczos`` as ``reorth`` (re-orthogonalise against all
    # previously stored Lanczos vectors).
    reorthogonalize: bool = True
    # Constant added to the diagonal of the tridiagonal matrix before its
    # eigen-decomposition.
    tikhonov: float = 0.0
    # ``shift * v`` is subtracted from every matrix-vector product, so the
    # returned eigenvalues are those of the shifted operator.
    shift: float = 0.0
    # NOTE(review): not read by any function visible in this module.
    restarts: int = 0
    # Optional seed for random number generation.
    # NOTE(review): confirm which callers honor it.
    seed: Optional[int] = None
    # NOTE(review): not read by any function visible in this module.
    use_double: bool = False
    # NOTE(review): not read via this class in the visible code;
    # ``multi_start_lanczos`` always passes ``return_vecs=False``.
    return_vecs: bool = False


def _normalize(v: Tensor, eps: float = 1e-12) -> Tensor:
    return v / (v.norm(dim=-1, keepdim=True) + eps)


def _mgs(v: Tensor, basis: Tensor, k: int) -> Tensor:
                                                                                 
    for i in range(k):
        vi = basis[:, i]
        proj = (v * vi).sum(dim=-1, keepdim=True) * vi
        v = v - proj
    return v


def lanczos(
    matvec: Callable[[Tensor], Tensor],
    v0: Tensor,
    steps: int,
    tol: float = 1e-8,
    reorth: bool = True,
    tikhonov: float = 0.0,
    shift: float = 0.0,
    return_vecs: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Run a batched Lanczos iteration and eigen-decompose the tridiagonal matrix.

    Args:
        matvec: Maps a ``(B, D)`` batch of vectors to the ``(B, D)`` batch of
            matrix-vector products.  Each row is assumed to be processed
            independently of the others.
        v0: ``(B, D)`` start vectors; they are normalised internally.
        steps: Number of Lanczos steps, i.e. the size of the tridiagonal
            matrix built for every batch row.
        tol: A row whose residual norm ``beta`` falls below this value is
            treated as having reached an invariant subspace.
        reorth: Re-orthogonalise every new residual against all previously
            stored Lanczos vectors (modified Gram-Schmidt order).
        tikhonov: Constant added to the diagonal of the tridiagonal matrix
            before the eigen-decomposition.
        shift: ``shift * v`` is subtracted from every product, so the returned
            eigenvalues are those of the shifted operator (not shifted back).
        return_vecs: Also return the eigenvectors of the tridiagonal matrix.

    Returns:
        ``(evals, evecs)``.  ``evals`` has shape ``(B, steps)`` in the order
        produced by ``torch.linalg.eigh`` (ascending).  ``evecs`` has shape
        ``(B, steps, steps)`` and holds eigenvectors of the tridiagonal matrix
        (coefficients in the Lanczos basis, not vectors in the original
        space); it is ``None`` unless ``return_vecs`` is true.

    Breakdown is tracked per batch row.  Once a row's ``beta`` drops below
    ``tol`` that row is frozen: its remaining diagonal entries repeat the last
    computed ``alpha`` and its remaining off-diagonals are zero.  The repeated
    ``alpha`` is a diagonal entry of the row's genuine tridiagonal block, so
    it never lies outside that block's extreme eigenvalues.  Without the
    per-row freeze, a row that breaks down earlier than the others would be
    renormalised from a (near-)zero residual and inject spurious values
    (typically ~0) into its spectrum.
    """
    B, D = v0.shape
    device, dtype = v0.device, v0.dtype
    # Same guard value as ``_normalize``; inlined so ``beta`` can be reused
    # for the normalisation instead of recomputing the norm.
    eps = 1e-12

    V = torch.zeros(B, steps, D, device=device, dtype=dtype)
    T_diag = torch.zeros(B, steps, device=device, dtype=dtype)
    T_off = torch.zeros(B, steps, device=device, dtype=dtype)

    v = v0 / (v0.norm(dim=-1, keepdim=True) + eps)
    v_prev = torch.zeros_like(v)
    active = torch.ones(B, dtype=torch.bool, device=device)
    last_alpha: Optional[Tensor] = None
    for j in range(steps):
        V[:, j] = v
        w = matvec(v)
        if shift != 0.0:
            w = w - shift * v
        alpha = (v * w).sum(dim=-1)
        if last_alpha is not None:
            # Frozen rows keep repeating the alpha from their breakdown step.
            alpha = torch.where(active, alpha, last_alpha)
        last_alpha = alpha
        T_diag[:, j] = alpha

        # Three-term recurrence: remove the components along v_{j-1} and v_j.
        if j > 0:
            w = w - T_off[:, j - 1].unsqueeze(-1) * v_prev
        w = w - alpha.unsqueeze(-1) * v
        if reorth:
            for k in range(j):
                vk = V[:, k]
                w = w - (w * vk).sum(dim=-1, keepdim=True) * vk

        beta = w.norm(dim=-1)
        # ``~(beta < tol)`` rather than ``beta >= tol`` so a NaN residual is
        # not silently classified as converged.
        active = active & ~(beta < tol)
        if j < steps - 1:
            # Frozen rows are decoupled from their padded tail.
            T_off[:, j] = torch.where(active, beta, torch.zeros_like(beta))

        if not bool(active.any()):
            # Every row has broken down: pad the remaining diagonal and stop.
            # The matching off-diagonals are already zero.
            if j + 1 < steps:
                T_diag[:, j + 1:] = alpha.unsqueeze(-1)
            break

        v_prev = v
        v = torch.where(
            active.unsqueeze(-1),
            w / (beta.unsqueeze(-1) + eps),
            torch.zeros_like(w),
        )

    # Assemble all tridiagonal matrices at once; ``eigh`` accepts a batch.
    T = torch.diag_embed(T_diag)
    if steps > 1:
        off = T_off[:, :-1]
        T = T + torch.diag_embed(off, offset=1) + torch.diag_embed(off, offset=-1)
    if tikhonov != 0.0:
        T = T + torch.eye(steps, device=device, dtype=dtype) * tikhonov
    evals, evecs = torch.linalg.eigh(T)
    if return_vecs:
        return evals, evecs
    return evals, None


def batched_lanczos_min(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    steps: int = 10,
    tol: float = 1e-8,
    wrt: Literal["state", "action"] = "action",
) -> Tensor:
    """Return the first (smallest) Lanczos eigenvalue estimate per batch row.

    The operator handed to ``lanczos`` is ``hvp`` applied to ``q_fn`` with
    either ``actions`` (``wrt == "action"``) or ``states`` (any other value)
    as the differentiated input; the other input is detached.  The start
    vector is drawn with ``torch.randn_like`` and ``lanczos`` is run with
    re-orthogonalisation on and no shift or Tikhonov term.

    The differentiated input must be 2-D because ``lanczos`` unpacks the
    start vector's shape as ``(B, D)``.  The result is ``evals[:, 0]``,
    one value per row.
    """
    differentiate_action = wrt == "action"
    if differentiate_action:
        point = actions.clone().detach().requires_grad_(True)
        fixed = states.detach()
    else:
        point = states.clone().detach().requires_grad_(True)
        fixed = actions.detach()

    def hessian_vector(direction: Tensor) -> Tensor:
        # ``hvp`` always receives (states, actions) in that positional order.
        if differentiate_action:
            return hvp(q_fn, fixed, point, direction, wrt="action", create_graph=False)
        return hvp(q_fn, point, fixed, direction, wrt="state", create_graph=False)

    start = torch.randn_like(point)
    ritz_values, _ = lanczos(
        hessian_vector,
        start,
        steps=steps,
        tol=tol,
        reorth=True,
        tikhonov=0.0,
        shift=0.0,
        return_vecs=False,
    )
    return ritz_values[:, 0]


def batched_power_min(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    iters: int = 20,
    wrt: Literal["state", "action"] = "action",
) -> Tensor:
    """Estimate the smallest Hessian eigenvalue per batch row by power iteration.

    The operator is ``hvp`` applied to ``q_fn`` with either ``actions``
    (``wrt == "action"``) or ``states`` (any other value) as the
    differentiated input; the other input is detached.

    Plain power iteration (on ``H`` or ``-H`` alike) converges to the
    eigenvalue of largest *magnitude*, which equals the minimum only when the
    most negative eigenvalue dominates.  The estimate is therefore built in
    two phases:

    1. ``iters`` plain iterations give the dominant eigenvalue ``lam_dom``.
    2. If any row has ``lam_dom > 0``, ``iters`` further iterations are run
       from a fresh random vector on ``H - sigma * I`` with
       ``sigma = max(lam_dom, 0)``.  For those rows the shifted operator's
       dominant eigenvector is the one belonging to the smallest eigenvalue
       of ``H``.  Rows with ``sigma == 0`` simply continue plain iteration.

    The result is the running minimum of the Rayleigh quotients of ``H``
    taken on unit-length iterates.  When no row has a positive dominant
    eigenvalue, phase 2 is skipped and the cost is ``iters`` Hessian-vector
    products; otherwise it is ``2 * iters``.  With ``iters == 0`` a zero
    tensor is returned.
    """
    if wrt == "action":
        var = actions.clone().detach().requires_grad_(True)
        other = states.detach()

        def mv(v: Tensor) -> Tensor:
            return hvp(q_fn, other, var, v, wrt="action", create_graph=False)
    else:
        var = states.clone().detach().requires_grad_(True)
        other = actions.detach()

        def mv(v: Tensor) -> Tensor:
            return hvp(q_fn, var, other, v, wrt="state", create_graph=False)

    def rayleigh(vec: Tensor, hvec: Tensor) -> Tensor:
        return (vec * hvec).sum(dim=-1) / (vec.norm(dim=-1) ** 2 + 1e-12)

    B, _dim = var.shape  # unpacking also enforces a 2-D (batch, dim) layout

    # Phase 1: dominant-magnitude eigenvalue.
    v = _normalize(torch.randn_like(var))
    lam_dom = torch.zeros(B, device=var.device, dtype=var.dtype)
    for _ in range(iters):
        hv = mv(v)
        lam_dom = rayleigh(v, hv)
        v = _normalize(hv)

    sigma = lam_dom.clamp_min(0.0)
    if iters <= 0 or not bool((sigma > 0).any()):
        # The dominant eigenvalue is non-positive in every row, so it already
        # is the smallest one.
        return lam_dom

    # Phase 2: shifted iteration.  Restart from a random vector because the
    # phase-1 iterate has almost no component along the wanted eigenvector.
    lam_min = lam_dom
    shift = sigma.unsqueeze(-1)
    v = _normalize(torch.randn_like(var))
    for _ in range(iters):
        hv = mv(v)
        # If the shifted operator annihilated the previous iterate, ``v`` is
        # no longer unit length and its quotient is meaningless; skip it.
        unit = v.norm(dim=-1) > 0.5
        candidate = torch.minimum(lam_min, rayleigh(v, hv))
        lam_min = torch.where(unit, candidate, lam_min)
        v = _normalize(hv - shift * v)
    return lam_min


def multi_start_lanczos(
    q_fn: Callable[[Tensor, Tensor], Tensor],
    states: Tensor,
    actions: Tensor,
    starts: int = 3,
    opts: Optional[LanczosOptions] = None,
) -> Tensor:
    """Run ``lanczos`` from several random start vectors and keep the per-row minimum.

    Args:
        q_fn: Function handed to ``hvp`` together with the states and actions.
        states: State batch; detached, and differentiated only when
            ``opts.wrt`` is not ``"action"``.
        actions: Action batch; detached, and differentiated only when
            ``opts.wrt == "action"``.
        starts: Number of independent random start vectors (at least 1).
        opts: Lanczos settings; defaults to ``LanczosOptions()``.  ``steps``,
            ``tol``, ``wrt``, ``reorthogonalize``, ``tikhonov``, ``shift`` and
            ``seed`` are used.  When ``seed`` is not ``None`` the start
            vectors come from a dedicated generator seeded with it, so
            repeated calls draw the same start vectors; otherwise the global
            RNG is used.

    Returns:
        Element-wise minimum over all starts of ``evals[:, 0]``, one value
        per batch row.

    Raises:
        ValueError: If ``starts`` is smaller than 1 (there would be no result
            to return).
    """
    if starts < 1:
        raise ValueError(f"starts must be >= 1, got {starts}")
    if opts is None:
        opts = LanczosOptions()
    if opts.wrt == "action":
        var = actions.clone().detach().requires_grad_(True)
        other = states.detach()

        def mv(v: Tensor) -> Tensor:
            return hvp(q_fn, other, var, v, wrt="action", create_graph=False)
    else:
        var = states.clone().detach().requires_grad_(True)
        other = actions.detach()

        def mv(v: Tensor) -> Tensor:
            return hvp(q_fn, var, other, v, wrt="state", create_graph=False)

    # A private generator keeps seeding local instead of touching global RNG state.
    generator = None
    if opts.seed is not None:
        generator = torch.Generator(device=var.device)
        generator.manual_seed(opts.seed)

    best = None
    for _start in range(starts):
        if generator is None:
            v0 = torch.randn_like(var)
        else:
            v0 = torch.randn(
                var.shape, generator=generator, device=var.device, dtype=var.dtype
            )
        evals, _ = lanczos(
            mv,
            v0,
            steps=opts.steps,
            tol=opts.tol,
            reorth=opts.reorthogonalize,
            tikhonov=opts.tikhonov,
            shift=opts.shift,
            return_vecs=False,
        )
        lam_min = evals[:, 0]
        best = lam_min if best is None else torch.minimum(best, lam_min)
    return best


def _demo():
    """Print both minimum-eigenvalue estimates for a simple quadratic ``q_fn``.

    ``q_fn`` is ``0.5 * ||a||^2`` and ignores the state, so the two printed
    tensors can be compared by eye.
    """
    def q_fn(s: Tensor, a: Tensor) -> Tensor:
        return 0.5 * a.pow(2).sum(dim=-1)

    batch, state_dim, action_dim = 4, 2, 3
    states = torch.randn(batch, state_dim)
    actions = torch.randn(batch, action_dim)
    estimates = (
        ("lanczos", batched_lanczos_min(q_fn, states, actions, steps=6, tol=1e-9, wrt="action")),
        ("power  ", batched_power_min(q_fn, states, actions, iters=16, wrt="action")),
    )
    for label, value in estimates:
        print(label, value)


# Run the demo only when this file is executed as the main module.
# NOTE(review): the relative import of ``.hvp`` at the top of the file means
# this has to be launched as a module of its package (``python -m ...``).
if __name__ == "__main__":
    _demo()
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
