import torch
import numpy as np
import copy
import math
import gc

from tool.logger import *
from tool.utils import get_parameters, set_parameters, save_model
from algorithm.client_selection import client_selection


def _select_clients(num_clients_K, FL_fraction, FL_drop_rate, training_dataset_size, client_datasets_size_list):
    """Pick the participating clients for the next round and log the choice.

    Thin wrapper around ``client_selection`` with ``style="FedAvg"``; the same
    call is made once before training and again after every aggregation.
    """
    logger.info("********** Client selection **********")
    idxs_users = client_selection(
        client_num=num_clients_K,
        fraction=FL_fraction,
        dataset_size=training_dataset_size,
        client_dataset_size_list=client_datasets_size_list,
        drop_rate=FL_drop_rate,
        style="FedAvg",
    )
    logger.info(f"********** Select client list: {idxs_users} **********")
    return idxs_users


def _global_anchor(global_model, device):
    """Return detached copies of ``global_model``'s parameters on ``device``.

    These are the fixed reference weights w^t of the FedProx proximal term.
    Detaching keeps gradients from ever flowing into the global model.
    """
    return [p.detach().clone().to(device) for p in global_model.parameters()]


def _squared_distance(model, anchor_params):
    """Return sum over parameters of ||w - w_anchor||_2^2 for ``model``.

    The result stays attached to the autograd graph of ``model``'s parameters,
    so it contributes gradient when added to a loss. Returns int 0 for a model
    without parameters.
    """
    return sum(
        (torch.sum((w - w_anchor) ** 2) for w, w_anchor in zip(model.parameters(), anchor_params)),
        0,
    )


def _class_targets(y, logits):
    """Return integer class targets shaped for ``CrossEntropyLoss`` given ``logits``.

    For (N, C) logits CrossEntropyLoss requires a 1-D (N,) target, so the
    (N, 1) label column is flattened in that case. Targets for logits of any
    other rank are only cast to long.
    """
    y = y.long()
    if logits.dim() == 2:
        y = y.reshape(-1)
    return y


def _weighted_average_parameters(theta_list, weights):
    """Per-layer weighted mean of several clients' parameter lists.

    Args:
        theta_list: one parameter sequence per client (as returned by
            ``get_parameters``); all sequences have the same layer layout.
        weights: one weight per client. They are normalized by their sum,
            like ``np.average``. A zero total falls back to a uniform mean.

    Returns:
        A list with one averaged entry per layer.

    Averaging layer by layer avoids building a NumPy object array. With an
    object array, NumPy's shape inference can fail or reshape the data when
    different layers share leading dimensions (e.g. a weight (k, m) and a
    bias (k,)).
    """
    weights = [float(w) for w in weights]
    total = sum(weights)
    if total <= 0:
        weights = [1.0] * len(theta_list)
        total = float(len(theta_list))
    return [
        sum(w * layer for w, layer in zip(weights, layer_across_clients)) / total
        for layer_across_clients in zip(*theta_list)
    ]


def Fed_Prox_NN(device,
                global_model,
                algorithm_step_T, num_clients_K, communication_round_I, FL_fraction, FL_drop_rate, local_step_size,
                training_dataloaders,
                training_dataset,
                client_dataset_list,
                *,
                mu=1,
                sub_batch_size=64):
    """Train ``global_model`` with FedProx over simulated clients.

    Every outer step performs one local SGD update on each selected client. The
    update minimises ``CE + (mu / 2) * ||w - w_global||^2``. Every
    ``communication_round_I`` steps, the selected clients' parameters are
    averaged into the global model, weighted by client dataset size. The result
    is copied back to every client and a new set of clients is selected.

    Args:
        device: torch device the local models and batches are moved to.
        global_model: torch.nn.Module holding the shared (global) weights.
        algorithm_step_T: number of outer steps.
        num_clients_K: total number of clients (one local model each).
        communication_round_I: aggregate every this many outer steps.
        FL_fraction: forwarded to ``client_selection`` as ``fraction``.
        FL_drop_rate: forwarded to ``client_selection`` as ``drop_rate``.
        local_step_size: SGD learning rate for the local updates.
        training_dataloaders: indexed by client id; each yields dicts with
            ``"X"`` (inputs) and ``"y"`` (integer class labels).
        training_dataset: only its length is used (denominator of the weights).
        client_dataset_list: only each entry's length is used (aggregation weights).
        mu: FedProx proximal coefficient (default 1, the previously hard-coded value).
        sub_batch_size: each batch is fed to the model in chunks of this many
            samples, with gradients accumulated across chunks (default 64).

    Returns:
        ``(global_model, local_model_list)``.

    Raises:
        ValueError: if ``sub_batch_size`` < 1.
    """
    if sub_batch_size < 1:
        raise ValueError(f"sub_batch_size must be >= 1, got {sub_batch_size}")

    logger.info("Training process")
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    training_dataset_size = len(training_dataset)
    client_datasets_size_list = [len(item) for item in client_dataset_list]
    # Share of the whole training set held by each client, used as its aggregation weight.
    average_weight = np.array([float(size / training_dataset_size) for size in client_datasets_size_list])

    selection_kwargs = dict(
        num_clients_K=num_clients_K,
        FL_fraction=FL_fraction,
        FL_drop_rate=FL_drop_rate,
        training_dataset_size=training_dataset_size,
        client_datasets_size_list=client_datasets_size_list,
    )

    # Every client starts from a copy of the global model.
    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]
    # Fixed reference w^t for the proximal term. It is refreshed only when the
    # global model changes (at aggregation), not for every batch.
    global_anchor = _global_anchor(global_model, device)
    idxs_users = _select_clients(**selection_kwargs)

    for iter_t in range(algorithm_step_T):
        # Simulate the clients' parallel local updates sequentially.
        ce_sum, reg_sum, loss_sum = 0.0, 0.0, 0.0
        for i in idxs_users:
            model = local_model_list[i].to(device)  # Module.to moves in place and returns the module
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)

            for batch in training_dataloaders[i]:
                optimizer.zero_grad()
                X = batch["X"].to(device)
                y = batch["y"].reshape(-1, 1).to(device)

                # Logged regularization term: L2 distance to the global model.
                # Parameters do not change until optimizer.step(), so one value per batch.
                with torch.no_grad():
                    reg_sum += math.sqrt(float(_squared_distance(model, global_anchor)))

                for start in range(0, len(X), sub_batch_size):
                    x_chunk = X[start:start + sub_batch_size]
                    n = x_chunk.shape[0]
                    prediction = model(x_chunk.reshape(n, -1))
                    ce = criterion(prediction, _class_targets(y[start:start + n], prediction))
                    # The proximal term must be built from the live parameters so it
                    # contributes gradient. It is recomputed per chunk because
                    # backward() frees the graph. Each chunk's loss carries its own
                    # (mu/2)||w - w_global||^2, matching how the CE gradients accumulate.
                    loss = ce + (mu / 2) * _squared_distance(model, global_anchor)
                    loss.backward()
                    ce_sum += float(ce)
                    loss_sum += float(loss)

                optimizer.step()
                optimizer.zero_grad()
                # Only the first batch is used, so each outer step is exactly one
                # local SGD update per client. Whether that batch differs between
                # steps depends on the dataloader (e.g. shuffling); not visible here.
                break

        n_selected = max(len(idxs_users), 1)  # avoid ZeroDivisionError when no client was selected
        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg CE over Client: {round(ce_sum / n_selected, 4)}; "
                    f"Avg RT over Client: {round(reg_sum / n_selected, 4)}; "
                    f"Avg Loss over Client: {round(loss_sum / n_selected, 4)}; ##########")

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) / communication_round_I} **********")

            # Global operation: dataset-size-weighted average of the selected clients.
            logger.info("********** Parameter aggregation **********")
            if len(idxs_users) > 0:
                theta_list = [get_parameters(local_model_list[j]) for j in idxs_users]
                theta_avg = _weighted_average_parameters(theta_list, [average_weight[j] for j in idxs_users])
                global_model = set_parameters(global_model, theta_avg)
            else:
                logger.info("********** No client selected, global model unchanged **********")

            # Parameter distribution
            logger.info("********** Parameter distribution **********")
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]
            global_anchor = _global_anchor(global_model, device)

            # Client reselection
            idxs_users = _select_clients(**selection_kwargs)

    logger.info("Training finish, return global model and local model list")
    return global_model, local_model_list
