import math

import torch


def PCB_merge(flat_task_checks, pcb_ratio=0.1):
    """Merge flattened parameter vectors (one per row) with PCB-Merging.

    Args:
        flat_task_checks: 2D tensor of shape (n, d), one flattened vector per
            row. It is not modified.
        pcb_ratio: fraction of each row's highest-scoring entries that keep a
            non-zero merge weight; entries below that percentile of their row
            are clamped to it and therefore get weight 0 after the per-row
            min-max normalization.

    Returns:
        A tuple ``(merged_tv, clamped_all_checks, scale)``:
            merged_tv: shape (d,), the ``scale``-weighted average of the rows
                of ``clamped_all_checks`` (0 where all weights are 0).
            clamped_all_checks: shape (n, d), the input with each row's
                magnitudes clipped to roughly its 0.01st..99.99th percentile,
                signs preserved.
            scale: shape (n, d), per-entry merge weights in [0, 1].
    """
    # Unpacking also rejects inputs that are not 2D. No step below modifies
    # its input in place, so the input tensor needs no defensive clone.
    n, _ = flat_task_checks.shape
    all_checks = flat_task_checks

    # Clip extreme magnitudes per row, then restore the original signs.
    all_checks_abs = clamp(torch.abs(all_checks),
                           min_ratio=0.0001, max_ratio=0.0001)
    clamped_all_checks = torch.sign(all_checks) * all_checks_abs

    # Self term: exp(n * s), s = squared per-row min-max normalized magnitude.
    logits = n * normalize(all_checks_abs, 1) ** 2
    # exp overflows once a logit exceeds ~log(dtype max): about 88 in float32
    # (i.e. more than ~88 rows) and about 11 in float16, which would turn the
    # result into inf/NaN. Everything downstream of task_pcb (the per-row
    # percentile clamp and per-row min-max normalization) is unaffected by a
    # positive per-row factor, so rows that would overflow are shifted down in
    # log space instead (up to normalize's tiny absolute epsilon). Rows whose
    # largest logit is below the limit are left untouched.
    limit = math.log(torch.finfo(logits.dtype).max) - 2.0
    shift = (logits.amax(dim=1, keepdim=True) - limit).clamp(min=0)
    self_pcb_act = torch.exp(logits - shift)

    # Cross term: each entry times the column sum over all rows, squashed.
    cross_pcb = all_checks * torch.sum(all_checks, dim=0)
    cross_pcb_act = act(cross_pcb)
    task_pcb = self_pcb_act * cross_pcb_act

    # Keep only the top pcb_ratio of each row, then rescale rows to [0, 1].
    scale = normalize(clamp(task_pcb, 1 - pcb_ratio, 0), dim=1)
    merged_tv = torch.sum(clamped_all_checks * scale, dim=0) / \
        torch.clamp(torch.sum(scale, dim=0), min=1e-12)
    return merged_tv, clamped_all_checks, scale

def normalize(x, dim=0):
    """Min-max rescale ``x`` to [0, 1] along ``dim``.

    Args:
        x: input tensor.
        dim: dimension along which the minimum and maximum are taken.

    Returns:
        Tensor of the same shape as ``x``; slices that are constant along
        ``dim`` map to 0 (never NaN).
    """
    min_values, _ = torch.min(x, dim=dim, keepdim=True)
    max_values, _ = torch.max(x, dim=dim, keepdim=True)
    # The tiny epsilon keeps the denominator non-zero for constant slices.
    denom = max_values - min_values + 1e-8
    # In float16 the literal 1e-8 rounds to 0, so a constant slice would give
    # 0 / 0 = NaN. Only an exactly-zero denominator (which implies a zero
    # numerator) is replaced, so float32/float64 results are unchanged.
    denom = denom.masked_fill(denom == 0, 1)
    return (x - min_values) / denom


def clamp(x, min_ratio=0, max_ratio=0):
    """Clip ``x`` to a percentile window of its own values.

    The lower bound is the sorted value at rank ``int(d * min_ratio)`` and the
    upper bound the sorted value at rank ``int(d * (1 - max_ratio) - 1)``,
    both kept inside ``[0, d - 1]``, where ``d`` is the size of the sorted
    dimension. A 1D tensor is sorted as a whole; otherwise every slice along
    dim 1 (each row of a 2D tensor) gets its own bounds.

    Args:
        x: tensor to clip.
        min_ratio: fraction in [0, 1] of the smallest values raised to the
            lower bound.
        max_ratio: fraction in [0, 1] of the largest values lowered to the
            upper bound.

    Returns:
        Tensor of the same shape as ``x``. Empty and 0-d inputs are returned
        unchanged. If a lower bound exceeds its upper bound, ``torch.clamp``
        sets the affected elements to the upper bound.

    Raises:
        ValueError: if ``min_ratio`` or ``max_ratio`` lies outside [0, 1].
    """
    if not (0 <= min_ratio <= 1 and 0 <= max_ratio <= 1):
        raise ValueError(
            f"ratios must be in [0, 1], got min_ratio={min_ratio}, "
            f"max_ratio={max_ratio}")
    if x.dim() == 0:
        return x  # a scalar has nothing to rank against
    sort_dim = 0 if x.dim() == 1 else 1
    d = x.size(sort_dim)
    if d == 0:
        return x
    sorted_x, _ = torch.sort(x, dim=sort_dim)
    # Keep both ranks inside [0, d - 1]: min_ratio == 1 would index one past
    # the end, and max_ratio == 1 yields rank -1, which would silently wrap
    # around to the largest value instead of the smallest.
    lo = min(max(int(d * min_ratio), 0), d - 1)
    hi = min(max(int(d * (1 - max_ratio) - 1), 0), d - 1)
    if sort_dim == 0:
        min_val, max_val = sorted_x[lo], sorted_x[hi]
    else:
        min_val = sorted_x[:, lo].unsqueeze(1)
        max_val = sorted_x[:, hi].unsqueeze(1)
    return torch.clamp(x, min_val, max_val)


def act(x):
    """Squash ``x`` element-wise into (-1, 1) with the hyperbolic tangent."""
    squashed = x.tanh()
    return squashed

# --- Core merge function ---


def merge_experts_with_pcb(expert_matrices: list, pcb_ratio: float = 0.1) -> torch.Tensor:
    """Merge a list of expert weight tensors into one base tensor with PCB-Merging.

    Each expert is flattened into one row, the rows are merged by
    ``PCB_merge`` and the merged vector is reshaped back.

    Args:
        expert_matrices: non-empty list of expert weight tensors (typically 2D
            matrices), all with the same shape.
        pcb_ratio: fraction in (0, 1] of each expert's highest-scoring
            parameters that receive a non-zero merge weight. The default 0.1
            keeps the top 10%.

    Returns:
        The merged tensor, with the experts' shape, on the first expert's
        device. For floating-point experts it has their (promoted) dtype.

    Raises:
        ValueError: if ``expert_matrices`` is empty, the experts' shapes
            differ, or ``pcb_ratio`` is outside (0, 1].
    """
    if not expert_matrices:
        raise ValueError("专家矩阵列表不能为空")
    if not 0 < pcb_ratio <= 1:
        raise ValueError(f"pcb_ratio must be in (0, 1], got {pcb_ratio}")

    original_shape = expert_matrices[0].shape
    device = expert_matrices[0].device
    dtype = expert_matrices[0].dtype
    for i, m in enumerate(expert_matrices):
        # Experts with equal numel but different shapes would otherwise be
        # silently merged and reshaped to the first expert's shape.
        if m.shape != original_shape:
            raise ValueError(
                f"expert {i} has shape {tuple(m.shape)}, "
                f"expected {tuple(original_shape)}")
        dtype = torch.promote_types(dtype, m.dtype)

    # PCB_merge exponentiates n * s and accumulates sums across experts; half
    # precision lacks the range and precision for that, so compute in at
    # least float32 (float64 stays float64).
    compute_dtype = torch.promote_types(dtype, torch.float32)

    # torch.stack needs all inputs on one device, so move each expert to the
    # first expert's device before stacking. Shape: (num_experts, num_params).
    flat_experts = torch.stack(
        [m.to(device=device, dtype=compute_dtype).flatten()
         for m in expert_matrices])

    # PCB_merge returns three values; only the merged vector is needed here.
    merged_flat_vector, _, _ = PCB_merge(flat_experts, pcb_ratio)

    base_matrix = merged_flat_vector.reshape(original_shape)
    # Restore the experts' precision; integer inputs keep the floating result
    # rather than being truncated.
    if dtype.is_floating_point:
        base_matrix = base_matrix.to(dtype)
    return base_matrix
