import torch
import torch.nn.functional as F
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os

# Dataset tag; interpolated into ``file_path`` below.
file_name = "suzuki_base"
# Path to the saved "dry" embedding map for ``file_name``.
# NOTE(review): hard-coded absolute path on shared storage — presumably
# loaded elsewhere in this file (not visible in this chunk).
file_path = (
    "/mnt/shared-storage-user/caipengxiang/workspace/ChemBOMAS/"
    "info_emcoder/saved_data_maps/"
    f"dry_{file_name}_embedding.pt"
)

def get_sim_matrix(data_maps, file_name=None, save_path=None, num_clusterd=3):
    """
    Compute the pairwise cosine-similarity matrix of a set of embeddings,
    optionally saving it as a heatmap PNG and a CSV.

    Args:
        data_maps: Mapping from name to embedding tensor. The values are
            combined with ``torch.stack`` and normalized along dim 1 of the
            stacked tensor, so presumably each value is a 1-D tensor of the
            same length — TODO confirm against callers.
        file_name: Tag inserted into the output file names
            (``cos_sim_heatmap_{file_name}.png`` and
            ``cos_sim_matrix_{file_name}.csv``).
        save_path: Directory to write the outputs into (created if missing).
            When falsy, nothing is written.
        num_clusterd: If exactly 3, every entry of ``data_maps`` goes into
            the matrix. If greater than 3, the entry keyed ``"nothing"`` is
            left out of the matrix and instead compared against every other
            entry.

    Returns:
        ``(cos_sim_np, nothing_sim)``:
        - ``cos_sim_np``: (N, N) numpy array, with rows and columns in the
          key order of ``data_maps`` (minus ``"nothing"`` when excluded).
        - ``nothing_sim``: length-N numpy array of cosine similarities
          between the ``"nothing"`` embedding and each row entry, or ``None``
          when ``num_clusterd == 3``.

    Raises:
        ValueError: if ``num_clusterd < 3``, or if ``num_clusterd > 3`` and
            ``data_maps`` has no ``"nothing"`` key.
    """
    if num_clusterd == 3:
        exclude_nothing = False
    elif num_clusterd > 3:
        exclude_nothing = True
    else:
        # Previously this case fell through and returned None, which broke
        # callers unpacking two values.
        raise ValueError(f"num_clusterd must be >= 3, got {num_clusterd!r}")

    names = list(data_maps.keys())
    if exclude_nothing:
        if "nothing" not in names:
            raise ValueError(
                "data_maps must contain a 'nothing' entry when num_clusterd > 3"
            )
        names.remove("nothing")

    # ---------------- cosine-similarity matrix ----------------
    # detach() so .numpy() also works when the embeddings carry autograd history.
    embs = torch.stack([data_maps[k] for k in names]).detach()
    embs_norm = F.normalize(embs, p=2, dim=1)
    cos_sim_np = (embs_norm @ embs_norm.T).cpu().numpy()

    nothing_sim = None
    if exclude_nothing:
        nothing_norm = F.normalize(
            data_maps["nothing"].detach().unsqueeze(0), p=2, dim=1
        )
        nothing_sim = (nothing_norm @ embs_norm.T).squeeze(0).cpu().numpy()

    if save_path:
        _save_sim_outputs(cos_sim_np, names, file_name, save_path)

    return cos_sim_np, nothing_sim


def _save_sim_outputs(cos_sim_np, names, file_name, save_path):
    """Write the similarity heatmap (PNG) and the matrix (CSV) into ``save_path``."""
    os.makedirs(save_path, exist_ok=True)
    cos_df = pd.DataFrame(cos_sim_np, index=names, columns=names)

    # ---------------- heatmap ----------------
    # Canvas grows with the number of names: 0.3 per name, at least 6.
    side = max(6, len(names) * 0.3)
    fig = plt.figure(figsize=(side, side))
    try:
        sns.heatmap(
            cos_df,
            cmap="coolwarm",
            annot=False,
            square=True,
            cbar_kws={"shrink": 0.7},
        )
        plt.title("Cosine Similarity Heatmap")
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, f"cos_sim_heatmap_{file_name}.png"))
    finally:
        # Without this, every call leaves an open pyplot figure behind.
        plt.close(fig)

    # ---------------- similarity matrix ----------------
    cos_df.to_csv(os.path.join(save_path, f"cos_sim_matrix_{file_name}.csv"))