import torch
import torch.nn as nn
from functorch import make_functional, vmap, jacrev
import numpy as np
from matplotlib import pyplot as plt
from NTK_func import NtkFcun
import argparse
from sklearn.preprocessing import normalize

parser = argparse.ArgumentParser(
    description='NTK-angle and gradient-weighted training experiments')
# argparse does not convert non-string defaults, so use a float literal to
# match `type=float`.
parser.add_argument('--weight', type=float, default=3.0,
                    help='weight hyper-parameter for the experiments (default: 3.0)')
args = parser.parse_args()

# Use the first GPU when available, otherwise fall back to the CPU instead of
# failing later on 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def setup_seed(seed):
    """Set the random-number seeds used by the experiments.

    Seeds the torch CPU generator, every CUDA generator and NumPy's global
    generator with ``seed``, then forces cuDNN into deterministic mode.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True



class NN(nn.Module):
    """Fully-connected ReLU network with input-gradient and angle helpers.

    Args:
        layer_sizes: widths of every layer, input first and output last
            (e.g. ``[2, 200, 200, 200, 1]``). One ``nn.Linear`` is created per
            consecutive pair; ReLU follows every layer except the last.
    """

    def __init__(self, layer_sizes):
        super(NN, self).__init__()
        # Layers are created in order, so RNG consumption matches a plain loop.
        self.linears = nn.ModuleList(
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        )

    def forward(self, x):
        for linear in self.linears[:-1]:
            x = torch.relu(linear(x))
        return self.linears[-1](x)

    @staticmethod
    def _as_grad_input(x):
        """Return ``x`` ready for autograd w.r.t. inputs without mutating the caller's tensor."""
        if x.requires_grad:
            return x
        # detach() shares storage; flipping requires_grad on the alias leaves
        # the caller's tensor untouched.
        return x.detach().requires_grad_()

    def get_gradient(self, x):
        """Derivative of the network output w.r.t. the first input coordinate.

        Args:
            x: 2-D input batch of shape ``(N, d)``.

        Returns:
            Tensor of shape ``(N,)``. The graph is kept (``create_graph=True``)
            so the result can be used inside a training loss.
        """
        x = self._as_grad_input(x)
        u = self.forward(x)
        # Summing the outputs (grad_outputs = ones) is exact per row because
        # every row of u depends only on the same row of x in this MLP.
        u_x = torch.autograd.grad(u, x, torch.ones_like(u),
                                  retain_graph=True, create_graph=True)[0]
        return u_x[:, 0]

    def get_single_angle(self, x):
        """Cosine between d u/d theta and d(d u/d x_0)/d theta for one sample.

        NOTE(review): ``[0]`` keeps only the gradient w.r.t. the first parameter
        tensor (``self.linears[0].weight``), not the full parameter vector;
        confirm this is intended.

        Args:
            x: a single sample of shape ``(1, d)``.

        Returns:
            0-d tensor holding the cosine similarity.
        """
        x = self._as_grad_input(x)
        u = self.forward(x)
        params = list(self.parameters())
        f_theta = torch.autograd.grad(u, params, retain_graph=True, create_graph=True)[0]
        f_x = torch.autograd.grad(u, x, retain_graph=True, create_graph=True)[0]
        # The last bias does not influence d u/d x, hence allow_unused=True.
        f_x_theta = torch.autograd.grad(f_x[0, 0], params, allow_unused=True)[0]
        return (torch.sum(f_theta * f_x_theta)
                / torch.sqrt(torch.sum(f_theta ** 2))
                / torch.sqrt(torch.sum(f_x_theta ** 2)))

    def get_angle(self, x):
        """Apply ``get_single_angle`` to every row of a 2-D batch ``(N, d)``.

        Works for any input width ``d``. Returns a list of N 0-d NumPy arrays.
        """
        n_rows, _ = x.shape
        return [self.get_single_angle(x[i:i + 1]).detach().cpu().numpy()
                for i in range(n_rows)]



def pro2(model, samples, epochs, weight, mode='regular'):
    """Train ``model`` to fit f(x, y) = x * y with full-batch SGD (lr = 0.01).

    Args:
        model: network mapping ``(N, 2)`` inputs to ``(N, 1)`` outputs. In
            'gradient' mode it must also provide ``get_gradient(x)`` returning
            the ``(N,)`` derivative of the output w.r.t. ``x[:, 0]``.
        samples: ``(N, 2)`` training inputs (presumably points built from
            sin/cos of an angle). Labels are ``samples[:, 0] * samples[:, 1]``.
        epochs: number of full-batch SGD steps; must be >= 1.
        weight: pair ``(w_value, w_gradient)``, only used in 'gradient' mode.
        mode: 'regular' (value loss only) or 'gradient' (weighted mean of the
            value loss and the x-derivative loss, whose target is
            d(x*y)/dx = y).

    Returns:
        List with the training loss of every epoch (NumPy scalars).

    Raises:
        ValueError: if ``mode`` is unknown or ``epochs`` < 1.
    """
    if mode not in ('regular', 'gradient'):
        raise ValueError(f"unknown mode: {mode!r}")
    if epochs < 1:
        raise ValueError("epochs must be >= 1")
    torch.cuda.empty_cache()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    train_loss = []
    labels = samples[:, 0] * samples[:, 1]
    grad_labels = samples[:, 1]  # d(x*y)/dx = y
    func_loss = lambda pred, target: 0.5 * torch.mean((pred - target) ** 2)
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(samples).squeeze(-1)
        loss_regular = func_loss(out, labels)
        if mode == 'regular':
            loss = loss_regular
            metric = loss_regular
        else:
            # get_gradient returns a single (N,) tensor, not a tuple.
            out_gx = model.get_gradient(samples)
            loss_gradientx = func_loss(out_gx, grad_labels)
            loss = ((weight[0] * loss_regular + weight[1] * loss_gradientx)
                    / (weight[0] + weight[1]))
            metric = loss_regular + loss_gradientx
        loss.backward()
        optimizer.step()
        train_loss.append(loss.detach().cpu().numpy())
    print(f"train_loss:{metric}")
    return train_loss


def weight_reset(m):
    """Re-initialise ``m`` in place when it is an ``nn.Conv2d`` or ``nn.Linear``.

    Any other module is left untouched, so this can be passed to
    ``module.apply(weight_reset)``.
    """
    resettable = (nn.Conv2d, nn.Linear)
    if not isinstance(m, resettable):
        return
    m.reset_parameters()

def _plot_angle_curve(layer_sizes, gama, x, color, label):
    """Build a fresh NN, compute NtkFcun.get_angle(x, x.clone()) and plot it against gama."""
    net = NN(layer_sizes).to(device)
    ntk_tool = NtkFcun(net)
    angle = ntk_tool.get_angle(x, x.clone())
    plt.plot(gama.detach().cpu().numpy(), angle, color=color, label=label)


def plot_kernel(n_size=200, num_sample=100, save_path='./sample/angle.png'):
    """Plot NtkFcun angles for a 1-D-input and a 2-D-input network and save the figure.

    Args:
        n_size: width of each of the three hidden layers.
        num_sample: number of gama values in [-3, 3].
        save_path: output image path; its directory is created if missing.
    """
    import os

    fig = plt.figure()
    # 1-D input: x is gama itself, shape (num_sample, 1).
    gama = torch.linspace(-3, 3, num_sample).unsqueeze(1)
    _plot_angle_curve([1] + [n_size] * 3 + [1], gama, gama.to(device),
                      'r', 'angle (1-D input)')
    # 2-D input: points (cos gama, sin gama) on the unit circle.
    gama = torch.linspace(-3, 3, num_sample)
    x = torch.stack((torch.cos(gama), torch.sin(gama)), dim=-1).to(device)
    _plot_angle_curve([2] + [n_size] * 3 + [1], gama, x,
                      'b', 'angle (2-D input)')
    plt.legend()
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path)
    plt.close(fig)  # don't leak figures (or overlay on them) across calls



def pro3(version=10):
    """Sweep the gradient-loss weight given to _pro3 over several seeds.

    For every seed ``i`` in ``range(version)`` the RNGs are seeded once with
    ``setup_seed(i)`` and _pro3 then runs for every weight in ``test_weight``
    in order. The final losses are saved as a ``(version, len(test_weight))``
    array to ./data_process/adweight.npy.
    """
    import os

    test_weight = [0.2, 0.4, 0.8, 1.6, 3.2, 3.5, 4, 5.4,
                   6.2, 6.6, 6.8, 10, 20, 35, 50, 75, 100, 200]
    data_record = np.zeros((version, len(test_weight)))
    for i in range(version):
        setup_seed(i)
        for j, w in enumerate(test_weight):
            # _pro3 returns a grad-tracking (possibly GPU) tensor; NumPy cannot
            # store that directly, so take its Python float value.
            data_record[i, j] = float(_pro3(w))
    os.makedirs("./data_process/", exist_ok=True)
    np.save("./data_process/" + "adweight" + ".npy", data_record)

def _pro3(w, regular_weight=1.37, epochs=20000):
    """Train a fresh ``[2, 200, 200, 200, 1]`` NN on f(x, y) = x*y with a value + x-derivative loss.

    Training inputs are (cos gama, sin gama) for 100 gama values in [-3, 3].
    The loss is the weighted mean of 0.5*MSE(net(x), x*y) and
    0.5*MSE(d net/d x_0, y), using weights ``(regular_weight, w)``.
    Optimisation is full-batch SGD with lr = 0.01.

    Args:
        w: weight of the derivative loss.
        regular_weight: weight of the value loss.
        epochs: number of SGD steps; must be >= 1.

    Returns:
        Tensor (not detached): unweighted ``loss_regular + loss_gradientx``
        from the final epoch.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1")
    gama = torch.linspace(-3, 3, 100)
    x = torch.stack((torch.cos(gama), torch.sin(gama)), dim=-1).to(device)
    n_size = 200
    layer_sizes = [2] + [n_size] * 3 + [1]
    net = NN(layer_sizes).to(device)
    ntk_tool = NtkFcun(net)
    xt = x.clone()
    # NOTE(review): both entries of the NTK-derived weight are overwritten
    # below, so only the container returned by get_weight is reused; confirm
    # whether the computed values were meant to be used.
    weight = ntk_tool.get_weight(x, xt)
    weight[0] = regular_weight
    weight[1] = w
    print(weight)
    torch.cuda.empty_cache()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
    # Targets are fixed; take them before training so they never require grad.
    labels = x[:, 0] * x[:, 1]
    grad_labels = x[:, 1].clone()  # d(x*y)/dx = y
    func_loss = lambda pred, target: 0.5 * torch.mean((pred - target) ** 2)
    for _ in range(epochs):
        net.train()
        optimizer.zero_grad()
        out = net(x).squeeze(-1)
        loss_regular = func_loss(out, labels)
        # get_gradient returns a single (N,) tensor, not a tuple.
        out_gx = net.get_gradient(x)
        loss_gradientx = func_loss(out_gx, grad_labels)
        loss = ((weight[0] * loss_regular + weight[1] * loss_gradientx)
                / (weight[0] + weight[1]))
        loss.backward()
        optimizer.step()
    conv_loss = loss_regular + loss_gradientx
    print(f"train_loss:{conv_loss}")
    return conv_loss

def _probe_single_sample(net, xi, label, epochs, lr=0.1):
    """Reset ``net``, record its output/input-gradient at ``xi``, train on (xi, label), re-record the gradient.

    Returns:
        ``(initial output, initial d net/dx, d net/dx after training)`` as floats.
    """
    net.apply(weight_reset)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    ini_out = net(xi).item()
    ini_gra = net.get_gradient(xi.unsqueeze(1)).item()
    for _ in range(epochs):
        net.train()
        optimizer.zero_grad()
        loss = 0.5 * torch.mean((net(xi) - label) ** 2)
        loss.backward()
        optimizer.step()
    trained_gra = net.get_gradient(xi.unsqueeze(1)).item()
    return ini_out, ini_gra, trained_gra


def pro4(version=10):
    """Measure how one SGD step toward a constant label changes d net/dx at the trained point.

    For each of 100 points x_i in linspace(-3, 3), the 1-D-input
    ``[1, 200, 200, 200, 1]`` network is re-initialised and given ``epochs`` (= 1)
    SGD steps (lr = 0.1) on 0.5*(net(x_i) - label)^2. This is repeated
    ``version`` times for label -10 and for label 10.

    Returns:
        dict mapping each label (-10, 10) to a dict with three
        ``(version, 100)`` arrays: 'ini_out', 'ini_gra' and 'trained_gra'.
    """
    epochs = 1
    n_size = 200
    layer_sizes = [1] + [n_size] * 3 + [1]
    net = NN(layer_sizes).to(device)
    num_sample = 100
    x = torch.linspace(-3, 3, num_sample).unsqueeze(1).to(device)
    torch.cuda.empty_cache()
    results = {}
    # Keep a separate set of arrays per label so the second sweep does not
    # overwrite the first.
    for label in (-10, 10):
        ini_out = np.zeros((version, num_sample))
        ini_gra = np.zeros((version, num_sample))
        trained_gra = np.zeros((version, num_sample))
        for k in range(version):
            for i in range(num_sample):
                ini_out[k, i], ini_gra[k, i], trained_gra[k, i] = \
                    _probe_single_sample(net, x[i], label, epochs)
                print("{}/100".format(i))
        results[label] = {'ini_out': ini_out, 'ini_gra': ini_gra,
                          'trained_gra': trained_gra}
    return results


def _pro4_star(mode='adaptive', epochs=10000, batch=2):
    """Mini-batch SGD on f(x, y) = x*y with NTK-neighbour batches or consecutive batches.

    Inputs are 100 points (cos gama, sin gama) with gama ~ U[0, 1). Each epoch
    takes one SGD step (lr = 1e-4) per sample ``i``, using a batch of size
    ``batch``:

    * 'adaptive': sample ``i`` plus the ``batch - 1`` indices with the largest
      ``dis_mat[i, :]`` from ``NtkFcun.get_batch``.
    * 'normal': ``batch`` consecutive samples starting at ``i`` (wrapping around).

    Returns:
        List of length ``epochs`` with the full-data loss after each epoch
        (NumPy scalars).

    Raises:
        ValueError: unknown ``mode`` or ``batch`` < 1.
    """
    if mode not in ('adaptive', 'normal'):
        raise ValueError(f"unknown mode: {mode!r}")
    if batch < 1:
        raise ValueError("batch must be >= 1")
    train_num = 100
    gama = torch.rand((train_num))
    x = torch.stack((torch.cos(gama), torch.sin(gama)), dim=-1).to(device)
    print(x.shape)
    n_size = 200
    layer_sizes = [2] + [n_size] * 3 + [1]
    net = NN(layer_sizes).to(device)
    ntk_tool = NtkFcun(net)
    xt = x.clone()
    dis_mat, batch_index = ntk_tool.get_batch(x, xt)
    dis_mat = dis_mat.detach().cpu().numpy()
    torch.cuda.empty_cache()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.0001)
    labels = x[:, 0] * x[:, 1]
    func_loss = lambda pred, target: 0.5 * torch.mean((pred - target) ** 2)

    # Batch index sets do not change between epochs, so build them once.
    batches = []
    for i in range(train_num):
        if mode == 'adaptive':
            # NOTE(review): if dis_mat[i, i] is the row maximum, i appears
            # twice in the batch; confirm the diagonal of dis_mat.
            order = np.argsort(dis_mat[i, :])[::-1]
            neighbour = np.sort(np.concatenate((np.array([i]), np.array(order[0:batch - 1]))))
        else:
            neighbour = (i + np.arange(batch)) % train_num
        batches.append(torch.as_tensor(neighbour, dtype=torch.long, device=x.device))

    train_loss = []
    for epoch in range(epochs):
        net.train()
        for idx in batches:
            # Clear gradients per mini-batch so each step uses only this
            # batch's gradient (not the running sum over the epoch).
            optimizer.zero_grad()
            out = net(x[idx]).squeeze(-1)
            loss = func_loss(out, labels[idx])
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            loss_epochs = func_loss(net(x).squeeze(-1), labels)
        train_loss.append(loss_epochs.detach().cpu().numpy())
        print(f"Epoch[{epoch+1}/{epochs}], train_loss:{loss_epochs}")
    print(batch_index)
    return train_loss

def pro4_star(version=10, epochs=40000):
    """Run _pro4_star in 'adaptive' and 'normal' mode for several seeds and save the loss curves.

    Each seed ``i`` in ``range(version)`` is set once with ``setup_seed(i)``
    before both runs. Results are saved as ``(version, epochs)`` arrays to
    ./data_process/batch_normal.npy and ./data_process/batch_adaptive.npy.
    """
    import os

    adaptive_batch_loss = np.zeros((version, epochs))
    norm_batch_loss = np.zeros((version, epochs))
    for i in range(version):
        setup_seed(i)
        # Pass epochs explicitly: _pro4_star's own default (10000) would not
        # fill a row of length `epochs`.
        adaptive_batch_loss[i, :] = _pro4_star('adaptive', epochs=epochs)
        norm_batch_loss[i, :] = _pro4_star('normal', epochs=epochs)
    os.makedirs("./data_process/", exist_ok=True)
    np.save("./data_process/" + "batch_normal" + ".npy", norm_batch_loss)
    np.save("./data_process/" + "batch_adaptive" + ".npy", adaptive_batch_loss)


if __name__ == '__main__':
    # Script entry point: only the NTK-angle plot runs by default; the other
    # experiments (pro2, pro3, pro4, pro4_star) must be invoked explicitly.
    plot_kernel()