import torch
import copy
import numpy as np
import gc

from tool.logger import *
from tool.utils import get_parameters, set_parameters, save_model

def ST_NN(device, global_model, algorithm_step_T, num_clients_K, local_step_size, training_dataloaders,
          sub_batch_size=64):
    """Train one independent copy of ``global_model`` per client with plain SGD.

    Every client starts from a deep copy of ``global_model``.  In each of the
    ``algorithm_step_T`` rounds, every client draws the first batch from its
    dataloader and performs exactly one SGD step on it.  The client models are
    never averaged or written back into ``global_model`` by this function.

    To bound peak memory, the batch is pushed through the model in chunks of
    at most ``sub_batch_size`` samples and the gradients are accumulated.  Each
    chunk's mean loss is weighted by its share of the batch, so the accumulated
    gradient (and the reported loss) equals the mean over the whole batch
    regardless of how many chunks it was split into.

    Args:
        device: Device the batch tensors are moved to before the forward pass.
        global_model: Model used as the starting point for every client.  It is
            switched to training mode but otherwise left unmodified.
        algorithm_step_T: Number of rounds; each round is one SGD step per client.
        num_clients_K: Number of client models to create and train.
        local_step_size: Learning rate handed to ``torch.optim.SGD``.
        training_dataloaders: Indexable collection with one iterable of batches
            per client.  Each batch must be a mapping with the keys ``"X"``
            (inputs) and ``"y"`` (labels); inputs are flattened to
            ``(n_samples, -1)`` before being fed to the model and labels are
            cast to ``long``.
        sub_batch_size: Maximum number of samples per forward/backward pass.

    Returns:
        The list of ``num_clients_K`` trained client models.

    Raises:
        ValueError: If ``sub_batch_size`` is not positive.
    """
    if sub_batch_size <= 0:
        raise ValueError(f"sub_batch_size must be positive, got {sub_batch_size}")

    # Training process
    logger.info("Training process")

    # Parameter Initialization
    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    # Plain SGD keeps no per-parameter state, so one optimizer per client can
    # be built up front and reused in every round.
    optimizer_list = [torch.optim.SGD(model.parameters(), lr=local_step_size)
                      for model in local_model_list]

    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    for iter_t in range(algorithm_step_T):
        # Simulate The Client Parallel Process
        for i, (model, optimizer) in enumerate(zip(local_model_list, optimizer_list)):
            model.train()

            # Only the first batch of the dataloader is used in each round.
            batch = next(iter(training_dataloaders[i]), None)
            if batch is None:
                # The dataloader produced nothing: no update for this client.
                continue

            X = batch["X"].to(device)
            y = batch["y"].to(device)
            num_samples = len(X)
            if num_samples == 0:
                # An empty batch carries no gradient information.
                continue

            # Drop any gradients inherited through the deep copy so that the
            # step below only reflects the current batch.
            optimizer.zero_grad()

            # Optimizing One Step By SGD Algorithm
            batch_loss = 0.0
            for start in range(0, num_samples, sub_batch_size):
                X_chunk = X[start:start + sub_batch_size]
                y_chunk = y[start:start + sub_batch_size]
                chunk_size = X_chunk.shape[0]  # the last chunk may be shorter
                prediction = model(X_chunk.reshape(chunk_size, -1))
                # Weight the chunk mean by its share of the batch so that the
                # accumulated gradient is the mean over the full batch.
                loss = criterion(prediction, y_chunk.long()) * (chunk_size / num_samples)
                loss.backward()
                # detach() keeps the running loss out of the autograd graph.
                batch_loss = batch_loss + loss.detach()
                del X_chunk, y_chunk, prediction, loss

            if (iter_t + 1) % 10 == 0:
                logger.info(f"########## Algorithm Epoch: {iter_t + 1} / {algorithm_step_T}; "
                            f"Client: {i + 1} / {num_clients_K};"
                            f" Loss: {round(float(batch_loss), 4)} ##########")

            optimizer.step()
            optimizer.zero_grad()

            # Release the batch eagerly to keep peak (GPU) memory low.
            del X, y, batch, batch_loss
            gc.collect()
            torch.cuda.empty_cache()

    logger.info("Training finish, return the local model list")
    return local_model_list


