import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns
from matplotlib.lines import Line2D
# Parameter grid scanned by this script. Values are strings because they are
# concatenated verbatim into the keys of the results dict loaded below.
# (Original author's note: remember to switch back to m=10000 later.)
# Previous grids, kept for reference:
#   Gamma2 = [0.2,0.3,0.4,0.5,0.6,0.7,0.8]
#   Gamma3 = ['2.20','2.25','2.30','2.35','2.40','2.45','2.50','2.55','2.60','2.65','2.70','2.75','2.80']
Gamma2 = [
    '-0.30', '-0.20', '-0.10', '0.00', '0.10', '0.20', '0.30',
]
Gamma3 = [
    '0.90', '1.10', '1.30', '1.50', '1.70', '1.90', '2.10',
]

# Filled by the fitting loop: fitted log-log slopes keyed by grid point.
QQ = {}

# Raw experiment results, looked up with string keys such as
# 'gam2_<g2>_gam3_<g3>_<m>_RD_w_0' in the fitting loop.
with open(r"/home/XXX/result/objk.pkl", 'rb') as pkl_file:
    Q = pickle.load(pkl_file)

# Fit a power law y ~ exp(b) * m**k to the mean RD_w_0 and RD_w_1 values
# recorded at each sample size m, for every (gamma2, gamma3) grid point.
# The fitted slopes k are stored in QQ under
#   'gamma2_<g2>_gamma3_<g3>_k_0'  (from RD_w_0)
#   'gamma2_<g2>_gamma3_<g3>_k_1'  (from RD_w_1)
# and one figure per statistic and grid point is saved to result/drawing/.
sample_sizes = [100, 1000, 2500, 5000, 10000]
log_sizes = np.log(sample_sizes)

plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200   # resolution of figures

for i in Gamma2:
    for j in Gamma3:
        print("**************************************************")
        print("i=", i)
        print("j=", j)
        # Keys in Q look like 'gam2_<g2>_gam3_<g3>_<m>_RD_w_0'.
        key_prefix = 'gam2_' + str(i) + '_gam3_' + str(j) + '_'
        RD_w_0 = [np.mean(Q[key_prefix + str(m) + '_RD_w_0']) for m in sample_sizes]
        RD_w_1 = [np.mean(Q[key_prefix + str(m) + '_RD_w_1']) for m in sample_sizes]
        qq_prefix = 'gamma2_' + str(i) + '_gamma3_' + str(j) + '_'

        # RD_w_0: log-log scatter with the fitted power law overlaid.
        plt.figure()
        ax = plt.gca()
        plt.scatter(sample_sizes, RD_w_0, label='training data', color='g')
        plt.xlabel(r'x', fontsize=24)
        plt.ylabel(r'y', rotation=0, fontsize=24)
        ax.set_xscale('log')
        ax.set_yscale('log')
        # Degree-1 (linear) fit in log-log space: z1 = [slope k, intercept b].
        z1 = np.polyfit(log_sizes, np.log(RD_w_0), 1)
        print(z1)
        p1 = np.poly1d(z1)  # coefficients ordered from highest to lowest degree
        x_poly = np.linspace(start=100, stop=10000, num=2000, endpoint=True)
        # Back-transform the linear log-log fit: y = exp(b) * x**k.
        Y_poly = x_poly ** z1[0] * math.e ** z1[1]
        QQ[qq_prefix + 'k_0'] = z1[0]
        print(p1)
        plt.plot(x_poly, Y_poly, linewidth=2.0, color='red',
                 label='auxiliary line: ' + str(round(z1[0], 4)), linestyle='--')
        plt.legend()
        plt.tight_layout()
        plt.savefig(r'/home/XXX/result/drawing/gamma2_%s_gamma3_%s_RD_w_0.png' % (i, j))
        plt.close()

        # RD_w_1: (log m, log y) scatter on linear axes; only the slope is kept.
        # Every point, including m=10000, must come from the RD_w_1 series.
        plt.figure()
        plt.scatter(log_sizes, np.log(RD_w_1), label='training data', color='g')
        z1 = np.polyfit(log_sizes, np.log(RD_w_1), 1)
        QQ[qq_prefix + 'k_1'] = z1[0]
        plt.savefig(r'/home/XXX/result/drawing/gamma2_%s_gamma3_%s_RD_w_1.png' % (i, j))
        plt.close()

print(QQ)
# Heatmap of the k_0 slopes over the (gamma2, gamma3) grid.
# Rows are filled in reverse Gamma2 order (largest gamma2 in row 0), and
# sns.heatmap draws row 0 at the top, so gamma2 decreases from top to bottom.
heatmap_k = np.zeros(shape=(len(Gamma2), len(Gamma3)))
for i, g2 in enumerate(Gamma2):
    for j, g3 in enumerate(Gamma3):
        heatmap_k[len(Gamma2) - 1 - i, j] = QQ['gamma2_' + str(g2) + '_gamma3_' + str(g3) + '_k_0']

plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200   # resolution of figures
fig, ax = plt.subplots()
ax = sns.heatmap(heatmap_k, linewidths=0.05, cmap='RdBu', vmin=-0.85, vmax=0.85)
ax.xaxis.set_ticks_position('top')
# Tick labels follow the row order of heatmap_k: the top row is Gamma2[-1].
ax.set_yticklabels([float(g2) for g2 in reversed(Gamma2)], fontsize=12)
ax.set_xticklabels([float(g3) for g3 in Gamma3], fontsize=12)
plt.savefig(r'/home/XXX/result/drawing/S_w_0.png')
# No plt.clf() here: after close() it would only create a new empty figure.
plt.close()

# Diagnostic plot: k_0 against gamma3 = 1.3, 1.5, 1.7 at gamma2 = -0.30,
# i.e. the three points that feed that row's zero-crossing fit below.
probe_keys = [
    'gamma2_-0.30_gamma3_1.30_k_0',
    'gamma2_-0.30_gamma3_1.50_k_0',
    'gamma2_-0.30_gamma3_1.70_k_0',
]
fig, ax = plt.subplots()
plt.scatter([1.3, 1.5, 1.7], [QQ[key] for key in probe_keys],
            label='training data', color='g')
plt.savefig(r'/home/XXX/result/drawing/test.png')
plt.close()

# For each gamma2 row, estimate the gamma3 at which the fitted slope k_0
# crosses zero. k_0 is fitted linearly (degree 1) against gamma3 over a
# hand-picked window of Gamma3 columns, and the root -b/a of a*x + b is taken.
# The window is given as a [lo, hi) slice of Gamma3 and depends on gamma2.
stars_0 = []
for g2 in Gamma2:
    g2_val = float(g2)
    if g2_val <= 0:
        lo, hi = 2, 5        # gamma3 = 1.3, 1.5, 1.7
    elif math.isclose(g2_val, 0.1):
        lo, hi = 3, 6        # gamma3 = 1.5, 1.7, 1.9
    elif math.isclose(g2_val, 0.2):
        lo, hi = 4, 7        # gamma3 = 1.7, 1.9, 2.1
    elif math.isclose(g2_val, 0.3):
        lo, hi = 5, 7        # gamma3 = 1.9, 2.1
    else:
        # Without this, the previous row's fit would be silently reused.
        raise ValueError('no gamma3 fitting window defined for gamma2=%s' % g2)
    window = Gamma3[lo:hi]
    x_fit = [float(g3) for g3 in window]
    val = [QQ['gamma2_' + str(g2) + '_gamma3_' + str(g3) + '_k_0'] for g3 in window]
    z1 = np.polyfit(x_fit, val, 1)  # degree-1 fit: z1 = [slope, intercept]
    print(z1)
    stars_0.append(-z1[1] / z1[0])
print('****************************************')
print(stars_0)


# The k_0 grid drawn with pcolor on the real (gamma3, gamma2) coordinates.
# The estimated zero-crossings stars_0 are overlaid as black stars, plus two
# dashed guide lines. Axis values are derived from Gamma2/Gamma3 so they
# cannot drift from the grid.
gamma2_vals = [float(g2) for g2 in Gamma2]
gamma3_vals = [float(g3) for g3 in Gamma3]
# heatmap_k[i][j] is k_0 at (Gamma2[i], Gamma3[j]); row i sits at y = gamma2_vals[i].
heatmap_k = np.array([
    [QQ['gamma2_' + str(g2) + '_gamma3_' + str(g3) + '_k_0'] for g3 in Gamma3]
    for g2 in Gamma2
])
XX, YY = np.meshgrid(gamma3_vals, gamma2_vals)

plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200   # resolution of figures
fig, ax = plt.subplots()
im = plt.pcolor(XX, YY, heatmap_k, cmap='RdBu', shading='auto', vmin=-0.85, vmax=0.85)
ax.set_xticks(gamma3_vals)  # x-axis ticks at the gamma3 grid values
ax.set_yticks(gamma2_vals)  # y-axis ticks at the gamma2 grid values
ax.xaxis.set_ticks_position('top')
plt.colorbar(im, ticks=[-0.6, 0, 0.6])
# One star per gamma2 row, at the estimated gamma3 where k_0 crosses zero.
plt.plot(stars_0, gamma2_vals, 'k*', markersize=14)
# Hand-placed dashed guide lines: a slanted segment from (1.5, 0.0) to
# (2.2, 0.35) and a vertical segment at gamma3 = 1.5 for gamma2 in [-0.35, 0].
x_aux = [1.5, 2.2]
y_aux = [0.0, 0.35]
plt.plot(x_aux, y_aux, 'k--', linewidth=1.5)
plt.plot([1.5, 1.5], [-0.35, 0], 'k--', linewidth=1.5)
ax.set_yticklabels(gamma2_vals, fontsize=12)
ax.set_xticklabels(gamma3_vals, fontsize=12)
plt.savefig(r'/home/XXX/result/drawing/S_w_0_star.png')
# No plt.clf() here: after close() it would only create a new empty figure.
plt.close()