#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
from torch import nn


def FedAvg(w):
    """Return the unweighted element-wise mean of a list of state dicts.

    The result is a deep copy of ``w[0]`` in which every entry ``k`` is
    replaced by ``torch.div(sum_i w[i][k], len(w))``. The inputs are not
    modified because the accumulation happens on the deep copy.

    Note: ``FedAvg_HEAL`` in this file stores its persistent state as
    attributes on this function object (``FedAvg.L_m`` and friends).
    """
    count = len(w)
    merged = copy.deepcopy(w[0])
    for name in merged.keys():
        running = merged[name]
        for other in w[1:]:
            running += other[name]
        merged[name] = torch.div(running, count)
    return merged

def FedAvgpBN(w):
    """Average client state dicts while keeping BatchNorm entries local.

    Entries whose key contains ``'bn'`` or ``'batch_norm'`` are treated as
    BatchNorm entries and are taken unchanged from ``w[0]``. Every other entry
    is the element-wise mean over all clients.

    NOTE(review): the plain ``'bn'`` substring test also matches any key that
    merely contains those two letters; confirm it fits the model's naming.

    Parameters
    ----------
    w : list[dict]
        Client state dicts with identical keys. ``w[0]`` supplies the BN
        entries and the key order of the result. The inputs are not modified.

    Returns
    -------
    dict
        Deep copy of ``w[0]`` with every non-BN entry replaced by the mean.
    """
    num_clients = len(w)
    w_avg = copy.deepcopy(w[0])
    # A set gives O(1) membership tests inside the loop below.
    bn_keys = {k for k in w_avg.keys() if 'bn' in k or 'batch_norm' in k}

    for k in w_avg.keys():
        if k in bn_keys:
            # The deepcopy already holds w[0]'s value. Leaving it in place,
            # rather than deleting and re-inserting it, keeps w[0]'s key order.
            continue
        for client_w in w[1:]:
            w_avg[k] += client_w[k]
        w_avg[k] = torch.div(w_avg[k], num_clients)
    return w_avg

def FedAvg_fsu(w, unlearn_clients, lamda):
    """Weighted aggregation with one or multiple unlearning clients.

    Each entry ``k`` of the result is::

        (sum_{u in U} w[u][k] + lamda * sum_{i not in U} w[i][k])
            / (|U| + (n - |U|) * lamda)

    where ``U`` is the set of distinct unlearning-client indices.

    Parameters
    ----------
    w : list[dict]
        Local model state dicts. Not modified.
    unlearn_clients : int or list/tuple/set/range of int
        Index/indices of clients to forget. Accepts a single integer for
        backward compatibility. Duplicate indices are counted once.
    lamda : float
        Weight applied to retained clients during aggregation.

    Returns
    -------
    dict
        Deep copy of ``w[0]`` with every entry replaced by the weighted mean.

    Raises
    ------
    ValueError
        If ``w`` is empty, ``lamda`` is negative, or the total weight is zero.
    IndexError
        If any unlearning index is outside ``[0, len(w))``.
    """

    n = len(w)
    if n == 0:
        raise ValueError("The list of weights is empty.")

    if isinstance(unlearn_clients, (list, tuple, set, frozenset, range)):
        candidates = unlearn_clients
    else:
        candidates = [unlearn_clients]
    # Deduplicate while keeping first-seen order. Otherwise a repeated index
    # is summed twice and counted twice in total_weight, while the retained
    # filter below excludes it only once, so the weights would not sum to
    # total_weight.
    unlearn_set = list(dict.fromkeys(candidates))

    if any(u < 0 or u >= n for u in unlearn_set):
        raise IndexError("unlearn_client index out of range.")
    if lamda < 0:
        raise ValueError("Lambda should be non-negative.")

    total_weight = len(unlearn_set) + (n - len(unlearn_set)) * lamda
    if total_weight == 0:
        raise ValueError("Total weight cannot be zero.")

    # Build the retained-client list once; a set avoids O(n*m) list lookups.
    unlearn_lookup = set(unlearn_set)
    retained = [i for i in range(n) if i not in unlearn_lookup]

    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        # Starting from a Python float promotes integer entries to floating
        # point, so the lamda-scaled additions below are always valid.
        acc = 0.0
        for u in unlearn_set:
            acc += w[u][k]
        for i in retained:
            acc += lamda * w[i][k]
        w_avg[k] = torch.div(acc, total_weight)
    return w_avg

def FedAvg_fsu_frz(w, unlearn_client, lamda, freeze_layers=None):
    """Weighted aggregation around one unlearning client, with frozen entries.

    For every non-frozen entry ``k`` the result is::

        (w[unlearn_client][k] + lamda * sum_{i != unlearn_client} w[i][k])
            / (1 + (n - 1) * lamda)

    Entries listed in ``freeze_layers`` keep ``w[0]``'s value.

    Parameters
    ----------
    w : list[dict]
        Local model state dicts. Not modified.
    unlearn_client : int
        Index of the client to forget; must lie in ``[0, len(w))``.
    lamda : float
        Weight applied to every other client.
    freeze_layers : container of str, optional
        Keys to leave untouched. Membership is tested with ``in``.

    Returns
    -------
    dict
        Deep copy of ``w[0]`` with the non-frozen entries aggregated.

    Raises
    ------
    IndexError
        If ``unlearn_client`` is out of range (this includes an empty ``w``).
    ValueError
        If the total weight ``1 + (n - 1) * lamda`` is zero.
    """
    n = len(w)
    # Negative indices used to "work" via Python indexing, but the
    # ``i != unlearn_client`` test below then never matched, so the last
    # client was counted twice. Reject them explicitly.
    if unlearn_client < 0 or unlearn_client >= n:
        raise IndexError("unlearn_client index out of range.")
    # Loop-invariant; computed once instead of once per key.
    total_weight = 1 + (n - 1) * lamda
    if total_weight == 0:
        raise ValueError("Total weight cannot be zero.")

    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        if freeze_layers and k in freeze_layers:
            # Frozen entries keep w[0]'s value from the deepcopy.
            continue

        acc = w[unlearn_client][k].clone()
        if not acc.is_floating_point():
            # An in-place add of a float-scaled tensor into an integer tensor
            # raises, so promote integer-typed entries first. These are
            # presumably buffers such as BatchNorm's num_batches_tracked.
            acc = acc.float()
        for i in range(n):
            if i != unlearn_client:
                acc += lamda * w[i][k]
        w_avg[k] = torch.div(acc, total_weight)

    return w_avg

def FedAvg_fsu_pBN(w, unlearn_client, lamda):
    """Weighted aggregation that favours a single unlearning client.

    Every entry ``k`` of the result is::

        (w[unlearn_client][k] + lamda * sum_{i != unlearn_client} w[i][k])
            / (1 + (n - 1) * lamda)

    NOTE(review): despite the ``pBN`` suffix, BatchNorm entries are aggregated
    exactly like every other entry here; confirm that is intended.

    Raises
    ------
    ValueError
        If ``w`` is empty, ``lamda`` is negative, or the total weight is zero.
    IndexError
        If ``unlearn_client`` is outside ``[0, len(w))``.
    """
    num_clients = len(w)
    if not num_clients:
        raise ValueError("The list of weights is empty.")
    if not 0 <= unlearn_client < num_clients:
        raise IndexError("unlearn_client index out of range.")
    if lamda < 0:
        raise ValueError("Lambda should be non-negative.")
    denominator = 1 + (num_clients - 1) * lamda
    if denominator == 0:
        raise ValueError("Total weight cannot be zero.")

    others = [idx for idx in range(num_clients) if idx != unlearn_client]
    result = copy.deepcopy(w[0])
    for key in result.keys():
        # Starting from a float yields a fresh tensor, never an alias of an input.
        mixed = 0.0 + w[unlearn_client][key]
        for idx in others:
            mixed += lamda * w[idx][key]
        result[key] = torch.div(mixed, denominator)
    return result

def FedAvg_w(w, client_weight):
    """Return the weighted sum ``sum_i client_weight[i] * w[i]`` of state dicts.

    The weights are used as given and are not normalised; pass weights that
    sum to 1 to obtain a weighted average. ``client_weight`` is indexed in
    parallel with ``w``. The result is a deep copy of ``w[0]`` (same keys and
    order) with each entry replaced by the weighted sum.
    """
    combined = copy.deepcopy(w[0])
    for name in combined.keys():
        total = w[0][name] * client_weight[0]
        for idx, state in enumerate(w[1:], start=1):
            total += state[name] * client_weight[idx]
        combined[name] = total
    return combined

def FedAvg_HEAL(w, tau=0.3, beta=0.4):
    """FedAvg variant following the approach of the FedHEAL paper.

      - Compute a parameter-update-consistency (PUC) score and discard
        (zero out) parameters whose consistency falls below ``tau``.
      - Adjust the per-client aggregation weights with a momentum-style update
        controlled by ``beta``.
      - Return the weighted sum of the (possibly zeroed) client parameters.

    Persistent state lives as attributes of the module-level ``FedAvg``
    function, so it survives across rounds. Callers may also initialise these
    attributes themselves:
      - ``FedAvg.L_m``: ``{client_idx: {key: float tensor}}`` running average
        of the sign mask ``(w - L_m) >= 0``.
      - ``FedAvg.epoch_index``: number of rounds processed so far.
      - ``FedAvg.device``: device used for newly created tensors.
      - ``FedAvg.client_weights``: ``{client_idx: float}`` aggregation weights
        (``p_m``), renormalised to sum to 1 each round.

    Parameters
    ----------
    w : list[dict]
        Client state dicts, indexed by client. Not modified.
    tau : float
        Consistency threshold. Entries whose mean consistency is below it are
        replaced by zeros before aggregation.
    beta : float
        Mixing coefficient between the weight-deviation term and the
        squared-distance term of the weight update.

    Returns
    -------
    dict
        Weighted sum of the client state dicts, using ``FedAvg.client_weights``.

    NOTE(review): ``grad`` is computed as ``w - L_m`` (parameters minus a
    sign-frequency record). The distance in the weight update is measured
    against client 0 rather than a global model, so client 0 always gets
    distance 0. Both look questionable relative to the paper; confirm intent.
    """
    if not hasattr(FedAvg, "device"):
        some_key = next(iter(w[0].keys()))
        FedAvg.device = w[0][some_key].device

    num_clients = len(w)
    # Work on shallow copies so that zeroing low-consistency entries below
    # does not overwrite the caller's state dicts.
    w = [dict(client_w) for client_w in w]

    def _zero_state(template):
        # Float zeros shaped like every entry of ``template``.
        return {key: torch.zeros_like(param, dtype=torch.float, device=FedAvg.device)
                for key, param in template.items()}

    if not hasattr(FedAvg, "L_m"):
        FedAvg.L_m = {client_idx: _zero_state(w[0]) for client_idx in range(num_clients)}
        FedAvg.epoch_index = 0
    if not hasattr(FedAvg, "epoch_index"):
        # L_m may have been initialised externally without a round counter.
        FedAvg.epoch_index = 0
    if not hasattr(FedAvg, "client_weights"):
        FedAvg.client_weights = {i: 1.0 / num_clients for i in range(num_clients)}

    # Clients not seen at initialisation get fresh state instead of KeyError.
    for client_idx in range(num_clients):
        if client_idx not in FedAvg.L_m:
            FedAvg.L_m[client_idx] = _zero_state(w[0])
        if client_idx not in FedAvg.client_weights:
            FedAvg.client_weights[client_idx] = 1.0 / num_clients

    # Update the running sign-consistency record and zero out entries whose
    # mean consistency is below tau.
    for client_idx in range(num_clients):
        client_w = w[client_idx]
        record = FedAvg.L_m[client_idx]
        for key in client_w.keys():
            grad = client_w[key] - record[key]
            sign_mask = (grad >= 0).float()
            record[key] = (record[key] * FedAvg.epoch_index + sign_mask) / (FedAvg.epoch_index + 1)

            if grad.float().mean() >= 0:
                c_m_i = record[key]
            else:
                c_m_i = 1 - record[key]

            if c_m_i.mean().item() < tau:
                client_w[key] = torch.zeros_like(client_w[key], device=FedAvg.device)

    FedAvg.epoch_index += 1

    # Momentum-style client-weight update, clamped at 0.
    weight_sum = 0.0
    for client_idx in range(num_clients):
        dist = sum((w[client_idx][key] - w[0][key]).pow(2).sum().item() for key in w[client_idx].keys())
        deltap = (1 - beta) * (FedAvg.client_weights[client_idx] - 1.0 / num_clients) + beta * dist
        FedAvg.client_weights[client_idx] = max(0, FedAvg.client_weights[client_idx] + deltap)
        weight_sum += FedAvg.client_weights[client_idx]

    if weight_sum > 0:
        for client_idx in range(num_clients):
            FedAvg.client_weights[client_idx] /= weight_sum
    else:
        # Every weight was clamped to 0; fall back to uniform weights instead
        # of raising ZeroDivisionError.
        for client_idx in range(num_clients):
            FedAvg.client_weights[client_idx] = 1.0 / num_clients

    w_avg = {}
    for key in w[0].keys():
        w_avg[key] = sum(FedAvg.client_weights[i] * w[i][key] for i in range(num_clients))

    return w_avg

def FedAvg_salun(w_glob, mask, delta_history_unlearn, unlearn_lr=0.1):
    """Apply a masked unlearning step to ``w_glob`` and return it.

    First, the per-key mean of the update history is computed as
    ``torch.stack([...]).mean(dim=0)``. The keys come from
    ``delta_history_unlearn[0]``. Then every key of ``w_glob`` that also has
    a mean delta is rebound to::

        w_glob[k] - unlearn_lr * mask[k] * mean_delta[k]

    ``w_glob`` is updated in place (its entries are rebound) and the same dict
    object is returned. ``mask`` must contain every such key.
    """
    first_delta = delta_history_unlearn[0]
    mean_delta = {}
    for name in first_delta:
        stacked = torch.stack([entry[name] for entry in delta_history_unlearn])
        mean_delta[name] = stacked.mean(dim=0)

    shared_names = [name for name in w_glob if name in mean_delta]
    for name in shared_names:
        w_glob[name] = w_glob[name] - unlearn_lr * mask[name] * mean_delta[name]

    return w_glob


def fedavg(local_models):
    """Average the parameters and buffers of several models with equal weight.

    Parameters
    ----------
    local_models : list of torch.nn.Module
        Local models sharing one architecture (their ``state_dict`` keys must
        match). ``local_models[0]`` is deep-copied to hold the result. None of
        the input models are modified.

    Returns
    -------
    torch.nn.Module
        A deep copy of ``local_models[0]`` loaded with the element-wise mean
        of all local state dicts. ``load_state_dict`` copies the means back
        into each entry's original dtype.
    """
    global_model = copy.deepcopy(local_models[0])
    avg_state_dict = global_model.state_dict()
    local_state_dicts = [model.state_dict() for model in local_models]
    num_models = len(local_models)

    for layer in avg_state_dict.keys():
        # Accumulate into a fresh zero tensor instead of scaling the copied
        # entry by 0. ``inf * 0`` is NaN, which would poison the average.
        # Promoting to at least float32 avoids summing in half precision or
        # in an integer dtype before dividing.
        acc_dtype = torch.promote_types(avg_state_dict[layer].dtype, torch.float32)
        acc = torch.zeros_like(avg_state_dict[layer], dtype=acc_dtype)
        for state_dict in local_state_dicts:
            acc += state_dict[layer]
        avg_state_dict[layer] = acc / num_models

    global_model.load_state_dict(avg_state_dict)
    return global_model


if __name__ == '__main__':
    class Net(nn.Module):
        """Tiny MLP used only to produce example state dicts for the demo."""

        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 2)
            # fc2 bridges fc1's 2 outputs to fc3's 3 inputs, as forward() requires.
            self.fc2 = nn.Linear(2, 3)
            self.fc3 = nn.Linear(3, 3)

        def forward(self, x):
            # Expects x of shape (batch, 2); log-probabilities over dim 1.
            x = torch.nn.functional.relu(self.fc1(x))
            x = torch.nn.functional.relu(self.fc2(x))
            x = self.fc3(x)
            return torch.nn.functional.log_softmax(x, dim=1)


    net1 = Net()
    w1 = net1.state_dict()
    net2 = Net()
    w2 = net2.state_dict()
    net3 = Net()
    w3 = net3.state_dict()
    ws = [w1, w2, w3]

    print(w1)
    print("-" * 50)
    print(w2)
    print("-" * 50)
    print(w3)
    print("-" * 50)

    weight = [0.2, 0.3, 0.5]
    w_a = FedAvg_w(ws, weight)
    print(w_a)