import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns

# Hyper-parameter grid. gamma_3 values stay strings because they are embedded
# verbatim (e.g. '2.20') in the keys of the result pickles.
Gamma2 = [0.2,0.3,0.4,0.5,0.6,0.7,0.8]
Gamma3 = ['2.20','2.25','2.30','2.35','2.40','2.45','2.50','2.55','2.60','2.65','2.70','2.75','2.80']
# k values stored for every (gamma_2, gamma_3) pair; each gets its own mean matrix.
num = [100,1000,2500,5000,10000]

# Directory containing the pickled cos-distance results.
RESULT_DIR = r"E:\3090深度学习资料\F盘\20210812_threelayers_limit_width_zqx\result_cos"

print('start')

# NOTE: pickle.load can execute arbitrary code; only point RESULT_DIR at
# trusted, locally produced experiment outputs.
with open(f"{RESULT_DIR}/objk_1.pkl", 'rb') as f:
    Q1 = pickle.load(f)
with open(f"{RESULT_DIR}/objk_2.pkl", 'rb') as f:
    Q2 = pickle.load(f)

# gamma_2 in 0.2-0.5 is stored in objk_1.pkl, 0.6-0.8 in objk_2.pkl. An explicit
# map (instead of chained float == tests) makes an unlisted gamma_2 fail loudly
# instead of silently reusing the previous iteration's matrix.
gamma2_source = {0.2: Q1, 0.3: Q1, 0.4: Q1, 0.5: Q1, 0.6: Q2, 0.7: Q2, 0.8: Q2}

# means[k][row, col] = mean |entry| of the matrix for (Gamma2[row], Gamma3[col], k).
means = {k: np.zeros(shape=(len(Gamma2), len(Gamma3))) for k in num}
for ii, g2 in enumerate(Gamma2):
    try:
        source = gamma2_source[g2]
    except KeyError:
        raise ValueError(f"no result file configured for gamma_2={g2}") from None
    for jj, g3 in enumerate(Gamma3):
        for k in num:
            # Same key format as 'gam2_'+str(i)+'_gam3_'+j+'_'+str(k)+'_cos'.
            cos_distance_matrix = source[f"gam2_{g2}_gam3_{g3}_{k}_cos"]
            means[k][ii, jj] = np.mean(np.abs(cos_distance_matrix))

mean_100 = means[100]
mean_1000 = means[1000]
mean_2500 = means[2500]
mean_5000 = means[5000]
mean_10000 = means[10000]

# Directory the heat-map PNGs are written to.
DRAW_DIR = r"E:\3090深度学习资料\F盘\20210812_threelayers_limit_width_zqx\drawing_cos"

plt.rcParams['savefig.dpi'] = 200  # resolution of the saved image
plt.rcParams['figure.dpi'] = 200   # resolution of the figure

# Grid coordinates: columns follow gamma_3 (x axis), rows follow gamma_2 (y axis),
# matching the (len(Gamma2), len(Gamma3)) layout of the mean_* arrays.
XX, YY = np.meshgrid([float(g) for g in Gamma3], Gamma2)


def _plot_mean_heatmap(mean, out_path):
    """Draw one mean-|cos| heat map over the (gamma_3, gamma_2) grid and save it.

    Args:
        mean: 2-D array shaped like XX/YY; row i is Gamma2[i], column j is Gamma3[j].
        out_path: path of the PNG file to write.
    """
    fig, ax = plt.subplots()
    try:
        # shading='auto' with coordinates the same shape as `mean` centres each
        # cell on its grid point.
        im = ax.pcolor(XX, YY, mean, cmap='RdBu', shading='auto')
        ax.set_xticks(XX[0])
        ax.set_yticks(YY[:, 0])
        ax.xaxis.set_ticks_position('top')
        fig.colorbar(im, ax=ax)
        ax.tick_params(labelsize=8)
        # Dashed guide lines: a vertical at gamma_3 = 2.5 and a diagonal from (2.5, 0.5).
        ax.plot([2.5, 2.5], [0.5, 0.85], 'k--', linewidth=1.5)
        ax.plot([2.5, 2.825], [0.5, 0.175], 'k--', linewidth=1.5)
        ax.set_xlabel(r'$\gamma_3$', fontsize=24)
        ax.set_ylabel(r'$\gamma_2$', fontsize=24)
        fig.savefig(out_path)
    finally:
        # Close exactly this figure. The old close()+clf() pair re-created an
        # empty figure on every pass and never released it.
        plt.close(fig)


for k, mean in ((100, mean_100), (1000, mean_1000), (2500, mean_2500),
                (5000, mean_5000), (10000, mean_10000)):
    _plot_mean_heatmap(mean, f"{DRAW_DIR}/{k}.png")