import os
import pickle
import matplotlib
import matplotlib.lines as lines
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns

###########Part1###########

# Swept hyper-parameter values, spelled exactly as they appear in Q's keys.
Gamma2 = [
    '-0.3', '-0.2', '-0.1', '+0.0', '+0.1', '+0.2', '+0.3',
]
Gamma3 = [
    '0.90', '1.10', '1.30', '1.50', '1.70', '1.90', '2.10',
]

# Fitted log-log slopes, filled in by the plotting loop below.
QQ = dict()

# Load the saved results. Q is indexed below with keys of the form
# 'gam2_<gamma2>_gam3_<gamma3>_<m>_RD_w_<layer>', whose values are averaged
# with np.mean.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
results_path = r"F:\deep_study\20210824_threelayers_limit_width_zqx\result/objk.pkl"
with open(results_path, 'rb') as f:
    Q = pickle.load(f)

# Hidden-layer widths m at which RD_w_0 / RD_w_1 were recorded; the keys of Q
# embed these integers verbatim.
widths = [100, 1000, 2500, 5000, 10000]
# Part 1 only draws figures for this one (gamma2, gamma3) pair, so only its
# means are computed (not those of every pair in the sweep).
example_pair = ('+0.0', '0.90')


def _mean_rd_over_widths(q, gam2, gam3, layer):
    """Return the mean of each recorded RD_w_<layer> entry of `q`, one per width.

    Entries are looked up as 'gam2_<gam2>_gam3_<gam3>_<m>_RD_w_<layer>' for
    every m in `widths`, in that order.
    """
    return [np.mean(q['gam2_' + gam2 + '_gam3_' + gam3 + '_' + str(m) + '_RD_w_' + str(layer)])
            for m in widths]


def _fit_power_law(xs, ys):
    """Fit log(y) = k * log(x) + b by least squares.

    Returns the coefficient array [k, b] from np.polyfit together with a dense
    curve (x, x**k * e**b) over [xs[0], xs[-1]] for drawing the fitted line.
    """
    # Degree 1: a straight line in log-log space, i.e. a power law y ~ x**k.
    coeffs = np.polyfit(np.log(xs), np.log(ys), 1)
    x_line = np.linspace(start=xs[0], stop=xs[-1], num=2000, endpoint=True)
    y_line = x_line ** coeffs[0] * math.e ** coeffs[1]
    return coeffs, x_line, y_line


# Only plot when the example pair is part of the sweep defined by Gamma2/Gamma3.
if example_pair[0] in Gamma2 and example_pair[1] in Gamma3:
    i, j = example_pair
    RD_w_0 = _mean_rd_over_widths(Q, i, j, 0)
    RD_w_1 = _mean_rd_over_widths(Q, i, j, 1)

    # ---- Figure 1: mean RD_w_0 against m, with its log-log fit ----
    plt.figure()
    ax = plt.gca()
    plt.scatter(widths, RD_w_0, label='data', color='blue')
    plt.xlabel(r'm', fontsize=20)
    # Set the log scales BEFORE fixing ticks: set_xscale/set_yscale install
    # fresh default locators and would silently discard the yticks below.
    ax.set_xscale('log')
    ax.set_yscale('log')
    plt.yticks([0.01], size=20)
    plt.tick_params(axis='both', which='major', labelsize=20)
    z1, x_poly, Y_poly = _fit_power_law(widths, RD_w_0)
    QQ['gamma2_' + str(i) + '_gamma3_' + str(j) + '_' + 'k_0'] = z1[0]
    plt.plot(x_poly, Y_poly, linewidth=2.0, color='grey',
             label='slope=' + str(round(z1[0], 4)), linestyle='-')
    plt.ylim([0.003, 0.07])
    plt.legend(fontsize=20)
    plt.tight_layout()
    plt.savefig(r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/example_gamma2_%s_gamma3_%s_RD_w_0.png' % (i, j))
    plt.close()

    # ---- Figure 2: mean RD_w_1 against m, with its log-log fit ----
    plt.figure()
    ax = plt.gca()
    plt.scatter(widths, RD_w_1, label='training data', color='g')
    plt.xlabel(r'x', fontsize=24)
    plt.ylabel(r'y', rotation=0, fontsize=24)
    ax.set_xscale('log')
    ax.set_yscale('log')
    z1, x_poly, Y_poly = _fit_power_law(widths, RD_w_1)
    print(z1)
    p1 = np.poly1d(z1)  # coefficients from highest to lowest degree
    QQ['gamma2_' + str(i) + '_gamma3_' + str(j) + '_' + 'k_1'] = z1[0]
    print(p1)  # human-readable form of the fitted line
    plt.plot(x_poly, Y_poly, linewidth=2.0, color='red',
             label='auxiliary line: ' + str(round(z1[0], 4)), linestyle='--')
    plt.legend()
    plt.tight_layout()
    # Pass the DPI per call instead of mutating plt.rcParams, so the setting
    # does not leak into figures created later in the script.
    plt.savefig(r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/example_gamma2_%s_gamma3_%s_RD_w_1.png' % (i, j), dpi=200)
    plt.close()

###########Part2###########

# Part 2 re-declares its inputs and reloads the results, presumably so that it
# can be run without Part 1.
Gamma2 = [
    '-0.3', '-0.2', '-0.1', '+0.0', '+0.1', '+0.2', '+0.3',
]
Gamma3 = [
    '0.90', '1.10', '1.30', '1.50', '1.70', '1.90', '2.10',
]
# Numeric counterparts of Gamma3 (0.9, 1.1, ..., 2.1), used as x-coordinates.
gammma3 = [float(g3) for g3 in Gamma3]
QQ = dict()  # fresh slope table for Part 2

# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
results_path = r"F:\deep_study\20210824_threelayers_limit_width_zqx\result/objk.pkl"
with open(results_path, 'rb') as f:
    Q = pickle.load(f)

# For every (gamma2, gamma3) pair and both weight layers, fit
# log(mean RD_w_<layer>) against log(m) with a straight line. Store the slope
# in QQ under 'gamma2_<g2>_gamma3_<g3>_k_<layer>'.
widths = [100, 1000, 2500, 5000, 10000]
log_widths = np.log(widths)  # loop-invariant, computed once
for i in Gamma2:
    for j in Gamma3:
        for layer in (0, 1):
            mean_rd = [np.mean(Q['gam2_' + str(i) + '_gam3_' + str(j) + '_' + str(m) + '_RD_w_' + str(layer)])
                       for m in widths]
            # Degree-1 fit (a straight line in log-log space); z1[0] is the slope.
            z1 = np.polyfit(log_widths, np.log(mean_rd), 1)
            QQ['gamma2_' + str(i) + '_gamma3_' + str(j) + '_' + 'k_' + str(layer)] = z1[0]

# Fitted slope '_k_0' at gamma2 = +0.0, one value per gamma3 (in Gamma3 order).
# NOTE(review): the list and the output file are named S_w_1, yet the '_k_0'
# slope (fitted to RD_w_0 in the loop above) is read -- confirm '_k_1' was not
# intended.
S_w_1 = [QQ['gamma2_' + '+0.0' + '_gamma3_' + str(gam3) + '_' + 'k_0'] for gam3 in Gamma3]

plt.style.use('bmh')
# Set the DPI before creating the figure: 'figure.dpi' is only read when a
# figure is created, so assigning it after plt.figure() does not affect it.
plt.rcParams['savefig.dpi'] = 200  # pixel density of the saved PNG
plt.rcParams['figure.dpi'] = 200   # pixel density of the figure canvas
plt.figure()
ax = plt.gca()
# Connect consecutive (gamma3, slope) points with thick green segments. The zip
# yields len(gammma3) - 1 segments instead of relying on a hard-coded range(6).
for x0, x1, y0, y1 in zip(gammma3, gammma3[1:], S_w_1, S_w_1[1:]):
    ax.add_line(lines.Line2D([x0, x1], [y0, y1], linewidth=3.0, color='g'))
plt.scatter(gammma3, S_w_1, color='g')
plt.xlabel(r'x', fontsize=22)
plt.ylabel(r'y', rotation=0, fontsize=22)
plt.xticks([1.1, 1.5, 1.9], size=16)
plt.yticks([-0.4, 0.0], size=16)
plt.tick_params(axis='both', which='major', labelsize=16)
plt.ylim([-0.7, 0.3])
plt.tight_layout()
plt.savefig(r'F:\deep_study\20210824_threelayers_limit_width_zqx\drawing/S_w_1_gam2_0.png')
plt.close()