import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# from sklearn.decomposition import PCA
from torch.autograd import Variable
from torch.nn import init
# make fake data
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np
import os
import sys
import platform
import shutil
from tqdm import tqdm
import pickle
from sklearn.decomposition import PCA


def mkdir(fn):  # Create a directory
    """Create directory ``fn`` (and any missing parents) if it does not exist.

    Calling it on an existing directory is a no-op. ``exist_ok=True`` avoids the
    check-then-create race of testing ``isdir`` first. If ``fn`` exists but is
    not a directory, ``FileExistsError`` is raised, as before.
    """
    os.makedirs(fn, exist_ok=True)


def save_fig(pltm, fntmp, fp=0, ax=0, isax=0, iseps=0, isShowPic=0, pos=None):  # Save the figure
    """Save the current figure as ``<fntmp>.png`` and optionally as EPS / PDF.

    Args:
      pltm: pyplot-like object providing ``rc``, ``savefig``, ``show`` and ``close``
        (normally ``matplotlib.pyplot``).
      fntmp: output path without extension.
      fp: if not 0, an object whose ``savefig(path, bbox_inches='tight')`` is
        called to write ``<fntmp>.pdf`` (presumably a matplotlib Figure — TODO
        confirm against callers).
      ax: axes object repositioned when ``isax == 1`` and ``pos`` is given.
      isax: when 1, set tick label sizes via ``pltm.rc`` and apply ``pos``.
      iseps: when truthy, also write ``<fntmp>.eps`` at 600 dpi.
      isShowPic: 1 -> show the figure; -1 -> return and leave it open;
        anything else -> close it.
      pos: position passed to ``ax.set_position(pos, which='both')`` when
        ``isax == 1``; skipped when None.
    """
    if isax == 1:
        # rc() updates the global rcParams, so these sizes also apply to
        # figures created after this call.
        pltm.rc('xtick', labelsize=18)
        pltm.rc('ytick', labelsize=10)
        if pos is not None:
            ax.set_position(pos, which='both')
    pltm.savefig('%s.png' % (fntmp))
    if iseps:
        pltm.savefig('%s.eps' % (fntmp), format='eps', dpi=600)
    if fp != 0:
        fp.savefig("%s.pdf" % (fntmp), bbox_inches='tight')
    if isShowPic == 1:
        pltm.show()
    elif isShowPic == -1:
        return
    else:
        pltm.close()


class Net(torch.nn.Module):
    """One-hidden-layer sigmoid network: in_features -> m -> out_features.

    All weights (and the hidden bias) are initialised from N(0, (1 / m**t)**2),
    i.e. standard deviation ``1 / m**t``. The output layer has no bias, so the
    state dict contains exactly ``l1.weight``, ``l1.bias`` and ``l2.weight``.
    The defaults (4 inputs, 3 outputs) match the original Iris setup.
    """

    def __init__(self, m, t, in_features=4, out_features=3):
        super(Net, self).__init__()
        std = 1 / m**(t)
        self.l1 = torch.nn.Linear(in_features, m)
        init.normal_(self.l1.weight, 0, std)
        init.normal_(self.l1.bias, 0, std)
        self.l2 = torch.nn.Linear(m, out_features, bias=False)
        init.normal_(self.l2.weight, 0, std)

    def forward(self, x):
        """Return the raw (pre-softmax) outputs for a batch ``x``."""
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        return self.l2(torch.sigmoid(self.l1(x)))


def one_hot(x, class_count):
    """Return a float one-hot matrix of shape ``(len(x), class_count)`` for labels ``x``.

    When ``x`` is a tensor, the identity matrix is created on ``x``'s device.
    Indexing a CPU tensor with CUDA indices raises, so building on the same
    device keeps this usable for GPU label tensors.
    """
    device = x.device if torch.is_tensor(x) else None
    return torch.eye(class_count, device=device)[x, :]


def hessian(loss, network_param):
    """
    Calculates the full Hessian of a Neural Network

    The parameters are flattened and concatenated in iteration order into one
    vector theta of length P. Entry ``[i, j]`` of the result is
    d^2 loss / (d theta_i d theta_j). Column ``j`` is obtained by
    differentiating d loss / d theta_j once more.

    Args:
      loss: scalar loss tensor, e.g. ``loss = loss_fn(out, y)``.
      network_param: iterable of parameter tensors, e.g. ``my_net.parameters()``.

    Returns:
      A ``(P, P)`` tensor on the CPU, so it can be passed straight to NumPy.

    Unlike ``backward()``, second derivatives are taken with
    ``torch.autograd.grad``, so ``param.grad`` is neither read nor modified.
    Repeated calls are therefore independent of each other.
    """
    param_list = list(network_param)
    first_derivative = torch.autograd.grad(loss, param_list, create_graph=True)
    derivative_tensor = torch.cat([g.flatten() for g in first_derivative])
    num_parameters = derivative_tensor.shape[0]
    hessian_mat = torch.zeros(num_parameters, num_parameters,
                              dtype=derivative_tensor.dtype)

    # A constant gradient (loss at most linear in the parameters) has no graph
    # to differentiate; its Hessian is identically zero.
    if not derivative_tensor.requires_grad:
        return hessian_mat

    for col_ind in range(num_parameters):
        unit_vec = torch.zeros_like(derivative_tensor)
        unit_vec[col_ind] = 1.
        second = torch.autograd.grad(derivative_tensor, param_list,
                                     grad_outputs=unit_vec,
                                     retain_graph=True, allow_unused=True)
        # allow_unused yields None for parameters the first derivative does
        # not depend on; their second derivatives are zero.
        hessian_col = torch.cat([
            (torch.zeros_like(p) if g is None else g).flatten()
            for p, g in zip(param_list, second)
        ])
        hessian_mat[:, col_ind] = hessian_col.cpu()
    return hessian_mat


# Command-line options, given as "-flag value" pairs:
#   -m <int>    number of hidden neurons of Net (required)
#   -t <float>  exponent of the initialisation std 1/m**t in Net (required)
#   -g <int>    CUDA device index; only used when CUDA is available (default 0)
# sys.argv holds the arguments supplied from outside the program, so their
# effect is only visible when the script is run from the command line with
# arguments. Unrecognised flags are skipped together with their value.
m = None
t = None
d = 0
lenarg = len(sys.argv)
ilen = 1
while ilen < lenarg:
    flag = sys.argv[ilen]
    if flag in ('-m', '-g', '-t') and ilen + 1 >= lenarg:
        sys.exit('missing value for option %s' % flag)
    if flag == '-m':
        m = np.int32(sys.argv[ilen + 1])
    elif flag == '-g':
        d = np.int32(sys.argv[ilen + 1])
    elif flag == '-t':
        t = np.float32(sys.argv[ilen + 1])
    ilen = ilen + 2
if m is None or t is None:
    sys.exit('usage: %s -m <hidden width> -t <init exponent> [-g <gpu index>]'
             % sys.argv[0])

# Select the compute device; the GPU index d only matters when CUDA is available.
device = torch.device("cuda:%s" % (d) if torch.cuda.is_available() else "cpu")

# Full Iris dataset: features as x, integer class labels as y.
iris = load_iris()
x = torch.from_numpy(iris.data).to(device)
y = torch.from_numpy(iris.target).to(device)

net = Net(m, t).to(device)
print(net)
PATH = '/home/dir/data/loss_landscape/test100/2/6.0/118516'
load_dir = '%s/model_fin.ckpt' % (PATH)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
Path1 = torch.load(load_dir, map_location=device)
net.load_state_dict(Path1)
# The optimizer is never stepped in this script.
optimizer = torch.optim.SGD(net.parameters(), lr=0.2)
# MSE between softmax probabilities and one-hot targets (not cross-entropy).
loss_func = torch.nn.MSELoss(reduction='mean')

out = net(x.float())
out = torch.nn.functional.softmax(out, dim=1)  # per-sample class probabilities
# Labels are moved to the CPU for one_hot (indexing a CPU identity matrix with
# CUDA indices fails); the result is then moved to the compute device.
target_onehot = one_hot(y.cpu(), 3).to(device)
loss = loss_func(out, target_onehot)
Hessian = hessian(loss, net.parameters())
e = np.linalg.eig(Hessian.detach().cpu().numpy())
print(e[0])
# np.linalg.eig may return complex values (numerical asymmetry); save the real
# parts, consistent with how the split network's eigenvalues are saved.
np.savetxt('%s/eig2.txt' % (PATH), e[0].real)


# Build a wider network that computes the same function as `net`. Hidden neuron
# `split_idx` is split into `n_copies` identical neurons: each copy keeps the
# neuron's incoming weights and bias, and its outgoing weights are divided
# evenly among the copies. All other hidden neurons are kept unchanged, so the
# split network has m + n_copies - 1 hidden neurons.
# (split_idx=2, n_copies=2 reproduces an earlier variant that halved neuron 2.)
split_idx = 1
n_copies = 3

para_dict = net.state_dict()
l2weight = para_dict['l2.weight']
l1weight = para_dict['l1.weight']
l1bias = para_dict['l1.bias']
print(l2weight.shape)
print(l1weight.shape)
print(l1bias.shape)

# Incoming side: repeat row `split_idx` of l1.weight and its l1.bias entry.
l1weightnew = torch.cat((l1weight[:split_idx],
                         l1weight[split_idx:split_idx + 1].repeat(n_copies, 1),
                         l1weight[split_idx + 1:]), 0)
l1biasnew = torch.cat((l1bias[:split_idx],
                       l1bias[split_idx:split_idx + 1].repeat(n_copies),
                       l1bias[split_idx + 1:]))
# Outgoing side: each copy gets 1/n_copies of column `split_idx` of l2.weight.
l2weightnew = torch.cat((l2weight[:, :split_idx],
                         (1 / n_copies * l2weight[:, split_idx:split_idx + 1]).repeat(1, n_copies),
                         l2weight[:, split_idx + 1:]), 1)

para_dict['l2.weight'] = l2weightnew
para_dict['l1.weight'] = l1weightnew
para_dict['l1.bias'] = l1biasnew
print(para_dict['l2.weight'].shape)
print(para_dict['l1.weight'].shape)
print(para_dict['l1.bias'].shape)

net_split = Net(m + n_copies - 1, t).to(device)
net_split.load_state_dict(para_dict)
out = net_split(x.float())
out = torch.nn.functional.softmax(out, dim=1)  # per-sample class probabilities
# Labels are moved to the CPU for one_hot (indexing a CPU identity matrix with
# CUDA indices fails); the result is then moved to the compute device.
target_onehot = one_hot(y.cpu(), 3).to(device)
loss = loss_func(out, target_onehot)
e = np.linalg.eig(hessian(loss, net_split.parameters()).detach().cpu().numpy())
print(e[0])
# np.linalg.eig may return complex values (numerical asymmetry); keep real parts.
np.savetxt('%s/eig4.txt' % (PATH), e[0].real)