import os,sys
import matplotlib
matplotlib.use('Agg')   
import pickle
import time  
import shutil 
import numpy as np
import matplotlib.pyplot as plt   
import platform
from mpl_toolkits.mplot3d import Axes3D 
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from datetime import datetime
from matplotlib.lines import Line2D
import torch
from torch.nn import init
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets  # 放置了许多常用数据集,包括手写数字识别
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader,Dataset,TensorDataset
# from tqdm import tqdm
import os, sys
from torch.nn import init

import os, sys
import time
import pickle
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
import math
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
import platform
import shutil



# Axes rectangle [left, bottom, width, height] that save_fig passes to
# ax.set_position when isax == 1.
Leftp, Bottomp = 0.08, 0.18
Widthp, Heightp = 0.78 - Leftp, 0.9 - Bottomp
pos = [Leftp, Bottomp, Widthp, Heightp]

def save_fig(pltm, fntmp, fp=1, ax=0, isax=0, iseps=0, isShowPic=0):
    """Save the current figure under the base name ``fntmp``.

    A PNG is always written. An EPS (dpi=600) is written when ``iseps`` is
    truthy and a tightly cropped PDF when ``fp`` is non-zero.

    Args:
        pltm: pyplot-like object providing ``rc``, ``savefig``, ``show`` and
            ``close``.
        fntmp: output path without extension.
        fp: write ``<fntmp>.pdf`` unless this is 0.
        ax: axes to reposition; only used when ``isax == 1``.
        isax: when 1, set the tick label sizes and move ``ax`` to the
            module-level ``pos`` rectangle before saving.
        iseps: write ``<fntmp>.eps`` when truthy.
        isShowPic: 1 shows the figure, -1 leaves it open and returns,
            any other value closes it.
    """
    if isax == 1:
        for tick_group, label_size in (('xtick', 18), ('ytick', 10)):
            pltm.rc(tick_group, labelsize=label_size)
        ax.set_position(pos, which='both')

    pltm.savefig(f'{fntmp}.png')
    if iseps:
        pltm.savefig(f'{fntmp}.eps', format='eps', dpi=600)
    if fp != 0:
        pltm.savefig(f'{fntmp}.pdf', bbox_inches='tight')

    if isShowPic == -1:
        # Caller wants to keep working with the open figure.
        return
    if isShowPic == 1:
        pltm.show()
    else:
        pltm.close()





def mkdir(fn):
    """Create directory ``fn`` if it does not already exist.

    Missing parent directories are created as well, and an already existing
    directory is not an error. ``FileExistsError`` is still raised when
    ``fn`` exists but is not a directory.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # isdir() + mkdir() and also works when the parent path is missing.
    os.makedirs(fn, exist_ok=True)

def one_hot(x, class_count):
    """Return float one-hot rows for the class indices in ``x``.

    Args:
        x: class index or indices (int, list of ints, or integer tensor).
        class_count: number of classes, i.e. the length of each row.

    Returns:
        Rows of the ``class_count x class_count`` identity matrix selected
        by ``x``.
    """
    # Build the identity on the same device as a tensor index so that CUDA
    # index tensors can be used; non-tensor indices keep the default device.
    device = x.device if torch.is_tensor(x) else None
    return torch.eye(class_count, device=device)[x, :]

class Net(torch.nn.Module):
    """Two-layer fully connected ReLU network.

    ``l1`` maps ``in_features`` inputs to ``m`` hidden units (with bias) and
    ``l2`` maps the hidden units to ``out_features`` outputs (no bias).
    All parameters are drawn from a normal distribution with mean 0 and
    standard deviation ``1 / m**t``.

    Args:
        m: number of hidden units.
        t: exponent controlling the initialisation scale ``1 / m**t``.
        in_features: flattened input size (784 by default).
        out_features: number of outputs (10 by default).
    """

    def __init__(self, m, t, in_features=784, out_features=10):
        super().__init__()
        std = 1 / m ** t
        # Construction/initialisation order is kept fixed so that a given
        # RNG seed yields the same parameters.
        self.l1 = torch.nn.Linear(in_features, m)
        init.normal_(self.l1.weight, 0, std)
        init.normal_(self.l1.bias, 0, std)
        self.l2 = torch.nn.Linear(m, out_features, bias=False)
        init.normal_(self.l2.weight, 0, std)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and return the raw outputs."""
        x = x.view(-1, self.l1.in_features)
        x = F.relu(self.l1(x))
        return self.l2(x)

# Preprocessing applied to every MNIST image.
transform = transforms.Compose([
    transforms.ToTensor(),  # convert to a tensor with values scaled to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # normalise with mean 0.1307 and std 0.3081
])

mnist_root = "/home/zhangzhongwang/data/saddle_points/MNIST/mnist"

train_dataset = datasets.MNIST(root=mnist_root,
                               train=True,  # training split
                               transform=transform,
                               download=True)

test_dataset = datasets.MNIST(root=mnist_root,
                              train=False,  # test split
                              transform=transform,
                              download=True)

# batch_size was previously only present as a commented-out assignment, so
# building the loaders below raised NameError.
batch_size = 32

# Not used in this part of the script; kept in case other code relies on them.
train_dataset0 = []
train_target0 = []

# Train on the first 1000 training samples only, visited in random order.
indices = list(range(len(train_dataset)))
train_indices = indices[:1000]
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
train_loader = DataLoader(
    train_dataset, batch_size=int(batch_size), num_workers=16, sampler=train_sampler)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=int(batch_size),
                         shuffle=False, num_workers=16)

# Materialise all batches once so they can be iterated repeatedly without
# going through the worker processes again.
train_loader = list(train_loader)
test_loader = list(test_loader)



# Experiment directory: the trained checkpoint is read from here and all
# outputs (pic1/, pic2/, order*.txt) are written below it.
path = '/home/zhangzhongwang/data/loss_landscape/test97_retrain/500/6.0/32_1.0/198933'
mkdir('%s/pic1' % (path))
mkdir('%s/pic2' % (path))

iii = 50
m = 400  # hidden width passed to Net
t = 3    # initialisation exponent passed to Net
device = torch.device("cuda:%s" % (1) if torch.cuda.is_available() else "cpu")
model = Net(m, t).to(device)

load_dir = '%s/model_fin.pkl' % (path)
# Map the checkpoint onto the device selected above. A hard-coded 'cuda:1'
# here made loading fail on machines without CUDA even though `device`
# already falls back to the CPU.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
Path1 = torch.load(load_dir, map_location=device)
model.load_state_dict(Path1)

# One row per hidden neuron: its input weights followed by its bias.
weight_new = torch.cat((model.l1.weight, model.l1.bias.unsqueeze(1)), 1)
print(weight_new.shape)
# Unit-normalise every hidden neuron's [weights, bias] row so that dot
# products between rows are cosine similarities.
tensor_all = [
    (weight_new[row, :] / torch.norm(weight_new[row, :])).detach().cpu().numpy()
    for row in range(m)
]

# Per-neuron amplitude: L2 norm of the input row times the L2 norm of the
# neuron's column in l2.weight.
W_lenth = torch.reshape(torch.sqrt(torch.sum(weight_new**2, axis=1)), (m, 1))
A_lenth = torch.reshape(torch.sqrt(torch.sum(model.l2.weight**2, axis=0)), (m, 1))
W_now_lenth = torch.multiply(W_lenth, A_lenth)

# Split neurons by amplitude with a single tensor comparison instead of one
# Python-level comparison (and device sync) per neuron.
is_large = (W_now_lenth.reshape(-1) > 0.01).detach().cpu().numpy()
large_index = np.flatnonzero(is_large).tolist()  # amplitude > 0.01
delate = np.flatnonzero(~is_large).tolist()      # everything else

print(len(large_index))

unit_rows = np.array(tensor_all)
print(unit_rows.shape)
# Cosine similarity between all pairs of large-amplitude neurons.
ori = np.matmul(unit_rows[large_index, :], unit_rows.transpose()[:, large_index])
print(ori.shape)

# np.savetxt('%s/ori_58.txt'%(path),ori)
# Greedy grouping of the large-amplitude neurons by direction.
# Repeatedly take the first unassigned position k and collect every
# unassigned position whose cosine similarity with k exceeds 0.9.
# Positions index rows/columns of `ori`, i.e. entries of `large_index`.
cos_distance_matrix_temp = ori
size = len(large_index)
order_temp = []   # all grouped positions, in the order they were assigned
prune_index = []  # one list of positions per group
order2 = list(range(size))  # positions not assigned to a group yet
for j in range(size):
    k = order2[0]
    order = []   # positions joining the group seeded by k
    order1 = []  # positions left for later groups
    for i in order2:
        if cos_distance_matrix_temp[k][i] > 0.9:
            order.append(i)
        else:
            order1.append(i)

    order_temp = order_temp + order
    prune_index.append(order)

    # `order` holds positions within `large_index`, so translate them to
    # neuron ids before indexing weight_new; indexing with the positions
    # directly saved the wrong rows whenever some neurons were filtered out.
    group_neurons = [large_index[i] for i in order]
    np.savetxt('%s/order%s.txt' % (path, j),
               weight_new[group_neurons, :].detach().cpu().numpy().transpose())

    if not order1:
        break
    order2 = order1

print(prune_index)
# Largest groups first; ties keep their discovery order.
prune_index.sort(key=len, reverse=True)
print(prune_index)

prune_index_all = [position for group in prune_index for position in group]
# Neuron ids of the grouped neurons, followed by the low-amplitude ones.
index_291 = (np.array(large_index)[prune_index_all]).tolist()
print(index_291)
index_all = index_291 + delate
print(len(index_all))

# Cosine similarity between all hidden neurons, then permute rows and
# columns into the order given by index_all.
unit_vectors = np.array(tensor_all)
ori = np.matmul(unit_vectors, np.array(tensor_all).transpose())
cos_distance_matrix_temp = ori[np.ix_(index_all, index_all)]
# np.savetxt('%s/test.txt'%(path),cos_distance_matrix_temp)
# W_now_lenth = W_now_lenth[order_temp,:]

# Draw the reordered cosine-similarity matrix as a heatmap with a colour
# scale fixed to [-1, 1], and save it as PDF and PNG under pic1/.
fig, ax = plt.subplots()
plt.imshow(cos_distance_matrix_temp, cmap='YlGnBu_r')
cb = plt.colorbar(ticks=[-1.0, 0.0, 1.0])
cb.ax.tick_params(labelsize=24)

plt.xlabel('index', fontsize=24)
plt.ylabel('index', fontsize=24)
plt.clim(-1, 1)

axis_ticks = [0, 100, 200, 300, 400]
plt.yticks(axis_ticks, size=24)
plt.xticks(axis_ticks, size=24)
plt.tight_layout()

fntmp = f'{path}/pic1/75000'
plt.savefig(f'{fntmp}.pdf', bbox_inches='tight')
plt.savefig(f'{fntmp}.png')


# fig,ax = plt.subplots()
# # ax = sns.heatmap(cos_distance_matrix_temp,linewidths = 0,vmin=-1,vmax=1,cmap='YlGnBu_r') # ,xticklabels = np.arange(40),yticklabels = np.arange(40))
# # ax.set_xticks(np.arange(40)) #设置x轴刻度
# # ax.set_yticks(np.arange(40)) #设置y轴刻度
# # ax.xaxis.set_ticks_position('top')
# # ax.set_xticklabels(range(40),fontsize=5)
# # ax.set_yticklabels(range(40),fontsize=5)
# plt.imshow(cos_distance_matrix_temp[291:,291:],cmap='YlGnBu_r')
# cb = plt.colorbar(ticks=[-1.0,0.0,1.0])
# cb.ax.tick_params(labelsize=24)
# # ax.xaxis.set_ticks_position('top')
# plt.xlabel('index',fontsize=24)
# plt.ylabel('index',fontsize=24)
# plt.clim(-1, 1)
# # plt.yticks([0,100,200,300,400],size=24)
# # plt.xticks([0,100,200,300,400],size=24)
# # plt.colorbar(fig, ) 
# # plt.xticks([])  #去掉x轴
# # plt.yticks([])  #去掉y轴
# # ax.set_title("cos distance: tanh",fontsize=18)
# plt.tight_layout()
# fntmp = '%s/pic1/75000_small'%(path)
# # save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)
# plt.savefig("%s.pdf" % (fntmp), bbox_inches='tight')
# plt.savefig('%s.png'%(fntmp))
# plt.close()
# plt.clf()

# fig,ax = plt.subplots()
# plt.plot(np.floor(np.linspace(start=1,stop=size,endpoint=True,num=size)),W_now_lenth.detach().cpu().numpy())
# plt.xticks([])  #去掉x轴
# my_x_ticks = np.floor(np.linspace(start=1,stop=size,endpoint=True,num=size))
# plt.xticks(my_x_ticks)
# plt.xlabel(r'Index',fontsize=18)
# plt.ylabel(r'Amplitude',fontsize=18)
# plt.tick_params(axis='y',which='major',labelsize=14)
# plt.tick_params(axis='x',which='major',labelsize=7)
# my_x_ticks = [1,10,20,30,size]
# plt.xticks(my_x_ticks)
# ax.set_title("Amplitude: tanh",fontsize=18)
# plt.tight_layout()
# plt.savefig('%s/pic2/%s.png'%(path,iii*100))
# print('%s/pic2/%s.png'%(path,iii*100))
# # plt.close()
# # plt.clf()


# model_prune = Net(m,t).to(device)
# model_prune.load_state_dict(Path1)
# para_dict = model_prune.state_dict()

# print(prune_index)
# sum=0
# for i in prune_index:
#     sum+=len(i)
# print(sum)

# delate_index=[]

# for ind1,i in enumerate(prune_index):
#     # if not ind1==1:
#     #     continue
#     sum=l2weight[:,i[0]]
#     delate_index.extend(i[1:])
#     for ind2,k in enumerate(i):
#         if ind2==0:
#             continue
#         # print(ratio_all[ind1][ind2-1])
#         sum=sum+l2weight.data[:,k]*ratio_all[ind1][ind2-1]

#         l2weight[:,k]=torch.zeros_like(l2weight[:,k])
#     l2weight[:,i[0]]=sum
#     # if ind1==2:
#     #     break




# para_dict['l2.weight']=l2weight
# para_dict['l1.weight']=l1weight
# para_dict['l1.bias']=l1bias
# print(para_dict['l2.weight'][:,2])
# model_prune2=Net(len(large_index),t).to(device)
# model_prune2.load_state_dict(para_dict)

# loss_fn = torch.nn.MSELoss(reduction='mean')
# prune_loss=[]
# delta_prune_loss=[]
# norm_prune_loss=[]


# class My_loss(nn.Module):
#     def __init__(self):
#         super().__init__()   #没有需要保存的参数和状态信息
        
#     def forward(self, x, y):  # 定义前向的函数运算即可
#         return torch.mean(torch.pow((x - y)/y, 2))


# class My_norm(nn.Module):
#     def __init__(self):
#         super().__init__()   #没有需要保存的参数和状态信息
        
#     def forward(self, x):  # 定义前向的函数运算即可
#         return torch.mean(torch.pow(x, 2))

# myloss=My_loss()
# mynorm=My_norm()
# print(model.l2.weight[:,23])
# print(model_prune2.l2.weight[:,23])

# for batch_idx, dataall in enumerate(train_loader,1):
#     data, target=dataall
#     data, target = data.to(device), target.to(device)
#     inputs = data
#     outputs = model(inputs)
#     outputs_prune = model_prune2(inputs)
#     outputs=torch.nn.functional.softmax(outputs)
#     outputs_prune=torch.nn.functional.softmax(outputs_prune)
#     target_onehot=one_hot(target, 10).to(device)
#     loss = loss_fn(outputs, target_onehot)
#     loss_prune=loss_fn(outputs, outputs_prune)
#     loss_prune_new=myloss(outputs, outputs_prune)
#     norm=mynorm(outputs)
#     prune_loss.append((loss_prune).item())
#     delta_prune_loss.append(loss_prune_new.item())
#     norm_prune_loss.append((loss_prune/norm).item())

# np.savetxt('%s/prune_loss.txt'%(path),prune_loss)
# np.savetxt('%s/delta_prune_loss.txt'%(path),delta_prune_loss)
# np.savetxt('%s/norm_prune_loss.txt'%(path),norm_prune_loss)


###figure
# plt.rcParams['savefig.dpi'] = 200  # DPI used when saving figures
# plt.rcParams['figure.dpi'] = 200  # figure resolution (DPI)




# ###print_new
# aaa=np.linspace(0,290,291)
# print(aaa)
# aaa=aaa.tolist()
# aaa_new=[]
# for i in aaa:
#     if i not in delate_index:
#         aaa_new.append(int(i))
#         print(int(i))
# print(aaa_new)
# print(model_prune2)

# para_dict_new={}
# para_dict_new['l1.weight']=model_prune2.l1.weight[aaa_new,:]
# para_dict_new['l1.bias']=model_prune2.l1.bias[aaa_new]
# para_dict_new['l2.weight']=model_prune2.l2.weight[:,aaa_new]

# torch.save(para_dict_new,'%s/model_new.ckpt'%(path))

# weight_new=torch.cat((model_prune2.l1.weight,model_prune2.l1.bias.unsqueeze(1)),1)
# print(weight_new.shape)
# print(weight_new.shape)
# tensor_all=[]
# for i in range(291):
#     tensor_ori=weight_new[i,:]
#     # print(tensor_ori.shape)#[785]
#     tensor_ori=tensor_ori/torch.norm(tensor_ori)
#     tensor_all.append(tensor_ori.detach().cpu().numpy())
# print(np.array(tensor_all).shape)

# ori=np.matmul(np.array(tensor_all)[aaa_new,:],((np.array(tensor_all).transpose())[:,aaa_new]))
# print(ori.shape)

# np.savetxt('%s/ori.txt'%(path),ori)
# cos_distance_matrix_temp=ori
# order = []
# size = 58
# order1 = range(size)
# k = 0
# order_temp = []
# order1 = []
# prune_index=[]
# for j in range(size):
#     if j != 0:
#         for i in order2:
#             if cos_distance_matrix_temp[k][i] > 0.5:
#                 order.append(i)
#             else:
#                 order1.append(i)
#     else:
#         for i in range(size):
#             if cos_distance_matrix_temp[0][i] > 0.5:
#                 order.append(i)
#             else:
#                 order1.append(i)
                

#     order_temp = order_temp + order
#     if len(order_temp) == size:
#         break
#     k = order1[0]
#     order2 = order1
#     order1 = []
#     order = []
# cos_distance_matrix_temp = cos_distance_matrix_temp[order_temp,:]
# cos_distance_matrix_temp = cos_distance_matrix_temp[:,order_temp]
# fig,ax = plt.subplots()
# # ax = sns.heatmap(cos_distance_matrix_temp,linewidths = 0,vmin=-1,vmax=1,cmap='YlGnBu_r') # ,xticklabels = np.arange(40),yticklabels = np.arange(40))
# # ax.set_xticks(np.arange(40)) #设置x轴刻度
# # ax.set_yticks(np.arange(40)) #设置y轴刻度
# # ax.xaxis.set_ticks_position('top')
# # ax.set_xticklabels(range(40),fontsize=5)
# # ax.set_yticklabels(range(40),fontsize=5)
# plt.imshow(cos_distance_matrix_temp,cmap='YlGnBu_r')
# cb = plt.colorbar(ticks=[-1.0,0.0,1.0])
# cb.ax.tick_params(labelsize=24)
# # ax.xaxis.set_ticks_position('top')
# plt.xlabel('index',fontsize=24)
# plt.ylabel('index',fontsize=24)
# plt.clim(-1, 1)
# # plt.yticks([0,40,80],size=24)
# # plt.xticks([0,40,80],size=24)
# # plt.colorbar(fig, ) 
# # plt.xticks([])  #去掉x轴
# # plt.yticks([])  #去掉y轴
# # ax.set_title("cos distance: tanh",fontsize=18)
# plt.tight_layout()
# fntmp = '%s/pic1/75000_new'%(path)
# # save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)
# plt.savefig("%s.pdf" % (fntmp), bbox_inches='tight')
# plt.savefig('%s.png'%(fntmp))