import torch
import numpy as np
import copy
import gc

from tool.logger import *
from tool.utils import get_parameters, set_parameters
from algorithm.client_selection import client_selection

def _select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                    training_dataset_size, client_datasets_size_list):
    """Ask `client_selection` for the participating client indices and log them."""
    logger.info("********** Client selection **********")
    idxs_users = client_selection(
        client_num=num_clients_K,
        fraction=FL_fraction,
        dataset_size=training_dataset_size,
        client_dataset_size_list=client_datasets_size_list,
        drop_rate=FL_drop_rate,
        style="FedAvg",
    )
    logger.info(f"********** Select client list: {idxs_users} **********")
    return idxs_users


def _local_sgd_step(model, optimizer, criterion, batch, device, micro_batch_size):
    """Apply one SGD step on `batch`, accumulating gradients over micro-batches.

    The batch is processed in slices of at most `micro_batch_size` samples.
    Each slice's mean loss is weighted by its share of the batch, so the
    accumulated gradient (and the returned loss) equals the full-batch mean
    regardless of how many slices the batch is split into.

    Returns:
        The batch-mean loss as a Python float (0.0 for an empty batch).
    """
    X = batch["X"].to(device)
    y = batch["y"].to(device)
    batch_size = len(X)

    # Start from clean gradients so nothing stale leaks into this step.
    optimizer.zero_grad()

    batch_loss = 0.0
    for start in range(0, batch_size, micro_batch_size):
        X_chunk = X[start:start + micro_batch_size]
        y_chunk = y[start:start + micro_batch_size]
        chunk_size = X_chunk.shape[0]  # the last slice may be shorter

        # Inputs are flattened to (chunk_size, features) before the forward pass.
        prediction = model(X_chunk.reshape(chunk_size, -1))

        # `criterion` averages over the slice; re-weight so that the summed
        # gradients match a single mean over the whole batch.
        loss = criterion(prediction, y_chunk.long()) * (chunk_size / batch_size)
        loss.backward()
        batch_loss = batch_loss + loss.detach()

    optimizer.step()
    optimizer.zero_grad()

    # Convert once to a plain float so callers do not keep device tensors alive.
    return float(batch_loss)


def Fed_AVG_NN(device,
               global_model,
               algorithm_step_T, num_clients_K, communication_round_I, FL_fraction, FL_drop_rate, local_step_size,
               training_dataloaders,
               training_dataset,
               client_dataset_list,
               micro_batch_size=64):
    """Train `global_model` with a simulated Federated Averaging loop.

    Every algorithm step, each selected client performs one SGD step on the
    first batch yielded by its dataloader. Every `communication_round_I`
    steps the selected clients' parameters are averaged (weighted by client
    dataset size), written into the global model, copied back to all clients,
    and a new set of clients is selected.

    Args:
        device: Device the batch tensors are moved to.
        global_model: Model to train; deep-copied once per client.
        algorithm_step_T: Total number of algorithm steps.
        num_clients_K: Number of clients (one local model each).
        communication_round_I: Number of steps between two aggregations.
        FL_fraction: Forwarded to `client_selection` as `fraction`.
        FL_drop_rate: Forwarded to `client_selection` as `drop_rate`.
        local_step_size: Learning rate of each client's SGD optimizer.
        training_dataloaders: Indexable by client index; each yields batches
            that are mappings with "X" and "y" tensors.
        training_dataset: Full training set; only its length is used.
        client_dataset_list: Per-client datasets; only their lengths are used.
        micro_batch_size: Maximum number of samples per forward/backward pass
            inside one SGD step (gradient accumulation). Defaults to 64.

    Returns:
        Tuple `(global_model, local_model_list)`. If `algorithm_step_T` is not
        a multiple of `communication_round_I`, the local models contain the
        not-yet-aggregated updates of the trailing steps.

    Raises:
        ValueError: If `micro_batch_size` is not a positive integer.
    """
    if micro_batch_size < 1:
        raise ValueError(f"micro_batch_size must be >= 1, got {micro_batch_size}")

    # Training process
    logger.info("Training process")

    training_dataset_size = len(training_dataset)
    client_datasets_size_list = [len(item) for item in client_dataset_list]
    # Aggregation weight of each client: its share of the training data.
    average_weight = np.array([float(size / training_dataset_size) for size in client_datasets_size_list])

    # Parameter Initialization
    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    # Per-client history of step losses (plain floats); diagnostic only, not returned.
    client_loss_list = [[] for _ in range(num_clients_K)]

    idxs_users = _select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                                 training_dataset_size, client_datasets_size_list)

    for iter_t in range(algorithm_step_T):
        # Simulate The Client Parallel Process
        for i in idxs_users:
            model = local_model_list[i]
            model.train()
            optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)

            # Exactly one SGD step per algorithm step: only the first batch
            # of a fresh pass over the client's dataloader is consumed.
            batch = next(iter(training_dataloaders[i]), None)
            if batch is None:
                continue

            step_loss = _local_sgd_step(model, optimizer, criterion, batch, device, micro_batch_size)
            client_loss_list[i].append(step_loss)

            if (iter_t + 1) % 10 == 0:
                logger.info(f"########## Algorithm Epoch: {iter_t + 1} / {algorithm_step_T}; "
                            f"Client: {i} / {num_clients_K};"
                            f" Loss: {round(step_loss, 4)} ##########")

            # Release the batch and cached device memory before the next client.
            del batch
            gc.collect()
            torch.cuda.empty_cache()

        # Communicate
        if (iter_t + 1) % communication_round_I == 0:
            logger.info(f"********** Communicate: {(iter_t + 1) // communication_round_I} **********")

            # Global operation
            logger.info("********** Parameter aggregation **********")
            theta_list = np.array(
                [get_parameters(local_model_list[client_id]) for client_id in idxs_users],
                dtype=object,
            )
            # np.average normalizes by the sum of the given weights, so the
            # selected clients' weights need not sum to one.
            theta_avg = np.average(
                theta_list,
                axis=0,
                weights=[average_weight[client_id] for client_id in idxs_users],
            ).tolist()
            global_model = set_parameters(global_model, theta_avg)

            # Parameter Distribution
            logger.info("********** Parameter distribution **********")
            local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

            # Client Reselection
            idxs_users = _select_clients(num_clients_K, FL_fraction, FL_drop_rate,
                                         training_dataset_size, client_datasets_size_list)

    logger.info("Training finish, return global model and local model list")
    return global_model, local_model_list
