import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns
from matplotlib.lines import Line2D
# Parameter grid. The two-decimal string forms are spliced into pickle keys,
# dictionary keys and file names, so their exact formatting matters.
Gamma2 = ['0.20', '0.30', '0.40', '0.50', '0.60', '0.70', '0.80']
gamma2 = [float(label) for label in Gamma2]  # numeric twin of Gamma2
Gamma3 = ['2.20', '2.30', '2.40', '2.50', '2.60', '2.70', '2.80']
QQ = dict()  # fitted exponents, keyed 'gamma2_<g2>_gamma3_<g3>_k_<0|1>'

# Raw results; the script reads keys of the form 'gam2_<g2>_gam3_<g3>_<N>_RD_w_<0|1>'.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open(r"/home/XXX/result_plt2/objk.pkl", mode='rb') as f:
    Q = pickle.load(f)

for i in Gamma2:
    for j in Gamma3:
        RD_w_0_100 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(100)+'_RD_w_0'])
        RD_w_1_100 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(100)+'_RD_w_1'])
        RD_w_0_1000 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(1000)+'_RD_w_0'])
        RD_w_1_1000 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(1000)+'_RD_w_1'])
        RD_w_0_2500 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(2500)+'_RD_w_0'])
        RD_w_1_2500 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(2500)+'_RD_w_1'])
        RD_w_0_5000 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(5000)+'_RD_w_0'])
        RD_w_1_5000 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(5000)+'_RD_w_1'])
        RD_w_0_10000 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(10000)+'_RD_w_0'])
        RD_w_1_10000 = np.mean(Q['gam2_'+str(i)+'_gam3_'+str(j)+'_'+str(10000)+'_RD_w_1'])

        plt.figure()
        ax = plt.gca()
        plt.rcParams['savefig.dpi'] = 200 #图片像素
        plt.rcParams['figure.dpi'] = 200 #分辨率
        #ax.add_line(Line2D(line1_xs, line1_ys, linewidth=1.0, color='red',label = 'auxiliary line'))
        #plt.plot(Q['test_set'],Y_test_network, linewidth=2.0, color='blue',label = 'NN output', linestyle='-')
        #plt.plot(x_poly,Y_poly, linewidth=2.0, color='red',label = 'auxiliary line',linestyle='--')
        plt.scatter([100,1000,2500,5000,10000],[RD_w_0_100,RD_w_0_1000,RD_w_0_2500,RD_w_0_5000,RD_w_0_10000],label = 'training data', color='g')
        # plt.scatter(x_train,Y_net_train,label = 'train_net', color='y')
        # plt.title('result',fontsize=15)        
        #ax.set_yscale('log')
        plt.xlabel(r'x',fontsize=24)
        plt.ylabel(r'y',rotation=0,fontsize=24)
        # my_x_ticks = [-1.0,-0.5,0,0.5,1.0,1.5]
        # plt.xticks(my_x_ticks)
        # my_y_ticks = [-1.0,-0.5,0.0,0.5,1.0]
        # plt.yticks(my_y_ticks)
        # plt.tick_params(axis='both',which='major',labelsize=20)
        # plt.legend()
        ax.set_xscale('log')
        ax.set_yscale('log')
        z1 = np.polyfit(np.log([100,1000,2500,5000,10000]), np.log([RD_w_0_100,RD_w_0_1000,RD_w_0_2500,RD_w_0_5000,RD_w_0_10000]), 1) # 用7次多项式拟合，可改变多项式阶数；
        print(z1)
        p1 = np.poly1d(z1) #得到多项式系数，按照阶数从高到低排列
        x_poly = np.linspace(start=100,stop =10000, num = 2000,endpoint=True)
        Y_poly = (x_poly)**z1[0] *  (math.e)**z1[1]
        QQ['gamma2_'+str(i)+'_gamma3_'+str(j)+'_'+'k_0'] = z1[0]
        print(p1) #显示多项式
        plt.plot(x_poly,Y_poly, linewidth=2.0, color='red',label = 'auxiliary line: '+str( round(z1[0], 4)),linestyle='--')
        # plt.xlim([-4.5,4.5])
        # plt.ylim([-4.5,4.5])
        plt.legend()
        plt.tight_layout()
        plt.savefig(r'/home/XXX/result_plt2/drawing/gamma2_%s_gamma3_%s_RD_w_0.png'%(i,j))
        plt.close()

        plt.figure()
        ax = plt.gca()
        plt.rcParams['savefig.dpi'] = 200 #图片像素
        plt.rcParams['figure.dpi'] = 200 #分辨率
        #ax.add_line(Line2D(line1_xs, line1_ys, linewidth=1.0, color='red',label = 'auxiliary line'))
        #plt.plot(Q['test_set'],Y_test_network, linewidth=2.0, color='blue',label = 'NN output', linestyle='-')
        #plt.plot(x_poly,Y_poly, linewidth=2.0, color='red',label = 'auxiliary line',linestyle='--')
        plt.scatter([100,1000,2500,5000,10000],[RD_w_1_100,RD_w_1_1000,RD_w_1_2500,RD_w_1_5000,RD_w_1_10000],label = 'training data', color='g')
        # plt.scatter(x_train,Y_net_train,label = 'train_net', color='y')
        # plt.title('result',fontsize=15)        
        #ax.set_yscale('log')
        plt.xlabel(r'x',fontsize=24)
        plt.ylabel(r'y',rotation=0,fontsize=24)
        # my_x_ticks = [-1.0,-0.5,0,0.5,1.0,1.5]
        # plt.xticks(my_x_ticks)
        # my_y_ticks = [-1.0,-0.5,0.0,0.5,1.0]
        # plt.yticks(my_y_ticks)
        # plt.tick_params(axis='both',which='major',labelsize=20)
        # plt.legend()
        ax.set_xscale('log')
        ax.set_yscale('log')
        z1 = np.polyfit(np.log([100,1000,2500,5000,10000]), np.log([RD_w_1_100,RD_w_1_1000,RD_w_1_2500,RD_w_1_5000,RD_w_1_10000]), 1) # 用7次多项式拟合，可改变多项式阶数；
        print(z1)
        p1 = np.poly1d(z1) #得到多项式系数，按照阶数从高到低排列
        x_poly = np.linspace(start=100,stop =10000, num = 2000,endpoint=True)
        Y_poly = (x_poly)**z1[0] *  (math.e)**z1[1]
        QQ['gamma2_'+str(i)+'_gamma3_'+str(j)+'_'+'k_1'] = z1[0]
        print(p1) #显示多项式
        plt.plot(x_poly,Y_poly, linewidth=2.0, color='red',label = 'auxiliary line: '+str( round(z1[0], 4)),linestyle='--')
        # plt.xlim([-4.5,4.5])
        # plt.ylim([-4.5,4.5])
        plt.legend()
        plt.tight_layout()
        plt.savefig(r'/home/XXX/result_plt2/drawing/gamma2_%s_gamma3_%s_RD_w_1.png'%(i,j))
        plt.close()

print(QQ)

# Seaborn heatmap of the fitted exponent k_0 over the (gamma3, gamma2) grid.
# Rows are ordered by descending gamma2 so that gamma2 increases upwards,
# matching the y tick labels.
heatmap_k = np.array([[QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_0'] for g3 in Gamma3]
                      for g2 in reversed(Gamma2)])

plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200  # figure resolution
fig, ax = plt.subplots()
sns.heatmap(heatmap_k, linewidths=0.05, cmap='RdBu', ax=ax)
ax.xaxis.set_ticks_position('top')
ax.set_yticklabels([float(g2) for g2 in reversed(Gamma2)], fontsize=12)
ax.set_xticklabels([float(g3) for g3 in Gamma3], fontsize=12)
fig.savefig(r'/home/XXX/result_plt2/drawing/S_w_0.png')
# Close only this figure. A plt.clf() here would implicitly create and leak a
# fresh empty figure.
plt.close(fig)

# For each gamma3 column, estimate the gamma2 at which k_0 crosses zero.
# A straight line is fitted through k_0 at three consecutive gamma2 values,
# and its root is taken. The window slides up the gamma2 grid as gamma3 grows.
# Each offset below selects the rows Gamma2[offset+1 : offset+4].
k0_window_offset = [0, 0, 1, 1, 2, 2, 3]
if len(k0_window_offset) != len(Gamma3):
    raise ValueError('k0_window_offset needs one entry per gamma3 value')

stars_0 = []
for g3, offset in zip(Gamma3, k0_window_offset):
    rows = range(offset + 1, offset + 4)
    val = [QQ['gamma2_' + Gamma2[r] + '_gamma3_' + g3 + '_k_0'] for r in rows]
    # Use the exact grid values as abscissae, the same ones the keys refer to.
    z1 = np.polyfit([gamma2[r] for r in rows], val, 1)  # degree-1 (linear) fit
    print(z1)
    stars_0.append(-z1[1] / z1[0])  # root of z1[0] * x + z1[1]
print(stars_0)

# k_0 on real (gamma3, gamma2) axes, with row i = Gamma2[i] and column j = Gamma3[j].
# The map is overlaid with the estimated zero crossings (stars) and an auxiliary line.
heatmap_k = np.array([[QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_0'] for g3 in Gamma3]
                      for g2 in Gamma2])
gamma3_values = [float(g3) for g3 in Gamma3]

XX, YY = np.meshgrid(gamma3_values, gamma2)

plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200  # figure resolution
fig, ax = plt.subplots()
im = ax.pcolor(XX, YY, heatmap_k, cmap='RdBu', shading='auto', vmin=-0.85, vmax=0.85)
ax.set_xticks(gamma3_values)
ax.set_yticks(gamma2)
ax.xaxis.set_ticks_position('top')
plt.colorbar(im, ticks=[-0.6, 0, 0.6])
ax.plot(gamma3_values, stars_0, 'k*', markersize=14)  # zero crossing per gamma3
ax.plot([2.15, 2.85], [0.325, 0.675], 'k--', linewidth=1.5)
fig.savefig(r'/home/XXX/result_plt2/drawing/S_w_0_xu.png')
# Close only this figure. A plt.clf() here would implicitly create and leak a
# fresh empty figure.
plt.close(fig)



# Seaborn heatmap of the fitted exponent k_1 over the (gamma3, gamma2) grid.
# Rows are ordered by descending gamma2 so that gamma2 increases upwards,
# matching the y tick labels.
heatmap_k = np.array([[QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_1'] for g3 in Gamma3]
                      for g2 in reversed(Gamma2)])

plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200  # figure resolution
fig, ax = plt.subplots()
sns.heatmap(heatmap_k, linewidths=0.05, cmap='RdBu', ax=ax)
ax.xaxis.set_ticks_position('top')
ax.set_yticklabels([float(g2) for g2 in reversed(Gamma2)], fontsize=12)
ax.set_xticklabels([float(g3) for g3 in Gamma3], fontsize=12)
fig.savefig(r'/home/XXX/result_plt2/drawing/S_w_1.png')
# Close only this figure. A plt.clf() here would implicitly create and leak a
# fresh empty figure.
plt.close(fig)



# For each gamma2 row, estimate the gamma3 at which k_1 crosses zero.
# A straight line is fitted through k_1 at (up to) three consecutive gamma3 values,
# and its root is taken. Window placement:
#   - rows with gamma2 >= 0.5 start at Gamma3[2] (2.40);
#   - lower rows start further right, using the indices in the table below.
# At gamma2 = 0.20 the slice holds only 2.70 and 2.80, so that fit uses 2 points.
k1_window_start = {'0.20': 5, '0.30': 4, '0.40': 3}

stars_0 = []
for g2_label, g2 in zip(Gamma2, gamma2):
    if g2 >= 0.5:
        start = 2
    elif g2_label in k1_window_start:
        start = k1_window_start[g2_label]
    else:
        # Previously this silently reused the fit from the prior row.
        raise ValueError('no k_1 fitting window defined for gamma2 = %s' % g2_label)
    cols = Gamma3[start:start + 3]
    val = [QQ['gamma2_' + g2_label + '_gamma3_' + g3 + '_k_1'] for g3 in cols]
    z1 = np.polyfit([float(g3) for g3 in cols], val, 1)  # degree-1 (linear) fit
    print(z1)
    stars_0.append(-z1[1] / z1[0])  # root of z1[0] * x + z1[1]
print(stars_0)

# k_1 on real (gamma3, gamma2) axes, with row i = Gamma2[i] and column j = Gamma3[j].
# The map is overlaid with the estimated zero crossings (stars) and two auxiliary lines.
heatmap_k = np.array([[QQ['gamma2_' + g2 + '_gamma3_' + g3 + '_k_1'] for g3 in Gamma3]
                      for g2 in Gamma2])
gamma3_values = [float(g3) for g3 in Gamma3]

XX, YY = np.meshgrid(gamma3_values, gamma2)

plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200  # figure resolution
fig, ax = plt.subplots()
im = ax.pcolor(XX, YY, heatmap_k, cmap='RdBu', shading='auto', vmin=-0.16, vmax=0.16)
ax.set_xticks(gamma3_values)
ax.set_yticks(gamma2)
ax.xaxis.set_ticks_position('top')
ax.tick_params(labelsize=8)
plt.colorbar(im, ticks=[-0.10, 0, 0.10])
ax.plot(stars_0, gamma2, 'k*', markersize=14)  # zero crossing per gamma2 row
ax.plot([2.5, 2.5], [0.5, 0.85], 'k--', linewidth=1.5)
ax.plot([2.5, 2.825], [0.5, 0.175], 'k--', linewidth=1.5)
fig.savefig(r'/home/XXX/result_plt2/drawing/S_w_1_xu.png')
# Close only this figure. A plt.clf() here would implicitly create and leak a
# fresh empty figure.
plt.close(fig)