import torch
import copy


def aggregate_special(clients, client_fraction, model, lr_global, lambda_reg=None, theta_reg=None, first_round=True):
    """Aggregate client weights with a global step size and an optional pull towards a reference state.

    Step 1 (always), starting from a copy of ``model``'s state dict::

        params = global + lr_global * sum_i fraction_i * (client_i - global)

    Step 2 (only when ``first_round`` is falsy) blends the result with ``theta_reg``::

        params = (params + lambda_reg * theta_reg) / (1 + lambda_reg)

    Args:
        clients: iterable of objects exposing ``model.state_dict()`` with (at
            least) the same keys as ``model.state_dict()``.
        client_fraction: per-client weights, consumed in lockstep with
            ``clients`` (``zip`` stops at the shorter of the two).
        model: the global model; it is not modified.
        lr_global: scale applied to every weighted client delta.
        lambda_reg: blending strength used in step 2. Required when
            ``first_round`` is falsy.
        theta_reg: mapping from state-dict key to tensor, used in step 2. Each
            entry is moved to the device of the corresponding aggregated
            tensor. Required when ``first_round`` is falsy.
        first_round: when truthy, step 2 is skipped.

    Returns:
        A new state dict with the same keys as ``model.state_dict()``. Its
        tensors are not shared with ``model``.

    Raises:
        TypeError: if step 2 is requested but ``lambda_reg`` or ``theta_reg`` is None.
    """
    if not first_round and (lambda_reg is None or theta_reg is None):
        raise TypeError(
            "aggregate_special: lambda_reg and theta_reg are required when first_round is False"
        )

    # state_dict() already returns detached references to the live tensors.
    # They are only read below; ``params`` entries are rebound, never mutated
    # in place. So only the output dict needs a deep copy.
    global_params = model.state_dict()
    params = copy.deepcopy(global_params)
    for client, fraction in zip(clients, client_fraction):
        client_params = client.model.state_dict()  # read-only, no copy needed
        for key in params:
            params[key] = lr_global * (client_params[key] - global_params[key]) * fraction + params[key]

    if not first_round:
        scale = 1 / (1 + lambda_reg)  # loop-invariant
        for key in params:
            params[key] = scale * (
                params[key] + lambda_reg * theta_reg[key].to(params[key].device)
            )

    return params


def aggregate_basic(clients, client_fraction, model):
    """Aggregate client weights into the global model by fraction-weighted deltas.

    Starting from a copy of ``model``'s state dict::

        params = global + sum_i fraction_i * (client_i - global)

    When the fractions sum to 1, this equals the fraction-weighted average of
    the client weights.

    Args:
        clients: iterable of objects exposing ``model.state_dict()`` with (at
            least) the same keys as ``model.state_dict()``.
        client_fraction: per-client weights, consumed in lockstep with
            ``clients`` (``zip`` stops at the shorter of the two).
        model: the global model; it is not modified.

    Returns:
        A new state dict with the same keys as ``model.state_dict()``. Its
        tensors are not shared with ``model``.
    """
    # state_dict() returns detached references that are only read here.
    # ``params`` entries are rebound, never mutated in place, so only the
    # output dict needs a deep copy.
    global_params = model.state_dict()
    params = copy.deepcopy(global_params)
    for client, fraction in zip(clients, client_fraction):
        client_params = client.model.state_dict()  # read-only, no copy needed
        for key in params:
            params[key] = (client_params[key] - global_params[key]) * fraction + params[key]

    return params

def aggregate_FDCL2(clients, client_fraction, model, lr_global):
    """Aggregate client weights with a global step size applied to the weighted deltas.

    Starting from a copy of ``model``'s state dict::

        params = global + lr_global * sum_i fraction_i * (client_i - global)

    Args:
        clients: iterable of objects exposing ``model.state_dict()`` with (at
            least) the same keys as ``model.state_dict()``.
        client_fraction: per-client weights, consumed in lockstep with
            ``clients`` (``zip`` stops at the shorter of the two).
        model: the global model; it is not modified.
        lr_global: scale applied to every weighted client delta.

    Returns:
        A new state dict with the same keys as ``model.state_dict()``. Its
        tensors are not shared with ``model``.
    """
    # Snapshots from state_dict() are detached references and only read here.
    # No per-client deep copy is required.
    global_params = model.state_dict()
    params = copy.deepcopy(global_params)
    for client, fraction in zip(clients, client_fraction):
        client_params = client.model.state_dict()
        for key in params:
            params[key] = lr_global * (client_params[key] - global_params[key]) * fraction + params[key]

    return params

def aggregate_FedCIL(clients, client_fraction, generator):
    """Return a new generator whose parameters are the fraction-weighted sum of the client generators.

    ``generator`` is deep-copied, and every parameter of the copy is reset to
    zero. Then ``fraction_i * client_i.generator`` parameters are accumulated
    into it, pairing parameters by their ``parameters()`` iteration order.

    Only ``parameters()`` are aggregated. Buffers (if the module has any) keep
    the values deep-copied from ``generator``.

    Args:
        clients: iterable of objects exposing a ``generator`` module.
            NOTE(review): architectures are assumed to match that of
            ``generator``; ``zip`` silently stops at the shorter parameter list.
        client_fraction: indexable per-client weights. An IndexError is raised
            if it has fewer entries than ``clients``.
        generator: template module; it is not modified.

    Returns:
        The newly built aggregated module.
    """
    new_generator = copy.deepcopy(generator)
    # Start from zero so the loop below yields a pure weighted sum.
    for param in new_generator.parameters():
        param.data = torch.zeros_like(param.data)

    for idx, client in enumerate(clients):
        fraction = client_fraction[idx]
        for server_param, user_param in zip(new_generator.parameters(), client.generator.parameters()):
            server_param.data = server_param.data + user_param.data * fraction

    return new_generator


def aggregate_discriminator(clients, client_fraction, model):
    """Aggregate client discriminators into ``model`` by fraction-weighted deltas.

    Starting from a copy of ``model``'s state dict::

        params = global + sum_i fraction_i * (client_i.discriminator - global)

    Args:
        clients: iterable of objects exposing ``discriminator.state_dict()``
            with (at least) the same keys as ``model.state_dict()``.
        client_fraction: per-client weights, consumed in lockstep with
            ``clients`` (``zip`` stops at the shorter of the two).
        model: the global discriminator; it is not modified.

    Returns:
        A new state dict with the same keys as ``model.state_dict()``. Its
        tensors are not shared with ``model``.
    """
    # state_dict() snapshots are detached references and are only read here.
    # Only the output dict needs a deep copy.
    global_params = model.state_dict()
    params = copy.deepcopy(global_params)
    for client, fraction in zip(clients, client_fraction):
        client_params = client.discriminator.state_dict()
        for key in params:
            params[key] = (client_params[key] - global_params[key]) * fraction + params[key]

    return params
