import os

#os.environ["CUDA_VISIBLE_DEVICES"] = "3"
#import clip.clip as clip
import torch
import numpy as np
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
import matplotlib.pyplot as plt
# Prefer CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
import json
def min_max_normalize(matrix):
    """Linearly rescale all values of ``matrix`` into [0, 1] using its global min/max.

    Args:
        matrix: array-like of numbers (ndarray, nested list, ...).

    Returns:
        A floating ndarray with the same shape as ``matrix``. When every value
        is identical (max == min) the result is an all-zero array instead of
        dividing by zero.
    """
    arr = np.asarray(matrix)
    # Integer/bool inputs are promoted to float64 before subtracting: doing
    # the arithmetic in the original dtype can wrap around (e.g. int8
    # 127 - (-128)) and numpy refuses to subtract bool arrays.
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    min_val = np.min(arr)
    max_val = np.max(arr)
    if max_val == min_val:
        # All values are equal: return zeros rather than 0/0. zeros_like
        # keeps the same dtype as the non-degenerate branch below.
        return np.zeros_like(arr)
    return (arr - min_val) / (max_val - min_val)


def min_max_normalize_tensor(tensor):
    """Linearly rescale all values of ``tensor`` into [0, 1] using its global min/max.

    Torch counterpart of ``min_max_normalize``.

    Args:
        tensor: a torch tensor of any shape (must be non-empty for ``min``/``max``).

    Returns:
        A floating tensor with the same shape and device as ``tensor``. When
        every value is identical (max == min) an all-zero tensor is returned
        instead of dividing by zero.
    """
    # Promote integer/bool tensors to the default float dtype first: subtracting
    # in the original integer dtype can overflow (e.g. int8), and torch cannot
    # subtract bool tensors.
    if not torch.is_floating_point(tensor):
        tensor = tensor.to(torch.get_default_dtype())
    min_val = tensor.min()
    max_val = tensor.max()
    if max_val == min_val:
        # zeros_like keeps the input's device and dtype, so the degenerate case
        # matches the normal branch (a plain torch.zeros would land on the CPU).
        return torch.zeros_like(tensor)
    return (tensor - min_val) / (max_val - min_val)

#abs_tensor = torch.abs(tensor)

def normalize_tensor(tensor):
    """Scale absolute values so each slice along dim 0 sums to 1.

    Each element is replaced by ``|x| / sum(|x| along dim 0)``. For a 1-D
    tensor this turns the whole vector into proportions; for a 2-D tensor each
    column is normalised independently.

    Args:
        tensor: a torch tensor with at least one dimension.

    Returns:
        A tensor of the same shape. Columns whose absolute values sum to zero
        come back as zeros rather than NaN (0/0).
    """
    abs_tensor = torch.abs(tensor)
    totals = torch.sum(abs_tensor, dim=0, keepdim=True)
    # A zero total means the whole column is zero; dividing by 1 there keeps it
    # zero instead of producing NaN.
    safe_totals = torch.where(totals == 0, torch.ones_like(totals), totals)
    return abs_tensor / safe_totals

# model, transforms, _ = clip.load("ViT-B/16", device=device, jit=False)
# model_path="/root/data1/cil_all/cil_orth/pth/cifar100_10-10/4_metrics_withou_UCB_e10_bs64_10-6_step50_model_parameters.pth"
# model.load_state_dict(torch.load(model_path))
# model.to(device)
# #model.visual.transformer#
# a=[]
# save_dir_vis = '/root/data1/cil_all/cil_orth/heat/visual'
# save_dir_text = '/root/data1/cil_all/cil_orth/heat/text'
# for block in model.visual.transformer.resblocks:#[64,10]
#         b=block.logits_sum.to("cpu").numpy()
#         b1 = b.T
        
#         a.append(min_max_normalize(b1))
# print(len(a))
# rcParams['font.family'] = 'Times New Roman'
# for i in range(len(a)):
#     fig, ax = plt.subplots(figsize = (32,5))
#     sns.heatmap(pd.DataFrame(np.round(a[i],2), columns = range(0,64), index = range(0,10)), 
#                 annot=False, vmax=1,vmin = 0, xticklabels= True, yticklabels= True, square=True, cmap="YlGnBu")
#     ax.set_title(' layer'+str(i), fontsize = 16)
#     ax.set_ylabel('Total Number of Experts', fontsize = 12)
#     ax.set_xlabel('Input Sampl', fontsize = 12) 
    
#     # 保存图像到指定文件夹内
#     save_path = os.path.join(save_dir, 'heatmap_' + str(i) + '.png')
#     fig.savefig(save_path)
#     plt.close(fig)  # 关闭当前图以节省内存

def heat(flag, arry, t, *,
         save_dir_vis='/root/data1/lxm/cil_orth/heat/visual_10-2_new3',
         save_dir_text='/root/data1/lxm/cil_orth/heat/text_10-2_new3'):
    """Render ``arry`` as a min-max-normalised heatmap and save it as a PNG.

    Args:
        flag: ``"text"`` saves into ``save_dir_text`` with the x-label
            'Total number of categories'. Any other value saves into
            ``save_dir_vis`` with the x-label 'Input Sample'.
        arry: a torch tensor (``.detach()``/``.cpu()`` are called on it).
            It is transposed before plotting, so its columns become heatmap
            rows. Presumably 2-D; a 1-D input fails on ``c.shape[1]``.
        t: identifier appended (via ``str``) to the output file name.
        save_dir_vis, save_dir_text: output directories. They default to the
            original hard-coded paths and are created if missing.

    The file is written to ``<dir>/1_heatmap_UCB10-2_<t>.png``. The figure is
    always closed, even if saving fails, so repeated calls do not accumulate
    open figures.
    """
    b = arry.detach().cpu().numpy()
    c = min_max_normalize(b.T)
    # Figure size grows with the matrix: half an inch per cell in each direction.
    fig, ax = plt.subplots(figsize=(c.shape[1] / 2, c.shape[0] / 2))
    try:
        sns.heatmap(pd.DataFrame(np.round(c, 2), columns=range(0, c.shape[1]), index=range(0, c.shape[0])),
                    annot=False, vmax=1, vmin=0, xticklabels=False, yticklabels=False, square=True, cmap="YlGnBu")

        if flag == "text":
            xlabel, save_dir = 'Total number of categories', save_dir_text
        else:
            xlabel, save_dir = 'Input Sample', save_dir_vis
        ax.set_ylabel('Total Number of Experts', fontsize=20)
        ax.set_xlabel(xlabel, fontsize=20)

        # Create the target directory so savefig does not fail on a fresh machine.
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, '1_heatmap_UCB10-2_' + str(t) + '.png')
        fig.savefig(save_path)
    finally:
        # Close the figure to free memory, including when savefig raises.
        plt.close(fig)


# a=[0.5956,0.3274,0.5771,0.1131,0.3304,0.2636,0.4664,0.4352,0.2839,0.3501]
# tensor = torch.tensor(a)
# print(normalize_tensor(tensor))




# import time

# # 记录训练开始时间
# start_time = time.time()
# for i in range(1000000): 
#     print(i)
# # 记录训练结束时间
# end_time = time.time()

# # 计算训练时长
# elapsed_time = end_time - start_time

# # 将时间转换为小时、分钟、秒
# hours, rem = divmod(elapsed_time, 3600)
# minutes, seconds = divmod(rem, 60)

# print(f"Training time: {int(hours)}h {int(minutes)}m {int(seconds)}s")


# a=f"Training time: {int(1)}h {int(1)}m {int(1)}s"

# with open('/root/data1/lxm/cil_orth/train_time.txt', 'a') as f:
#     f.write("text:"+str(1)+", ".join(map(str, a))+ "\n")