import torch
import copy
import numpy as np
import gc

from tool.logger import *
from tool.utils import get_parameters, set_parameters, save_model


def _local_sgd_step(device, model, criterion, dataloader, local_step_size, sub_batch_size):
    """Apply one SGD update to ``model`` using the first batch drawn from ``dataloader``.

    The batch is processed in chunks of at most ``sub_batch_size`` samples. Each
    chunk's loss is back-propagated separately, so gradients accumulate as the SUM
    of the per-chunk mean losses before the single optimizer step.

    Args:
        device: Target device for the model and the batch tensors.
        model: Local model; updated in place.
        criterion: Loss called as ``criterion(prediction, target)``.
        dataloader: Iterable yielding dicts with ``"X"`` and ``"y"`` tensors.
        local_step_size: SGD learning rate.
        sub_batch_size: Maximum number of samples per forward/backward chunk.

    Returns:
        The summed chunk losses of the batch as a Python float, or ``0.0`` when the
        dataloader yields nothing or the batch contains no samples.
    """
    # Move the model before building the optimizer, as recommended by torch.optim.
    model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=local_step_size)
    optimizer.zero_grad()

    # Only the first batch is used per step (one mini-batch per local update).
    # NOTE(review): whether this batch changes between steps depends on the
    # loader's shuffling, which is not visible here.
    batch = next(iter(dataloader), None)
    if batch is None:
        return 0.0

    X = batch["X"].to(device)
    y = batch["y"].reshape(-1, 1).to(device)

    batch_loss = 0.0
    for start in range(0, len(X), sub_batch_size):
        prediction = model(X[start:start + sub_batch_size])
        loss = criterion(prediction, y[start:start + sub_batch_size].float())
        # Gradients accumulate across chunks; the step below uses their sum.
        # NOTE(review): this is a sum of chunk means, not a mean over the whole
        # batch - confirm that this scaling is intended.
        loss.backward()
        batch_loss += loss.item()

    # With no chunks (empty batch) no gradients exist, so the step is a no-op.
    optimizer.step()
    optimizer.zero_grad()
    return batch_loss


def ST_LR(device, global_model, algorithm_step_T, num_clients_K, local_step_size, training_dataloaders,
          sub_batch_size=64):
    """Train an independent copy of ``global_model`` for each client with local SGD.

    Each client starts from a deep copy of ``global_model``. In each of the
    ``algorithm_step_T`` steps, every client takes a single SGD step on the first
    batch its dataloader yields. The local models are never aggregated here. The
    loss is ``BCELoss``, so the model must output probabilities in [0, 1].

    Args:
        device: Device the local models and batch tensors are moved to.
        global_model: Initial model. It is switched to train mode and deep-copied
            once per client; its parameters are not modified.
        algorithm_step_T: Number of training steps.
        num_clients_K: Number of clients, i.e. local model copies.
        local_step_size: SGD learning rate for every local step.
        training_dataloaders: Indexable by client index. ``training_dataloaders[i]``
            yields dicts with ``"X"`` (model input) and ``"y"`` (targets, reshaped
            to ``(-1, 1)``).
        sub_batch_size: Samples per forward/backward chunk (default 64). Must be
            positive.

    Returns:
        The list of ``num_clients_K`` locally trained models.

    Raises:
        ValueError: If ``sub_batch_size`` is not positive.
    """
    if sub_batch_size <= 0:
        raise ValueError(f"sub_batch_size must be positive, got {sub_batch_size}")

    logger.info("Training process")
    criterion = torch.nn.BCELoss(reduction='mean')

    # Parameter initialization: every client gets its own copy of the global model.
    global_model.train()
    local_model_list = [copy.deepcopy(global_model) for _ in range(num_clients_K)]

    for iter_t in range(algorithm_step_T):
        # Simulate the clients' parallel local updates sequentially.
        total_loss = 0.0
        for i, model in enumerate(local_model_list):
            total_loss += _local_sgd_step(
                device, model, criterion, training_dataloaders[i], local_step_size, sub_batch_size
            )
            # The step's tensors are out of scope once the helper returns; reclaim them now.
            gc.collect()
            torch.cuda.empty_cache()

        # Accumulate unrounded losses and round only for display.
        avg_loss = total_loss / num_clients_K if num_clients_K else 0.0
        logger.info(f"########## Step: {iter_t + 1} / {algorithm_step_T}; "
                    f"Avg Loss over Client: {round(float(avg_loss), 4)} ##########")

    logger.info("Training finish, return the local model list")
    return local_model_list
