import random
import torch
import numpy as np
from any_tools import get_l2_norm_by_column, get_l2_norm_by_row, get_hist_graph, get_bottle_k, get_top_k

def np2th(weights, conv=False):
    """Turn a NumPy array into a torch tensor, optionally reordering conv axes.

    When ``conv`` is true the 4-D array is treated as an HWIO kernel and its
    axes are permuted to OIHW before conversion; otherwise the array is
    converted unchanged.

    Args:
        weights: NumPy array to convert.
        conv: If true, permute axes (H, W, I, O) -> (O, I, H, W) first.

    Returns:
        The result of ``torch.from_numpy`` on the (possibly permuted) array.
    """
    source = weights.transpose([3, 2, 0, 1]) if conv else weights
    return torch.from_numpy(source)


def process_matrix(weight, name):
    """Flatten a parameter to 2-D, picking which axis to preserve from its name.

    If ``name`` contains the substring ``"out"``, the last axis is kept as the
    column axis and all leading axes are merged into rows. Otherwise the first
    axis is kept as the row axis and all trailing axes are merged into columns.

    Args:
        weight: Array-like with ``shape`` and ``reshape`` (torch tensor or
            NumPy array).
        name: Parameter name used to choose the flattening rule.

    Returns:
        The weight reshaped to two dimensions.
    """
    if "out" in name:
        rows, cols = -1, weight.shape[-1]
    else:
        rows, cols = weight.shape[0], -1
    return weight.reshape(rows, cols)


def _get_attention_kernel(model_weight, index, projection):
    """Load one attention projection kernel of an encoder block as a 2-D tensor.

    Args:
        model_weight: Mapping from checkpoint parameter names to NumPy arrays
            (presumably the result of ``np.load`` on a ``.npz`` checkpoint).
        index: Encoder block number substituted into the parameter name.
        projection: Projection sub-module name, e.g. ``"query"`` or ``"key"``.

    Returns:
        The kernel converted with ``np2th`` and flattened with
        ``process_matrix``.
    """
    name = (
        f"Transformer/encoderblock_{index}/"
        f"MultiHeadDotProductAttention_1/{projection}/kernel"
    )
    return process_matrix(np2th(model_weight[name]), name)


def get_query(model_weight, index):
    """Return the query projection kernel of encoder block ``index`` as 2-D."""
    return _get_attention_kernel(model_weight, index, "query")


def get_key(model_weight, index):
    """Return the key projection kernel of encoder block ``index`` as 2-D."""
    return _get_attention_kernel(model_weight, index, "key")

# # def get_query(model_weight, index):
# #     query_name = f"Transformer/encoderblock_{index}/MultiHeadDotProductAttention_1/query/kernel"
# #     query_weight = np2th(model_weight[query_name])
# #     re_weight = process_matrix(query_weight, query_name)
# #     return re_weight
# # key_name = "Transformer/encoderblock_10/MultiHeadDotProductAttention_1/key/kernel"
# # value_name = "Transformer/encoderblock_1/MultiHeadDotProductAttention_1/value/kernel"
# # name = "Transformer/encoderblock_1/MlpBlock_3/Dense_0/kernel"
#
# # for i in range(12):
# #     sub_weight = value_weight[:, i, :]
# #     print(sub_weight.size())
# #     re_weight_norm = get_l2_norm_by_column(sub_weight)
# #     # get_hist_graph(re_weight_norm)
# #
# #     mean_=torch.mean(re_weight_norm)
# #     print(mean_)
# # print(weight.size())
#
#     # Compute the bin edges
# bin_edges = torch.linspace(0, 768, 13)
#
# print(bin_edges)
#
# # Use torch.histc to count how many elements fall into each bin
# # counts = torch.histc(tensor, bins=num_bins, min=min_val, max=max_val)
#
#
# model_weight = np.load("ViT-B_16.npz")
#
#
# index = 10
# query = get_query(model_weight, index)
# key = get_key(model_weight, index)
#
# query_norm = get_l2_norm_by_column(query)
# key_norm = get_l2_norm_by_column(key)
#
# top_k = int(query_norm.shape[0] * 0.6)
#
# _, query_indiceds = torch.topk(query_norm, top_k, largest=True, sorted=True)
#
# query_indiceds = query_indiceds.float()
# counts = torch.histc(query_indiceds, bins=12, min=0, max=768)
#
# print(counts.long())
#

# _, key_indiceds = torch.topk(key_norm, top_k, largest=True, sorted=True)
# print(query_indiceds)
#
# q_index = query_indiceds.numpy()
# k_index = key_indiceds.numpy()
#
# # # Two lists
# # list1 = [1, 2, 3, 4, 5]
# # list2 = [4, 5, 6, 7, 8]
# #
# # # Convert to NumPy arrays (np.intersect1d also converts list inputs automatically)
# intersection = np.intersect1d(q_index, k_index)
#
#
# print("交集结果:", len(intersection) / len(query_indiceds))
# print(query_indiceds)
# re_weight_norm = get_l2_norm_by_row(re_weight)
# get_hist_graph(re_weight_norm)





# print(indices)
#
# new_input = input[:, sorted_indices]
# new_weight = weight[sorted_indices, :]
# new_output = torch.mm(new_input, new_weight)
# print(torch.norm(output))
# print(torch.norm(new_output))


# get_hist_graph(re_weight_norm)

# count = (re_weight_norm < 0.1).sum().item()
#
# print(count)
# print(count / re_weight_norm.shape[0])


# for i in model_weight:
#     print(i)
# random.seed(42)
# torch.manual_seed(42)
# torch.cuda.manual_seed(42)
#
#
# input_size = 768
# output_size = 768
#
# input = torch.rand([100, input_size])
#
# weight = torch.rand([input_size, 300])
#
# output = torch.mm(input, weight)
# # weight = torch.tensor().float()
#
#
# weight_norm = torch.norm(weight, dim=1, p=2)
#
# input_norm = torch.norm(input, dim=0, p=2)
#
# print(input_norm.size())
# print(input_norm)
#
#
# print(weight_norm.size())
# print(weight_norm)
#
#
# multiple = input_norm * weight_norm
# print(multiple.size())
# print(multiple)
#
# top_k = 500
# _, indices = torch.topk(multiple, top_k, largest=True, sorted=True)
# sorted_indices, sorted_order = torch.sort(indices)
#
# random_indices = torch.randperm(768)[:top_k]
# random_sorted_indices, _ = torch.sort(random_indices)
# print(indices)
#
# new_input = input[:, sorted_indices]
# new_weight = weight[sorted_indices, :]
# new_output = torch.mm(new_input, new_weight)
# print(torch.norm(output))
# print(torch.norm(new_output))
#
# random_input = input[:, random_sorted_indices]
# random_weight =  weight[random_sorted_indices, :]
# random_output = torch.mm(random_input, random_weight)
# print(torch.norm(random_output))