import torch.nn.functional as F
import torch.nn as nn
import torch
if __name__=="__main__":

    class MyModel(nn.Module):
        """Minimal two-layer network: a bias-free 2x2 convolution followed by
        a bias-free linear layer.

        The linear layer takes 4 input features, so a single-channel 3x3 input
        (which the convolution reduces to a 2x2 map) fits it exactly.
        """

        def __init__(self):
            super().__init__()

            # conv1 is created before d1 so that parameter registration order
            # and random-initialisation order are the same as always.
            self.conv1 = nn.Conv2d(1, 1, kernel_size=2, bias=False)
            self.d1 = nn.Linear(in_features=4, out_features=1, bias=False)

        def forward(self, x):
            """Apply the convolution, flatten all non-batch dims, then project.

            No activation is applied between the convolution and the linear
            layer.
            """
            feature_map = self.conv1(x)
            flat_features = torch.flatten(feature_map, start_dim=1)
            return self.d1(flat_features)
    # Driver: run one forward pass of MyModel on a random 1x1x3x3 input, build
    # a gradient with AgA_Grad, apply a single AgA_AdamW step, and print the
    # parameters before and after the step.
    x = torch.rand((1, 1, 3, 3))
    # Integer-typed target; subtracting the float model output promotes the
    # result to float.
    target = torch.ones([1], dtype=torch.long)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def loss_func(y, target):
        """Return the signed residual ``target - y``, with target moved to ``device``."""
        return target.to(device) - y

    model = MyModel().to(device)
    # Select training mode before the forward pass rather than after it.
    model.train()
    y = model(x.to(device))
    loss = loss_func(y, target)

    # Project-local module; imported here, right before its first use.
    from AgA import AgA_Grad, AgA_AdamW
    H_tool = AgA_Grad(model)
    # The same loss tensor is passed for both arguments.
    # NOTE(review): forward() is invoked directly; if AgA_Grad is an nn.Module,
    # calling the instance instead would also run registered hooks - confirm.
    aga_grad = H_tool.forward(loss, loss)

    print(aga_grad)
    # Parameters before the optimizer step.
    for name, param in model.named_parameters():
        print(name, param)
    optimizer = AgA_AdamW(model.parameters())
    optimizer.zero_grad()
    # This optimizer's step() is given the AgA gradient explicitly.
    optimizer.step(aga_grad)

    # Parameters after the optimizer step.
    for name, param in model.named_parameters():
        print(name, param)