import os, sys
import time
import pickle
import warnings

warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
import matplotlib
import math
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from matplotlib import cm
import platform
import shutil

# Make CUDA devices 0-7 visible to this process (set before any GPU use;
# note that no GPU work appears in the visible part of this script).
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'

# Axes rectangle [left, bottom, width, height] handed to ax.set_position()
# by save_fig: 0.18 margins on the left and bottom, with the right edge at
# 0.88 and the top edge at 0.9 (presumably figure-fraction coordinates, as
# matplotlib's Axes.set_position expects).
Leftp, Bottomp = 0.18, 0.18
Widthp, Heightp = 0.88 - Leftp, 0.9 - Bottomp
pos = [Leftp, Bottomp, Widthp, Heightp]


def mkdir(fn):
    """Create the directory ``fn`` if it does not already exist.

    Missing parent directories are created too, and an already-existing
    directory is not an error, so the call is safe to repeat.

    Args:
        fn: Path of the directory to create (a trailing separator is fine).

    Raises:
        FileExistsError: ``fn`` exists but is not a directory.
    """
    # exist_ok=True replaces the racy isdir()-then-mkdir() check, and
    # makedirs also creates intermediate directories that os.mkdir would
    # reject with FileNotFoundError.
    os.makedirs(fn, exist_ok=True)


def save_fig(pltm, fntmp, fp=1, ax=0, isax=0, iseps=0, isShowPic=0):
    """Save the current figure as ``<fntmp>.png``, optionally also EPS/PDF.

    Args:
        pltm: pyplot-like object providing ``rc``, ``savefig``, ``show`` and
            ``close`` (the script passes ``matplotlib.pyplot``).
        fntmp: Output path without extension.
        fp: When non-zero, also write ``<fntmp>.pdf`` with
            ``bbox_inches='tight'``.
        ax: Axes to reposition; only used when ``isax == 1``.
        isax: When exactly 1, set tick label sizes via ``pltm.rc``
            (xtick 18, ytick 10) and move ``ax`` to the module-level ``pos``
            rectangle before saving.
        iseps: When truthy, also write ``<fntmp>.eps`` at dpi=600.
        isShowPic: 1 -> call ``pltm.show()``; -1 -> return and leave the
            figure open; any other value -> ``pltm.close()``.
    """
    if isax == 1:
        # NOTE(review): with matplotlib, rc() updates global rcParams, so
        # these sizes also carry over to figures created afterwards.
        for tick_group, label_size in (('xtick', 18), ('ytick', 10)):
            pltm.rc(tick_group, labelsize=label_size)
        ax.set_position(pos, which='both')

    # PNG is always written; EPS and PDF are opt-in / opt-out respectively.
    pltm.savefig('%s.png' % (fntmp))
    if iseps:
        pltm.savefig('%s.eps' % (fntmp), format='eps', dpi=600)
    if fp != 0:
        pltm.savefig("%s.pdf" % (fntmp), bbox_inches='tight')

    if isShowPic == 1:
        pltm.show()
    elif isShowPic != -1:
        pltm.close()

# ---------------------------------------------------------------------------
# Plot the 'width-1' and 'width-500' fitted curves over the test inputs
# together with the training data, and save the figure under
# ``<path>output/``.
# ---------------------------------------------------------------------------

# Run directory of the 'width-1' curve. Both of its pickles are read from
# here and the figure is written to its ``output/`` sub-directory, so
# switching runs only needs this one line, e.g. the alternative run
#   '/home/dir/data/loss_landscape/test72/6.0/3/30391/'
path = '/home/dir/data/loss_landscape/test72/6.0/1/100264/'
# Derived from ``path`` so the locations cannot drift apart.
path1 = os.path.join(path, 'objsy.pkl')
path11 = os.path.join(path, 'objs.pkl')

# Pickle holding the 'width-500' curve.
path2 = '/home/dir/data/loss_landscape/test73/0.8/500/9802/objsy.pkl'

# Entry of the width-500 run's 'y_all' sequence that is plotted. It is also
# embedded in the output file name, so both must come from this one constant.
SNAPSHOT_IDX = 5000

# NOTE(review): pickle.load can execute arbitrary code; these files are
# assumed to be trusted, locally produced experiment outputs.
with open(path11, 'rb') as f:
    R1 = pickle.load(f)   # provides test_inputs, train_inputs, y_true_train

with open(path1, 'rb') as f:
    Ry2 = pickle.load(f)  # 'width-1' run; its 'y_all' is plotted

with open(path2, 'rb') as f:
    Ry1 = pickle.load(f)  # 'width-500' run; its 'y_all' is plotted

test_inputs = R1['test_inputs']
train_inputs = R1['train_inputs']
out_dir = os.path.join(path, 'output', '')  # keeps the trailing separator
mkdir(out_dir)

plt.figure()
ax = plt.gca()
y1 = Ry2['y_all'][-1]             # last entry of the width-1 run
y11 = Ry1['y_all'][SNAPSHOT_IDX]  # selected entry of the width-500 run
y2 = R1['y_true_train']

plt.plot(test_inputs, y11, 'r-', label='width-500', linewidth=3)
plt.plot(test_inputs, y1, 'k--', label='width-1', linewidth=3)
plt.plot(train_inputs, y2, 'b*', label='training data')
plt.xlabel('x', fontsize=18)
plt.ylabel('y', fontsize=18)
plt.legend(fontsize=18, loc='best')
plt.tick_params(labelsize=18)
fntmp = os.path.join(out_dir, 'aaa_%d' % SNAPSHOT_IDX)
save_fig(plt, fntmp, ax=ax, isax=1, iseps=0)