import os
import pickle
import matplotlib
import matplotlib.pyplot as plt
from seaborn.matrix import heatmap
matplotlib.use('Agg')
import numpy as np
import shutil 
import math
import seaborn as sns
from matplotlib.lines import Line2D
# TODO(original author): remember to change this back to m=10000 later
# (the pickle keys below hard-code the "_10000_" segment).
# Parameter grids used in earlier runs, kept for reference:
# Gamma2 = [0.2,0.3,0.4,0.5,0.6,0.7,0.8]
# Gamma3 = ['2.20','2.25','2.30','2.35','2.40','2.45','2.50','2.55','2.60','2.65','2.70','2.75','2.80']

# Grid values are kept as strings because they are spliced verbatim into the
# pickle keys; the two-decimal formatting must match how the keys were written.
Gamma2 = ['-0.30', '-0.20', '-0.10', '0.00', '0.10', '0.20', '0.30']
Gamma3 = ['0.90', '1.10', '1.30', '1.50', '1.70', '1.90', '2.10']
QQ = {}

# NOTE(review): pickle.load executes arbitrary code from the file; only point
# this at result files produced by this project.
with open(r"/home/XXX/result_plt1_new/objk.pkl", 'rb') as f:
    Q = pickle.load(f)

# For each (gamma2, gamma3) pair, store log(mean(...)) of the "lr" entry
# under a "gamma2_..._gamma3_..." key.
for g2 in Gamma2:
    for g3 in Gamma3:
        print("**************************************************")
        print("i=", g2)
        print("j=", g3)
        samples = Q[f"gam2_{g2}_gam3_{g3}_10000_lr"]
        QQ[f"gamma2_{g2}_gamma3_{g3}_10000_lr"] = np.log(np.mean(samples))

print(QQ)

# Rows run over Gamma2 in *reverse* order (row 0 = last Gamma2 value) and
# columns over Gamma3 in order.
heatmap_k = np.zeros(shape=(len(Gamma2), len(Gamma3)))
for row, g2 in enumerate(reversed(Gamma2)):
    for col, g3 in enumerate(Gamma3):
        heatmap_k[row, col] = QQ[f"gamma2_{g2}_gamma3_{g3}_10000_lr"]

# Render at 200 dpi both on screen and in the saved PNG.
plt.rcParams['savefig.dpi'] = 200  # pixel density of the saved image
plt.rcParams['figure.dpi'] = 200  # figure resolution
fig, ax = plt.subplots()

# Tick labels are derived from the same grids used to fill heatmap_k, so they
# cannot drift out of sync with the data. heatmap_k stores Gamma2 in reverse
# order (row 0 holds the last Gamma2 value), and seaborn draws row 0 at the
# top, so the y labels are reversed to match. The previous hard-coded
# ascending list mislabelled every row except the middle one.
x_labels = [float(g) for g in Gamma3]
y_labels = [float(g) for g in reversed(Gamma2)]

# Passing the labels explicitly makes seaborn place one tick per row/column,
# so the set_*ticklabels calls below always see a matching number of ticks.
ax = sns.heatmap(heatmap_k, linewidths=0.05, cmap='RdBu',
                 xticklabels=x_labels, yticklabels=y_labels, ax=ax)
ax.xaxis.set_ticks_position('top')  # show the Gamma3 axis above the map
# Re-apply the same labels to enlarge the tick font.
ax.set_yticklabels(y_labels, fontsize=12)
ax.set_xticklabels(x_labels, fontsize=12)
# ax.set_title("cos distance: relu", fontsize=16)

out_path = r'/home/XXX/result_plt1_new/drawing/log_lr.png'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
plt.savefig(out_path)
# Close exactly this figure; the former close() + clf() pair re-created and
# left open a new, empty figure.
plt.close(fig)