from typing import Union, Generator
from copy import deepcopy
import random
import pynvml
import numpy as np
import torch
def disp_num_params(model):
    """Print and return the fraction of prunable parameters currently in use.

    Walks ``model.prunable_layers`` in lockstep with ``model.prunable_layer_prefixes``
    (stopping at the shorter of the two, as ``zip`` does). For each layer it adds
    ``layer.num_weight`` to the in-use count and ``layer.mask.nelement()`` to the
    total count. Presumably ``num_weight`` is the number of unmasked weights and
    ``mask`` covers every prunable weight of the layer — TODO confirm against the
    layer implementation.

    Returns:
        A tuple ``(in_use / total, in_use)``.

    Raises:
        ZeroDivisionError: If the total element count is zero (e.g. no layers).
    """
    in_use, capacity = 0, 0
    for layer, _prefix in zip(model.prunable_layers, model.prunable_layer_prefixes):
        in_use += layer.num_weight
        capacity += layer.mask.nelement()

    fraction = in_use / capacity
    print(f"Total: {in_use}/{capacity} = {fraction}")

    return fraction, in_use

def select_best_gpu(min_memory=11 * 1024):  # min_memory is in MB; default is 11 GB
    """Pick the GPU with the most free memory among those with at least ``min_memory`` MB free.

    Queries NVML for every visible device, prints each device's free memory
    (floor-divided to whole MB), and returns the index of the qualifying device
    with the largest free memory. Ties keep the lowest index.

    NOTE(review): the returned value is an NVML device index; it may not match the
    CUDA ordinal when CUDA_VISIBLE_DEVICES / CUDA_DEVICE_ORDER reorder devices — verify
    against callers.

    Args:
        min_memory: Minimum free memory, in MB, a GPU must have to be eligible.

    Returns:
        The NVML index of the selected GPU.

    Raises:
        RuntimeError: If no GPU has at least ``min_memory`` MB free.
    """
    pynvml.nvmlInit()
    # Always shut NVML down, even if one of the queries below raises.
    try:
        num_gpus = pynvml.nvmlDeviceGetCount()

        best_gpu = None
        max_free_memory = 0

        for i in range(num_gpus):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            free_memory = mem_info.free // 1024 ** 2  # convert bytes to MB

            print(f"GPU {i}: Free memory: {free_memory} MB")

            # Use an explicit "nothing chosen yet" check rather than relying on the
            # 0 sentinel, so a GPU with 0 MB free is still selectable when min_memory <= 0.
            if free_memory >= min_memory and (best_gpu is None or free_memory > max_free_memory):
                max_free_memory = free_memory
                best_gpu = i
    finally:
        pynvml.nvmlShutdown()

    if best_gpu is None:
        raise RuntimeError(f"No GPU found with at least {min_memory / 1024} GB free memory!")

    print(f"Selected GPU {best_gpu} with {max_free_memory} MB free memory.")
    return best_gpu

def copy_dict(ori_dict: Union[dict, Generator]):
    """Build a new plain dict holding the same key/value objects (shallow copy).

    ``ori_dict`` may be a mapping, whose ``items()`` are used, or any iterable of
    ``(key, value)`` pairs such as ``model.named_parameters()``. Values are not
    copied; the new dict references the very same objects.
    """
    if isinstance(ori_dict, dict):
        pairs = ori_dict.items()
    else:
        pairs = ori_dict
    return {name: value for name, value in pairs}


def deepcopy_dict(ori_dict: Union[dict, Generator]):
    """Build a new dict whose values are ``.clone()``-d copies of the originals.

    ``ori_dict`` may be a mapping, whose ``items()`` are used, or any iterable of
    ``(key, value)`` pairs. Every value must provide a ``clone()`` method
    (presumably ``torch.Tensor`` — TODO confirm against callers), so the result
    owns fresh storage instead of aliasing the inputs.
    """
    if isinstance(ori_dict, dict):
        pairs = ori_dict.items()
    else:
        pairs = ori_dict
    return {name: tensor.clone() for name, tensor in pairs}


def copy_shuffle_list(inp_list):
    """Return a deep copy of ``inp_list`` shuffled in random order.

    The caller's list and its elements are left untouched. Ordering comes from
    the module-level ``random`` generator, so it respects ``random.seed``.
    """
    result = deepcopy(inp_list)
    # Shuffle the private copy in place; random.shuffle returns None.
    random.shuffle(result)
    return result


def dirichlet_split_noniid(train_labels, alpha, n_clients):
    """Split sample indices into ``n_clients`` non-IID subsets using a Dirichlet prior.

    For every class ``k`` a proportion vector is drawn from
    ``Dirichlet([alpha] * n_clients)``. The indices of the samples labelled ``k``
    are then cut into ``n_clients`` consecutive chunks according to those
    proportions. Smaller ``alpha`` gives more skewed (more non-IID) splits.
    Randomness comes from the global ``np.random`` state.

    Args:
        train_labels: Array-like of per-sample class labels (NumPy array, list, ...).
            Labels are assumed to be non-negative integers; classes are taken to
            be ``0 .. max(train_labels)``.
        alpha: Dirichlet concentration parameter.
        n_clients: Number of subsets to produce.

    Returns:
        A list of ``n_clients`` NumPy index arrays, ordered by ascending length.
        Together they cover every sample index exactly once.
    """
    # Accept plain sequences as well as arrays; the original required an object
    # with ``.max()``.
    labels = np.asarray(train_labels)

    n_classes = int(labels.max()) + 1
    # (K, N) matrix: row k holds the fraction of class k assigned to each client.
    label_distribution = np.random.dirichlet([alpha] * n_clients, n_classes)
    # For each of the K classes, the indices of the samples carrying that label.
    class_idcs = [np.argwhere(labels == y).flatten() for y in range(n_classes)]

    # Per-client list of index chunks, concatenated at the end.
    client_idcs = [[] for _ in range(n_clients)]
    for k_idcs, fracs in zip(class_idcs, label_distribution):
        # Cumulative proportions (minus the final 1.0) give the cut points that
        # np.split uses to divide class k's indices into n_clients chunks.
        split_points = (np.cumsum(fracs)[:-1] * len(k_idcs)).astype(int)
        for i, idcs in enumerate(np.split(k_idcs, split_points)):
            client_idcs[i].append(idcs)

    client_idcs = [np.concatenate(idcs) for idcs in client_idcs]
    return sorted(client_idcs, key=len)

def compute_same_params_ratio(model1, model2, atol=1e-5, rtol=1e-5):
    """Return the fraction of ``model1``'s state-dict elements that are close to ``model2``'s.

    Every entry of ``model1.state_dict()`` is compared elementwise with the entry of
    the same name in ``model2.state_dict()`` using ``torch.isclose``. Keys present
    only in ``model2`` are ignored.

    Args:
        model1: Reference object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        model2: Object to compare against; must contain every key of ``model1``.
        atol: Absolute tolerance passed to ``torch.isclose``.
        rtol: Relative tolerance passed to ``torch.isclose``. The default 1e-5
            equals torch's own default, so omitting it matches earlier behavior.

    Returns:
        ``same_elements / total_elements`` as a float in [0, 1].

    Raises:
        KeyError: If a key of ``model1`` is missing from ``model2``.
        ValueError: If two same-named entries have different shapes.
        ZeroDivisionError: If ``model1``'s state dict has no elements.
    """
    state_dict1 = model1.state_dict()
    state_dict2 = model2.state_dict()

    same_elements = 0
    total_elements = 0

    for key, param1 in state_dict1.items():
        param2 = state_dict2[key]

        if param1.shape != param2.shape:
            raise ValueError(f"参数 {key} 的形状不匹配")

        total_elements += param1.numel()
        # Shapes match, so compare elementwise directly. The original view(-1)
        # raised on non-contiguous tensors. Moving param2 onto param1's device
        # avoids a device-mismatch error (e.g. a CPU snapshot vs. a GPU model).
        # isclose is used because exact float equality is too strict.
        same = torch.isclose(param1, param2.to(param1.device), rtol=rtol, atol=atol)
        same_elements += same.sum().item()

    return same_elements / total_elements