import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets  # 放置了许多常用数据集,包括手写数字识别
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader,Dataset,TensorDataset
# from tqdm import tqdm
import os, sys
from torch.nn import init

import os, sys
import time
import pickle
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
import math
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
import platform
import shutil


# Expose CUDA devices 0-7 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# Axes rectangle [left, bottom, width, height] in figure fractions; save_fig()
# passes it to ax.set_position() when called with isax=1.
Leftp, Bottomp = 0.18, 0.18
Widthp, Heightp = 0.88 - Leftp, 0.9 - Bottomp
pos = [Leftp, Bottomp, Widthp, Heightp]


def mkdir(fn):
    """Create directory *fn* (and any missing parents) if it does not exist.

    Args:
        fn: Path of the directory to create.

    Raises:
        FileExistsError: If *fn* exists but is not a directory.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # ``if not isdir: mkdir`` and also works when parent directories are missing.
    os.makedirs(fn, exist_ok=True)


def save_fig(pltm, fntmp, fp=0, ax=0, isax=0, iseps=0, isShowPic=0):
    """Save the current figure as ``<fntmp>.png`` and optionally in other formats.

    Args:
        pltm: Plotting module/object providing ``rc``, ``savefig``, ``show``
            and ``close`` (called with ``matplotlib.pyplot`` in this file).
        fntmp: Output path without extension.
        fp: Object with a ``savefig`` method used to additionally write
            ``<fntmp>.pdf``; ``0`` disables the PDF.
        ax: Axes to reposition to the module-level ``pos`` rectangle; only
            used when ``isax == 1``.
        isax: ``1`` to set tick label sizes and reposition ``ax``.
        iseps: Truthy to also write ``<fntmp>.eps`` at 600 dpi.
        isShowPic: ``1`` shows the figure, ``-1`` returns leaving it open,
            any other value closes it.
    """
    if isax == 1:
        for tick_group, label_size in (('xtick', 18), ('ytick', 10)):
            pltm.rc(tick_group, labelsize=label_size)
        ax.set_position(pos, which='both')

    pltm.savefig('%s.png' % (fntmp))
    if iseps:
        pltm.savefig('%s.eps' % (fntmp), format='eps', dpi=600)
    if fp != 0:
        fp.savefig("%s.pdf" % (fntmp), bbox_inches='tight')

    # -1 means "leave the figure open for the caller".
    if isShowPic == -1:
        return
    if isShowPic == 1:
        pltm.show()
    else:
        pltm.close()


def one_hot(x, class_count):
    """Return float one-hot rows for the class indices in *x*.

    Args:
        x: Class index or indices (int, sequence of ints, or integer tensor).
        class_count: Number of classes, i.e. the length of each one-hot row.

    Returns:
        Tensor: Rows of the ``class_count x class_count`` identity matrix
        selected by *x*. When *x* is a tensor the result lives on ``x``'s device.
    """
    # Build the identity matrix on the same device as the index tensor:
    # indexing a CPU tensor with CUDA indices raises a device-mismatch error.
    device = x.device if isinstance(x, torch.Tensor) else None
    return torch.eye(class_count, device=device)[x, :]

def parse_cli_args(argv):
    """Extract the ``-m``, ``-g`` and ``-t`` options from an argv-style list.

    Arguments are consumed as ``flag value`` pairs starting at ``argv[1]``;
    pairs whose flag is not recognised are skipped.

    Args:
        argv: Sequence of strings laid out like ``sys.argv`` (program name
            first).

    Returns:
        dict: Maps ``'m'`` and ``'d'`` (``np.int32``) and ``'t'``
        (``np.float32``) to the parsed values. Only options that are present
        in *argv* appear in the result.

    Raises:
        ValueError: If a recognised flag is the last element and therefore has
            no value, or if a value cannot be converted.
    """
    option_table = {
        '-m': ('m', np.int32),
        '-g': ('d', np.int32),
        '-t': ('t', np.float32),
    }
    parsed = {}
    for idx in range(1, len(argv), 2):
        flag = argv[idx]
        if flag not in option_table:
            continue
        if idx + 1 >= len(argv):
            raise ValueError('option %s requires a value' % flag)
        key, cast = option_table[flag]
        parsed[key] = cast(argv[idx + 1])
    return parsed


# Command-line options (supplied when the script is launched):
#   -m  passed to Net as m (hidden layer width) and used in the output path
#   -t  passed to Net as t (exponent of the init std 1/m**t) and used in the output path
#   -g  index used to build the "cuda:<d>" device string
lenarg = len(sys.argv)
_cli_args = parse_cli_args(sys.argv)
# -g is only read when CUDA is available; default to device 0 instead of
# leaving ``d`` undefined (which used to surface later as a NameError).
d = _cli_args.get('d', np.int32(0))
if 'm' in _cli_args:
    m = _cli_args['m']
if 't' in _cli_args:
    t = _cli_args['t']
# Fail fast with a usage message when run as a script without the required
# options; when merely imported, sys.argv belongs to the host program.
if __name__ == "__main__" and not ('m' in _cli_args and 't' in _cli_args):
    raise SystemExit('usage: %s -m <hidden units> -t <init exponent> [-g <gpu index>]'
                     % sys.argv[0])
# Output layout: <path_ori>/<m>/<t>/<random run id>/
path_ori='/home/dir/data/saddle_points/test97/'
mkdir(path_ori)
path1='%s%s/'%(path_ori,m)
mkdir(path1)
path='%s%s/'%(path1,t)
mkdir(path)
# Random run id so repeated runs with the same (m, t) get separate folders.
# Draw a scalar directly: calling int() on a 1-element array is deprecated
# since NumPy 1.25 and is scheduled to become an error.
subFolderName = '%s' % int(abs(np.random.normal(1.0) * 100000))
path='%s%s/'%(path,subFolderName)
mkdir(path)
# Keep a copy of this script next to its results (skipped on Windows).
if not platform.system() == 'Windows':
    shutil.copy(__file__, '%s%s' % (path, os.path.basename(__file__)))
# Preprocessing applied to every MNIST image when it is read from the dataset.
transform = transforms.Compose([
    transforms.ToTensor(),  # convert to a tensor, scaling values into [0, 1]
    transforms.Normalize((0.1307,),(0.3081,))  # normalize: first tuple is the mean, second the standard deviation
])


train_dataset = datasets.MNIST(root= "/home/dir/data/saddle_points/MNIST/mnist",
                              train=True,  # training split
                              transform=transform,  # preprocessing defined above (plain transforms.ToTensor() would also work)
                              download=True
                              )

test_dataset = datasets.MNIST(root= "/home/dir/data/saddle_points/MNIST/mnist",
                              train=False,  # test split
                              transform=transform,  # same preprocessing as the training split
                           download=True)
batch_size=60000
train_dataset0=[]
train_target0=[]

# Train on the first 1000 training samples only. Each sample is read once
# (indexing the dataset applies the transform, so reading image and label
# separately did the preprocessing twice).
for i in range(min(1000, len(train_dataset))):
    image, label = train_dataset[i]
    train_dataset0.append(image.unsqueeze(0))
    train_target0.append(label)

train_dataset_new=torch.cat(train_dataset0,0)
# Explicit int64 labels: they are later used as index tensors, and the dtype
# no longer depends on the platform's default NumPy integer.
train_label_new= torch.tensor(train_target0, dtype=torch.long)


train_dataset2 = TensorDataset(train_dataset_new, train_label_new)
train_loader = DataLoader(train_dataset2, batch_size=batch_size, shuffle=True,num_workers=8)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False,num_workers=8)
# Materialize the batches once; every epoch then reuses the same in-memory
# batches instead of re-running the loaders (so shuffling happens only here).
train_loader=list(train_loader)
test_loader=list(test_loader)
#
class Net(torch.nn.Module):
    """Two-layer fully connected network: 784 -> m (ReLU) -> 10.

    All parameters are drawn from a zero-mean normal distribution with
    standard deviation ``1 / m**t``. The output layer has no bias.

    Args:
        m: Number of hidden units.
        t: Exponent controlling the initialization scale.
    """

    def __init__(self, m, t):
        super().__init__()
        init_std = 1 / m ** (t)

        # Hidden layer; weight and bias are re-initialized in that order.
        self.l1 = torch.nn.Linear(784, m)
        for param in (self.l1.weight, self.l1.bias):
            init.normal_(param, 0, init_std)

        # Bias-free output layer.
        self.l2 = torch.nn.Linear(m, 10, bias=False)
        init.normal_(self.l2.weight, 0, init_std)

    def forward(self, x):
        """Flatten *x* to rows of 784 values and return the 10 raw outputs per row."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.l1(flat))
        return self.l2(hidden)


# m=400  (disabled manual override of the hidden width)
# Use the GPU selected by ``d`` when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:%s" % (d))
else:
    device = torch.device("cpu")

model = Net(m, t)
model = model.to(device)
# Snapshot the freshly initialized parameters into the run directory.
torch.save(model.state_dict(), "%smodel.ckpt" % (path))

# Mean squared error (cross-entropy is left disabled).
# loss_fn = torch.nn.CrossEntropyLoss()
loss_fn = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Histories appended to by train() and test().
lossall, train_acc, test_acc = [], [], []
print(model)
def train(epoch):
    """Run one optimization pass over ``train_loader``.

    For every batch the softmax of the model output is fitted to the one-hot
    labels with ``loss_fn``. The batch loss is appended to ``lossall`` and the
    running accuracy (in percent, over the batches seen so far in this call)
    to ``train_acc``.

    Args:
        epoch: Zero-based epoch index; the loss is printed when it is a
            multiple of 100.
    """
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader, 1):
        # Build the one-hot targets while the labels are still on the CPU so
        # the lookup never mixes a CPU table with CUDA indices.
        target_onehot = one_hot(target, 10).to(device)
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        # dim=1 is the class axis of the (batch, 10) output; stating it
        # explicitly avoids the deprecated implicit-dim behaviour.
        outputs = torch.nn.functional.softmax(model(data), dim=1)
        loss = loss_fn(outputs, target_onehot)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        if epoch % 100 == 0:
            # batch_idx is already 1-based (enumerate start=1).
            print("[%d, %5d] loss: %.3f" % (epoch + 1, batch_idx, batch_loss))
        lossall.append(batch_loss)

        # Predicted class = index of the largest probability in each row.
        _, predicted = torch.max(outputs.detach(), dim=1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        train_acc.append(100 * correct / total)

def test():
    """Evaluate ``model`` on ``test_loader`` and append the accuracy (%) to ``test_acc``."""
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            scores = model(images)
            # torch.max over dim=1 returns (values, indices); the indices are
            # the predicted classes.
            predicted = torch.max(scores.data, dim=1)[1]
            n_seen += labels.size(0)  # number of samples in this batch
            n_correct += (predicted == labels).sum().item()  # correctly classified samples
    test_acc.append(100 * n_correct / n_seen)
    # print("Accuracy on the test set %d %%" % (100*correct/total))
    # print('cor:%s,tot:%s'%(correct,total))


def _save_metrics():
    """Write the loss and accuracy histories collected so far under ``path``."""
    np.savetxt('%sloss.txt' % (path), lossall)
    np.savetxt('%strainacc.txt' % (path), train_acc)
    np.savetxt('%stestacc.txt' % (path), test_acc)


def _plot_loss(log_x):
    """Plot the training loss with a log y-axis and save it with ``save_fig``.

    Args:
        log_x: When true the x-axis is logarithmic too and the figure is saved
            as ``loss log``; otherwise it is saved as ``loss``.
    """
    plt.figure()
    ax = plt.gca()
    plt.plot(lossall, 'k-', label='Train')
    if log_x:
        ax.set_xscale('log')
    ax.set_yscale('log')
    plt.title('loss', fontsize=15)
    fntmp = '%sloss log' % (path) if log_x else '%sloss' % (path)
    save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)


def _checkpoint():
    """Save the metric histories and refresh both loss figures."""
    _save_metrics()
    _plot_loss(log_x=True)
    _plot_loss(log_x=False)


if __name__ == "__main__":

    for epoch in range(500000):
        train(epoch)
        # Evaluate before saving so the files written below include this
        # epoch's test accuracy.
        if epoch % 10 == 0:
            test()

        # Stop once the most recent training loss is below 1e-5.
        converged = lossall[-1] < 1e-5
        if epoch % 1000 == 0 or converged:
            _checkpoint()
        if converged:
            break
    else:
        # Epoch budget exhausted without converging: flush the final state,
        # which is otherwise only saved every 1000 epochs.
        _checkpoint()
#
