import torch
from torch.distributions import Normal
import numpy as np
from  Pruning_Config.loader_and_pruning import get_global_FFN_mask_by_norm
import torch.nn.functional as F
import math


# Name fragments of per-encoder-block parameters in the ViT .npz checkpoint.
# Functions below build keys into the dict returned by get_model_weight as
# f"Transformer/encoderblock_{i}/" + FRAGMENT + "/kernel".
ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"  # first dense layer of the MLP block
FC_1 = "MlpBlock_3/Dense_1"  # second dense layer of the MLP block
ATTENTION_NORM = "LayerNorm_0"  # NOTE(review): presumably the pre-attention LayerNorm — confirm
MLP_NORM = "LayerNorm_2"  # NOTE(review): presumably the pre-MLP LayerNorm — confirm


# Load model parameters
def get_model_weight(pth):
    """Load every array in an ``.npz`` checkpoint as a torch tensor.

    Args:
        pth: Path to a NumPy ``.npz`` archive (e.g. ``"ViT-B_16.npz"``).

    Returns:
        dict mapping each array name in the archive to a ``torch.Tensor``
        built with ``torch.from_numpy`` (the tensor shares memory with the
        NumPy array read from the archive).
    """
    # For .npz files np.load returns an NpzFile that keeps the underlying
    # file open; the context manager closes it once all arrays are read.
    with np.load(pth) as archive:
        return {name: torch.from_numpy(array) for name, array in archive.items()}


def weight_svd(weight, head_dim):
    """Print SVD diagnostics for a 2-D weight and return its singular values.

    Args:
        weight: 2-D tensor to decompose.
        head_dim: number of leading right-singular vectors whose row norms
            are printed.

    Returns:
        1-D tensor of the singular values of ``weight`` (descending order,
        length ``min(weight.shape)``).

    Raises:
        ValueError: if ``weight`` is not 2-D.
    """
    # Check the rank; comparing ``weight.shape`` (a torch.Size) to the int 2
    # is always unequal and would never reject anything.
    if weight.dim() != 2:
        raise ValueError("wrong weight size, not equal to 2")
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    # Row norms of the top-``head_dim`` right singular vectors (orthonormal
    # rows, so these should all be ~1).
    print(torch.norm(Vh[:head_dim, :], dim=1))
    # Error of the full (untruncated) reconstruction: a sanity check on the
    # decomposition, expected to be near floating-point precision.
    # NOTE(review): if a rank-``head_dim`` truncation error was intended,
    # use U[:, :head_dim], S[:head_dim], Vh[:head_dim, :] here instead.
    recon = U @ torch.diag(S) @ Vh
    print(torch.norm(recon - weight))
    print('-----------------------')
    return S

# Analyze the structure of the combined value-output (VO) projection
def analyze_VO_(model_dict, layers=(0,)):
    """Collect singular values of each head's combined value-output product.

    For every requested encoder block, the value kernel is permuted so that
    axis 0 indexes heads. Each head's value slice is then matrix-multiplied
    with the matching slice of the attention-output kernel. The singular
    values of every per-head product (from ``weight_svd``) are concatenated
    into one row per layer.

    NOTE(review): assumes the value kernel is (hidden, num_heads, head_dim)
    and the output kernel is (num_heads, head_dim, hidden), as the permute
    and per-head indexing below suggest — confirm against the checkpoint.

    Args:
        model_dict: mapping of checkpoint parameter names to tensors, as
            returned by ``get_model_weight``.
        layers: iterable of encoder-block indices to analyze. Defaults to
            ``(0,)`` (first block only); pass e.g. ``range(12)`` for more.

    Returns:
        2-D tensor with one row per entry of ``layers``; each row is the
        concatenation of every head's singular values.

    Raises:
        ValueError: if ``layers`` is empty.
    """
    layers = list(layers)
    if not layers:
        raise ValueError("layers must contain at least one encoder-block index")

    layer_rows = []
    for index in layers:
        layer_name = f"Transformer/encoderblock_{index}/"
        value_weight = model_dict[layer_name + ATTENTION_V + "/kernel"]
        output_weight = model_dict[layer_name + ATTENTION_OUT + "/kernel"]

        # Move the head axis to the front so value_by_head[h] is head h's slice.
        value_by_head = value_weight.permute(1, 0, 2)
        print("new value size : ", value_by_head.size())
        # The inner matmul dimension bounds the rank of each per-head product.
        rank_bound = value_by_head.shape[-1]

        head_spectra = []
        for head_i in range(value_by_head.shape[0]):
            head_weight = torch.matmul(value_by_head[head_i], output_weight[head_i])
            head_spectra.append(weight_svd(head_weight, rank_bound))
        layer_rows.append(torch.cat(head_spectra))

    total_svd = torch.stack(layer_rows)
    print(total_svd.size())
    return total_svd


def column_similarity_cosine(tensor):
    """
    Compute the cosine similarity between every pair of columns of a 2-D
    tensor, print it, and return the smallest and largest pairwise values.

    Args:
        tensor (torch.Tensor): 2-D floating-point tensor of shape
            [n_rows, n_cols].

    Returns:
        min_sim (float): smallest pairwise column cosine similarity.
        max_sim (float): largest pairwise column cosine similarity.
        pairwise_similarities (torch.Tensor): 1-D tensor with the similarity
            of every column pair (i, j), i < j, in upper-triangle row order.

    Raises:
        ValueError: if ``tensor`` is not 2-D or has fewer than two columns.
    """
    if tensor.dim() != 2:
        raise ValueError(f"expected a 2-D tensor, got shape {tuple(tensor.shape)}")
    n_cols = tensor.size(1)
    if n_cols < 2:
        raise ValueError("need at least two columns to compute pairwise similarity")

    # Scale every column to unit length; one matrix product then yields all
    # pairwise cosines. This needs O(n_cols^2) memory, whereas broadcasting
    # F.cosine_similarity materializes an [n_cols, n_cols, n_rows] tensor.
    unit_cols = F.normalize(tensor, dim=0, eps=1e-8)
    sim_matrix = unit_cols.T @ unit_cols
    print(sim_matrix)

    # Upper triangle without the diagonal: each unordered pair exactly once.
    triu_indices = torch.triu_indices(n_cols, n_cols, offset=1)
    pairwise_similarities = sim_matrix[triu_indices[0], triu_indices[1]]

    min_sim = pairwise_similarities.min().item()
    max_sim = pairwise_similarities.max().item()

    print(f"张量形状: {tensor.shape}")
    print(f"所有两两列之间的余弦相似度: {pairwise_similarities}")
    print(f"最小列余弦相似度: {min_sim:.4f}")
    print(f"最大列余弦相似度: {max_sim:.4f}")

    return min_sim, max_sim, pairwise_similarities

def analyze_FC1_(model_dict, index=0, head=11):
    """Report pairwise column cosine similarity for one slice of a value kernel.

    NOTE(review): despite the function name, this reads the attention value
    kernel (ATTENTION_V), not FC_0. It slices that kernel with
    ``[:, head, :]``, which requires a 3-D tensor. Confirm which weight is
    intended.

    Args:
        model_dict: mapping of checkpoint parameter names to tensors, as
            returned by ``get_model_weight``.
        index: encoder-block index (default 0).
        head: index along axis 1 of the value kernel to slice (default 11).

    Returns:
        The ``(min_sim, max_sim, pairwise_similarities)`` tuple produced by
        ``column_similarity_cosine``.
    """
    layer_name = f"Transformer/encoderblock_{index}/"
    value_weight = model_dict[layer_name + ATTENTION_V + "/kernel"]
    head_weight = value_weight[:, head, :]
    print(head_weight.size())
    return column_similarity_cosine(head_weight)


if __name__ == '__main__':
    # Script entry: load the ViT-B/16 checkpoint and run the per-head
    # column-similarity analysis on block 0's value kernel.
    analyze_FC1_(get_model_weight("ViT-B_16.npz"))

    # Other experiments (disabled):
    #   toy = torch.tensor([[1, 2, 3], [2, 3, 4]]).float()
    #   min_sim, max_sim, matrix = column_similarity_cosine(toy)
    #   get_global_FFN_mask_by_norm(total_svd, 0.3)  # total_svd e.g. from analyze_VO_
