import torch

class FNNOW:
    """Second-order optimizer that preconditions each parameter's gradient with a
    row-sampled approximation of that parameter's own Hessian block.

    Every trainable parameter is handled independently (cross-parameter Hessian
    blocks are ignored). For a parameter with ``D`` entries,
    ``d = max(1, min(int(sampler * D), D))`` rows of its ``D x D`` Hessian ``H``
    are computed exactly by double backpropagation; row indices are drawn
    without replacement with probability proportional to the squared gradient
    entries. Writing ``C^T`` for those ``d x D`` rows and ``W`` for the
    ``d x d`` block of ``H`` at the sampled rows and columns, the update
    direction for a vector ``g`` is::

        g - (1/eps) * C @ M^{-1} @ C^T @ g,    M = W + (1/eps) * C^T @ C

    which, by the Woodbury identity, equals
    ``eps * (C @ W^{-1} @ C^T + eps * I)^{-1} @ g`` when ``W`` is invertible.
    ``M`` is inverted through its SVD with each singular value ``s`` replaced
    by ``sqrt(s**2 + clp)`` so a (near-)singular ``M`` stays invertible.

    Args:
        lr: step size applied to the preconditioned direction.
        eps: damping of the Hessian approximation; the gradient also gets an
            ``eps * param`` term added. Must be non-zero (it is divided by).
        sampler: fraction of each parameter's coordinates whose Hessian rows
            are sampled per step.
        clp: damping added to the squared singular values of ``M``.
        beta: stored but not read by any method of this class.
        device: device each ``(x, y)`` batch is moved to.
    """

    def __init__(self, lr=0.001, eps=0.01, sampler=0.0001, clp=1e-4, beta=0.0, device='cuda'):
        self.lr = lr
        self.eps = eps
        self.sampler = sampler
        self.clp = clp
        self.beta = beta  # not read anywhere in this class
        self.device = device
        self.old_Hg_dict = {}  # not read anywhere in this class

    def train(self, model, dataloader, criterion, epochs):
        """Run ``epochs`` passes over ``dataloader``, updating ``model`` in place.

        Each step computes, for every parameter with ``requires_grad``, the
        direction ``Hg`` for ``grad + eps * param`` and applies
        ``param -= lr * Hg``. All directions of a step are computed from the
        same, unmodified parameters before any of them is applied.

        Args:
            model: module whose trainable parameters are updated.
            dataloader: iterable of ``(x, y)`` batches.
            criterion: callable; ``criterion(model(x), y)`` must return a
                scalar loss tensor.
            epochs: number of passes over ``dataloader``.

        Returns:
            dict mapping parameter name to the sum of the directions ``Hg``
            (before scaling by ``lr``) applied to it during this call.
        """
        model.train()
        Hg_dict = {}
        for _ in range(epochs):
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                model.zero_grad()
                loss = criterion(model(x), y)

                # Compute every direction against the untouched graph first: writing
                # to a parameter mid-loop could overwrite tensors the graph saved
                # for the (second) derivatives of parameters processed later.
                steps = []
                for name, param in model.named_parameters():
                    if not param.requires_grad:
                        continue
                    HdD, g, idx = self._hessian(loss, param)
                    # The graph is only needed for the Hessian rows; detaching keeps
                    # the Woodbury algebra below from recording autograd history.
                    g = g.detach() + self.eps * param.detach()
                    steps.append((name, param, self._woodbury(g, HdD, idx)))

                with torch.no_grad():
                    for name, param, Hg in steps:
                        param -= self.lr * Hg
                        if name not in Hg_dict:
                            Hg_dict[name] = torch.zeros_like(Hg)
                        Hg_dict[name] += Hg
        return Hg_dict

    def _hessian(self, loss, param):
        """Sample rows of the Hessian of ``loss`` with respect to ``param``.

        Args:
            loss: scalar tensor whose graph depends on ``param``.
            param: tensor to differentiate with respect to.

        Returns:
            A tuple ``(HdD, g, idx)``:

            * ``HdD`` -- ``(d, D)`` tensor (``D = param.numel()``) whose row
              ``i`` is the Hessian row for flattened coordinate ``idx[i]``, in
              the parameter's own coordinate order. Rows are zero where the
              gradient does not depend on ``param``.
            * ``g`` -- gradient of ``loss`` w.r.t. ``param``, shaped like
              ``param`` and computed with ``create_graph=True``.
            * ``idx`` -- ``(d,)`` tensor of distinct sampled coordinates.
        """
        g = torch.autograd.grad(loss, param, create_graph=True)[0]
        flat = g.reshape(-1)
        D = flat.numel()
        d = max(1, min(int(self.sampler * D), D))

        # Sample coordinates in proportion to their squared gradient.
        # multinomial(replacement=False) needs at least d strictly positive
        # weights, so zero-gradient coordinates get a negligible floor instead of
        # yielding 0/0 probabilities (the draw is uniform when g == 0).
        weights = flat.detach().float().pow(2)
        weights = weights.clamp_min(torch.finfo(weights.dtype).tiny)
        idx = torch.multinomial(weights, d, replacement=False)

        if not flat.requires_grad:
            # The gradient carries no graph (e.g. loss linear in param), so its
            # derivative w.r.t. param -- the Hessian block -- is zero.
            return flat.new_zeros((d, D)), g, idx

        rows = []
        for i in idx.tolist():
            row = torch.autograd.grad(flat[i], param, retain_graph=True, allow_unused=True)[0]
            rows.append(flat.new_zeros(D) if row is None else row.reshape(-1))
        return torch.stack(rows), g, idx

    def _woodbury(self, grad, HdD, idx):
        """Precondition ``grad`` using sampled Hessian rows.

        Args:
            grad: tensor to precondition; flattened to length ``D`` internally.
            HdD: ``(d, D)`` sampled Hessian rows, as returned by ``_hessian``.
            idx: ``(d,)`` coordinates that the rows of ``HdD`` belong to.

        Returns:
            ``grad - (1/eps) * HdD.T @ M_damped^{-1} @ HdD @ grad`` reshaped like
            ``grad``, where ``M = HdD[:, idx] + (1/eps) * HdD @ HdD.T`` and
            ``M_damped^{-1}`` is its SVD-based inverse with every singular value
            ``s`` replaced by ``sqrt(s**2 + clp)``.
        """
        g = grad.reshape(-1)
        # d x d block H[idx, idx]; columns are taken in the same order as the rows
        # so the block lines up with HdD @ HdD.T.
        Hdd = HdD[:, idx]
        H_mid = Hdd + (1 / self.eps) * (HdD @ HdD.T)
        U, S, Vh = torch.linalg.svd(H_mid)
        S = torch.sqrt(S**2 + self.clp)
        # V @ diag(1/S) @ U^T: dividing by S scales column j of V by 1/S[j].
        H_mid_inv = (Vh.T / S) @ U.T
        correction = (1 / self.eps) * (HdD.T @ (H_mid_inv @ (HdD @ g)))
        return (g - correction).reshape_as(grad)