import torch
from collections import Counter
import math
import time
import numpy as np

def cal_score(features, batch_size):
    """Compute one entropy score per feature tensor in ``features``.

    For every feature tensor, each per-sample matrix is pruned with
    ``convert_to_symmetric`` and passed to ``add_edges_to_form_cycle``, which
    returns an integer count. The score for the tensor is the base-2 Shannon
    entropy (``calculate_entropy``) of those counts across the batch.

    Args:
        features: iterable of torch tensors. 2-D tensors whose width is 1000
            or 10 are cropped to the first 961 / 9 columns and reshaped to
            ``(batch_size, 31, 31)`` / ``(batch_size, 3, 3)``. Other tensors
            are used as-is. The helpers index ``sample[i][j]``, so presumably
            each sample must be a square 2-D matrix (i.e. a 3-D tensor) --
            TODO confirm against callers.
        batch_size: leading dimension used when reshaping the 2-D tensors
            above. It must equal ``feature.shape[0]`` for them, or ``view``
            raises.

    Returns:
        list with one entropy value per entry of ``features``.
        The input tensors are not modified.
    """
    ent_list = []
    for feature in features:
        if feature.ndim == 2:
            if feature.shape[1] == 1000:
                feature = feature[:, :961].view(batch_size, 31, 31)
            elif feature.shape[1] == 10:
                feature = feature[:, :9].view(batch_size, 3, 3)
        # convert_to_symmetric mutates its argument in place. For a CPU tensor,
        # ``.cpu()`` returns the same tensor and ``.numpy()`` shares its memory,
        # so without an explicit copy the caller's features would be silently
        # overwritten (GPU tensors were copied by ``.cpu()``, making the old
        # behaviour device-dependent). ``detach`` allows tensors requiring grad.
        matrices = feature.detach().cpu().numpy().copy()
        num_edges = [
            add_edges_to_form_cycle(convert_to_symmetric(matrix))
            for matrix in matrices
        ]
        ent_list.append(calculate_entropy(num_edges))
    return ent_list

def calculate_entropy(arr):
    """Return the Shannon entropy, in bits, of the values in ``arr``.

    Each distinct value is treated as a category whose probability is its
    relative frequency in ``arr``.

    Args:
        arr: sized collection of hashable values.

    Returns:
        ``-sum(p * log2(p))`` over all categories; ``0`` for an empty input.
    """
    frequencies = Counter(arr)
    n_items = len(arr)
    probabilities = (freq / n_items for freq in frequencies.values())
    weighted_logs = (p * math.log2(p) for p in probabilities)
    # Subtract from 0 instead of negating, so an all-zero sum yields 0.0
    # rather than -0.0.
    return 0 - sum(weighted_logs)

# Prune a square weight matrix pairwise; diagonal values must be preserved.
def convert_to_symmetric(matrix):
    """Keep only the larger entry of every mirrored off-diagonal pair, in place.

    For each ``i < j``, one of ``matrix[i][j]`` / ``matrix[j][i]`` is set to 0:
    the lower-triangle entry if the upper one is strictly larger, otherwise
    (ties, or any comparison that is False) the upper-triangle entry.
    Diagonal entries are never touched.

    NOTE(review): despite the name, the result is generally *not* symmetric,
    because one cell of every pair becomes 0. ``add_edges_to_form_cycle`` only
    scans the upper triangle, so pairs whose lower entry wins never become
    edges there. Confirm this is intended; a true symmetrisation would write
    ``max(a_ij, a_ji)`` into both cells.

    Args:
        matrix: square matrix indexable as ``matrix[i][j]`` (nested lists or a
            2-D numpy array). It is modified in place.

    Returns:
        The same ``matrix`` object.
    """
    size = len(matrix)
    for row in range(size):
        for col in range(row + 1, size):
            upper_wins = matrix[row][col] > matrix[col][row]
            r, c = (col, row) if upper_wins else (row, col)
            matrix[r][c] = 0
    return matrix

class UnionFind:
    """Disjoint-set forest over the integers ``0 .. size-1``.

    Uses path compression in ``find`` and union by rank in ``union``.
    """

    def __init__(self, size):
        # parent[k] is k's parent; a node is a root when parent[k] == k.
        self.parent = list(range(size))
        # Upper bound on each root's tree height, used to pick the new root.
        self.rank = [0] * size

    def find(self, node):
        """Return the root of ``node``'s set, pointing every node on the path at it."""
        root = node
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: re-hang every node on the path directly under the root.
        while node != root:
            next_node = self.parent[node]
            self.parent[node] = root
            node = next_node
        return root

    def union(self, node1, node2):
        """Merge the sets containing ``node1`` and ``node2``.

        Returns:
            True if two distinct sets were merged; False if both nodes were
            already in the same set.
        """
        keep, attach = self.find(node1), self.find(node2)
        if keep == attach:
            return False
        # The root with the smaller rank is attached under the other one; on a
        # tie node1's root stays the root and its rank grows by one.
        if self.rank[keep] < self.rank[attach]:
            keep, attach = attach, keep
        self.parent[attach] = keep
        if self.rank[keep] == self.rank[attach]:
            self.rank[keep] += 1
        return True
    
def add_edges_to_form_cycle(adj_matrix):
    """Add edges heaviest-first and report when the first cycle appears.

    Candidate edges are the upper-triangle entries ``adj_matrix[i][j]``
    (``i < j``) with a positive weight. They are sorted descending by the tuple
    ``(weight, i, j)``, so ties in weight go to larger ``i`` and then larger
    ``j``. The edges are then merged one by one into a union-find forest.

    Args:
        adj_matrix: square matrix indexable as ``adj_matrix[i][j]``. Only the
            upper triangle is read.

    Returns:
        The 1-based position of the first edge whose endpoints are already
        connected (i.e. the edge that closes a cycle). If no edge closes a
        cycle, returns ``number_of_candidate_edges + 1``.
        NOTE(review): that fallback can coincide with a genuine cycle position.
    """
    size = len(adj_matrix)
    candidate_edges = sorted(
        (
            (adj_matrix[r][c], r, c)
            for r in range(size)
            for c in range(r + 1, size)
            if adj_matrix[r][c] > 0  # only non-zero (positive) weights
        ),
        reverse=True,
    )

    forest = UnionFind(size)
    for position, (_, u, v) in enumerate(candidate_edges, start=1):
        if not forest.union(u, v):  # u and v already connected -> cycle
            return position
    return len(candidate_edges) + 1

# Capture the output (feature maps) of one layer during a forward pass.
def get_feature_maps(model, input_tensor, target_layer):
    """Run ``model`` on ``input_tensor`` and return ``target_layer``'s output.

    A forward hook is registered on ``target_layer`` for the duration of a
    single forward pass, which runs under ``torch.no_grad()``. The hook is
    always removed afterwards, even if the forward pass raises.

    Args:
        model: callable invoked as ``model(input_tensor)``; its return value
            is discarded.
        input_tensor: input passed to ``model``.
        target_layer: module whose forward output is captured. If it runs
            more than once during the pass, the last output is returned.

    Returns:
        The output object produced by ``target_layer`` during the pass.

    Raises:
        KeyError: if ``target_layer`` was not executed during the pass.
        Any exception raised by ``model`` propagates unchanged.
    """
    # Local storage the hook closure writes into.
    feature_maps = {}

    def hook_function(module, input, output):
        feature_maps["output"] = output

    hook = target_layer.register_forward_hook(hook_function)
    try:
        with torch.no_grad():
            model(input_tensor)
    finally:
        # Without this, a failing forward pass would leave the hook registered
        # on target_layer, where it would keep firing on every later call.
        hook.remove()

    return feature_maps["output"]