import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns

# Parameter grids. They are kept as strings because they are spliced verbatim
# into the dictionary keys of the pickled results and into output file names.
Gamma2 = [
    '-0.3', '-0.2', '-0.1', '+0.0', '+0.1', '+0.2', '+0.3',
]
Gamma3 = [
    '0.90', '1.10', '1.30', '1.50', '1.70', '1.90', '2.10',
]

# Collects one fitted log-log slope per (gamma2, gamma3, weight index).
QQ = dict()

# Raw results of the runs. NOTE: pickle executes arbitrary code on load, so
# this must only ever point at a trusted, locally produced file.
with open(
    r"F:\deep_study\20210824_threelayers_limit_width_zqx\result/objk.pkl",
    'rb',
) as f:
    Q = pickle.load(f)

# Sample sizes that appear in the keys of Q (presumably the network widths,
# judging by the "limit_width" directory name -- TODO confirm).
SIZES = (100, 1000, 2500, 5000, 10000)

# Configure the resolution once, before any figure exists, so that the very
# first figure is created with the same settings as all later ones.
plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
plt.rcParams['figure.dpi'] = 200  # resolution of the figure itself


def _fit_and_plot_power_law(sizes, values, png_path):
    """Plot ``values`` against ``sizes`` on log-log axes with a fitted power law.

    A straight line is fitted to ``log(values)`` versus ``log(sizes)``; its
    slope is the exponent ``k`` of ``values ~ C * sizes**k``.  The fitted
    coefficients and the corresponding polynomial are printed, and the figure
    is written to ``png_path`` and closed.

    Args:
        sizes: Sequence of positive x positions, in ascending order.
        values: Sequence of positive y values, one per entry of ``sizes``.
        png_path: Destination of the saved figure.

    Returns:
        The fitted slope ``k`` (first coefficient returned by ``np.polyfit``).
    """
    # Degree-1 fit in log space; coefficients are ordered highest degree first.
    coeffs = np.polyfit(np.log(sizes), np.log(values), 1)
    print(coeffs)
    print(np.poly1d(coeffs))
    slope, intercept = coeffs

    fig, ax = plt.subplots()
    ax.scatter(sizes, values, label='training data', color='g')
    ax.set_xlabel(r'x', fontsize=24)
    ax.set_ylabel(r'y', rotation=0, fontsize=24)
    ax.set_xscale('log')
    ax.set_yscale('log')

    # Map the fitted line back from log space: y = e**intercept * x**slope.
    x_line = np.linspace(start=sizes[0], stop=sizes[-1], num=2000, endpoint=True)
    y_line = x_line ** slope * math.e ** intercept
    ax.plot(
        x_line,
        y_line,
        linewidth=2.0,
        color='red',
        label='auxiliary line: ' + str(round(slope, 4)),
        linestyle='--',
    )

    ax.legend()
    fig.tight_layout()
    fig.savefig(png_path)
    plt.close(fig)
    return slope


for i in Gamma2:
    for j in Gamma3:
        # Mean of every recorded series, for both weight indices, looked up
        # before any figure is produced for this (gamma2, gamma3) pair.
        means = {
            w: [
                np.mean(Q['gam2_%s_gam3_%s_%d_RD_w_%d' % (i, j, n, w)])
                for n in SIZES
            ]
            for w in (0, 1)
        }
        for w in (0, 1):
            QQ['gamma2_%s_gamma3_%s_k_%d' % (i, j, w)] = _fit_and_plot_power_law(
                SIZES,
                means[w],
                r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/gamma2_%s_gamma3_%s_RD_w_%d.png' % (i, j, w),
            )

print(QQ)

# Seaborn draws row 0 of the matrix at the top of the plot, so the rows are
# filled in reversed Gamma2 order to put the largest gamma2 on top.
heatmap_k = np.array([
    [QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_0'] for g3 in Gamma3]
    for g2 in reversed(Gamma2)
])

plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
plt.rcParams['figure.dpi'] = 200  # resolution of the figure itself
fig, ax = plt.subplots()
sns.heatmap(heatmap_k, linewidths=0.05, cmap='RdBu', ax=ax)
ax.xaxis.set_ticks_position('top')
# Tick labels are listed from the first (top) row downwards, so they must use
# the same reversed order as the rows; an ascending list would label every
# row with the negated gamma2.
ax.set_yticklabels([float(g2) for g2 in reversed(Gamma2)], fontsize=12)
ax.set_xticklabels([float(g3) for g3 in Gamma3], fontsize=12)
fig.savefig(r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/S_w_0.png')
# Close exactly this figure; calling plt.clf() after plt.close() would create
# a fresh empty figure that is never released.
plt.close(fig)

# Fit window per gamma2: (index of the first gamma3 column used, the gamma3
# values the slopes are regressed on, whether the selected slopes are printed
# before the fit result).
_ZERO_CROSSING_WINDOWS = {
    '-0.3': (1, [1.1, 1.3, 1.5, 1.7, 1.9], True),
    '-0.2': (1, [1.1, 1.3, 1.5, 1.7, 1.9], True),
    '-0.1': (2, [1.3, 1.5, 1.7], True),
    '+0.0': (2, [1.3, 1.5, 1.7], True),
    '+0.1': (3, [1.5, 1.7, 1.9], False),
    '+0.2': (4, [1.7, 1.9, 2.1], False),
    '+0.3': (4, [1.7, 1.9, 2.1], False),
}

stars_0 = []
heatmap_k = np.zeros(shape=(len(Gamma2), len(Gamma3)))

for gamma2 in Gamma2:
    window = _ZERO_CROSSING_WINDOWS.get(gamma2)
    if window is not None:
        first_col, gamma3_points, echo_slopes = window
        val = [
            QQ['gamma2_' + str(gamma2) + '_gamma3_' + str(Gamma3[first_col + offset]) + '_' + 'k_0']
            for offset in range(len(gamma3_points))
        ]
        if echo_slopes:
            print(val)
        # Degree-1 (linear) fit of the slopes over the selected gamma3 window.
        z1 = np.polyfit(gamma3_points, val, 1)
        print(z1)
    # Root of the fitted line z1[0]*x + z1[1]: the gamma3 at which the fitted
    # slope crosses zero.
    stars_0.append(-z1[1] / z1[0])
print(stars_0)

# One row per gamma2 (ascending) and one column per gamma3, matching the
# orientation of the meshgrid below.
heatmap_k = np.array([
    [QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_0'] for g3 in Gamma3]
    for g2 in Gamma2
])

# Numeric axis values are derived from the string grids so the two cannot
# drift apart.
gamma3_axis = [float(g3) for g3 in Gamma3]
gamma2_axis = [float(g2) for g2 in Gamma2]
XX, YY = np.meshgrid(gamma3_axis, gamma2_axis)

plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
plt.rcParams['figure.dpi'] = 200  # resolution of the figure itself
fig, ax = plt.subplots()
im = ax.pcolor(XX, YY, heatmap_k, cmap='RdBu', shading='auto', vmin=-0.85, vmax=0.85)
ax.set_xticks(gamma3_axis)
ax.set_yticks(gamma2_axis)
ax.xaxis.set_ticks_position('top')
fig.colorbar(im, ax=ax, ticks=[-0.6, 0, 0.6])

# Estimated zero crossings (one per gamma2) plus two dashed guide lines that
# meet at (1.5, 0).
ax.plot(stars_0, gamma2_axis, 'k*', markersize=14)
ax.plot([1.5, 2.2], [0, 0.35], 'k--', linewidth=1.5)
x_aux = [1.5, 1.5]
y_aux = [0.0, -0.35]
ax.plot(x_aux, y_aux, 'k--', linewidth=1.5)

fig.savefig(r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/S_w_0_xu.png')
# Close exactly this figure; calling plt.clf() after plt.close() would create
# a fresh empty figure that is never released.
plt.close(fig)


# Seaborn draws row 0 of the matrix at the top of the plot, so the rows are
# filled in reversed Gamma2 order to put the largest gamma2 on top.
heatmap_k = np.array([
    [QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_1'] for g3 in Gamma3]
    for g2 in reversed(Gamma2)
])

plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
plt.rcParams['figure.dpi'] = 200  # resolution of the figure itself
fig, ax = plt.subplots()
sns.heatmap(heatmap_k, linewidths=0.05, cmap='RdBu', ax=ax)
ax.xaxis.set_ticks_position('top')
# Tick labels are listed from the first (top) row downwards, so they must use
# the same reversed order as the rows; an ascending list would label every
# row with the negated gamma2.
ax.set_yticklabels([float(g2) for g2 in reversed(Gamma2)], fontsize=12)
ax.set_xticklabels([float(g3) for g3 in Gamma3], fontsize=12)
fig.savefig(r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/S_w_1.png')
# Close exactly this figure; calling plt.clf() after plt.close() would create
# a fresh empty figure that is never released.
plt.close(fig)


# Zero-crossing fits for the k_1 slopes.
# NOTE(review): Gamma2 holds strings ('-0.3' ... '+0.3'), so every comparison
# with a float below is False and none of the branches ever runs; stars_0
# therefore stays empty and the final print shows []. The indices used inside
# the branches (up to j+11) would also exceed the 7 entries of Gamma3.
# Presumably this is left over from a different parameter grid -- confirm
# before re-enabling.
stars_0 = []
# Placeholder array; it is rebuilt from QQ before the next plot reads it.
heatmap_k = np.zeros(shape=(len(Gamma2),len(Gamma3)))
for i in range(len(Gamma2)):
    if Gamma2[i] == 0.5:
        val = []
        for j in range(9):
            val.append(QQ['gamma2_'+str(Gamma2[i])+'_gamma3_'+str(Gamma3[j+2])+'_'+'k_1'])
        z1 = np.polyfit([2.30,2.35,2.4,2.45,2.5,2.55,2.6,2.65,2.70], val, 1) # degree-1 (linear) fit; the degree can be changed
        print(z1)
    if Gamma2[i] == 0.4:
        val = []
        for j in range(9):
            val.append(QQ['gamma2_'+str(Gamma2[i])+'_gamma3_'+str(Gamma3[j+4])+'_'+'k_1'])
        z1 = np.polyfit([2.4,2.45,2.5,2.55,2.6,2.65,2.7,2.75,2.8], val, 1) # degree-1 (linear) fit; the degree can be changed
        print(z1)
    if Gamma2[i] == 0.3:
        val = []
        for j in range(5):
            val.append(QQ['gamma2_'+str(Gamma2[i])+'_gamma3_'+str(Gamma3[j+8])+'_'+'k_1'])
        z1 = np.polyfit([2.6,2.65,2.7,2.75,2.8], val, 1) # degree-1 (linear) fit; the degree can be changed
        print(z1)
    if Gamma2[i] == 0.2:
        val = []
        for j in range(2):
            val.append(QQ['gamma2_'+str(Gamma2[i])+'_gamma3_'+str(Gamma3[j+11])+'_'+'k_1'])
        z1 = np.polyfit([2.75,2.8], val, 1) # degree-1 (linear) fit; the degree can be changed
        print(z1)
    # Zero-crossing accumulation is disabled for k_1:
    # stars_0.append(-z1[1]/z1[0])
print(stars_0)

# One row per gamma2 (ascending) and one column per gamma3, matching the
# orientation of the meshgrid below.
heatmap_k = np.array([
    [QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_1'] for g3 in Gamma3]
    for g2 in Gamma2
])

# Numeric axis values are derived from the string grids so the two cannot
# drift apart.
gamma3_axis = [float(g3) for g3 in Gamma3]
gamma2_axis = [float(g2) for g2 in Gamma2]
XX, YY = np.meshgrid(gamma3_axis, gamma2_axis)

plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
plt.rcParams['figure.dpi'] = 200  # resolution of the figure itself
fig, ax = plt.subplots()
im = ax.pcolor(XX, YY, heatmap_k, cmap='RdBu', shading='auto', vmin=-1.8, vmax=-0.0)
ax.set_xticks(gamma3_axis)
ax.set_yticks(gamma2_axis)
ax.xaxis.set_ticks_position('top')
# Tick label size is set before the colorbar exists, so only the main axes
# are affected.
ax.tick_params(labelsize=8)

fig.colorbar(im, ax=ax, ticks=[-1.2, -0.20])

fig.savefig(r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/S_w_1_xu.png')
# Close exactly this figure; calling plt.clf() after plt.close() would create
# a fresh empty figure that is never released.
plt.close(fig)
