import os
import argparse
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Key fragments inserted between "transformer.encoder.layer.<i>" and a LoRA
# suffix such as "Lora_A" to address one sub-module's LoRA parameters.
ATTENTION_Q, ATTENTION_K, ATTENTION_V, ATTENTION_OUT = (
    ".attn.query.",
    ".attn.key.",
    ".attn.value.",
    ".attn.out.",
)
# Presumably the two linear layers of the feed-forward block (naming only).
FC_0, FC_1 = ".ffn.fc1.", ".ffn.fc2."

# Select matplotlib's non-interactive Agg backend: this script only writes
# figures to files via savefig, so no display/GUI is required.
plt.switch_backend('Agg')


def get_args_parser(argv=None):
    """Parse the command-line options for the LoRA angle analysis.

    Despite its name, this returns the *parsed* namespace, not the parser.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line (useful for tests and notebooks). ``None`` keeps the
            original behaviour of parsing ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with string attributes ``dataset_name``,
        ``output_dir`` and ``model_name``. The checkpoint path is built as
        ``<output_dir>/<dataset_name>/<model_name>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", default="kitti", type=str)
    parser.add_argument("--output_dir", default="output_lora", type=str)
    # Other checkpoints previously analysed (original note: -aug all-no-res):
    #   "kitti 82.42_original_rank16.pth"
    #   "kitti 80.03_original_rank8.pth"
    #   "kitti 79.89_original_rank128.pth"
    parser.add_argument("--model_name", default="kitti 81.01111.pth", type=str)
    return parser.parse_args(argv)


def load_lora(args, map_location=None):
    """Load a saved LoRA checkpoint and return it as a plain dict.

    The file is read from ``<args.output_dir>/<args.dataset_name>/<args.model_name>``.

    Args:
        args: Namespace-like object with ``output_dir``, ``dataset_name`` and
            ``model_name`` attributes.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"`` lets
            a checkpoint saved from GPU be inspected on a CPU-only machine.
            ``None`` (the default) keeps ``torch.load``'s own behaviour.

    Returns:
        A new ``dict`` (shallow copy) mapping each key of the loaded mapping to
        its value.
    """
    path = os.path.join(args.output_dir, args.dataset_name, args.model_name)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    load_dict = torch.load(path, map_location=map_location)
    return dict(load_dict)


def cosine_distance_to_angle(matrix):
    """Return the angle, in degrees, between every pair of ROWS of a 2-D tensor.

    Each row of ``matrix`` is treated as one vector; pass ``matrix.T`` to
    compare columns instead. For every pair ``(i, j)`` with ``i < j`` the
    cosine similarity is clamped to ``[-1, 1]`` and converted to an angle with
    ``acos``.

    Args:
        matrix: 2-D tensor of shape ``(n_vectors, dim)``.

    Returns:
        1-D float64 ``np.ndarray`` of length ``n * (n - 1) / 2``. The angles are
        in row-major upper-triangle order ``(0,1), (0,2), ..., (n-2,n-1)``.
        As a side effect the max and min angle are printed; this is skipped
        when fewer than two rows are given, because the result is then empty.
    """
    # Work on a detached float64 copy. This improves precision for
    # half-precision weights and allows NumPy conversion even if the tensor
    # requires grad.
    vectors = matrix.detach().to(torch.float64)

    # After L2-normalizing the rows, the Gram matrix IS the cosine-similarity
    # matrix. One matmul replaces the per-pair Python loop.
    # All-zero rows normalize to zero, giving similarity 0, i.e. 90 degrees.
    unit = F.normalize(vectors, p=2, dim=1, eps=1e-8)
    sim_matrix = unit @ unit.T

    n_rows = sim_matrix.size(0)
    rows, cols = torch.triu_indices(n_rows, n_rows, offset=1, device=sim_matrix.device)
    pairwise_sim = sim_matrix[rows, cols]

    # Clamp guards acos against rounding slightly outside [-1, 1].
    angle_deg = torch.rad2deg(torch.acos(pairwise_sim.clamp(-1.0, 1.0)))

    an_list = angle_deg.cpu().numpy()
    if an_list.size:
        print(np.max(an_list))
        print(np.min(an_list))
    return an_list


def visualize_angle_distribution(angles, title="角度分布直方图", save_path="test", y_max=0.04):
    """Plot a normalized histogram and KDE curve of angle values, then save it.

    Args:
        angles: 1-D sequence of angles in degrees. The histogram covers 0-180
            with 30 bins and is normalized to a density.
        title: Chart title. It is NOT drawn on the figure; it is kept only for
            interface compatibility.
        save_path: Path passed to ``savefig``. Without an extension,
            matplotlib appends its default format (``png`` unless rcParams
            say otherwise), so the default writes ``test.png``.
        y_max: Upper limit of the density (y) axis.
    """
    # Local import: scipy is only needed by this function.
    from scipy.stats import gaussian_kde

    angles = np.asarray(angles, dtype=float)
    fig, ax = plt.subplots(figsize=(5, 3))
    try:
        ax.hist(angles, bins=30, alpha=0.6, range=(0, 180), color='skyblue',
                edgecolor='black', density=True)

        # gaussian_kde raises for fewer than two points or zero variance,
        # so skip the curve for degenerate input.
        if angles.size > 1 and np.ptp(angles) > 0:
            kde = gaussian_kde(angles)
            x_range = np.linspace(angles.min(), angles.max(), 100)
            ax.plot(x_range, kde(x_range), 'r-', linewidth=1, label='kde')
            ax.legend()

        ax.set_ylim(0, y_max)
        ax.set_xlabel('Angle')
        ax.grid(alpha=0.3)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting fails.
        plt.close(fig)

    print(f"分布图已保存到: {save_path}")


def analyze_lora_weight(model_dict, index, type):
    """Plot the angle distribution between the rows of one LoRA_A matrix.

    The looked-up key is ``"transformer.encoder.layer.<index><type>Lora_A"``,
    where ``type`` is a key fragment such as ``ATTENTION_Q`` (``".attn.query."``).

    Args:
        model_dict: Mapping from parameter name to tensor.
        index: Encoder layer number, formatted into the key.
        type: Key fragment (a ``str``) selecting the sub-module. The name
            shadows the builtin ``type`` but is kept because callers may pass
            it by keyword.

    Raises:
        KeyError: If the constructed key is not present in ``model_dict``.
    """
    prefix = f"transformer.encoder.layer.{index}"
    key = "".join((prefix, type, "Lora_A"))
    angle_values = cosine_distance_to_angle(model_dict[key])
    visualize_angle_distribution(angle_values)

if __name__ == '__main__':
    # Parse the CLI options, load the checkpoint, then plot pairwise angles
    # between the rows of layer 0's query LoRA_A matrix.
    cli_args = get_args_parser()
    weights = load_lora(cli_args)
    analyze_lora_weight(model_dict=weights, index=0, type=ATTENTION_Q)

    # Random baseline (disabled): with torch.manual_seed(42), draw a
    # torch.normal(0, 1, size=(768, 768)) matrix, pass its transpose to
    # cosine_distance_to_angle, and plot it with visualize_angle_distribution.
