import torch
import torch.optim as optim
from util import calculate_residuals
# Module-level side effect: tensors are printed with 10 digits of precision
# from the moment this module is imported.
torch.set_printoptions(precision=10)


class NGDOptimizer(optim.Optimizer):
    """Optimizer that steps along a pseudo-inverse-preconditioned residual.

    Each call to :meth:`step` evaluates the residual vector ``s`` returned by
    ``calculate_residuals``, builds its Jacobian ``J`` with respect to the
    model parameters (one row per residual entry) and applies::

        theta <- theta - lr * J^T (J J^T)^+ s

    where ``^+`` denotes the Moore-Penrose pseudo-inverse.

    NOTE(review): "NGD" presumably stands for natural gradient descent;
    the code itself only establishes the update formula above.
    """

    def __init__(self, model, lr):
        """Register ``model``'s parameters in a single parameter group.

        Args:
            model: Module whose ``parameters()`` are optimized. It is also
                kept on ``self.model`` because ``step`` passes it to
                ``calculate_residuals``.
            lr: Step size, stored as the group's ``"lr"`` default.
        """
        defaults = dict(lr=lr)
        super().__init__(model.parameters(), defaults)
        self.model = model

    def step(self, x_f, x_b, y_b):
        """Perform one parameter update.

        Args:
            x_f, x_b, y_b: Forwarded unchanged to
                ``calculate_residuals(self.model, x_f, x_b, y_b)``. Their
                meaning is defined by that helper, not by this class.

        Returns:
            None. Parameters of ``self.model`` are modified in place.

        Only parameters with ``requires_grad=True`` are differentiated and
        updated; the step is a no-op when there are no such parameters or
        the residual tensor is empty.
        """
        self.zero_grad()

        # The second value returned by the helper is not needed here.
        sk_all, _ = calculate_residuals(self.model, x_f, x_b, y_b)

        # Frozen parameters cannot be differentiated, so leave them out of
        # both the Jacobian columns and the update.
        params = [p for p in self.model.parameters() if p.requires_grad]
        # Treat every residual entry as one scalar equation (one row of J).
        residuals = sk_all.reshape(-1)
        if not params or residuals.numel() == 0:
            return

        Jk = self._residual_jacobian(residuals, params)

        # The update is plain numerics: keep it out of the autograd graph so
        # no graph is built through the pseudo-inverse and the in-place
        # parameter writes are legal on leaf tensors.
        with torch.no_grad():
            gram_pinv = torch.linalg.pinv(Jk @ Jk.T)
            grad_next = Jk.T @ (gram_pinv @ residuals.detach())

            # Every parameter uses the first group's learning rate.
            lr = self.param_groups[0]["lr"]
            idx = 0
            for param in params:
                param_size = param.numel()
                param.sub_(lr * grad_next[idx:idx + param_size].view_as(param))
                idx += param_size

    @staticmethod
    def _residual_jacobian(residuals, params):
        """Return the ``(len(residuals), total_param_count)`` Jacobian.

        Row ``i`` is the gradient of ``residuals[i]`` with respect to all
        ``params``, flattened and concatenated in order. A parameter that
        does not influence a residual contributes zeros.
        """
        rows = []
        for residual in residuals.unbind():
            # retain_graph: the same graph is reused for every residual.
            grads = torch.autograd.grad(
                residual, params, retain_graph=True, allow_unused=True
            )
            rows.append(torch.cat([
                (torch.zeros_like(p) if g is None else g).reshape(-1)
                for g, p in zip(grads, params)
            ]))
        return torch.stack(rows)
            