import torch
import numpy as np
import copy
import gc
import math
from tool.logger import *
from tool.utils import get_parameters, set_parameters, save_model
from algorithm.client_selection import client_selection
from algorithm.Optimizers import Scaffold_Optimizer


def _select_clients(num_clients_K, FL_fraction, FL_drop_rate, training_dataset_size, client_datasets_size_list):
    """Log and return a FedAvg-style client selection for the next round."""
    logger.info("********** Client selection **********")
    idxs_users = client_selection(
        client_num=num_clients_K,
        fraction=FL_fraction,
        dataset_size=training_dataset_size,
        client_dataset_size_list=client_datasets_size_list,
        drop_rate=FL_drop_rate,
        style="FedAvg",
    )
    logger.info(f"********** Select client list: {idxs_users} **********")
    return idxs_users


def _parameters_on(model, device):
    """Return {name: tensor} of `model`'s parameter data moved to `device` (no copy if already there)."""
    return {k: v.data.to(device) for k, v in model.named_parameters()}


def _scaffold_client_update(global_params, local_params, global_control, client_control,
                            num_local_steps, local_step_size):
    """Compute one client's SCAFFOLD control-variate update (option II of the paper).

    With x the round-start global parameters, y_i the client's parameters after
    K local steps, c the server control and c_i the client control:

        c_i+    = c_i - c + (x - y_i) / (K * local_step_size)
        delta_y = y_i - x
        delta_c = c_i+ - c_i

    Args:
        global_params, local_params, global_control, client_control: dicts mapping
            parameter name -> tensor. For a given name all four tensors must share
            shape and device.
        num_local_steps: K, the number of local steps the client actually took.
        local_step_size: the client learning rate used for those steps.

    Returns:
        (new_client_control, delta_y, delta_control), each a name -> tensor dict.
        If num_local_steps <= 0 the client did not move. Its control is then returned
        unchanged (as clones) and both deltas are zero, avoiding a division by zero.
    """
    if num_local_steps <= 0:
        return (
            {k: client_control[k].clone() for k in global_params},
            {k: torch.zeros_like(x) for k, x in global_params.items()},
            {k: torch.zeros_like(x) for k, x in global_params.items()},
        )
    scale = 1.0 / (num_local_steps * local_step_size)
    new_control, delta_y, delta_control = {}, {}, {}
    for k, x in global_params.items():
        y = local_params[k]
        c_plus = client_control[k] - global_control[k] + (x - y) * scale
        new_control[k] = c_plus
        delta_y[k] = y - x
        delta_control[k] = c_plus - client_control[k]
    return new_control, delta_y, delta_control


def _scaffold_local_step(model, dataloader, optimizer, criterion, global_control, device):
    """Run one control-corrected SGD step on the first batch that does not raise RuntimeError.

    Returns the batch loss rounded to 4 decimals, or None if every batch failed.
    The correction itself happens inside `Scaffold_Optimizer.step`, which receives
    the server control and `model.control`. Presumably it applies
    y <- y - lr * (g + c - c_i); confirm in algorithm.Optimizers.
    """
    for batch in dataloader:
        model.zero_grad()
        X = batch["X"].to(device)
        y = batch["y"].reshape(-1, 1).to(device)
        try:
            prediction = model(X).to(device)
            loss = criterion(prediction, y.float())
            loss.backward()
            optimizer.step(device, global_control, model.control)
            optimizer.zero_grad()
        except RuntimeError:
            logger.info("Something wrong happen in inference. Skipping this batch of data !")
            continue
        return round(float(loss), 4)
    return None


def _scaffold_aggregate(global_model, local_model_list, idxs_users, local_steps,
                        local_step_size, num_clients_K, slr, device):
    """Apply the SCAFFOLD server update in place and refresh c_i of the participating clients.

        x <- x + slr * mean_j(delta_y_j)             (mean over participating clients S)
        c <- c + sum_j(delta_c_j) / num_clients_K    (= c + |S|/N * mean_j(delta_c_j))

    The 1/N scaling keeps c equal to the mean of all N client controls under partial
    participation. Each participating client's control, delta_y and delta_control
    are written back onto its local model. Does nothing if no client was selected.
    """
    num_selected = len(idxs_users)
    if num_selected == 0:
        return
    global_params = _parameters_on(global_model, device)
    global_control = {k: global_model.control[k].to(device) for k in global_params}
    sum_delta_y = {k: torch.zeros_like(x) for k, x in global_params.items()}
    sum_delta_c = {k: torch.zeros_like(x) for k, x in global_params.items()}

    for j in idxs_users:
        local_model = local_model_list[j]
        client_control = {k: local_model.control[k].to(device) for k in global_params}
        new_control, delta_y, delta_control = _scaffold_client_update(
            global_params, _parameters_on(local_model, device), global_control, client_control,
            local_steps[j], local_step_size,
        )
        for k in global_params:
            sum_delta_y[k] += delta_y[k]
            sum_delta_c[k] += delta_control[k]
            local_model.control[k] = new_control[k]
            # Kept on the model for inspection; moved to CPU to free device memory.
            local_model.delta_y[k] = delta_y[k].cpu()
            local_model.delta_control[k] = delta_control[k].cpu()

    for k, v in global_model.named_parameters():
        v.data += (sum_delta_y[k] * (slr / num_selected)).to(v.device)
        global_model.control[k] = global_control[k] + sum_delta_c[k] / num_clients_K


def _redistribute(global_model, local_model_list):
    """Return a fresh copy of `global_model` per client, carrying over each client's own control c_i.

    A plain deepcopy would overwrite every c_i with the server control c. SCAFFOLD
    requires client controls to persist across rounds, so they are copied back in.
    """
    new_models = [copy.deepcopy(global_model) for _ in local_model_list]
    param_names = [k for k, _ in global_model.named_parameters()]
    for new_model, old_model in zip(new_models, local_model_list):
        for k in param_names:
            new_model.control[k] = old_model.control[k]
    return new_models


def Scaffold_LR(device,
                global_model,
                algorithm_step_T, num_clients_K, communication_round_I, FL_fraction, FL_drop_rate, local_step_size,
                training_dataloaders,
                training_dataset,
                client_dataset_list,
                ):
    """Train `global_model` with SCAFFOLD (Karimireddy et al., 2020) over simulated clients.

    Each outer iteration performs one local SGD step on every selected client. The
    step uses the first batch of its dataloader that does not raise RuntimeError.
    Every `communication_round_I` iterations, three things happen:
      - the server aggregates the round's parameter and control deltas;
      - the global model is redistributed, and clients keep their own control c_i;
      - clients are reselected.
    Iterations after the last multiple of `communication_round_I` are not aggregated.

    Args:
        device: torch device used for local training and control-variate arithmetic.
        global_model: model with dict-like `control`, `delta_control` and `delta_y`
            attributes indexed by parameter name (inferred from usage). All three are
            (re)initialised to zeros here.
        algorithm_step_T: total number of outer iterations (local steps).
        num_clients_K: total number of clients N.
        communication_round_I: local steps between server aggregations (K).
        FL_fraction, FL_drop_rate: forwarded to `client_selection`.
        local_step_size: client learning rate; also used in the c_i update.
        training_dataloaders: per-client iterables of dicts with "X" and "y" entries.
        training_dataset: full training set; only its length is used.
        client_dataset_list: per-client datasets; only their lengths are used.

    Returns:
        (global_model, local_model_list) after training.
    """
    logger.info("Training process")

    training_dataset_size = len(training_dataset)
    client_datasets_size_list = [len(item) for item in client_dataset_list]

    # Server (global) learning rate, eta_g in the SCAFFOLD paper.
    slr = 1

    # Initialise the server control c and the delta buffers to zero. The deep copies
    # below then start every client with c_i = 0 as well.
    for k, v in global_model.named_parameters():
        global_model.control[k] = torch.zeros_like(v.data)
        global_model.delta_control[k] = torch.zeros_like(v.data)
        global_model.delta_y[k] = torch.zeros_like(v.data)

    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]
    # Successful local steps per client in the current round (K in the c_i update).
    local_steps = [0] * num_clients_K

    criterion = torch.nn.BCELoss(reduction='mean')

    idxs_users = _select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                                 training_dataset_size, client_datasets_size_list)

    for iter_t in range(algorithm_step_T):
        # Simulate the clients' parallel local training sequentially.
        avg_loss_over_step = 0
        for i in idxs_users:
            model = local_model_list[i]
            model.train()
            model.zero_grad()
            model.to(device)
            optimizer = Scaffold_Optimizer(model.parameters(), method="sgd", learning_rate=local_step_size)
            loss_value = _scaffold_local_step(model, training_dataloaders[i], optimizer, criterion,
                                              global_model.control, device)
            if loss_value is not None:
                avg_loss_over_step += loss_value
                local_steps[i] += 1

        # max(..., 1) guards against an empty client selection.
        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg Loss over Client: {round(float(avg_loss_over_step / max(len(idxs_users), 1)), 4)} ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) / communication_round_I} **********")

            logger.info("********** Parameter aggregation **********")
            _scaffold_aggregate(global_model, local_model_list, idxs_users, local_steps,
                                local_step_size, num_clients_K, slr, device)

            logger.info("********** Parameter distribution **********")
            local_model_list = _redistribute(global_model, local_model_list)
            local_steps = [0] * num_clients_K

            idxs_users = _select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                                         training_dataset_size, client_datasets_size_list)

    logger.info("Training finish, return global model and local model list")
    return global_model, local_model_list
