import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns

# Hyper-parameter grid of the experiments; every (gamma2, gamma3, width)
# combination identifies one group of entries in objk.pkl.
Gamma2 = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
Gamma3 = ['2.20', '2.25', '2.30', '2.35', '2.40', '2.45', '2.50', '2.55',
          '2.60', '2.65', '2.70', '2.75', '2.80']
num = [100, 1000, 2500, 5000, 10000]
# Results collected by the loop below: '<tag>_dis' -> mean squared residual.
QQ = {}

# Input/output directory. The trailing '/' is kept so every path built from it
# is byte-identical to the paths previously hard-coded in this script.
FolderName = r"D:\work_xzq_doctor\20210701_threelayers_limit_width\linear_regression_average_variation/"

# Degree of the polynomial fitted to the final (w1, w2) point cloud (1 = line).
FIT_DEGREE = 1

# NOTE: pickle.load can execute arbitrary code; only load files you produced.
with open(FolderName + 'objk.pkl', 'rb') as f:
    Q = pickle.load(f)

# Figure resolution, set once before any figure is created.
plt.rcParams['savefig.dpi'] = 200  # dpi of the saved PNG files
plt.rcParams['figure.dpi'] = 200   # dpi of the in-memory figure


def _normalize_by_range(values):
    """Return *values* divided by their range ``max - min``.

    A constant input (zero range) is returned unchanged instead of being
    turned into inf/nan by a division by zero.
    """
    span = np.max(values) - np.min(values)
    if span == 0:
        return values
    return values / span


def _mean_squared_residual(x, y, coeffs):
    """Return the mean squared vertical distance of the points (x, y) from the
    polynomial with coefficients *coeffs* (highest power first, as returned by
    ``np.polyfit``)."""
    return np.mean((y - np.polyval(coeffs, x)) ** 2)


for i in Gamma2:
    for j in Gamma3:
        for k in num:
            tag = 'gamma_' + str(i) + '_gam3_' + j + '_' + str(k)
            # Each entry is indexed as a sequence of snapshots: [0] is plotted
            # as 'initial' and [-1] as 'now'. The last element of every
            # snapshot is excluded from the analysis.
            w1_history = Q[tag + '_w1_w2_w1_of_layer_2to3']
            w2_history = Q[tag + '_w1_w2_w2_of_layer_2to3']

            # Normalise each coordinate of the final snapshot by its own range
            # and write it back into Q (slice assignment; the snapshots are
            # numpy arrays). An alternative once tried here divided both by
            # sqrt(range(w1)**2 + range(w2)**2) instead.
            w1_history[-1][:-1] = _normalize_by_range(w1_history[-1][:-1])
            w2_history[-1][:-1] = _normalize_by_range(w2_history[-1][:-1])
            w1 = w1_history[-1][:-1]
            w2 = w2_history[-1][:-1]

            # Least-squares polynomial fit of w2 against w1. Point order does
            # not affect the fit, so no sorting is needed.
            coeffs = np.polyfit(w1, w2, FIT_DEGREE)
            x_poly = np.linspace(np.min(w1), np.max(w1), num=100, endpoint=True)
            y_poly = np.polyval(coeffs, x_poly)

            fig, ax = plt.subplots()
            ax.set_xlabel(r'w1', fontsize=18)
            ax.set_ylabel(r'w2', fontsize=18)
            # NOTE(review): the initial snapshot is plotted un-normalised while
            # the final one is range-normalised; confirm this is intended.
            ax.scatter(w1_history[0][:-1], w2_history[0][:-1],
                       color='r', s=10, label='initial')
            ax.scatter(w1, w2, color='g', s=10, label='now')
            ax.plot(x_poly, y_poly, linewidth=2.0, color='red',
                    label='auxiliary line', linestyle='--')
            ax.legend()
            fig.savefig(FolderName + 'largest_w1_w2_layer2to3_%s.png' % tag)
            plt.close(fig)

            # Average squared distance of the final points from the fit.
            if np.size(w1) != k:
                print('something wrong: %s has %d points, expected %d'
                      % (tag, np.size(w1), k))
            dis_all = _mean_squared_residual(w1, w2, coeffs)
            print(dis_all)

            QQ[tag + '_dis'] = dis_all

def savefile(folder=None, results=None):
    """Persist the collected results to disk.

    Writes two files into *folder*:

    * ``objk_dis.pkl``: the whole *results* dict, pickled with protocol 4;
    * ``objk_dis.txt``: one ``key: value`` line per entry, skipping values
      with more than 200 elements.

    Args:
        folder: target directory; defaults to the module-level ``FolderName``.
        results: dict to save; defaults to the module-level ``QQ``.
    """
    if folder is None:
        folder = FolderName
    if results is None:
        results = QQ
    with open('%s/objk_dis.pkl' % (folder), 'wb') as f:
        pickle.dump(results, f, protocol=4)
    # Human-readable summary; the `with` guarantees the file is flushed and closed.
    with open('%s/objk_dis.txt' % (folder), 'w') as text_file:
        for para in results:
            if np.size(results[para]) > 200:
                continue
            text_file.write('%s: %s\n' % (para, results[para]))

# Persist the collected residuals (QQ) into FolderName once every configuration has been processed.
savefile()
 