#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
from numpy import dtype
import torch
from torch import nn
import numpy as np
import math

def FedAvg(w):
    """Return the element-wise mean of a sequence of key -> tensor mappings.

    Args:
        w: Indexable, sized collection of dict-like objects that all provide
            the keys of ``w[0]`` (presumably model ``state_dict``s - confirm
            against the caller).

    Returns:
        A deep copy of ``w[0]`` in which every value has been replaced by the
        mean of that key's values over all entries of ``w``.
    """
    # Accumulate into a deep copy so none of the inputs is modified.
    averaged = copy.deepcopy(w[0])
    for name in averaged:
        running_sum = averaged[name]
        for idx in range(1, len(w)):
            running_sum += w[idx][name]
        averaged[name] = torch.div(running_sum, len(w))
    return averaged

def FedAvgP(w, args):
    """Return the element-wise mean of a sequence of same-shaped tensors.

    Args:
        w: Non-empty sequence of tensors to average.
        args: Object exposing ``device``; the accumulator is created there.

    Returns:
        A detached tensor with the shape of ``w[0]`` holding the mean of all
        entries of ``w``.
    """
    # zeros_like keeps the full shape of the inputs; sizing the accumulator
    # from shape[0] alone only worked for 1-D tensors.
    w_avg = torch.zeros_like(w[0], device=args.device)
    for k in w:
        w_avg += k
    return torch.div(w_avg, len(w)).detach()

def FedAvgGradient(grads_list):
    """Average parameter gradients across clients, in place on the first entry.

    Each element of ``grads_list`` is indexed as ``entry[0]['params']`` and the
    resulting objects must expose ``.grad`` (presumably an optimizer's
    ``param_groups`` - confirm against the caller). Only index 0 of each entry
    is read.

    Args:
        grads_list: Non-empty sequence of such entries, one per client.

    Returns:
        ``grads_list[0]``, whose parameters' ``.grad`` now hold the mean of the
        corresponding gradients over all clients.
    """
    client_count = float(len(grads_list))
    base_params = grads_list[0][0]['params']

    # Scale the first client's gradients, then add every other client's
    # scaled gradients on top, so the result is the mean.
    for target in base_params:
        target.grad = torch.div(target.grad, client_count)

    for client_idx in range(1, len(grads_list)):
        other_params = grads_list[client_idx][0]['params']
        for target, other in zip(base_params, other_params):
            target.grad = target.grad + torch.div(other.grad, client_count)

    return grads_list[0]


def communication1(args, para, topo):
    """Mix every user's layers with its peers' layers using the weights in ``topo``.

    For user ``kk`` the result for each layer is
    ``sum_tt topo[kk][tt] * para[tt][layer]``, accumulated on ``args.device``.
    The accumulators for user ``kk`` are shaped like that user's own layers
    (``para[kk]``).

    Args:
        args: Object exposing ``num_users`` and ``device``.
        para: ``para[u]`` is the iterable of layer tensors of user ``u``.
        topo: Weight table indexed as ``topo[kk][tt]``.

    Returns:
        A list with one entry per user; each entry is a list of new tensors,
        detached from the inputs and with ``requires_grad`` set to True.
    """
    para_record_list = []
    for kk in range(args.num_users):
        # Zero accumulators with the same structure as para[kk], on args.device.
        para_record = [torch.zeros_like(layer, device=args.device) for layer in para[kk]]

        # The result is detached below, so recording an autograd graph for the
        # accumulation would only waste memory.
        with torch.no_grad():
            for tt in range(args.num_users):
                weight = topo[kk][tt]
                if weight == 0:
                    # A zero weight means "no link": skipping avoids useless
                    # work and keeps a peer's NaN/Inf from leaking in through
                    # 0 * nan.
                    continue
                for layer_idx, layer in enumerate(para[tt]):
                    para_record[layer_idx] += weight * layer.to(args.device)

        # Cut the mixed values loose from the inputs and turn them into fresh
        # leaf tensors that track gradients again.
        para_record = [layer.detach().requires_grad_(True) for layer in para_record]
        para_record_list.append(para_record)

    return para_record_list

def communication(args, para, topo):
    """Mix every user's layers with its peers' layers using the weights in ``topo``.

    For user ``kk`` the result for each layer is
    ``sum_tt topo[kk][tt] * para[tt][layer]``, accumulated on ``args.device``.
    All accumulators are shaped like the first user's layers (``para[0]``).

    Args:
        args: Object exposing ``num_users`` and ``device``.
        para: ``para[u]`` is the iterable of layer tensors of user ``u``.
        topo: Weight table indexed as ``topo[kk][tt]``.

    Returns:
        A list with one entry per user; each entry is a list of new tensors,
        detached from the inputs and with ``requires_grad`` set to True.
    """
    para_record_list = []
    for kk in range(args.num_users):
        # Zero accumulators with the same structure as para[0], on args.device.
        para_record = [torch.zeros_like(layer, device=args.device) for layer in para[0]]

        # The result is detached below, so recording an autograd graph for the
        # accumulation would only waste memory.
        with torch.no_grad():
            for tt in range(args.num_users):
                weight = topo[kk][tt]
                if weight == 0:
                    # A zero weight means "no link": skipping avoids useless
                    # work and keeps a peer's NaN/Inf from leaking in through
                    # 0 * nan.
                    continue
                for layer_idx, layer in enumerate(para[tt]):
                    para_record[layer_idx] += weight * layer.to(args.device)

        # Cut the mixed values loose from the inputs and turn them into fresh
        # leaf tensors that track gradients again.
        para_record = [layer.detach().requires_grad_(True) for layer in para_record]
        para_record_list.append(para_record)

    return para_record_list


def ComTopo(client, model):
    """Build the ``client`` x ``client`` mixing-weight matrix for a topology.

    Supported values of ``model``:

    * ``"exp"``     - row ``i`` puts equal weight on the columns
      ``(i + 2**j - 1) % client`` for ``j = 0 .. floor(log2(client))``.
    * ``"circle"``  - ring; weight 0.5 on each of the two neighbours.
    * ``"linear"``  - chain; interior nodes put 0.5 on each neighbour, the two
      end nodes put 1 on their only neighbour.
    * ``"circle1"`` - ring with a self weight of 0.2 and 0.4 per neighbour.
    * ``"circle2"`` - ring; weight 1/5 on each of the offsets -2 .. 2.
    * ``"cent"``    - every entry is ``1 / client``.

    When several neighbour offsets fall on the same column (very small
    ``client``), their weights are summed so each row still adds up to 1.
    Any other ``model`` value leaves the matrix all zeros.

    Args:
        client: Number of nodes.
        model: Topology name, see above.

    Returns:
        A ``numpy.ndarray`` of shape ``(client, client)``.
    """
    W = np.zeros((client, client))
    if model == "exp":
        # log2 is exact for powers of two, unlike log(x, 2).
        nodeForCom = math.floor(math.log2(client)) + 1
        for i in range(client):
            for j in range(nodeForCom):
                W[i][(i + 2**j - 1) % client] = 1 / nodeForCom
    elif model == "circle":
        for i in range(client):
            # Accumulate rather than assign: with fewer than 3 nodes both
            # neighbours are the same column.
            W[i][(i + 1) % client] += 0.5
            W[i][(i - 1) % client] += 0.5
    elif model == "linear":
        for i in range(client):
            neighbours = [j for j in (i - 1, i + 1) if 0 <= j < client]
            if neighbours:
                for j in neighbours:
                    W[i][j] = 1 / len(neighbours)
            else:
                # A single node has no neighbour; it keeps its own value.
                W[i][i] = 1
    elif model == "circle1":
        for i in range(client):
            W[i][i] += 0.2
            W[i][(i + 1) % client] += 0.4
            W[i][(i - 1) % client] += 0.4
    elif model == "circle2":
        for i in range(client):
            for j in range(-2, 3):
                # Offsets wrap onto each other when client < 5.
                W[i][(i + j) % client] += 1 / 5
    elif model == "cent":
        for i in range(client):
            for j in range(client):
                W[i][j] = 1 / client

    return W