import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import pandas as pd
from multiprocessing import Pool, Manager, set_start_method

def compute_cosine_similarity(z1, z2, eps=1e-8):
    """Return the row-wise cosine similarity of two tensors.

    Each row of ``z1`` and ``z2`` (norm taken along ``dim=1``) is scaled to unit
    length, and the dot product of corresponding rows is returned.  For 2-D
    inputs of shape ``(N, D)`` the result has shape ``(N,)``.

    Args:
        z1, z2: tensors of the same shape.
        eps: lower bound applied to every row norm, so that an all-zero row
            yields a similarity of 0 instead of NaN (0 / 0).

    Returns:
        Tensor of cosine similarities, one per row.
    """
    # clamp_min keeps zero-norm rows finite; non-degenerate rows are unaffected.
    z1_unit = z1 / z1.norm(dim=1, keepdim=True).clamp_min(eps)
    z2_unit = z2 / z2.norm(dim=1, keepdim=True).clamp_min(eps)
    return (z1_unit * z2_unit).sum(dim=1)

def compute_l2_distance(z1, z2):
    """Return the average Euclidean distance between corresponding rows.

    ``z1 - z2`` is reduced with the 2-norm along ``dim=1`` (one distance per
    row), and those per-row distances are averaged into a 0-d tensor.
    """
    difference = z1 - z2
    row_distances = difference.norm(dim=1)
    return row_distances.mean()


def compute_abs_distance(z1, z2):
    """Return the mean absolute element-wise difference between two tensors.

    This is not a per-row norm: ``|z1 - z2|`` is averaged over *all* elements,
    producing a 0-d tensor.  ``z1`` and ``z2`` must be broadcast-compatible.
    """
    return (z1 - z2).abs().mean()

def _select_rows(z, mask):
    """Return the rows of *z* (flattened to ``(-1, z.size(-1))``) where *mask* is True.

    Raises:
        ValueError: if *mask* does not contain exactly one element per row of *z*.
    """
    rows = z.reshape(-1, z.size(-1))  # reshape (unlike view) also accepts non-contiguous z
    flat_mask = mask.reshape(-1)
    if flat_mask.numel() != rows.size(0):
        raise ValueError(
            f"mask has {flat_mask.numel()} elements but z has {rows.size(0)} rows"
        )
    return rows[flat_mask]


def calculate_similarity(i, j, cache_dir, device, avg_similarity_matrix_predicted, avg_similarity_matrix_unpredicted):
    """Compare the cached tensors of steps *i* and *j* and store the results symmetrically.

    Loads ``z_step_{i}.pt`` / ``mask_step_{i}.pt`` and their step-*j* counterparts
    from *cache_dir* onto *device*.  The rows of each ``z`` (``z`` flattened to
    ``(-1, z.size(-1))``) are split into two groups by the flattened masks:

    * "unpredicted" rows (mask True at both steps): the mean per-row L2
      distance (``compute_l2_distance``) is written to
      ``avg_similarity_matrix_unpredicted[i][j]`` and ``[j][i]``;
    * "predicted" rows (mask False at both steps): the mean absolute
      element-wise difference (``compute_abs_distance``) is written to
      ``avg_similarity_matrix_predicted[i][j]`` and ``[j][i]``.

    A group with no common rows leaves its matrix untouched.  Despite the
    parameter names, the stored values are distances (0 means identical).

    Args:
        i, j: step indices, used both in the cache file names and as matrix indices.
        cache_dir: directory that holds the cached ``.pt`` files.
        device: torch device (string or ``torch.device``) to load the tensors onto.
        avg_similarity_matrix_predicted, avg_similarity_matrix_unpredicted:
            nested, index-assignable row lists (e.g. ``multiprocessing.Manager``
            list proxies) large enough to be indexed by ``i`` and ``j``.

    Raises:
        ValueError: if a mask's element count does not match the number of rows of z.
    """
    print(f"Calculating similarity between step {i} and step {j}")

    def load(name):
        # map_location restores tensors directly on the target device instead of
        # first materialising them on whatever device they were saved from.
        return torch.load(os.path.join(cache_dir, name), map_location=device)

    z1 = load(f"z_step_{i}.pt")
    z2 = load(f"z_step_{j}.pt")
    # Cast to bool so `~` below is a logical NOT.  On integer masks it is a bitwise
    # NOT, where ~0 and ~1 are both non-zero and every row would count as "predicted".
    mask1 = load(f"mask_step_{i}.pt").bool()
    mask2 = load(f"mask_step_{j}.pt").bool()

    with torch.no_grad():  # no gradients are needed; saves memory
        # Unpredicted part: rows whose mask is True at both steps.
        unpredicted_common = mask1 & mask2
        z1_unpredicted = _select_rows(z1, unpredicted_common)
        z2_unpredicted = _select_rows(z2, unpredicted_common)
        if z1_unpredicted.size(0) > 0:
            sim_unpredicted = compute_l2_distance(z1_unpredicted, z2_unpredicted).cpu().numpy()
            value = float(sim_unpredicted.mean())
            avg_similarity_matrix_unpredicted[i][j] = avg_similarity_matrix_unpredicted[j][i] = value
        del z1_unpredicted, z2_unpredicted, unpredicted_common

        # Predicted part: rows whose mask is False at both steps.
        predicted_common = ~mask1 & ~mask2
        z1_predicted = _select_rows(z1, predicted_common)
        z2_predicted = _select_rows(z2, predicted_common)
        if z1_predicted.size(0) > 0:
            sim_predicted = compute_abs_distance(z1_predicted, z2_predicted).cpu().numpy()
            print(f"z1 abs mean: {z1_predicted.abs().mean()},"
                  f" z2 abs mean: {z2_predicted.abs().mean()}, delta abs mean: {sim_predicted} \n")
            value = float(sim_predicted.mean())
            avg_similarity_matrix_predicted[i][j] = avg_similarity_matrix_predicted[j][i] = value
        del z1_predicted, z2_predicted, predicted_common

    # Drop the remaining references before asking the allocator to release cached
    # GPU memory; empty_cache is a no-op when CUDA was never initialised.
    del z1, z2, mask1, mask2
    torch.cuda.empty_cache()

def _shared_matrix_to_array(shared_matrix):
    """Copy a Manager-backed list of row lists into a float64 NumPy array."""
    # ``[:]`` fetches a whole list from the manager in a single round trip,
    # instead of one round trip per element as element-wise indexing would.
    return np.array([row[:] for row in shared_matrix[:]], dtype=np.float64)


def _save_matrix_plot(matrix, title, path, dpi):
    """Render *matrix* as a single heatmap with a colorbar and write it to *path*."""
    fig, ax = plt.subplots(figsize=(8, 8), dpi=dpi)
    try:
        # Colour limits come from the data: the values are non-negative distances
        # that can exceed 1, so a fixed [-1, 1] range would clip them.
        image = ax.imshow(matrix, cmap="viridis")
        ax.set_title(title)
        fig.colorbar(image, ax=ax)
        fig.savefig(path, dpi=dpi)
    finally:
        plt.close(fig)  # always release the figure, even if saving fails


def calculate_and_plot_similarity(num_steps, cache_dir, output_dir, k=10, device="cuda:0",
                                  max_concurrent_tasks=2, dpi=1024):
    """Compute pairwise step distances in parallel, then export them to Excel and PNG.

    Every pair ``(i, j)`` with ``0 <= i < j < num_steps`` is handed to
    ``calculate_similarity`` in a pool of spawned worker processes.  The
    workers fill two shared ``num_steps x num_steps`` matrices.  Entries that
    are never written, including the diagonal, stay 0.

    Outputs written to *output_dir* (created if missing):

    * ``similarity_matrices.xlsx``: sheets "Unpredicted Similarity" and
      "Predicted Similarity".  The sheet names are kept for compatibility,
      although the values are distances.
    * ``unpredicted_similarity_matrix.png`` / ``predicted_similarity_matrix.png``:
      one heatmap each.

    Args:
        num_steps: number of cached steps (files ``z_step_{n}.pt`` /
            ``mask_step_{n}.pt`` for ``n`` in ``range(num_steps)``).
        cache_dir: directory holding the cached tensors.
        output_dir: directory for the Excel file and the plots.
        k: unused; kept for backward compatibility with existing callers.
        device: torch device each worker loads its tensors onto.
        max_concurrent_tasks: number of worker processes.
        dpi: resolution of the saved PNGs.  The 1024 default yields
            8192x8192-pixel images; lower it to save memory and disk space.

    Raises:
        ValueError: if ``num_steps`` is smaller than 1.
    """
    from multiprocessing import get_context

    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    # Fail fast on an unusable output path, before the expensive computation.
    os.makedirs(output_dir, exist_ok=True)

    # CUDA cannot be re-initialised in forked children, so the workers are spawned.
    # A dedicated context avoids mutating the process-wide start method.
    ctx = get_context("spawn")
    with ctx.Manager() as manager:
        shared_predicted = manager.list([manager.list([0] * num_steps) for _ in range(num_steps)])
        shared_unpredicted = manager.list([manager.list([0] * num_steps) for _ in range(num_steps)])

        # Limit parallelism to max_concurrent_tasks worker processes.
        with ctx.Pool(processes=max_concurrent_tasks) as pool:
            tasks = [
                pool.apply_async(
                    calculate_similarity,
                    args=(i, j, cache_dir, device, shared_predicted, shared_unpredicted),
                )
                for i in range(num_steps)
                for j in range(i + 1, num_steps)
            ]
            for task in tasks:
                task.get()  # wait for completion and re-raise any worker error

        # Copy the results out while the manager is alive; its proxies die with it.
        avg_similarity_matrix_predicted = _shared_matrix_to_array(shared_predicted)
        avg_similarity_matrix_unpredicted = _shared_matrix_to_array(shared_unpredicted)

    # Save the raw matrices to an Excel file.
    with pd.ExcelWriter(os.path.join(output_dir, "similarity_matrices.xlsx")) as writer:
        pd.DataFrame(avg_similarity_matrix_unpredicted).to_excel(writer, sheet_name="Unpredicted Similarity")
        pd.DataFrame(avg_similarity_matrix_predicted).to_excel(writer, sheet_name="Predicted Similarity")

    # One figure per matrix, so each PNG contains only its own heatmap.
    # The titles name the metrics calculate_similarity currently computes
    # (L2 for unpredicted rows, mean absolute difference for predicted rows).
    _save_matrix_plot(
        avg_similarity_matrix_unpredicted,
        "Unpredicted Average L2 Distance",
        os.path.join(output_dir, "unpredicted_similarity_matrix.png"),
        dpi,
    )
    _save_matrix_plot(
        avg_similarity_matrix_predicted,
        "Predicted Mean Absolute Difference",
        os.path.join(output_dir, "predicted_similarity_matrix.png"),
        dpi,
    )
