import torch
import numpy as np
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
def compute_cosine_similarity(z1, z2, eps=1e-8):
    """Return the row-wise cosine similarity between two tensors.

    Both inputs are flattened to 2-D with the last dimension kept as the
    feature axis, then row ``n`` of ``z1`` is compared with row ``n`` of
    ``z2``.

    Args:
        z1: Tensor of shape ``(..., D)``.
        z2: Tensor with the same number of rows and the same ``D`` after
            flattening.
        eps: Lower bound applied to each row norm before dividing. With
            float32/float64 inputs this makes an all-zero row yield a
            similarity of 0 instead of NaN.

    Returns:
        1-D tensor with one similarity value per flattened row.
    """
    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # permuted tensors, are accepted as well.
    flat1 = z1.reshape(-1, z1.size(-1))
    flat2 = z2.reshape(-1, z2.size(-1))

    # Clamp the norms so zero-length rows do not produce 0/0 = NaN, which
    # would otherwise contaminate any mean taken over the result.
    unit1 = flat1 / flat1.norm(dim=1, keepdim=True).clamp_min(eps)
    unit2 = flat2 / flat2.norm(dim=1, keepdim=True).clamp_min(eps)

    return (unit1 * unit2).sum(dim=1)


def _load_step_tensor(cache_dir, prefix, step, device):
    """Load ``<cache_dir>/<prefix>_step_<step>.pt`` directly onto ``device``."""
    path = os.path.join(cache_dir, f"{prefix}_step_{step}.pt")
    # map_location avoids first restoring the tensor on the device it was
    # saved from, which may be busy or absent on this machine.
    return torch.load(path, map_location=device)


def _mean_masked_similarity(z1, z2, common_mask):
    """Return the mean cosine similarity over rows selected by ``common_mask``.

    Returns ``None`` when the mask selects no rows.
    """
    flat_mask = common_mask.reshape(-1)
    rows1 = z1.reshape(-1, z1.size(-1))[flat_mask]
    rows2 = z2.reshape(-1, z2.size(-1))[flat_mask]
    if rows1.size(0) == 0:
        return None
    return float(compute_cosine_similarity(rows1, rows2).cpu().numpy().mean())


def _save_similarity_heatmap(matrix, title, path):
    """Draw one similarity matrix as a heatmap with a colorbar and save it."""
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        image = ax.imshow(matrix, cmap='viridis', vmin=-1, vmax=1)
        ax.set_title(title)
        fig.colorbar(image, ax=ax)
        fig.savefig(path)
    finally:
        plt.close(fig)  # always release the figure's memory


def calculate_and_plot_similarity(num_steps, cache_dir, output_dir, k=10, device="cuda:0"):
    """Compute step-to-step average cosine similarities and plot them.

    For every pair of steps ``(i, j)`` the cached tensors
    ``z_step_{i}.pt`` / ``mask_step_{i}.pt`` and their ``j`` counterparts
    are loaded from ``cache_dir``. Positions where both masks are set are
    treated as unpredicted in both steps; positions where neither is set
    are treated as predicted in both steps. The mean cosine similarity of
    the corresponding feature rows is stored symmetrically in the
    respective matrix. Pairs with no common positions keep the value 0.

    Two heatmaps are written to ``output_dir``:
    ``unpredicted_similarity_matrix.png`` and
    ``predicted_similarity_matrix.png``, each containing only its own
    matrix.

    Args:
        num_steps: Number of cached steps (files are indexed ``0..num_steps-1``).
        cache_dir: Directory holding the cached ``z`` and ``mask`` tensors.
        output_dir: Directory for the PNG files; created if missing.
        k: Currently unused; retained for interface compatibility.
        device: Device the tensors are loaded onto for the computation.

    Returns:
        Tuple ``(unpredicted_matrix, predicted_matrix)`` of
        ``(num_steps, num_steps)`` NumPy arrays.
    """
    sim_predicted = np.zeros((num_steps, num_steps))
    sim_unpredicted = np.zeros((num_steps, num_steps))

    # No gradients are needed anywhere in this analysis.
    with torch.no_grad():
        for i in range(num_steps):
            z1 = _load_step_tensor(cache_dir, "z", i, device)
            # Cast so that `&` / `~` below are logical operations and the
            # result is usable as a boolean index regardless of stored dtype.
            mask1 = _load_step_tensor(cache_dir, "mask", i, device).bool()

            # Start at j == i so the diagonal holds the real self-similarity
            # instead of being left at the zero initial value.
            for j in range(i, num_steps):
                if j == i:
                    z2, mask2 = z1, mask1
                else:
                    z2 = _load_step_tensor(cache_dir, "z", j, device)
                    mask2 = _load_step_tensor(cache_dir, "mask", j, device).bool()

                # Rows still unpredicted at both steps.
                value = _mean_masked_similarity(z1, z2, mask1 & mask2)
                if value is not None:
                    sim_unpredicted[i, j] = sim_unpredicted[j, i] = value

                # Rows already predicted at both steps.
                value = _mean_masked_similarity(z1, z2, ~mask1 & ~mask2)
                if value is not None:
                    sim_predicted[i, j] = sim_predicted[j, i] = value

                # Drop the step-j tensors before loading the next pair to
                # keep device memory usage low.
                del z2, mask2
                torch.cuda.empty_cache()

            del z1, mask1
            torch.cuda.empty_cache()

            print(f"Step {i} done")

    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)

    # One figure per matrix so each file shows exactly what its name says.
    _save_similarity_heatmap(
        sim_unpredicted,
        'Unpredicted Average Cosine Similarity',
        os.path.join(output_dir, 'unpredicted_similarity_matrix.png'),
    )
    _save_similarity_heatmap(
        sim_predicted,
        'Predicted Average Cosine Similarity',
        os.path.join(output_dir, 'predicted_similarity_matrix.png'),
    )

    return sim_unpredicted, sim_predicted
