from libs import *
from dataset import *


def fedavg(avg_model, client_gradients, learning_rate):
    """Apply one plain-averaging gradient step to ``avg_model`` in place.

    For every key of the first client's gradient dict, the tensors of all
    clients are stacked and averaged element-wise.  Each model parameter whose
    name matches a key is then decreased by ``learning_rate`` times that mean.

    Args:
        avg_model: module whose ``named_parameters()`` are updated in place.
        client_gradients: non-empty list of ``{parameter name: tensor}`` dicts,
            one per client.
        learning_rate: scalar step size.

    Returns:
        The same ``avg_model`` object.
    """
    mean_by_name = {
        key: torch.stack([client[key] for client in client_gradients]).mean(dim=0)
        for key in client_gradients[0]
    }

    # Hand-written update (no optimizer); kept outside autograd.
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            mean_grad = mean_by_name.get(name)
            if mean_grad is not None:
                param.sub_(learning_rate * mean_grad)
    return avg_model


def h_public(avg_model, client_gradients, learning_rate, attackers):
    """Filter client gradients against client 0 and apply the mean of the kept ones.

    Each client's gradient dict is flattened into one row.  In each of three
    rounds a random subset of coordinates is drawn and every client is scored
    on that subset by a coordinate-wise closeness ratio to client 0's gradient
    minus a penalty ``rho * max(norm, tau / norm)`` on its sub-sampled norm.
    The ``k`` highest-scoring clients are kept.  Clients kept in every round
    (or, if there are none, in any round) are averaged, and that mean is
    subtracted from the model scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts; entry 0 is the
            reference (presumably a trusted/public client - confirm with caller).
        learning_rate: scalar step size.
        attackers: count used in ``k = n - attackers - 4``; ``k`` is clamped
            to ``[1, n]`` so a selection is always possible.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )

    n, d = flat_grads.shape
    b = n  # coordinates requested per round (the slice below caps it at d)
    niters = 3
    rho = 0.1
    tau = 100
    # A non-positive k would make topk fail (k < 0) or select nobody (k == 0),
    # which ends in a NaN update; always keep between 1 and n clients.
    k = max(1, min(n, int(n - attackers - 4)))

    I_good = []  # one set of kept client indices per round
    for _ in range(niters):
        # Sample coordinates on the gradients' own device (no global needed).
        r = torch.randperm(d, device=flat_grads.device)[:b].sort().values
        subsampled_grads = flat_grads[:, r]  # [n_clients, <=b]
        reference = subsampled_grads[0]

        m_i = torch.abs(reference)
        denom = torch.abs(subsampled_grads - reference) + m_i
        # denom == 0 means both values are exactly 0: treat as a perfect match
        # instead of letting 0/0 produce NaN.
        ratio = torch.where(denom > 0, m_i / denom, torch.ones_like(denom))
        similarity = ratio.sum(dim=1) / b

        norms = torch.linalg.norm(subsampled_grads, dim=1)
        scores = similarity - rho * torch.maximum(norms, tau / norms)

        # Keep the k highest-scoring clients.
        _, selected_indices = torch.topk(scores, k=k)
        I_good.append(set(selected_indices.tolist()))

    # Prefer clients kept in every round; fall back to the union of all rounds.
    I_final = set.intersection(*I_good)
    if not I_final:
        I_final = set.union(*I_good)

    final_grads = torch.mean(flat_grads[sorted(I_final), :], dim=0)

    # Cut the flat vector back into per-parameter tensors.
    aggregated_updates = {}
    pointer = 0
    for name, param in client_gradients[0].items():
        numel = param.numel()
        aggregated_updates[name] = final_grads[pointer : pointer + numel].view_as(param)
        pointer += numel

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in aggregated_updates:
                param -= learning_rate * aggregated_updates[name]

    return avg_model

def zeno(avg_model, client_gradients, learning_rate):
    """Apply the first client gradient that points close enough to client 0's.

    Gradients are flattened per client.  Clients 1..n-1 are scanned in order
    and the first one whose cosine similarity with client 0, minus ``rho`` plus
    ``epsilon``, is positive is used as the update.  If none qualifies the
    update is all zeros.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts; entry 0 is the
            reference direction.
        learning_rate: scalar step size.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )

    rho = 0.2
    epsilon = 0.1
    reference_dir = flat_grads[0] / torch.linalg.norm(flat_grads[0])
    chosen = torch.zeros_like(flat_grads[0])

    for candidate in flat_grads[1:]:
        direction = candidate / torch.linalg.norm(candidate)
        if torch.dot(direction, reference_dir) - rho + epsilon > 0:
            chosen = candidate
            break

    # Cut the flat vector back into per-parameter tensors.
    aggregated_updates = {}
    offset = 0
    for name, template in client_gradients[0].items():
        end = offset + template.numel()
        aggregated_updates[name] = chosen[offset:end].view_as(template)
        offset = end

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            update = aggregated_updates.get(name)
            if update is not None:
                param.sub_(learning_rate * update)

    return avg_model

def trust(avg_model, client_gradients, learning_rate):
    """Aggregate gradients weighted by their cosine similarity to client 0.

    With ``g_0`` the flattened gradient of client 0 and ``c_i`` the cosine
    similarity between ``g_i`` and ``g_0``, client ``i`` receives the weight
    ``ReLU(c_i) * ||g_0|| / ||g_i||``.  The weighted sum of the gradients is
    divided by ``sum_j ReLU(c_j)`` and subtracted from the model scaled by
    ``learning_rate``.  Zero norms and a zero denominator are replaced by
    ``1e-8`` to avoid division by zero.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: non-empty list of ``{name: tensor}`` dicts; entry 0
            is the reference gradient.
        learning_rate: scalar step size.

    Returns:
        The same ``avg_model`` object.
    """
    eps = 1e-8

    # Step 1: flatten every client's gradient into one row.
    flat = torch.stack(
        [torch.cat([g.flatten() for g in grads.values()]) for grads in client_gradients]
    )

    # Step 2: cosine similarity to client 0 and the per-client weight.
    norms = torch.linalg.norm(flat, dim=1)
    norms = torch.where(norms > 0, norms, torch.full_like(norms, eps))
    reference_norm = norms[0]
    cosine = (flat @ flat[0]) / (norms * reference_norm)
    trust_scores = torch.relu(cosine)
    client_weights = trust_scores * (reference_norm / norms)

    # Step 3: normaliser sum_j ReLU(c_j), guarded against zero.
    denominator = trust_scores.sum()
    if not denominator > 0:
        denominator = eps

    # Step 4: weighted aggregation, parameter by parameter.
    averaged_gradients = {}
    for key in client_gradients[0]:
        weighted_sum = torch.zeros_like(client_gradients[0][key])
        for weight, grads in zip(client_weights, client_gradients):
            weighted_sum += weight * grads[key]
        averaged_gradients[key] = weighted_sum / denominator

    # Step 5: hand-written update of the global model.
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in averaged_gradients:
                param -= learning_rate * averaged_gradients[name]
    return avg_model

def h_median(avg_model, client_gradients, learning_rate, attackers):
    """Filter client gradients against their coordinate-wise median and apply the mean.

    Each client's gradient dict is flattened into one row and the
    coordinate-wise median of all rows is the reference.  In each of three
    rounds a random subset of coordinates is drawn and every client is scored
    on that subset by a coordinate-wise closeness ratio to the reference minus
    a penalty ``rho * max(norm, tau / norm)`` on its sub-sampled norm.  The
    ``k`` highest-scoring clients are kept.  Clients kept in every round (or,
    if there are none, in any round) are averaged, and that mean is subtracted
    from the model scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts, one per client.
        learning_rate: scalar step size.
        attackers: count used in ``k = n - attackers``; ``k`` is clamped to
            ``[1, n]`` so a selection is always possible.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )
    comp_grad = torch.median(flat_grads, dim=0)[0]

    n, d = flat_grads.shape
    b = n  # coordinates requested per round (the slice below caps it at d)
    niters = 3
    rho = 10
    tau = 0.1
    # A non-positive k would make topk fail (k < 0) or select nobody (k == 0),
    # which ends in a NaN update; always keep between 1 and n clients.
    k = max(1, min(n, int(n - attackers)))

    I_good = []  # one set of kept client indices per round
    for _ in range(niters):
        # Sample coordinates on the gradients' own device (no global needed).
        r = torch.randperm(d, device=flat_grads.device)[:b].sort().values
        subsampled_grads = flat_grads[:, r]  # [n_clients, <=b]
        subsampled_comp_grads = comp_grad[r]

        # Magnitude of the reference gradient on the sampled coordinates.
        m_i = torch.abs(subsampled_comp_grads)
        denom = torch.abs(subsampled_comp_grads - subsampled_grads) + m_i
        # denom == 0 means both values are exactly 0: treat as a perfect match
        # instead of letting 0/0 produce NaN.
        ratio = torch.where(denom > 0, m_i / denom, torch.ones_like(denom))
        similarity = ratio.sum(dim=1) / b

        norms = torch.linalg.norm(subsampled_grads, dim=1)
        scores = similarity - rho * torch.maximum(norms, tau / norms)

        # Keep the k highest-scoring clients.
        _, selected_indices = torch.topk(scores, k=k)
        I_good.append(set(selected_indices.tolist()))

    # Prefer clients kept in every round; fall back to the union of all rounds.
    I_final = set.intersection(*I_good)
    if not I_final:
        I_final = set.union(*I_good)

    final_grads = torch.mean(flat_grads[sorted(I_final), :], dim=0)

    # Cut the flat vector back into per-parameter tensors.
    aggregated_updates = {}
    pointer = 0
    for name, param in client_gradients[0].items():
        numel = param.numel()
        aggregated_updates[name] = final_grads[pointer : pointer + numel].view_as(param)
        pointer += numel

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in aggregated_updates:
                param -= learning_rate * aggregated_updates[name]

    return avg_model


def h_krum(avg_model, client_gradients, learning_rate, attackers):
    """Filter client gradients against the ``krum_al`` pick and apply the mean.

    Each client's gradient dict is flattened into one row and
    ``krum_al(flat_grads, attackers)`` supplies the reference row.  In each of
    three rounds a random subset of coordinates is drawn and every client is
    scored on that subset by a coordinate-wise closeness ratio to the
    reference minus a penalty ``rho * max(norm, tau / norm)`` on its
    sub-sampled norm.  The ``k`` highest-scoring clients are kept.  Clients
    kept in every round (or, if there are none, in any round) are averaged,
    and that mean is subtracted from the model scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts, one per client.
        learning_rate: scalar step size.
        attackers: forwarded to ``krum_al`` and used in ``k = n - attackers``;
            ``k`` is clamped to ``[1, n]`` so a selection is always possible.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )
    comp_grad = krum_al(flat_grads, attackers)

    n, d = flat_grads.shape
    b = n  # coordinates requested per round (the slice below caps it at d)
    niters = 3
    rho = 0.1
    tau = 100
    # A non-positive k would make topk fail (k < 0) or select nobody (k == 0),
    # which ends in a NaN update; always keep between 1 and n clients.
    k = max(1, min(n, int(n - attackers)))

    I_good = []  # one set of kept client indices per round
    for _ in range(niters):
        # Sample coordinates on the gradients' own device (no global needed).
        r = torch.randperm(d, device=flat_grads.device)[:b].sort().values
        subsampled_grads = flat_grads[:, r]  # [n_clients, <=b]
        subsampled_comp_grads = comp_grad[r]

        # Magnitude of the reference gradient on the sampled coordinates.
        m_i = torch.abs(subsampled_comp_grads)
        denom = torch.abs(subsampled_comp_grads - subsampled_grads) + m_i
        # denom == 0 means both values are exactly 0: treat as a perfect match
        # instead of letting 0/0 produce NaN.
        ratio = torch.where(denom > 0, m_i / denom, torch.ones_like(denom))
        similarity = ratio.sum(dim=1) / b

        norms = torch.linalg.norm(subsampled_grads, dim=1)
        scores = similarity - rho * torch.maximum(norms, tau / norms)

        # Keep the k highest-scoring clients.
        _, selected_indices = torch.topk(scores, k=k)
        I_good.append(set(selected_indices.tolist()))

    # Prefer clients kept in every round; fall back to the union of all rounds.
    I_final = set.intersection(*I_good)
    if not I_final:
        I_final = set.union(*I_good)

    final_grads = torch.mean(flat_grads[sorted(I_final), :], dim=0)

    # Cut the flat vector back into per-parameter tensors.
    aggregated_updates = {}
    pointer = 0
    for name, param in client_gradients[0].items():
        numel = param.numel()
        aggregated_updates[name] = final_grads[pointer : pointer + numel].view_as(param)
        pointer += numel

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in aggregated_updates:
                param -= learning_rate * aggregated_updates[name]

    return avg_model


def h_gm(avg_model, client_gradients, learning_rate, attackers):
    """Filter client gradients against the ``gm`` estimate and apply the mean.

    Each client's gradient dict is flattened into one row and
    ``gm(flat_grads)`` supplies the reference vector.  In each of three rounds
    a random subset of coordinates is drawn and every client is scored on that
    subset by a coordinate-wise closeness ratio to the reference minus a
    penalty ``rho * max(norm, tau / norm)`` on its sub-sampled norm.  The
    ``k`` highest-scoring clients are kept.  Clients kept in every round (or,
    if there are none, in any round) are averaged, and that mean is subtracted
    from the model scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts, one per client.
        learning_rate: scalar step size.
        attackers: count used in ``k = n - attackers``; ``k`` is clamped to
            ``[1, n]`` so a selection is always possible.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )
    comp_grad = gm(flat_grads)

    n, d = flat_grads.shape
    b = n  # coordinates requested per round (the slice below caps it at d)
    niters = 3
    rho = 0.1
    tau = 100
    # A non-positive k would make topk fail (k < 0) or select nobody (k == 0),
    # which ends in a NaN update; always keep between 1 and n clients.
    k = max(1, min(n, int(n - attackers)))

    I_good = []  # one set of kept client indices per round
    for _ in range(niters):
        # Sample coordinates on the gradients' own device (no global needed).
        r = torch.randperm(d, device=flat_grads.device)[:b].sort().values
        subsampled_grads = flat_grads[:, r]  # [n_clients, <=b]
        subsampled_comp_grads = comp_grad[r]

        # Magnitude of the reference gradient on the sampled coordinates.
        m_i = torch.abs(subsampled_comp_grads)
        denom = torch.abs(subsampled_comp_grads - subsampled_grads) + m_i
        # denom == 0 means both values are exactly 0: treat as a perfect match
        # instead of letting 0/0 produce NaN.
        ratio = torch.where(denom > 0, m_i / denom, torch.ones_like(denom))
        similarity = ratio.sum(dim=1) / b

        norms = torch.linalg.norm(subsampled_grads, dim=1)
        scores = similarity - rho * torch.maximum(norms, tau / norms)

        # Keep the k highest-scoring clients.
        _, selected_indices = torch.topk(scores, k=k)
        I_good.append(set(selected_indices.tolist()))

    # Prefer clients kept in every round; fall back to the union of all rounds.
    I_final = set.intersection(*I_good)
    if not I_final:
        I_final = set.union(*I_good)

    final_grads = torch.mean(flat_grads[sorted(I_final), :], dim=0)

    # Cut the flat vector back into per-parameter tensors.
    aggregated_updates = {}
    pointer = 0
    for name, param in client_gradients[0].items():
        numel = param.numel()
        aggregated_updates[name] = final_grads[pointer : pointer + numel].view_as(param)
        pointer += numel

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in aggregated_updates:
                param -= learning_rate * aggregated_updates[name]

    return avg_model


def h_mca(avg_model, client_gradients, learning_rate, attackers):
    """Filter client gradients against the ``mca_al`` estimate and apply the mean.

    Each client's gradient dict is flattened into one row and
    ``mca_al(flat_grads)`` supplies the reference vector.  In each of three
    rounds a random subset of coordinates is drawn and every client is scored
    on that subset by a coordinate-wise closeness ratio to the reference minus
    a penalty ``rho * max(norm, tau / norm)`` on its sub-sampled norm.  The
    ``k`` highest-scoring clients are kept.  Clients kept in every round (or,
    if there are none, in any round) are averaged, and that mean is subtracted
    from the model scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts, one per client.
        learning_rate: scalar step size.
        attackers: count used in ``k = n - attackers``; ``k`` is clamped to
            ``[1, n]`` so a selection is always possible.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )
    comp_grad = mca_al(flat_grads)

    n, d = flat_grads.shape
    b = n  # coordinates requested per round (the slice below caps it at d)
    niters = 3
    rho = 0.1
    tau = 100
    # A non-positive k would make topk fail (k < 0) or select nobody (k == 0),
    # which ends in a NaN update; always keep between 1 and n clients.
    k = max(1, min(n, int(n - attackers)))

    I_good = []  # one set of kept client indices per round
    for _ in range(niters):
        # Sample coordinates on the gradients' own device (no global needed).
        r = torch.randperm(d, device=flat_grads.device)[:b].sort().values
        subsampled_grads = flat_grads[:, r]  # [n_clients, <=b]
        subsampled_comp_grads = comp_grad[r]

        # Magnitude of the reference gradient on the sampled coordinates.
        m_i = torch.abs(subsampled_comp_grads)
        denom = torch.abs(subsampled_comp_grads - subsampled_grads) + m_i
        # denom == 0 means both values are exactly 0: treat as a perfect match
        # instead of letting 0/0 produce NaN.
        ratio = torch.where(denom > 0, m_i / denom, torch.ones_like(denom))
        similarity = ratio.sum(dim=1) / b

        norms = torch.linalg.norm(subsampled_grads, dim=1)
        scores = similarity - rho * torch.maximum(norms, tau / norms)

        # Keep the k highest-scoring clients.
        _, selected_indices = torch.topk(scores, k=k)
        I_good.append(set(selected_indices.tolist()))

    # Prefer clients kept in every round; fall back to the union of all rounds.
    I_final = set.intersection(*I_good)
    if not I_final:
        I_final = set.union(*I_good)

    final_grads = torch.mean(flat_grads[sorted(I_final), :], dim=0)

    # Cut the flat vector back into per-parameter tensors.
    aggregated_updates = {}
    pointer = 0
    for name, param in client_gradients[0].items():
        numel = param.numel()
        aggregated_updates[name] = final_grads[pointer : pointer + numel].view_as(param)
        pointer += numel

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in aggregated_updates:
                param -= learning_rate * aggregated_updates[name]

    return avg_model


def h_cclip(avg_model, client_gradients, learning_rate, attackers, previous_v=None):
    """Filter client gradients against the ``cclip_al`` estimate and apply the mean.

    Each client's gradient dict is flattened into one row and
    ``cclip_al(flat_grads, previous_v)`` supplies the reference vector.  In
    each of three rounds a random subset of coordinates is drawn and every
    client is scored on that subset by a coordinate-wise closeness ratio to
    the reference minus a penalty ``rho * max(norm, tau / norm)`` on its
    sub-sampled norm.  The ``k`` highest-scoring clients are kept.  Clients
    kept in every round (or, if there are none, in any round) are averaged,
    and that mean is subtracted from the model scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts, one per client.
        learning_rate: scalar step size.
        attackers: count used in ``k = n - attackers``; ``k`` is clamped to
            ``[1, n]`` so a selection is always possible.
        previous_v: optional flat vector forwarded unchanged to ``cclip_al``.

    Returns:
        Tuple ``(avg_model, final_grads)`` where ``final_grads`` is the flat
        mean of the selected client gradients.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )
    comp_grad = cclip_al(flat_grads, previous_v)

    n, d = flat_grads.shape
    b = n  # coordinates requested per round (the slice below caps it at d)
    niters = 3
    rho = 0.1
    tau = 100
    # A non-positive k would make topk fail (k < 0) or select nobody (k == 0),
    # which ends in a NaN update; always keep between 1 and n clients.
    k = max(1, min(n, int(n - attackers)))

    I_good = []  # one set of kept client indices per round
    for _ in range(niters):
        # Sample coordinates on the gradients' own device (no global needed).
        r = torch.randperm(d, device=flat_grads.device)[:b].sort().values
        subsampled_grads = flat_grads[:, r]  # [n_clients, <=b]
        subsampled_comp_grads = comp_grad[r]

        # Magnitude of the reference gradient on the sampled coordinates.
        m_i = torch.abs(subsampled_comp_grads)
        denom = torch.abs(subsampled_comp_grads - subsampled_grads) + m_i
        # denom == 0 means both values are exactly 0: treat as a perfect match
        # instead of letting 0/0 produce NaN.
        ratio = torch.where(denom > 0, m_i / denom, torch.ones_like(denom))
        similarity = ratio.sum(dim=1) / b

        norms = torch.linalg.norm(subsampled_grads, dim=1)
        scores = similarity - rho * torch.maximum(norms, tau / norms)

        # Keep the k highest-scoring clients.
        _, selected_indices = torch.topk(scores, k=k)
        I_good.append(set(selected_indices.tolist()))

    # Prefer clients kept in every round; fall back to the union of all rounds.
    I_final = set.intersection(*I_good)
    if not I_final:
        I_final = set.union(*I_good)

    final_grads = torch.mean(flat_grads[sorted(I_final), :], dim=0)

    # Cut the flat vector back into per-parameter tensors.
    aggregated_updates = {}
    pointer = 0
    for name, param in client_gradients[0].items():
        numel = param.numel()
        aggregated_updates[name] = final_grads[pointer : pointer + numel].view_as(param)
        pointer += numel

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in aggregated_updates:
                param -= learning_rate * aggregated_updates[name]

    return avg_model, final_grads


def gm(client_gradients):
    """Estimate the geometric median of the rows of ``client_gradients``.

    Starts from the coordinate-wise median and repeats an inverse-distance
    re-weighting step (rows closer than ``1e-6`` to the current estimate get
    weight 0) for at most 1000 iterations, stopping once the estimate moves by
    less than ``1e-5``.

    Args:
        client_gradients: 2-D tensor with one flattened gradient per row.

    Returns:
        1-D tensor holding the estimate.  If every row coincides with the
        current estimate (all weights are zero) that estimate is returned
        instead of the NaN a 0/0 division would produce.
    """
    median = torch.median(client_gradients, dim=0)[0]

    for _ in range(1000):
        distances = torch.norm(client_gradients - median, dim=1)
        inv_distances = torch.where(
            distances < 1e-6, torch.zeros_like(distances), 1 / distances
        )
        total_weight = torch.sum(inv_distances)
        if total_weight == 0:
            # All rows sit on the estimate: nothing left to re-weight.
            return median

        new_median = (
            torch.sum(client_gradients * inv_distances.unsqueeze(1), dim=0) / total_weight
        )
        if torch.norm(new_median - median) < 1e-5:
            return new_median
        median = new_median

    return median


def raga(avg_model, client_gradients, learning_rate):
    """Step the model along the ``gm`` aggregate of the client gradients.

    Client gradient dicts are flattened into rows, ``gm`` reduces them to one
    flat vector, and that vector is cut up in ``named_parameters()`` order
    (trainable parameters only) and subtracted from the model scaled by
    ``learning_rate``.

    Args:
        avg_model: module updated in place.
        client_gradients: list of ``{name: tensor}`` dicts, one per client;
            assumed to be laid out in the model's trainable-parameter order.
        learning_rate: scalar step size.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )
    aggregate = gm(flat_grads)

    # Slice the flat aggregate back into parameter-shaped pieces.
    updates = {}
    offset = 0
    for name, param in avg_model.named_parameters():
        if not param.requires_grad:
            continue
        end = offset + param.numel()
        updates[name] = aggregate[offset:end].view_as(param)
        offset = end

    # Hand-written update (no optimizer); kept outside autograd.
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            step = updates.get(name)
            if step is not None:
                param.sub_(learning_rate * step)
    return avg_model


def rfa(avg_model, client_params, learning_rate):
    """Overwrite the model's trainable parameters with the ``gm`` aggregate.

    Client parameter dicts are flattened into rows, ``gm`` reduces them to one
    flat vector, and each trainable parameter (in ``named_parameters()``
    order) has its data replaced by a clone of its slice of that vector.
    ``learning_rate`` is accepted but not used.

    Args:
        avg_model: module updated in place.
        client_params: list of ``{name: tensor}`` dicts, one per client;
            assumed to be laid out in the model's trainable-parameter order.
        learning_rate: unused.

    Returns:
        The same ``avg_model`` object.
    """
    flat_params = torch.stack(
        [torch.cat([p.view(-1) for p in param_dict.values()]) for param_dict in client_params]
    )
    aggregate = gm(flat_params)

    # Slice the flat aggregate back into parameter-shaped pieces.
    replacements = {}
    offset = 0
    for name, param in avg_model.named_parameters():
        if not param.requires_grad:
            continue
        end = offset + param.numel()
        replacements[name] = aggregate[offset:end].view_as(param)
        offset = end

    # Replace the stored values directly (no optimizer involved).
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name not in replacements:
                continue
            param.data = replacements[name].data.clone()
    return avg_model


def median(avg_model, client_gradients, learning_rate):
    """Apply one coordinate-wise-median gradient step to ``avg_model`` in place.

    For every key of the first client's gradient dict, the tensors of all
    clients are stacked and their element-wise median is taken.  Each model
    parameter whose name matches a key is then decreased by ``learning_rate``
    times that median.

    Args:
        avg_model: module whose ``named_parameters()`` are updated in place.
        client_gradients: non-empty list of ``{parameter name: tensor}`` dicts,
            one per client.
        learning_rate: scalar step size.

    Returns:
        The same ``avg_model`` object.
    """
    median_by_name = {
        key: torch.stack([client[key] for client in client_gradients]).median(dim=0)[0]
        for key in client_gradients[0]
    }

    # Hand-written update (no optimizer); kept outside autograd.
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            median_grad = median_by_name.get(name)
            if median_grad is not None:
                param.sub_(learning_rate * median_grad)
    return avg_model


def krum(avg_model, client_gradients, learning_rate, q):
    """Step the model along the single client gradient closest to its peers.

    Pairwise Euclidean distances between client gradients are computed over
    all keys.  Each client's score is the sum of the ``n - q - 1`` smallest
    entries of its distance row (the row includes the zero self-distance).
    The gradient of the lowest-scoring client is subtracted from the model
    scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place through ``named_parameters()``.
        client_gradients: list of ``{name: tensor}`` dicts, one per client.
        learning_rate: scalar step size.
        q: count subtracted when choosing how many distances to sum.

    Returns:
        The same ``avg_model`` object.
    """
    num_clients = len(client_gradients)
    distance_matrix = torch.zeros((num_clients, num_clients))

    # Fill every off-diagonal entry with the L2 distance between two clients.
    for i, grads_i in enumerate(client_gradients):
        for j, grads_j in enumerate(client_gradients):
            if i == j:
                continue
            squared = sum(
                torch.sum((grads_i[key] - grads_j[key]) ** 2) for key in grads_i.keys()
            )
            distance_matrix[i][j] = torch.sqrt(squared)

    # Score = sum of the n - q - 1 smallest distances in the client's row.
    keep = num_clients - q - 1
    scores = torch.zeros(num_clients)
    for i in range(num_clients):
        ascending = torch.sort(distance_matrix[i])[0]
        scores[i] = torch.sum(ascending[:keep])

    winner = client_gradients[torch.argmin(scores)]

    # Hand-written update (no optimizer); kept outside autograd.
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in winner:
                param.sub_(learning_rate * winner[name])

    return avg_model


def krum_al(client_gradients, q):
    """Return the row whose nearest neighbours are closest in squared distance.

    Pairwise squared Euclidean distances between rows are computed.  Each
    row's score is the sum of its ``(n - q - 2) + 1`` smallest distances to
    the other rows; the row with the lowest score is returned.

    Args:
        client_gradients: 2-D tensor (or sequence) with one flat gradient per row.
        q: count subtracted when choosing how many distances to sum.

    Returns:
        The selected row of ``client_gradients``.
    """
    n = len(client_gradients)
    neighbours = (n - q - 2) + 1

    # Symmetric matrix of squared distances; the diagonal stays zero.
    sq_dist = torch.zeros(n, n)
    for i in range(n):
        for j in range(i + 1, n):
            value = torch.norm(client_gradients[i] - client_gradients[j]) ** 2
            sq_dist[i, j] = value
            sq_dist[j, i] = value

    scores = []
    for i in range(n):
        row = sq_dist[i].clone()
        row[i] = float("inf")  # exclude the row's distance to itself
        ascending = torch.sort(row)[0]
        scores.append(torch.sum(ascending[:neighbours]))

    # Lowest score = closest to its neighbours.
    best = torch.argmin(torch.tensor(scores))
    return client_gradients[best]


def mca_al(client_gradients):
    """Iteratively re-weight rows with a Gaussian kernel around the running centre.

    The centre starts at the row mean.  ``sigma`` is fixed at ten times the
    squared median distance of the rows to that mean.  Each iteration weights
    row ``i`` by ``exp(-||row_i - centre||^2 / sigma)`` and moves the centre to
    the weighted mean, stopping after 1000 iterations or once the centre
    moves by less than ``1e-5``.

    Args:
        client_gradients: 2-D tensor with one flattened gradient per row.

    Returns:
        1-D tensor holding the final centre.  When ``sigma`` is zero (the
        median distance to the mean is zero, e.g. identical rows) the mean is
        returned directly, because the kernel would otherwise evaluate 0/0
        and yield NaN.
    """
    median = torch.mean(client_gradients, dim=0)
    sigma = 10 * torch.median(torch.linalg.norm(client_gradients - median, dim=1)) ** 2
    if sigma == 0:
        return median

    for _ in range(1000):
        sq_distances = torch.linalg.norm(client_gradients - median, dim=1) ** 2
        weights = torch.exp(-sq_distances / sigma)
        weighted_sum = torch.sum(weights.unsqueeze(1) * client_gradients, dim=0)
        new_median = weighted_sum / torch.sum(weights)

        if torch.norm(new_median - median) < 1e-5:
            return new_median
        median = new_median

    return median


def mca(avg_model, client_gradients, learning_rate):
    """Step the model along the ``mca_al`` aggregate of the client gradients.

    Client gradient dicts are flattened into rows, ``mca_al`` reduces them to
    one flat vector, and that vector is cut up in ``named_parameters()`` order
    (trainable parameters only) and subtracted from the model scaled by
    ``learning_rate``.

    Args:
        avg_model: module updated in place.
        client_gradients: list of ``{name: tensor}`` dicts, one per client;
            assumed to be laid out in the model's trainable-parameter order.
        learning_rate: scalar step size.

    Returns:
        The same ``avg_model`` object.
    """
    flat_grads = torch.stack(
        [torch.cat([g.view(-1) for g in grad_dict.values()]) for grad_dict in client_gradients]
    )
    aggregate = mca_al(flat_grads)

    # Slice the flat aggregate back into parameter-shaped pieces.
    updates = {}
    offset = 0
    for name, param in avg_model.named_parameters():
        if not param.requires_grad:
            continue
        end = offset + param.numel()
        updates[name] = aggregate[offset:end].view_as(param)
        offset = end

    # Hand-written update (no optimizer); kept outside autograd.
    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            step = updates.get(name)
            if step is not None:
                param.sub_(learning_rate * step)
    return avg_model


def cclip_al(client_gradients, previous_v=None):
    """Run one centred-clipping step over the rows of ``client_gradients``.

    Starting from ``v`` (a copy of ``previous_v``, or zeros when it is None),
    each row's difference to ``v`` is scaled by ``min(1, tau / ||diff||)``
    with ``tau = 1`` and the mean of the scaled differences is added to ``v``.

    Args:
        client_gradients: 2-D tensor with one flattened gradient per row.
        previous_v: optional 1-D starting point; it is cloned, not modified.

    Returns:
        The updated 1-D tensor ``v``.
    """
    tau = 1.0  # fixed clipping threshold
    L = 1  # fixed number of clipping iterations

    if previous_v is None:
        # Allocate next to the input instead of on a module-level device so
        # the subtraction below can never hit a device/dtype mismatch.
        v = torch.zeros(
            client_gradients.shape[1],
            dtype=client_gradients.dtype,
            device=client_gradients.device,
        )
    else:
        v = previous_v.clone()

    for _ in range(L):
        diffs = client_gradients - v.unsqueeze(0)  # m_i - v
        norms = torch.linalg.norm(diffs, dim=1, keepdim=True)  # ||m_i - v||

        # min(1, tau / ||m_i - v||); the 1e-10 avoids division by zero.
        clip_factors = torch.minimum(
            torch.tensor(1.0, device=v.device), tau / (norms + 1e-10)
        )

        # v <- v + mean_i((m_i - v) * clip_factor_i)
        v += torch.mean(diffs * clip_factors, dim=0)

    return v


def cclip(avg_model, client_gradients, learning_rate, previous_v=None):
    """Aggregate client updates with one centred-clipping step and apply it.

    Client update dicts are flattened into rows.  Starting from ``v`` (a copy
    of ``previous_v``, or zeros when it is None), each row's difference to
    ``v`` is scaled by ``min(1, tau / ||diff||)`` with ``tau = 1`` and the mean
    of the scaled differences is added to ``v``.  ``v`` is then cut up in
    ``named_parameters()`` order (trainable parameters only) and subtracted
    from the model scaled by ``learning_rate``.

    Args:
        avg_model: module updated in place.
        client_gradients: list of ``{name: tensor}`` dicts, one per client;
            assumed to be laid out in the model's trainable-parameter order.
        learning_rate: scalar step size.
        previous_v: optional flat starting point; it is cloned, not modified.

    Returns:
        Tuple ``(avg_model, v)`` with the flat aggregate ``v``.
    """
    tau = 1.0  # fixed clipping threshold
    L = 1  # fixed number of clipping iterations

    # One flattened update per row: [n_clients, total_params].
    flat_updates = torch.stack(
        [torch.cat([t.view(-1) for t in update.values()]) for update in client_gradients]
    )

    if previous_v is None:
        # Allocate next to the updates instead of on a module-level device so
        # the subtraction below can never hit a device/dtype mismatch.
        v = torch.zeros(
            flat_updates.shape[1], dtype=flat_updates.dtype, device=flat_updates.device
        )
    else:
        v = previous_v.clone()

    for _ in range(L):
        diffs = flat_updates - v.unsqueeze(0)  # m_i - v
        norms = torch.linalg.norm(diffs, dim=1, keepdim=True)  # ||m_i - v||

        # min(1, tau / ||m_i - v||); the 1e-10 avoids division by zero.
        clip_factors = torch.minimum(
            torch.tensor(1.0, device=v.device), tau / (norms + 1e-10)
        )

        # v <- v + mean_i((m_i - v) * clip_factor_i)
        v += torch.mean(diffs * clip_factors, dim=0)

    # Slice the flat aggregate back into parameter-shaped pieces.
    aggregated_updates = {}
    pointer = 0
    for name, param in avg_model.named_parameters():
        if param.requires_grad:
            numel = param.numel()
            aggregated_updates[name] = v[pointer : pointer + numel].view_as(param)
            pointer += numel

    with torch.no_grad():
        for name, param in avg_model.named_parameters():
            if name in aggregated_updates:
                param -= learning_rate * aggregated_updates[name]

    return avg_model, v

def evaluate_global_model(data_x, data_y, model, dataset_name):
    """Compute the average cross-entropy loss and accuracy of ``model``.

    The model is switched to eval mode and moved to ``device``.  Data is read
    in order in batches of at most 200 samples for ``"CIFAR100"`` and at most
    1000 otherwise.

    Args:
        data_x: inputs; ``data_x.shape[0]`` is taken as the sample count.
        data_y: labels, flattened and cast to long before use.
        model: module to evaluate.
        dataset_name: forwarded to ``Dataset`` and used to pick the batch size.

    Returns:
        Tuple ``(loss, accuracy)``: the sample-weighted sum of batch losses
        and the number of correct predictions, each divided by the sample
        count.
    """
    model.eval()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    n_test = data_x.shape[0]
    batch_cap = 200 if dataset_name == "CIFAR100" else 1000
    batch_size = min(batch_cap, n_test)
    test_gen = data.DataLoader(
        Dataset(data_x, data_y, dataset_name=dataset_name),
        batch_size=batch_size,
        shuffle=False,
    )

    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        batches = iter(test_gen)
        for _ in range(int(cp.ceil(n_test / batch_size))):
            batch_x, batch_y = next(batches)
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            labels = batch_y.reshape(-1).long()

            outputs = model(batch_x)
            # Weight the batch loss by its size before accumulating.
            total_loss += criterion(outputs, labels).item() * batch_x.size(0)

            _, predicted = torch.max(outputs.data, 1)
            num_correct += (predicted == labels).sum().item()

    return total_loss / n_test, num_correct / n_test


def train_grad(
    avg_model, train_x, train_y, learning_rate, batch_size, epoch, dataset_name
):
    """Accumulate loss gradients of a copy of ``avg_model`` over the local data.

    A deep copy of the model is put in train mode on ``device`` and every
    batch of every epoch is back-propagated through it.  No optimizer step is
    taken and gradients are never reset in between, so the returned tensors
    are the sum over all processed batches.

    Args:
        avg_model: model to copy; it is not modified.
        train_x: training inputs; ``train_x.shape[0]`` is the sample count.
        train_y: training labels, flattened and cast to long before use.
        learning_rate: only used to construct the (never stepped) SGD optimizer.
        batch_size: DataLoader batch size.
        epoch: number of passes over the data.
        dataset_name: forwarded to ``Dataset``.

    Returns:
        Dict ``{name: detached clone of param.grad}`` for every trainable
        parameter of the copy.
    """
    loader = data.DataLoader(
        Dataset(train_x, train_y, train=True, dataset_name=dataset_name),
        batch_size=batch_size,
        shuffle=True,
    )

    model = copy.deepcopy(avg_model)
    model.train()
    model = model.to(device)

    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.SGD(trainable, lr=learning_rate)
    optimizer.zero_grad()

    criterion = nn.CrossEntropyLoss()

    for _ in range(epoch):
        batches = iter(loader)
        for _ in range(int(cp.ceil(train_x.shape[0] / batch_size))):
            batch_x, batch_y = next(batches)
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            loss = criterion(model(batch_x), batch_y.reshape(-1).long())
            # backward() adds into .grad, so gradients pile up across batches.
            loss.backward()

    gradients = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            gradients[name] = param.grad.clone().detach()

    return gradients


def train_param(
    avg_model, train_x, train_y, learning_rate, batch_size, epoch, dataset_name
):
    """Train a copy of ``avg_model`` locally and return its trainable parameters.

    A deep copy of the model is put in train mode on ``device``.  In every
    epoch all batches are back-propagated first and a single SGD step is taken
    afterwards.

    NOTE(review): gradients are zeroed only once, before the first epoch, so
    from the second epoch on each step also contains the gradients of all
    earlier epochs - confirm this accumulation is intended.

    Args:
        avg_model: model to copy; it is not modified.
        train_x: training inputs; ``train_x.shape[0]`` is the sample count.
        train_y: training labels, flattened and cast to long before use.
        learning_rate: SGD learning rate.
        batch_size: DataLoader batch size.
        epoch: number of passes over the data (one optimizer step each).
        dataset_name: forwarded to ``Dataset``.

    Returns:
        Dict ``{name: parameter}`` holding the copy's trainable parameters
        (the live parameter objects, not detached clones).
    """
    loader = data.DataLoader(
        Dataset(train_x, train_y, train=True, dataset_name=dataset_name),
        batch_size=batch_size,
        shuffle=True,
    )

    model = copy.deepcopy(avg_model)
    model.train()
    model = model.to(device)

    trainable = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.SGD(trainable, lr=learning_rate)
    optimizer.zero_grad()

    criterion = nn.CrossEntropyLoss()

    for _ in range(epoch):
        batches = iter(loader)
        for _ in range(int(cp.ceil(train_x.shape[0] / batch_size))):
            batch_x, batch_y = next(batches)
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            loss = criterion(model(batch_x), batch_y.reshape(-1).long())
            loss.backward()

        # One step per epoch on the gradients accumulated so far.
        optimizer.step()

    params = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            params[name] = param

    return params
