import torch
import torch.nn as nn
import torch.autograd.functional as func
import torch.nn.functional as F
import numpy as np
from functorch import make_functional_with_buffers

def product(xs, ys):
    """
    Inner product of two tensors, or of two lists/tuples of tensors.

    For lists/tuples the per-element inner products are summed, i.e. the
    collections are treated as one long flattened vector.

    :param xs: a tensor, or a list/tuple of tensors
    :param ys: same structure and shapes as ``xs``
    :return: 0-dim tensor moved to the CPU (0.0 for empty collections)
    """
    if isinstance(xs, (list, tuple)):
        terms = [torch.sum(x * y) for x, y in zip(xs, ys)]
        if not terms:
            # sum([]) would be the int 0, which has no .cpu(); keep the return type a tensor
            return torch.zeros(())
        # Sum on the original device, then do a single transfer to the CPU
        return sum(terms).cpu()
    return torch.sum(xs * ys).cpu()

def normalise(v):
    """
    Scale a tensor, or a list/tuple of tensors, to (approximately) unit norm.

    For lists/tuples the norm is taken over all elements jointly, and the result
    is a list. Both paths add ``1e-6`` to the norm, so a zero input returns zeros
    instead of NaN.

    :param v: a tensor, or a list/tuple of tensors
    :return: the normalised tensor, or a list of normalised tensors
    """
    if isinstance(v, (list, tuple)):
        norm = (product(v, v) ** 0.5).cpu().item()
        return [vi / (norm + 1e-6) for vi in v]
    # Same epsilon as the list path, so zero vectors do not produce NaN
    return v / (torch.linalg.vector_norm(v) + 1e-6)

def add(xs, ys):
    """
    Element-wise sum of two tensors, or of two lists/tuples of tensors.

    :param xs: a tensor, or a list/tuple of tensors
    :param ys: same structure and shapes as ``xs``
    :return: ``xs + ys`` for tensors; a list of pairwise sums for lists/tuples
             (tuples are added element-wise, not concatenated)
    """
    if isinstance(xs, (list, tuple)):
        return [x + y for x, y in zip(xs, ys)]
    return xs + ys

def scale(beta, xs):
    """
    Multiply a tensor, or every tensor in a list/tuple, by the scalar ``beta``.

    :param beta: scalar factor
    :param xs: a tensor, or a list/tuple of tensors
    :return: ``beta * xs`` for tensors; a list of scaled tensors for lists/tuples
    """
    if isinstance(xs, (list, tuple)):
        return [beta * x for x in xs]
    return beta * xs

class derivatives:
    """
    Matrix-free curvature and Jacobian operators for ``net`` on one batch,
    plus power iteration to estimate each operator's largest-magnitude
    eigenvalue.

    The batch is ``self.data = (inputs, targets)``. It starts as random tensors
    of ``input_shape`` / ``output_shape`` and is replaced via :meth:`update`.

    Supported ``mode`` strings (see :meth:`Av`):
      'H'          Hessian of ``criterion(net(inputs), targets)`` w.r.t. the parameters
      'GN'         Gauss-Newton product DF^T Hc DF (DF: Jacobian of the outputs
                   w.r.t. the parameters, Hc: Hessian of the criterion w.r.t. the outputs)
      'H-GN'       difference of the two operators above
      'NTK'        DF DF^T, acting on output-shaped vectors
      'jac1train'  J J^T, J = Jacobian of the full child chain w.r.t. the inputs (net in train mode)
      'jac1eval'   same as 'jac1train' with the net in eval mode
      'jac2train'  J J^T, J = Jacobian of children[1:] w.r.t. the first child's
                   output (net in train mode)
      'jac2eval'   same as 'jac2train' with the net in eval mode

    NOTE(review): the 'jac*' modes assume ``net``'s forward is exactly the
    sequential application of ``net.children()``. TODO: confirm this for the
    models this class is used with.
    """

    # Modes whose vectors live in parameter space (a list of parameter-shaped tensors).
    _PARAM_MODES = ('H', 'GN', 'H-GN')
    # Modes whose vectors live in output space (a single tensor of output_shape).
    _OUTPUT_MODES = ('NTK', 'jac1train', 'jac2train', 'jac1eval', 'jac2eval')

    def __init__(self, net, criterion, input_shape, output_shape, device):
        """
        :param net: the torch module to analyse
        :param criterion: loss called as ``criterion(outputs, targets)``
        :param input_shape: shape of the placeholder input batch
        :param output_shape: shape of the placeholder targets and of output-space vectors
        :param device: device on which placeholder data and random vectors are created
        """
        self.net = net
        self.criterion = criterion
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.device = device

        # Placeholder batch until update() supplies real data.
        self.data = (torch.randn(input_shape).to(self.device), torch.randn(output_shape).to(self.device))

        # Functional form of net: pfmap(params, buffers, x).
        # NOTE(review): make_functional_with_buffers builds its own copy of net here,
        # so later self.net.train()/eval() calls presumably do not change how pf_map
        # behaves (relevant for dropout/batch-norm). Confirm this before relying on
        # train/eval semantics in the 'GN'/'NTK' modes.
        pfmap = make_functional_with_buffers(self.net)[0]

        def pf_map(*params):
            # Network outputs on the current batch, as a function of the parameters.
            return pfmap(params, self.net.buffers(), self.data[0])

        self.pf_map = pf_map

        self.children = list(net.children())
        # Output of the first child on the current inputs; recomputed by update().
        # It is computed in whatever train/eval mode the net is in at that moment.
        self.first_layer_features = self.children[0](self.data[0])

        def forward_from_first(features):
            # Apply every child in order.
            for child in self.children:
                features = child(features)
            return features

        def forward_from_second(features):
            # Apply every child except the first.
            for child in self.children[1:]:
                features = child(features)
            return features

        def loss(*params):
            # Loss on the current batch, as a function of the parameters.
            return self.criterion(self.pf_map(*params), self.data[1])

        def cost(features):
            # Loss on the current targets, as a function of the network outputs.
            return self.criterion(features, self.data[1])

        self.loss = loss
        self.cost = cost
        self.forward_from_first = forward_from_first
        self.forward_from_second = forward_from_second

        def vec(mode):
            """Return a fresh random unit vector in the space the ``mode`` operator acts on."""
            if mode in self._PARAM_MODES:
                v = [torch.randn(p.shape).to(self.device) for p in self.net.parameters()]
            elif mode in self._OUTPUT_MODES:
                v = torch.randn(self.output_shape).to(self.device)
            else:
                raise ValueError(f"unknown mode: {mode!r}")
            return normalise(v)

        self.vec = vec

    def update(self, data):
        """Replace the current batch ``(inputs, targets)`` and recompute the first-layer features."""
        self.data = data
        self.first_layer_features = self.children[0](self.data[0])

    def _hessian_vp(self, v):
        """Hessian-vector product of the loss w.r.t. the parameters, via double backprop."""
        params = tuple(self.net.parameters())
        loss = self.criterion(self.net(self.data[0]), self.data[1])
        grads = list(torch.autograd.grad(loss, params, create_graph=True))
        return list(torch.autograd.grad(product(grads, v), params))

    def _gauss_newton_vp(self, v):
        """Gauss-Newton-vector product DF^T Hc DF v, returned as a list of tensors."""
        params = tuple(self.net.parameters())
        outputs, DFv = func.jvp(self.pf_map, params, tuple(v))
        HcDFv = func.hvp(self.cost, outputs, DFv)[1]
        return list(func.vjp(self.pf_map, params, HcDFv)[1])

    @staticmethod
    def _jjt_vp(fn, x, v):
        """Return J J^T v, where J is the Jacobian of ``fn`` evaluated at ``x``."""
        vtJ = func.vjp(fn, x, v)[1]
        return func.jvp(fn, x, vtJ)[1]

    def Av(self, v, mode):
        """
        Apply the operator selected by ``mode`` to the vector ``v``.

        The net is put in eval mode for 'jac1eval' / 'jac2eval' and in train mode
        for every other mode.
        NOTE(review): in train mode, layers such as batch-norm may update their
        running statistics on every call. Confirm that this is acceptable.

        :param v: list of parameter-shaped tensors for 'H', 'GN', 'H-GN';
                  a tensor of ``output_shape`` for the other modes
        :param mode: one of the mode strings listed in the class docstring
        :return: list of tensors for the parameter-space modes, a tensor otherwise
        :raises ValueError: if ``mode`` is not recognised
        """
        if mode in ('jac1eval', 'jac2eval'):
            self.net.eval()
        elif mode in self._PARAM_MODES or mode in self._OUTPUT_MODES:
            self.net.train()
        else:
            raise ValueError(f"unknown mode: {mode!r}")

        if mode == 'H':
            return self._hessian_vp(v)
        if mode == 'GN':
            return self._gauss_newton_vp(v)
        if mode == 'H-GN':
            hv = self._hessian_vp(v)
            gnv = self._gauss_newton_vp(v)
            return [a - b for a, b in zip(hv, gnv)]
        if mode == 'NTK':
            return self._jjt_vp(self.pf_map, tuple(self.net.parameters()), v)
        if mode in ('jac1train', 'jac1eval'):
            return self._jjt_vp(self.forward_from_first, self.data[0], v)
        # Remaining modes: 'jac2train', 'jac2eval'
        return self._jjt_vp(self.forward_from_second, self.first_layer_features, v)

    def power(self, mode, num_iters = 100, tol = 1e-3):
        """
        Estimate the largest-magnitude eigenvalue of the ``mode`` operator by power iteration.

        Starts from a fresh random unit vector (``self.vec(mode)``). Each step
        computes the Rayleigh quotient v^T A v and re-normalises A v. Iteration
        stops after ``num_iters`` steps, or once successive estimates differ by
        less than ``tol`` relative to the previous estimate.

        :param mode: operator selector, see :meth:`Av`
        :param num_iters: maximum number of iterations (must be >= 1)
        :param tol: relative tolerance used for the convergence test
        :return: the most recent eigenvalue estimate, as a 0-d numpy array
        :raises ValueError: if ``num_iters < 1`` or ``mode`` is unknown
        """
        if num_iters < 1:
            raise ValueError(f"num_iters must be >= 1, got {num_iters}")

        v = self.vec(mode)
        eigenvalue = None
        for _ in range(num_iters):
            w = self.Av(v, mode)
            # v is (approximately) unit-norm, so this is the Rayleigh quotient.
            estimate = product(v, w)
            v = normalise(w)
            converged = (
                eigenvalue is not None
                and abs(eigenvalue - estimate) / (abs(eigenvalue) + 1e-6) < tol
            )
            # Always keep the newest estimate, including on the converging step.
            eigenvalue = estimate
            if converged:
                break

        return eigenvalue.cpu().numpy()

# class Net(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.lin1 = nn.Linear(3072, 200)
#         self.relu1 = nn.ReLU()
#         self.lin2 = nn.Linear(200, 200)
#         self.relu2 = nn.ReLU()
#         self.lin3 = nn.Linear(200, 10)

#     def forward(self, x):
#         z1 = self.relu1(self.lin1(x))
#         z2 = self.relu2(self.lin2(z1))
#         z3 = self.lin3(z2)
#         return z3



# net = Net()
    
# pd = derivatives(net, nn.MSELoss(), [5000, 3072], [5000, 10], 'cpu')
# v = tuple([torch.randn(list(p.shape)) for p in net.parameters()])
# # print(pd.power('GN'))
# # print((2/50000)*pd.power('NTK'))
# print(pd.power('H'))

        
    
            
    
