import os
import pickle
import matplotlib
import matplotlib.lines as lines
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns

########### Part 1: example power-law fits ###########
# For a few representative (gamma2, gamma3) settings, fit the mean of the
# stored RD_w_0 / RD_w_1 samples against the width m with a straight line in
# log-log space, and save one scatter-plus-fit figure per quantity.

Gamma2 = [0.2,0.3,0.4,0.5,0.6,0.7,0.8]
Gamma3 = ['2.20','2.25','2.30','2.35','2.40','2.45','2.50','2.55','2.60','2.65','2.70','2.75','2.80']

# Widths m for which results are stored in the pickle.
WIDTHS = [100, 1000, 2500, 5000, 10000]

# gamma2 value shared by all example figures.
EXAMPLE_GAMMA2 = 0.5
# gamma3 -> {w index: (y tick positions, y limits)} for each example figure.
# Kept in increasing gamma3 order so figures and printed fits appear in the
# same order as a scan over Gamma3.
EXAMPLE_AXES = {
    '2.20': {0: ([1e-1], [0.3e-1, 0.5e0]), 1: ([1e0], [2e-1, 3e0])},
    '2.50': {0: ([1e-1, 1e0], [1.9e-1, 3.1e0]), 1: ([1e0], [4e-1, 4.1e0])},
    '2.80': {0: ([1e0], [3e-1, 4.5e0]), 1: ([1e0, 1e1], [9e-1, 11e0])},
}
# Filled with (gamma2, gamma3, w index).
EXAMPLE_FIG_TEMPLATE = r'F:\deep_study\20210812_threelayers_limit_width_zqx\drawing/example_gamma2_%s_gamma3_%s_RD_w_%d.png'


def _mean_rd(results, gam2, gam3, m, w):
    """Return the mean of results['gam2_<gam2>_gam3_<gam3>_<m>_RD_w_<w>']."""
    key = 'gam2_' + str(gam2) + '_gam3_' + str(gam3) + '_' + str(m) + '_RD_w_' + str(w)
    return np.mean(results[key])


def _plot_power_law_fit(ms, values, yticks, ylim, path):
    """Plot values against ms on log-log axes with a fitted power law and save it.

    The fit is a degree-1 least-squares polynomial in log-log space, i.e.
    values ~ exp(intercept) * ms**slope.

    Args:
        ms: x positions (widths); must be positive because they are log-transformed.
        values: y values; must be positive because they are log-transformed.
        yticks: y tick positions to show.
        ylim: (bottom, top) y-axis limits.
        path: output image path.

    Returns:
        The np.polyfit coefficient array [slope, intercept].
    """
    coeffs = np.polyfit(np.log(ms), np.log(values), 1)
    slope, intercept = coeffs

    plt.figure()
    ax = plt.gca()
    plt.scatter(ms, values, label='data', color='blue')
    plt.xlabel(r'm', fontsize=22)
    plt.tick_params(axis='both', which='major', labelsize=20)
    ax.set_xscale('log')
    ax.set_yscale('log')
    # Ticks must be set AFTER set_yscale: switching the scale installs the
    # scale's default locator and would silently discard ticks set earlier.
    plt.yticks(yticks, size=20)

    # Draw the fitted power law across the sampled width range.
    x_fit = np.linspace(start=min(ms), stop=max(ms), num=2000, endpoint=True)
    y_fit = np.exp(intercept) * x_fit ** slope
    plt.plot(x_fit, y_fit, linewidth=2.0, color='grey',
             label='slope=' + str(round(slope, 4)), linestyle='-')
    plt.ylim(ylim)
    plt.legend(fontsize=20)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()
    return coeffs


QQ = {}

# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
with open(r"F:\deep_study\20210812_threelayers_limit_width_zqx\result/objk.pkl", 'rb') as f:
    Q = pickle.load(f)

for j, axes_by_w in EXAMPLE_AXES.items():
    for w, (yticks, ylim) in axes_by_w.items():
        values = [_mean_rd(Q, EXAMPLE_GAMMA2, j, m, w) for m in WIDTHS]
        coeffs = _plot_power_law_fit(WIDTHS, values, yticks, ylim,
                                     EXAMPLE_FIG_TEMPLATE % (EXAMPLE_GAMMA2, j, w))
        # Store the fitted slope under e.g. 'gamma2_0.5_gamma3_2.50_k_0'.
        QQ['gamma2_' + str(EXAMPLE_GAMMA2) + '_gamma3_' + str(j) + '_k_' + str(w)] = coeffs[0]
        if w == 1:
            # Only the RD_w_1 fit is echoed to stdout.
            print(np.poly1d(coeffs))


########### Part 2: slope vs gamma3 at gamma2 = 0.5 ###########
# Fit the log-log slope of mean RD_w_0 / RD_w_1 against the width m for every
# (gamma2, gamma3) pair, then plot how those slopes vary with gamma3 at
# gamma2 = 0.5.

Gamma2 = [0.2,0.3,0.4,0.5,0.6,0.7,0.8]
Gamma3 = ['2.20','2.25','2.30','2.35','2.40','2.45','2.50','2.55','2.60','2.65','2.70','2.75','2.80']
# Numeric version of Gamma3 (same order), used as the x-axis. Derived from
# Gamma3 so the two lists cannot drift apart; the triple-m name is kept in
# case later code refers to it.
gammma3 = [float(g) for g in Gamma3]

# Widths m for which results are stored in the pickle.
WIDTHS = [100, 1000, 2500, 5000, 10000]


def _fitted_slope(results, gam2, gam3, w):
    """Return the slope of a straight-line fit of log(mean RD_w_<w>) against log(m).

    For each width m in WIDTHS, the samples stored under
    results['gam2_<gam2>_gam3_<gam3>_<m>_RD_w_<w>'] are averaged first; the
    degree-1 least-squares fit is then done in log-log space.
    """
    means = [
        np.mean(results['gam2_' + str(gam2) + '_gam3_' + str(gam3) + '_' + str(m) + '_RD_w_' + str(w)])
        for m in WIDTHS
    ]
    return np.polyfit(np.log(WIDTHS), np.log(means), 1)[0]


def _plot_slope_curve(xs, slopes, yticks, ylim, path):
    """Plot slopes against xs as a green line with point markers and save it to path.

    Args:
        xs: gamma3 values for the x-axis.
        slopes: one fitted slope per entry of xs.
        yticks: y tick positions to show.
        ylim: (bottom, top) y-axis limits.
        path: output image path.
    """
    plt.figure()
    # One polyline through all points (works for any number of gamma3 values).
    plt.plot(xs, slopes, linewidth=3.0, color='g')
    plt.scatter(xs, slopes, color='g')
    plt.xlabel(r'x', fontsize=22)
    plt.ylabel(r'y', rotation=0, fontsize=22)
    plt.xticks([2.20, 2.30, 2.40, 2.50, 2.60, 2.70, 2.80], size=16)
    plt.yticks(yticks, size=16)
    plt.tick_params(axis='both', which='major', labelsize=16)
    plt.ylim(ylim)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


QQ = {}

# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
with open(r"F:\deep_study\20210812_threelayers_limit_width_zqx\result/objk.pkl", 'rb') as f:
    Q = pickle.load(f)

# Slope for every (gamma2, gamma3) pair; k_0 is fitted to RD_w_0, k_1 to RD_w_1.
for i in Gamma2:
    for j in Gamma3:
        for w in (0, 1):
            QQ['gamma2_' + str(i) + '_gamma3_' + str(j) + '_k_' + str(w)] = _fitted_slope(Q, i, j, w)

# NOTE(review): S_w_1 collects k_0 (RD_w_0) slopes and S_w_2 collects k_1
# (RD_w_1) slopes, so the "_1"/"_2" suffixes look 1-based -- confirm intent.
S_w_1 = [QQ['gamma2_0.5_gamma3_' + g + '_k_0'] for g in Gamma3]
S_w_2 = [QQ['gamma2_0.5_gamma3_' + g + '_k_1'] for g in Gamma3]

# Global style and resolution, set once before any of these figures is created
# (they also apply to any figure created afterwards).
plt.style.use('bmh')
plt.rcParams['savefig.dpi'] = 200  # resolution of saved images
plt.rcParams['figure.dpi'] = 200   # resolution of the figure canvas

_plot_slope_curve(gammma3, S_w_1, [-0.2, 0.0], [-0.4, 0.2],
                  r'F:\deep_study\20210812_threelayers_limit_width_zqx\drawing/S_w_1_gam2_0_5.png')
_plot_slope_curve(gammma3, S_w_2, [-0.1, 0.0, 0.1], [-0.2, 0.2],
                  r'F:\deep_study\20210812_threelayers_limit_width_zqx\drawing/S_w_2_gam2_0_5.png')