import os, sys
import time
import pickle
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
import math
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
import platform
import shutil

# Make GPU indices 0..7 visible to CUDA for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(gpu_id) for gpu_id in range(8))

# Axes rectangle [left, bottom, width, height] that save_fig() hands to
# ax.set_position() when it is called with isax=1.
Leftp, Bottomp = 0.18, 0.18
Widthp, Heightp = 0.88 - Leftp, 0.9 - Bottomp
pos = [Leftp, Bottomp, Widthp, Heightp]


def mkdir(fn):  # Create a directory
    """Create directory ``fn`` if it does not exist yet.

    Missing parent directories are created as well, and an already existing
    directory is left untouched, so the call is safe to repeat.

    Args:
        fn: Path of the directory to create.

    Raises:
        FileExistsError: If ``fn`` exists but is not a directory.
    """
    # makedirs(exist_ok=True) avoids the check-then-create race of
    # "isdir() followed by mkdir()" and also handles nested paths.
    os.makedirs(fn, exist_ok=True)


def save_fig(pltm, fntmp, fp=0, ax=0, isax=0, iseps=0, isShowPic=0):  # Save the figure
    """Write the current figure to ``<fntmp>.png`` and optional extra formats.

    Args:
        pltm: Plotting object offering ``savefig``/``show``/``close``/``rc``
            (the callers in this file pass ``matplotlib.pyplot``).
        fntmp: Output path without extension.
        fp: When not ``0``, an object whose ``savefig`` is additionally used
            to write ``<fntmp>.pdf`` with ``bbox_inches='tight'``.
        ax: Axes to reposition; only touched when ``isax == 1``.
        isax: ``1`` sets the tick label sizes and moves ``ax`` to the
            module-level rectangle ``pos`` before saving.
        iseps: Truthy to also write ``<fntmp>.eps`` at 600 dpi.
        isShowPic: ``1`` shows the figure, ``-1`` returns with the figure
            left open, any other value closes it.
    """
    if isax == 1:
        for tick_group, label_size in (('xtick', 18), ('ytick', 10)):
            pltm.rc(tick_group, labelsize=label_size)
        ax.set_position(pos, which='both')

    pltm.savefig('%s.png' % (fntmp))
    if iseps:
        pltm.savefig('%s.eps' % (fntmp), format='eps', dpi=600)
    if fp != 0:
        fp.savefig("%s.pdf" % (fntmp), bbox_inches='tight')

    if isShowPic == 1:
        pltm.show()
    elif isShowPic != -1:
        # -1 means "keep the figure open for the caller"; everything else closes it.
        pltm.close()


# All parameters
R = {}

R['input_dim'] = 1
R['output_dim'] = 1
R['ActFuc'] = 1  # 0: ReLU; 1: Tanh; 2:Sin; 3:x**50; 4:Sigmoid
R['hidden_units'] = [3]

R['learning_rate'] = 5e-3
R['learning_rateDecay'] = 5e-8

plot_epoch = 1000

# Fallbacks for the values normally supplied with -s / -t. Without them the
# script died with a KeyError whenever one of the two flags was omitted.
# NOTE(review): 10 is taken from the previously commented-out setting and 0.5
# reproduces the 1/sqrt(width) scaling of the commented-out astddev/bstddev
# formulas below -- confirm these are the intended defaults.
R['train_size'] = 10
R['times'] = np.float32(0.5)

R['test_size'] = 100
R['x_start'] = -15
R['x_end'] = 15
R['device'] = "0"
R['asi'] = 0
R['tuning_points'] = []
R['check_epoch'] = 10  # find the tuning point
R['tuning_ind'] = []
Ry = {}
Ry['y_all'] = []
Rw = {}
Rw['weight_R'] = []


def apply_cli_args(argv, params):
    """Overwrite entries of ``params`` from ``-flag value`` pairs in ``argv``.

    Flags are looked for at positions 1, 3, 5, ... of ``argv`` (``argv[0]`` is
    the program name); the element right after a flag is its value.

    Recognised flags:
        -m  width of the hidden layer  -> params['hidden_units'] = [int32]
        -g  CUDA device index          -> params['device'] (int32)
        -t  initialisation exponent    -> params['times'] (float32)
        -s  number of training samples -> params['train_size'] (int32)

    Args:
        argv: Argument list, e.g. ``sys.argv``.
        params: Dictionary updated in place.

    Returns:
        The same ``params`` object.

    Raises:
        ValueError: If a recognised flag is the last element of ``argv``.
    """
    converters = {
        '-m': ('hidden_units', lambda text: [np.int32(text)]),
        '-g': ('device', np.int32),
        '-t': ('times', np.float32),
        '-s': ('train_size', np.int32),
    }
    for idx in range(1, len(argv), 2):
        flag = argv[idx]
        if flag not in converters:
            continue
        if idx + 1 >= len(argv):
            raise ValueError('command-line option %s expects a value' % flag)
        key, convert = converters[flag]
        params[key] = convert(argv[idx + 1])
    return params


# sys.argv is simply the list of arguments the user passed on the command
# line, i.e. values that come from outside the program rather than from the
# code itself; they only have an effect when the script is run with arguments.
lenarg = len(sys.argv)
apply_cli_args(sys.argv, R)

R['batch_size'] = R['train_size']
# Initial standard deviation 1 / width**times. 'hidden_units' is a list, so
# the width of the (first) hidden layer has to be picked out explicitly.
R['astddev'] = 1 / (R['hidden_units'][0] ** R['times'])
R['bstddev'] = 1 / (R['hidden_units'][0] ** R['times'])
# R['astddev'] = 0.5 * np.sqrt(1 / R['hidden_units'][0])  # For weight
# R['bstddev'] = np.sqrt(1 / R['hidden_units'][0])  # For bias
R['full_net'] = [R['input_dim']] + R['hidden_units'] + [R['output_dim']]

if R['input_dim'] == 1:
    # Evenly spaced 1-D grids; the test grid extends 0.5 beyond both ends of
    # the training interval.
    R['test_inputs'] = np.reshape(np.linspace(R['x_start'] - 0.5, R['x_end'] + 0.5, num=R['test_size'],
                                              endpoint=True), [R['test_size'], 1])
    R['train_inputs'] = np.reshape(np.linspace(R['x_start'], R['x_end'], num=R['train_size'],
                                               endpoint=True), [R['train_size'], 1])
else:
    # Uniform random samples in [x_start, x_end) for every input dimension.
    R['test_inputs'] = np.random.rand(R['test_size'], R['input_dim']) * (R['x_end'] - R['x_start']) + R['x_start']
    R['train_inputs'] = np.random.rand(R['train_size'], R['input_dim']) * (R['x_end'] - R['x_start']) + R['x_start']


def ReLU(x):
    """Return the element-wise hyperbolic tangent of ``x``.

    NOTE: despite its name this helper evaluates ``np.tanh``, not a rectified
    linear unit. The name is kept because ``get_y`` calls it.
    """
    squashed = np.tanh(x)
    return squashed


def get_y(xs, sampleNo=None):  # Function to fit
    """Evaluate the target function the network is trained to fit.

    The target is ``ReLU(xs) + ReLU(xs + 7) + ReLU(xs - 7)`` applied
    element-wise, where ``ReLU`` is the helper defined in this file.

    Args:
        xs: Input values (scalar or NumPy array).
        sampleNo: Unused by the current target. It is kept, now optional, so
            that both existing call styles work: ``get_y(xs, n)`` and the
            one-argument ``get_y(xy)`` used by ``Model.plot_y``.

    Returns:
        The target values, with the same shape as ``xs``.
    """
    # Alternative targets that were tried before:
    #   np.sin(xs * math.pi)
    #   np.random.normal(0, 0.5, sampleNo)
    #   sum over input dimensions ii of np.cos(4 * xs[:, ii:ii + 1])
    return ReLU(xs) + ReLU(xs + 7) + ReLU(xs - 7)


# R['train_inputs']=np.asarray([[-1],[-1/2],[1/4],[3/4],[1]])
test_inputs = R['test_inputs']
train_inputs = R['train_inputs']
# R['y_true_test'] = get_y(test_inputs)
# R['y_true_train'] = np.asarray([[1],[1/2],[1/4],[3/4],[5/6]])
R['y_true_train'] = get_y(R['train_inputs'], R['train_size'])

# Make a folder to save all output. The resulting layout is
#   <BaseDir0>/<example_folder>/<R['times']>/<hidden width>/<random run id>/
neu_ind_folder = '%s' % (R['hidden_units'][0])
example_folder = 'test73'
sBaseDir0 = '/home/dir/data/loss_landscape'
sBaseDir = sBaseDir0 + '/' + example_folder
# BaseDir = '../../../nn/fitnd/'
if platform.system() == 'Windows':
    BaseDir0 = '../../../nn/%s' % (sBaseDir0)
else:
    BaseDir0 = sBaseDir0
    matplotlib.use('Agg')
mkdir(BaseDir0)
BaseDir = '%s/%s' % (BaseDir0, example_folder)
mkdir(BaseDir)
BaseDir_a = '%s/%s' % (BaseDir, R['times'])
mkdir(BaseDir_a)
BaseDir_neu = '%s/%s' % (BaseDir_a, neu_ind_folder)
mkdir(BaseDir_neu)
# Random run id: |N(1, 1) sample| * 1e5, truncated to an integer. A scalar
# sample is drawn because calling int() on a one-element array is deprecated
# in recent NumPy releases.
subFolderName = '%s' % (int(np.absolute(np.random.normal(1.0) * 100000)))
# FolderName keeps its trailing '/', later code concatenates file names onto it.
FolderName = '%s/%s/' % (BaseDir_neu, subFolderName)
mkdir(FolderName)

# mkdir('%smodel/'%(FolderName))
# print(subFolderName)

# Keep a copy of this script next to the results of the run.
if not platform.system() == 'Windows':
    shutil.copy(__file__, '%s%s' % (FolderName, os.path.basename(__file__)))

# Use the requested CUDA device when CUDA is available, otherwise the CPU.
device = torch.device("cuda:%s" % (R['device']) if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print(device)


def weights_init(m):  # Initialization weight
    """Re-draw the parameters of an ``nn.Linear`` module from normal distributions.

    Weights are sampled in place from N(0, R['astddev']) and biases from
    N(0, R['bstddev']). Modules of any other type are left untouched, which
    makes the function usable with ``Module.apply``.

    Args:
        m: The module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    m.weight.data.normal_(mean=0, std=R['astddev'])
    m.bias.data.normal_(mean=0, std=R['bstddev'])


class Act_op(nn.Module):  # Custom activation function
    """Activation module computing ``relu(x) ** 3`` element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the cube of the rectified input."""
        # An earlier variant used x ** 50 (or F.relu(x) * F.relu(1 - x)).
        rectified = F.relu(x)
        return rectified.pow(3)


def getWini(hidden_units=(10, 20, 40), input_dim=1, output_dim_final=1, astddev=0.05, bstddev=0.05):
    """Draw random initial weights and biases for a fully connected network.

    The layer sizes are ``[input_dim] + hidden_units + [output_dim_final]``.
    For every pair of consecutive sizes a weight matrix and a bias vector are
    sampled from zero-mean normal distributions, in layer order and with the
    weight drawn before the bias of the same layer.

    Args:
        hidden_units: Sequence with the width of each hidden layer. It may be
            empty, in which case a single input-to-output layer is produced.
        input_dim: Number of network inputs.
        output_dim_final: Number of network outputs.
        astddev: Standard deviation used for the weights.
        bstddev: Standard deviation used for the biases.

    Returns:
        ``(w_Univ0, b_Univ0)``: two lists with one float32 array per layer.
        Weights have shape ``(fan_out, fan_in)`` (the layout ``nn.Linear``
        stores), biases have shape ``(fan_out,)``.
    """
    layer_sizes = [input_dim] + list(hidden_units) + [output_dim_final]

    w_Univ0 = []
    b_Univ0 = []
    for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        ua_w = np.float32(np.random.normal(loc=0.0, scale=astddev, size=[fan_in, fan_out]))
        ua_b = np.float32(np.random.normal(loc=0.0, scale=bstddev, size=[fan_out]))
        # Sampled as (fan_in, fan_out); transposed to (fan_out, fan_in).
        w_Univ0.append(np.transpose(ua_w))
        b_Univ0.append(ua_b)
    return w_Univ0, b_Univ0


# R['n_fixed'] is used by Model.plot_weight as the number of neurons whose
# weight history is plotted. Presumably it also was the number of hidden
# neurons initialised from the hand-picked values in the commented-out code
# below -- TODO confirm; with 0 nothing is fixed.
R['n_fixed'] = 0
#
# R['W_fixed'] = []
# w_in = np.array([[0.33075222], [0.45510107]])
# b_h = np.array([0 - 0.20241769, -0.32820928])
# W_out = np.array([[0.4650716, 0.0429663]])
# b_out=np.array([-0.3])
# [ 0.33075222  0.45510107 -0.20241769 -0.32820928  0.4650716   0.0429663 ]
# w_in=np.float32(np.random.normal(loc=0.0,scale=0.4,size=[R['n_fixed'],1]))
# W_out=np.float32(np.random.normal(loc=0.0,scale=0.4,size=[1,R['n_fixed']]))
# b_h=np.float32(np.random.normal(loc=0.0,scale=0.4,size=[R['n_fixed']]))
# b_out=np.float32(np.random.normal(loc=0.0,scale=0.4,size=[1]))


# R['W_fixed'].append(w_in)
# R['W_fixed'].append(W_out)
# R['b_fixed'] = []
# R['b_fixed'].append(b_h)
# R['b_fixed'].append(b_out)
# Clamp n_fixed to the width of the first hidden layer.
min_n = np.min([R['n_fixed'], R['hidden_units'][0]])
R['n_fixed'] = min_n
# print(min_n)
# Draw the initial weights/biases that the Network constructor copies into
# its linear layers.
w_Univ0, b_Univ0 = getWini(hidden_units=R['hidden_units'], input_dim=R['input_dim'], output_dim_final=R['output_dim'],
                           astddev=R['astddev'], bstddev=R['bstddev'])
# w_Univ0[0]=R['W_fixed'][0]
# w_Univ0[1]=R['W_fixed'][1]
# b_Univ0[0]=R['b_fixed'][0]

# w_Univ0[0][0:min_n, :] = R['W_fixed'][0][0:min_n, :]
# w_Univ0[1][:, 0:min_n] = R['W_fixed'][1][:, 0:min_n]
# b_Univ0[0][0:min_n] = R['b_fixed'][0][0:min_n]
# b_Univ0[1]=R['b_fixed'][1]

# List meant to collect gradient norms during training (see Model.run_onestep).
R['deri']=[]
# Debug output: shapes of the first layer's initial weight and bias.
print(np.shape(w_Univ0[0]))
print(np.shape(b_Univ0[0]))
# Model.plot_y saves its per-epoch figures into this sub-folder.
mkdir('%soutput/'%(FolderName))

class _Sin(nn.Module):
    """Element-wise sine activation (``torch.nn`` ships no such module)."""

    def forward(self, x):
        return torch.sin(x)


class Network(nn.Module):  # DNN 0: ReLU; 1: Tanh; 2:Sin; 3:x**50; 4:Sigmoid
    """Fully connected network configured by the module-level dict ``R``.

    Layer sizes come from ``R['full_net']``, the activation from
    ``R['ActFuc']`` (0 ReLU, 1 Tanh, 2 Sin, 3 ``Act_op``, 4 Sigmoid) and the
    initial parameters from the module-level lists ``w_Univ0`` / ``b_Univ0``.

    Sub-modules:
        block:  hidden layers + activations + bias-free output layer.
        block3: the hidden part of ``block`` only. It holds the *same*
            ``nn.Linear`` objects, so it always sees the current weights.
        block2: only when ``R['asi']`` is truthy. An independent copy of
            ``block`` whose output weights are negated, so the sum
            ``block(x) + block2(x)`` returned by ``forward`` is zero at
            initialisation.

    Raises:
        ValueError: If ``R['ActFuc']`` is not one of the supported codes.
    """

    # R['ActFuc'] code -> (module name in block/block3, name in block2, factory)
    _ACTIVATIONS = {
        0: ('relu', 'relu2', nn.ReLU),
        1: ('tanh', 'tanh2', nn.Tanh),
        2: ('sin', 'sin2', _Sin),
        3: ('relu3', '**502', Act_op),
        4: ('sigmoid', 'sigmoid2', nn.Sigmoid),
    }

    def __init__(self):
        super(Network, self).__init__()
        try:
            act_name, act_name_asi, make_act = self._ACTIVATIONS[R['ActFuc']]
        except KeyError:
            raise ValueError("unsupported R['ActFuc'] value: %r" % (R['ActFuc'],))
        n_hidden = len(R['full_net']) - 2

        self.block3 = nn.Sequential()
        self.block = nn.Sequential()
        for i in range(n_hidden):
            d_linear = self._make_linear(i, verbose=True)
            # One Linear object registered twice: block and block3 share it.
            self.block3.add_module('linear' + str(i), d_linear)
            self.block.add_module('linear' + str(i), d_linear)
            self.block.add_module(act_name + str(i), make_act())
            self.block3.add_module(act_name + str(i), make_act())
        self.block.add_module('linear' + str(n_hidden), self._make_linear(n_hidden, bias=False))

        if R['asi']:
            self.block2 = nn.Sequential()
            for i in range(n_hidden):
                self.block2.add_module('linear2' + str(i), self._make_linear(i, verbose=True))
                self.block2.add_module(act_name_asi + str(i), make_act())
            # The output layer has no bias, so only its weight is negated.
            self.block2.add_module('linear2' + str(n_hidden),
                                   self._make_linear(n_hidden, bias=False, negate=True))

    @staticmethod
    def _make_linear(i, bias=True, negate=False, verbose=False):
        """Build layer ``i`` of ``R['full_net']`` initialised from ``w_Univ0``/``b_Univ0``."""
        d_linear = nn.Linear(R['full_net'][i], R['full_net'][i + 1], bias=bias)
        if verbose:
            print('weight1: start')
            print(tuple(d_linear.weight.shape))
            print('weight1: end')
        weight = -w_Univ0[i] if negate else w_Univ0[i]
        d_linear.weight.data = torch.tensor(weight)
        if bias:
            d_linear.bias.data = torch.tensor(b_Univ0[i])
        if verbose:
            print('weight2: start')
            print(tuple(d_linear.weight.shape))
            print('weight2: end')
        return d_linear

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        if R['asi']:
            out = self.block(x) + self.block2(x)
        else:
            out = self.block(x)
        return out

    def hidden(self, x):
        """Return the activations after the hidden part of ``block``."""
        out = self.block3(x)
        return out


class Model():
    """Training driver for the module-level network ``net_``.

    The class keeps no state of its own. It reads the globals ``net_``,
    ``criterion``, ``device``, ``train_inputs``, ``test_inputs``,
    ``FolderName``, ``plot_epoch`` and ``t0`` and stores all results in the
    module-level dicts ``R`` (settings, loss history, latest predictions,
    gradient norms), ``Ry`` (history of test predictions) and ``Rw``
    (history of weights).
    """

    def __init__(self):
        """Evaluate the untrained network, save a checkpoint and start the logs."""
        y_train = net_(torch.FloatTensor(train_inputs).to(device))
        loss_train = float(criterion(y_train.cpu(), torch.FloatTensor(R['y_true_train'])).cpu())
        y_test = net_(torch.FloatTensor(test_inputs).to(device))

        # Checkpoint of the initial parameters; run() reloads it before training.
        nametmp = '%smodel/' % (FolderName)
        mkdir(nametmp)
        torch.save(net_.state_dict(), "%smodel.ckpt" % (nametmp))

        R['y_train'] = y_train.cpu().detach().numpy()
        R['y_test'] = y_test.cpu().detach().numpy()
        self.record_weight()

        R['loss_train'] = [loss_train]

    def run_onestep(self):
        """Log the current loss and gradient norm, then do one round of SGD.

        Appends to ``R['loss_train']`` the mean-squared error on the whole
        training set and to ``R['deri']`` the Euclidean norm of the gradient
        of that same loss with respect to all trainable parameters. Both are
        measured before the parameters are updated and once per call, so
        ``R['deri'][k]`` belongs to ``R['loss_train'][k + 1]`` (index 0 of the
        loss list is written by ``__init__``).

        Afterwards ``R['train_size'] // R['batch_size'] + 1`` mini-batch SGD
        updates are performed and ``R['learning_rate']`` is decayed.
        """
        with torch.no_grad():
            y_test = net_(torch.FloatTensor(test_inputs).to(device))
        y_train = net_(torch.FloatTensor(train_inputs).to(device))
        loss_train_t = criterion(y_train, torch.FloatTensor(R['y_true_train']).to(device))

        # Gradient norm of the logged training loss. autograd.grad does not
        # touch the .grad fields, so the optimiser below is unaffected.
        params = [p for p in net_.parameters() if p.requires_grad]
        grads = torch.autograd.grad(loss_train_t, params, allow_unused=True)
        grad_sq = sum(float(g.pow(2).sum()) for g in grads if g is not None)
        R['deri'].append(math.sqrt(grad_sq))

        R['y_train'] = y_train.cpu().detach().numpy()
        R['y_test'] = y_test.cpu().detach().numpy()
        R['loss_train'].append(float(loss_train_t.detach().cpu()))

        # The optimiser is rebuilt on every call so that it picks up the
        # decayed learning rate.
        optimizer = torch.optim.SGD(net_.parameters(), lr=R['learning_rate'])

        for _ in range(R['train_size'] // R['batch_size'] + 1):  # bootstrap
            mask = np.random.choice(R['train_size'], R['batch_size'], replace=False)
            y_batch = net_(torch.FloatTensor(train_inputs[mask]).to(device))
            loss = criterion(y_batch, torch.FloatTensor(R['y_true_train'][mask]).to(device))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        R['learning_rate'] = R['learning_rate'] * (1 - R['learning_rateDecay'])

    def record_weight(self):
        """Append a snapshot of (input weights, hidden biases, output weights) to ``Rw``.

        The snapshot reads ``net_.block[0]`` and ``net_.block[2]``, i.e. it
        expects ``block`` to be linear / activation / linear.
        """
        tmp_w1 = np.squeeze(net_.block[0].weight.cpu().detach().numpy())
        tmp_b1 = np.squeeze(net_.block[0].bias.cpu().detach().numpy())
        tmp_w2 = np.squeeze(net_.block[2].weight.cpu().detach().numpy())
        if R['hidden_units'][0] == 1:
            # With a single neuron the squeezed arrays are 0-d and cannot be
            # concatenated. They are copied because on the CPU .numpy()
            # shares memory with the live parameters; without the copy every
            # stored snapshot would change as training goes on.
            tmp_w = [tmp_w1.copy(), tmp_b1.copy(), tmp_w2.copy()]
        else:
            # np.concatenate already returns a new array.
            tmp_w = np.concatenate((tmp_w1, tmp_b1, tmp_w2), axis=0)
        Rw['weight_R'].append(tmp_w)

    def run(self, step_n=1):
        """Reload the initial checkpoint and train for ``step_n`` steps.

        Every step logs weights and test predictions. Progress is printed
        every 1000 steps, and every ``plot_epoch`` steps the hidden
        activations are stored in ``R['y_hid']``, the loss and prediction
        figures are written and all logs are saved to disk.

        Args:
            step_n: Number of calls to ``run_onestep``.
        """
        # Load parameters
        nametmp = '%smodel/model.ckpt' % (FolderName)
        net_.load_state_dict(torch.load(nametmp, map_location=device))
        net_.eval()

        for epoch in range(step_n):

            self.run_onestep()
            self.record_weight()
            Ry['y_all'].append(R['y_test'])

            if epoch % 1000 == 0:
                print('time elapse: %.3f' % (time.time() - t0))
                print('model, epoch: %d, train loss: %f' % (epoch, R['loss_train'][-1]))

            if epoch % plot_epoch == 0:
                y_hid = net_.hidden(torch.FloatTensor(test_inputs).to(device))
                R['y_hid'] = y_hid.cpu().detach().numpy()

                self.plot_loss()
                self.plot_y(epoch)
                self.save_file()

    def plot_weight(self):
        """Plot |weight| over training steps for the first ``R['n_fixed']`` neurons.

        One row per neuron with three log-log panels; the figure is saved as
        ``<FolderName>weightevolve.png``.
        """
        weight_R = np.stack(Rw['weight_R'])
        plt.figure()
        for i_sub in range(R['n_fixed']):
            for ji in range(3):
                ax = plt.subplot(R['n_fixed'], 3, 3 * i_sub + ji + 1)
                ax.plot(abs(weight_R[:, ji * R['n_fixed'] + i_sub]))
                plt.title('%s' % (3 * i_sub + ji))
                ax.set_xscale('log')
                ax.set_yscale('log')
                ax.set_ylim([5e-2, 1e1])

        fntmp = '%sweightevolve' % (FolderName)
        save_fig(plt, fntmp, iseps=0)

    def plot_loss(self):
        """Save a log-log plot of the training loss as ``<FolderName>loss.png``.

        Entries listed in ``R['tuning_ind']`` are marked with red stars.
        """
        plt.figure()
        ax = plt.gca()
        y2 = np.asarray(R['loss_train'])
        plt.plot(y2, 'k-', label='Train')
        if len(R['tuning_ind']) > 0:
            plt.plot(R['tuning_ind'], y2[R['tuning_ind']], 'r*')
        ax.set_xscale('log')
        ax.set_yscale('log')
        plt.legend(fontsize=18)
        plt.title('train_loss', fontsize=18)
        plt.tick_params(labelsize=18)
        fntmp = '%sloss' % (FolderName)
        save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)

    def plot_tuning(self):
        """Plot the predictions in ``R['y_tuning']`` against the training targets.

        Each curve is labelled with the matching entry of ``R['loss_tuning']``;
        the figure is saved as ``<FolderName>turn.png``.
        """
        plt.figure()
        ax = plt.gca()
        y2 = R['y_true_train']
        plt.plot(train_inputs, y2, 'b*', label='True')
        for iit in range(len(R['y_tuning'])):
            plt.plot(test_inputs, R['y_tuning'][iit], '-', label='%.3f' % (R['loss_tuning'][iit]))
        plt.title('turn points', fontsize=15)
        plt.legend(fontsize=18)
        fntmp = '%sturn' % (FolderName)
        save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)

    def plot_y(self, epoch):
        """Save a figure comparing the network output with the target.

        Args:
            epoch: Current training step; used in the title and file name.
        """
        if R['input_dim'] == 2:
            X = np.arange(R['x_start'], R['x_end'], 0.1)
            Y = np.arange(R['x_start'], R['x_end'], 0.1)
            X, Y = np.meshgrid(X, Y)
            xy = np.concatenate((np.reshape(X, [-1, 1]), np.reshape(Y, [-1, 1])), axis=1)
            # NOTE(review): get_y works element-wise, so for 2-D inputs it
            # returns one value per coordinate rather than per point and Z may
            # not match the shape of X/Y -- confirm the intended 2-D target.
            Z = np.reshape(get_y(xy, np.shape(xy)[0]), [len(X), -1])

            fp = plt.figure()
            # Figure.gca() no longer accepts keyword arguments in current
            # Matplotlib; add_subplot creates the 3-D axes instead.
            ax = fp.add_subplot(111, projection='3d')
            surf = ax.plot_surface(X, Y, Z - np.min(Z), cmap=cm.coolwarm, linewidth=0, antialiased=False)
            ax.zaxis.set_major_locator(LinearLocator(5))
            ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
            fp.colorbar(surf, shrink=0.5, aspect=5)
            # The predictions are stored as a column; scatter needs them flat.
            ax.scatter(train_inputs[:, 0], train_inputs[:, 1],
                       np.ravel(R['y_train'] - np.min(R['y_train'])))
            fntmp = '%s2du%s' % (FolderName, epoch)
            save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)

        if R['input_dim'] == 1:
            plt.figure()
            ax = plt.gca()
            y1 = R['y_test']
            y2 = R['y_true_train']
            plt.plot(test_inputs, y1, 'r-', label='Test')
            plt.plot(train_inputs, y2, 'b*', label='True')
            plt.title('epoch=%s'%(epoch), fontsize=18)
            plt.legend(fontsize=18)
            plt.tick_params(labelsize=18)
            fntmp = '%soutput/u_m%s' % (FolderName, epoch)
            save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)

    def save_file(self):
        """Pickle ``R``, ``Ry`` and ``Rw`` and write a short text summary.

        ``Output.txt`` lists every entry of ``R`` with at most 20 elements,
        followed by the latest training loss and the first weight snapshot.
        """
        with open('%s/objs.pkl' % (FolderName), 'wb') as f:
            pickle.dump(R, f, protocol=4)
        with open('%s/objsy.pkl'%(FolderName), 'wb') as f:
            pickle.dump(Ry, f, protocol=4)
        with open('%s/objsw.pkl'%(FolderName), 'wb') as f:
            pickle.dump(Rw, f, protocol=4)

        # The context manager closes the file even if formatting a value fails.
        with open("%s/Output.txt" % (FolderName), "w") as text_file:
            for para in R:
                if np.size(R[para]) > 20:
                    continue
                text_file.write('%s: %s\n' % (para, R[para]))
            text_file.write('loss end: %s\n' % (R['loss_train'][-1]))
            text_file.write('weight ini: %s\n' % (Rw['weight_R'][0]))


t0 = time.time()
net_ = Network().to(device)
# net_.apply(weights_init)
print(net_)

criterion = nn.MSELoss(reduction='mean').to(device)

model = Model()
# model.run(600)
model.run(3000000)

# Look for candidate saddle points: logged steps whose gradient norm is below
# 1e-9. A step is kept only if its training loss differs by at least 1e-4 from
# the loss of every position kept so far, so each loss plateau is reported once.
# zip() stops at the shorter of the two logs, which keeps the loss index valid.
# NOTE(review): this pairs R['deri'][k] with R['loss_train'][k]; confirm that
# the two logs are meant to share an index.
saddle_position = []
for ind, (grad_norm, loss_here) in enumerate(zip(R['deri'], R['loss_train'])):
    if grad_norm >= 1e-9:
        continue
    if not any(abs(loss_here - R['loss_train'][j]) < 1e-4 for j in saddle_position):
        saddle_position.append(ind)

# Write the step indices (one per line) into the run folder.
np.savetxt('%s/saddle_pos.txt' % (FolderName), saddle_position, fmt='%d')





# t0 = time.time()
# net_ = Network() 
# # net_.apply(weights_init)
# print(net_)

# criterion = nn.MSELoss(reduction='mean') 
# optimizer = torch.optim.Adam(net_.parameters(), lr=R['learning_rate'])

# model = Model()
# model.run(1000000)




