import collections
import logging
import math
import sys
import copy

import torch
import torch.distributed as dist

def schedule_lr(t, args):
    """Return the step-decayed learning rate for round ``t``.

    The base rate ``args.lr`` is used until half of ``args.round_num`` is
    reached, is multiplied by 0.1 from the halfway point, and by 0.01 from
    the three-quarter point onwards.

    Args:
        t: Current round index.
        args: Object exposing ``lr`` (base learning rate) and ``round_num``
            (total number of rounds).

    Returns:
        The learning rate to use at round ``t``.
    """
    # Milestones are checked from latest to earliest so the strongest decay
    # that applies wins.
    for fraction, factor in ((0.75, 0.01), (0.5, 0.1)):
        if t >= fraction * args.round_num:
            return args.lr * factor
    return args.lr

def combine_reduction(reductions, new_reductions, t):
    """Fold ``new_reductions`` into the running mean held in ``reductions``.

    Each entry is updated as ``(old * t + new) / (t + 1)``, i.e. the running
    mean when ``old`` is treated as the average of ``t`` earlier values.

    Args:
        reductions: Nested sequence (outer list of mutable inner sequences)
            holding the running values; it is modified in place.
        new_reductions: Nested sequence of the same layout with the values to
            merge in. Pairing uses ``zip``, so extra entries on either side
            are ignored.
        t: Weight given to the existing running values.

    Returns:
        The same ``reductions`` object, after the in-place update.
    """
    paired_rows = zip(reductions, new_reductions)
    for row_idx, (running_row, incoming_row) in enumerate(paired_rows):
        paired_items = zip(running_row, incoming_row)
        for col_idx, (running, incoming) in enumerate(paired_items):
            merged = (running * t + incoming) / (t + 1)
            reductions[row_idx][col_idx] = merged
    return reductions


def fedAvg_communicate(global_model, models, args, Ls):
    """Aggregate client models into ``global_model`` by weighted averaging.

    For every state-dict entry the client deltas ``client - global`` are
    weighted by ``Ls[i]``, summed, divided by ``sum(Ls)`` and added to the
    global value after scaling by ``args.globallr``. The result is loaded
    back into ``global_model``.

    Args:
        global_model: Model whose ``state_dict`` is the aggregation baseline
            and which receives the aggregated state.
        models: Client models; each ``state_dict`` key must also exist in the
            global model's ``state_dict``.
        args: Object exposing ``globallr``, the server-side step size.
        Ls: Per-client weights, indexed in the same order as ``models``.

    Returns:
        ``global_model`` (the same object), updated in place. If no client
        contributed any state entry, it is returned unchanged.

    Raises:
        ValueError: If client updates exist but ``sum(Ls)`` is zero, which
            would otherwise overwrite the global model with NaN/inf values.
    """
    trained_dict = global_model.state_dict()
    # Computed once instead of once per state-dict key.
    total = sum(Ls)

    trained_dict_update = {}
    for i, model in enumerate(models):
        model_dict = model.state_dict()
        for k in model_dict:
            update_data = (model_dict[k].data - trained_dict[k].data) * Ls[i]
            if k in trained_dict_update:
                trained_dict_update[k] += update_data
            else:
                trained_dict_update[k] = update_data

    if not trained_dict_update:
        # Nothing to aggregate (e.g. no participating clients this round).
        return global_model
    if total == 0:
        raise ValueError(
            "sum(Ls) is zero; cannot compute a weighted average of client updates"
        )

    for k in trained_dict:
        trained_dict[k] = trained_dict[k] + args.globallr * (trained_dict_update[k] / total)

    global_model.load_state_dict(trained_dict)
    return global_model

def fedNova_communicate(global_model, models, args, Ls):
    """Aggregate client models into ``global_model`` using weights ``Ls``.

    For every state-dict entry the client deltas ``client - global`` are
    weighted by ``Ls[i]``, summed, divided by ``sum(Ls)`` and added to the
    global value after scaling by ``args.globallr``. The result is loaded
    back into ``global_model``.

    NOTE(review): the update applied here is the plain ``Ls``-weighted mean
    of client deltas; no additional per-client normalization is performed.
    Confirm whether a FedNova-style normalized update is still intended.

    Args:
        global_model: Model whose ``state_dict`` is the aggregation baseline
            and which receives the aggregated state.
        models: Client models; each ``state_dict`` key must also exist in the
            global model's ``state_dict``.
        args: Object exposing ``globallr``, the server-side step size.
        Ls: Per-client weights, indexed in the same order as ``models``.

    Returns:
        ``global_model`` (the same object), updated in place. If no client
        contributed any state entry, it is returned unchanged.

    Raises:
        ValueError: If client updates exist but ``sum(Ls)`` is zero, which
            would otherwise overwrite the global model with NaN/inf values.
    """
    trained_dict = global_model.state_dict()
    total = sum(Ls)

    trained_dict_update = {}
    for i, model in enumerate(models):
        model_dict = model.state_dict()
        for k in model_dict:
            # Compute the weighted delta once and reuse it for both branches.
            update_data = (model_dict[k].data - trained_dict[k].data) * Ls[i]
            if k in trained_dict_update:
                trained_dict_update[k] += update_data
            else:
                trained_dict_update[k] = update_data

    if not trained_dict_update:
        # Nothing to aggregate (e.g. no participating clients this round).
        return global_model
    if total == 0:
        raise ValueError(
            "sum(Ls) is zero; cannot compute a weighted average of client updates"
        )

    for k in trained_dict:
        trained_dict[k] = trained_dict[k] + args.globallr * (trained_dict_update[k] / total)

    global_model.load_state_dict(trained_dict)

    return global_model