import torch

class FAGH:
    """Trainer that preconditions per-parameter gradients with a rank-one curvature model.

    For every trainable parameter tensor the first row ``v`` of its Hessian
    block is computed, the block is approximated by ``v v^T / v[0]``, and the
    damped system ``(eps * I + v v^T / v[0])`` is inverted in closed form
    (Sherman-Morrison) to obtain the update direction.
    """

    def __init__(self, lr=0.01, eps=0.1, beta1=0.0, beta2=0.0, device='cuda'):
        """Store the hyper-parameters.

        Args:
            lr: Step size applied to the preconditioned gradient.
            eps: Damping added to the curvature model; also scales the
                ``eps * param`` term added to every gradient in ``train``.
            beta1: Stored only; no method of this class reads it.
            beta2: Stored only; no method of this class reads it.
            device: Device each batch is moved to in ``train``. The model
                itself is not moved and must already live there.
        """
        self.lr = lr
        self.eps = eps
        self.beta1 = beta1
        self.beta2 = beta2
        self.device = device
        # Initialised empty; not touched by the methods of this class.
        self.old_v_dict = {}
        self.old_g_dict = {}

    def sherman(self, v, g):
        """Return ``(eps * I + v v^T / v[0])^{-1} g`` reshaped like ``g``.

        Args:
            v: Flattened first Hessian row, with as many elements as ``g``.
            g: Gradient tensor of any shape.

        Returns:
            The preconditioned gradient, with the shape of ``g``.

        The Sherman-Morrison expression is evaluated with ``v[0]`` multiplied
        through, ``g/eps - v (v.g) / (eps * (eps * v[0] + v.v))``. That is
        algebraically the same as dividing by ``v[0]`` but stays finite when
        ``v[0] == 0``. If the denominator is exactly zero (for example an
        all-zero row) the rank-one correction is dropped and ``g / eps`` is
        returned.
        """
        grad = g.reshape(-1)
        row = v.reshape(-1)
        vg = torch.dot(row, grad)
        denom = self.eps * row[0] + torch.dot(row, row)
        # The unselected branch may hold inf/nan; torch.where discards it.
        correction = torch.where(
            denom != 0,
            row * vg / (self.eps * denom),
            torch.zeros_like(row),
        )
        Hg = grad / self.eps - correction
        return Hg.reshape(g.shape)

    def train(self, model, dataloader, criterion, epochs):
        """Run ``epochs`` passes over ``dataloader``, updating ``model`` in place.

        Args:
            model: Module whose trainable parameters are updated.
            dataloader: Iterable of ``(x, y)`` batches.
            criterion: Callable ``criterion(model(x), y)`` returning a scalar loss.
            epochs: Number of passes over ``dataloader``.

        Returns:
            ``(v_dict, g_dict)``: per parameter name, the sum over all steps of
            the flattened Hessian rows and of the gradients. The gradients
            include the ``eps * param`` term and have the parameter's shape.
        """
        model.train()
        v_dict = {}
        g_dict = {}
        for _ in range(epochs):
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)
                model.zero_grad()
                loss = criterion(model(x), y)

                # Every parameter differentiates the same graph, so updates
                # are collected here and applied only once the graph is no
                # longer needed. Writing to a parameter earlier could change
                # tensors the graph has saved for backward.
                pending = []
                for name, param in model.named_parameters():
                    if not param.requires_grad:
                        continue
                    v, g = self._hessian(loss, param)
                    with torch.no_grad():
                        g = g.detach() + self.eps * param
                        pending.append((param, self.sherman(v, g)))
                        if name not in v_dict:
                            v_dict[name] = torch.zeros_like(v)
                            g_dict[name] = torch.zeros_like(g)
                        v_dict[name] += v
                        g_dict[name] += g

                with torch.no_grad():
                    for param, step in pending:
                        param.sub_(self.lr * step)
        return v_dict, g_dict

    def _hessian(self, loss, param):
        """Return the first Hessian row and the gradient of ``loss`` w.r.t. ``param``.

        Args:
            loss: Scalar loss whose graph is still alive.
            param: Parameter tensor to differentiate with respect to.

        Returns:
            ``(v, g)``: ``v`` is the flattened derivative of the first gradient
            entry with respect to ``param``; ``g`` is the gradient, shaped like
            ``param`` and still attached to the graph. ``v`` is all zeros when
            the gradient does not depend on ``param``.
        """
        # create_graph=True so the gradient itself can be differentiated.
        g = torch.autograd.grad(loss, param, create_graph=True)[0]
        flat = g.reshape(-1)
        first = flat[0]
        row = None
        if first.requires_grad:
            # retain_graph keeps the shared graph usable for later parameters.
            row = torch.autograd.grad(
                first, param, retain_graph=True, allow_unused=True
            )[0]
        if row is None:
            # Gradient is constant in ``param``: zero curvature.
            v = torch.zeros_like(flat)
        else:
            v = row.reshape(-1)
        return v, g
