#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import copy
import torch
from torch import nn
import networkx as nx
import random
import numpy as np
import pdb

# Seed Python's built-in `random` module at import time, so the
# `random.randint` draw in FedGraph starts from a fixed stream.
# NOTE(review): only `random` is seeded here; `np.random.shuffle` in FedGraph
# and the networkx graph (created with seed=None) are not seeded by this line.
random.seed(886)
def FedAvg(w):
    """Return the element-wise mean of a sequence of model state dicts.

    Args:
        w: Non-empty indexable sequence of dicts (e.g. ``state_dict()``
            results) that share the same keys.

    Returns:
        A new dict whose value for each key is ``torch.div`` of the sum over
        all entries by ``len(w)``. ``w[0]`` is deep-copied first, so the
        inputs are not modified.
    """
    n_models = len(w)
    result = copy.deepcopy(w[0])
    for name in result:
        running = result[name]
        # Accumulate the remaining models in their original order.
        for other in (w[j] for j in range(1, n_models)):
            running += other[name]
        result[name] = torch.div(running, n_models)
    return result


def FedAvg_ourpre(w, new_sets):
    """Average the state dicts of a chosen subset of clients.

    Args:
        w: Indexable collection of state dicts.
        new_sets: Non-empty sequence of indices into ``w`` selecting the
            clients to average (repeats are counted each time they appear).

    Returns:
        A new dict holding, per key, the sum over the selected entries
        divided (via ``torch.div``) by ``len(new_sets)``. The first selected
        entry is deep-copied, so the inputs are not modified.
    """
    first, rest = new_sets[0], new_sets[1:]
    merged = copy.deepcopy(w[first])
    count = len(new_sets)
    for key in merged:
        total = merged[key]
        for idx in rest:
            total += w[idx][key]
        merged[key] = torch.div(total, count)
    return merged



def FedGraph(all_client_w, global_w, num_client, edge_prob):
    """Aggregate client models by mixing their updates over a random graph.

    A G(n, p) random graph is drawn with one node per participating client.
    Each client's update (its weights minus ``global_w``) is summed with the
    updates of its graph neighbours. A random subset of clients is then
    chosen, their neighbourhood sums are averaged, and the average is added
    back onto ``global_w``.

    Args:
        all_client_w: Indexable collection of client state dicts, indexed by
            the client ids in ``num_client``. Not modified.
        global_w: Current global state dict (same keys as the client dicts).
            Not modified.
        num_client: Sequence of participating client ids. Not modified.
        edge_prob: Edge probability passed to ``nx.fast_gnp_random_graph``.

    Returns:
        A new state dict with the aggregated weights.

    Raises:
        ValueError: If ``num_client`` is empty.
    """
    clients = list(num_client)
    n = len(clients)
    if n == 0:
        raise ValueError("FedGraph requires at least one client")

    G = nx.fast_gnp_random_graph(n, edge_prob, seed=None, directed=False)
    # The graph's nodes are 0..n-1; node ``idx`` stands for client
    # ``clients[idx]``. Mapping through ``node_of`` lets arbitrary client ids
    # be used, not only ids that happen to equal 0..n-1.
    node_of = {cid: idx for idx, cid in enumerate(clients)}

    # Per-client deltas w.r.t. the global model, built in fresh dicts so the
    # caller's state dicts are not overwritten with deltas.
    deltas = {
        cid: {k: all_client_w[cid][k] - global_w[k] for k in all_client_w[cid].keys()}
        for cid in clients
    }

    # Number of clients whose neighbourhood sums are averaged. max() keeps the
    # range valid for a single client (random.randint(1, 0) would raise).
    rand_num = random.randint(1, max(1, n - 1))
    order = list(clients)
    np.random.shuffle(order)  # shuffle a private copy, not the caller's sequence
    selected_client = order[:rand_num]

    w_collect = []
    for cid in selected_client:
        # Private copy: the in-place additions below must not alter deltas[cid],
        # which other clients may read as a neighbour contribution.
        w_curr = copy.deepcopy(deltas[cid])
        neighbours = [clients[j] for j in G.adj[node_of[cid]]]
        for k in w_curr.keys():
            for nb in neighbours:
                w_curr[k] += deltas[nb][k]
                # w_curr[k] += (1/len(neighbours)) * deltas[nb][k]
        w_collect.append(w_curr)

    # Every entry of w_collect is already a private copy, so accumulate into
    # the first one directly.
    w_result = w_collect[0]
    for k in w_result.keys():
        for i in range(1, len(w_collect)):
            w_result[k] += w_collect[i][k]
        w_result[k] = torch.div(w_result[k], len(w_collect))
        w_result[k] += global_w[k]

    return w_result



# def Fedqfair(pre_global_w, Delta):
#     for d in Delta:
#         for k in pre_global_w.keys():
#             if type(Delta[k]) != int:
#                 if pre_global_w[k].size() == d[k].size():
#                     pre_global_w[k] -= d[k]
#                 else:
#                     pass
#             else:
#                 pass
#     return pre_global_w

def Fedqfair(weights_before, Deltas, hs):
    """Subtract the normalised sum of client deltas from a model, in place.

    Every client delta is divided by ``sum(hs)``. The scaled deltas are then
    summed entry by entry, and the i-th sum is subtracted from the i-th
    tensor of ``weights_before.parameters()``.

    Args:
        weights_before: ``nn.Module`` whose parameters are updated in place.
        Deltas: Non-empty sequence with one entry per client. Each entry is
            iterated and indexed by what it yields
            (``client_delta[layer] for layer in client_delta``), so presumably
            a dict such as a state dict. TODO: confirm against the caller.
        hs: Per-client normalisers; only their sum is used.

    Returns:
        ``weights_before`` (the same module object, modified in place). The
        tensors in ``Deltas`` are not modified.

    Raises:
        IndexError: If there are more delta entries than model parameters.
    """
    denominator = np.sum(np.asarray(hs))
    # Fresh tensors, so the in-place accumulation below cannot touch Deltas.
    scaled_deltas = [
        [client_delta[layer] * 1.0 / denominator for layer in client_delta]
        for client_delta in Deltas
    ]

    updates = []
    for i in range(len(Deltas[0])):
        tmp = scaled_deltas[0][i]
        for j in range(1, len(Deltas)):
            tmp += scaled_deltas[j][i]
        updates.append(tmp)

    # Materialise the parameter list once instead of once per layer.
    params = list(weights_before.parameters())
    with torch.no_grad():
        for idx, val in enumerate(updates):
            # NOTE(review): matching is positional. It assumes each delta's
            # iteration order equals parameters() order and that the delta
            # contains no buffers (e.g. BatchNorm running stats). Confirm.
            params[idx].sub_(val)
    return weights_before


# def meta_agg(self, solns, weight_before, num_qry_samples):
#     aggregate_grads_weighted(solns=solns, weights_before=weight_before, num_samples=num_qry_samples)

def meta_agg(solns, weights_before, num_samples, lr=1e-2):
    """Take one outer step using a sample-weighted average of client gradients.

    For each layer ``i``::

        g_i   = sum_c(num_samples[c] * solns[c][i]) / sum(num_samples)
        new_i = weights_before[i] - lr * g_i / len(solns)

    Args:
        solns: Non-empty sequence with one entry per client. Each entry is an
            indexable sequence of per-layer tensors.
        weights_before: Per-layer tensors of the current global model.
        num_samples: Per-client sample counts used as averaging weights. If
            they sum to zero, the division yields inf/nan values.
        lr: Outer step size. Defaults to the previously hard-coded ``1e-2``.

    Returns:
        A new list of per-layer tensors; the inputs are not modified.
    """
    m = len(solns)
    # Same for every layer, so compute it once.
    total_sz = sum(num_samples)
    g = []
    for i in range(len(solns[0])):
        grad_sum = torch.zeros_like(solns[0][i])
        for ic, sz in enumerate(num_samples):
            grad_sum += solns[ic][i] * sz
        g.append(grad_sum / total_sz)

    # NOTE(review): g is already a weighted mean, so the extra division by m
    # shrinks the step by the number of clients a second time. Confirm this
    # is intended.
    return [u - (v * lr / m) for u, v in zip(weights_before, g)]


def meta_agg_our_4(solns, weights_before, num_samples, weighting, balance, lr=1e-2):
    """Take one outer step with sample- and weighting-scaled client gradients.

    For each layer ``i``::

        f_c   = weighting[c] ** balance            (np.float_power)
        g_i   = sum_c(num_samples[c] * f_c * solns[c][i]) / sum(num_samples)
        new_i = weights_before[i] - lr * g_i / len(solns)

    The denominator contains only the sample counts, not the ``f_c``
    factors, so in general ``g_i`` is not a convex combination of the client
    gradients. NOTE(review): confirm this is intended. An earlier variant
    ("our 4") used ``weighting[c] ** (1 - balance)`` instead.

    Args:
        solns: Non-empty sequence with one entry per client. Each entry is an
            indexable sequence of per-layer tensors.
        weights_before: Per-layer tensors of the current global model.
        num_samples: Per-client sample counts.
        weighting: Per-client weights, indexed in the same order as
            ``num_samples``.
        balance: Exponent applied to each ``weighting`` entry.
        lr: Outer step size. Defaults to the previously hard-coded ``1e-2``.

    Returns:
        A new list of per-layer tensors; the inputs are not modified.
    """
    m = len(solns)
    # Both quantities are the same for every layer, so compute them once.
    total_sz = sum(num_samples)
    factors = [np.float_power(weighting[ic], balance) for ic, _ in enumerate(num_samples)]
    g = []
    for i in range(len(solns[0])):
        grad_sum = torch.zeros_like(solns[0][i])
        for ic, sz in enumerate(num_samples):
            grad_sum += solns[ic][i] * sz * factors[ic]
        g.append(grad_sum / total_sz)

    return [u - (v * lr / m) for u, v in zip(weights_before, g)]