import logging
from typing import List

from utils.exp_configs import create_exp_config_setup
from clients.base import Client


def base_train_function(
    num_epochs,
    model,
    optimizer,
    train_dataloader,
    public_loader,
    device,
    train_args,
):
    """Build a closure that runs a plain supervised training loop.

    Args:
        num_epochs: Default number of passes over ``train_dataloader``.
        model: Object supporting ``.to(device)`` and callable as
            ``model(pixel_values=..., labels=...)``; the result must expose
            a ``loss`` attribute with a ``backward()`` method.
        optimizer: Object providing ``zero_grad()`` and ``step()``.
        train_dataloader: Iterable yielding ``(img, labels)`` pairs whose
            items support ``.to(device)``.
        public_loader: Not used by this training function.
        device: Device the model and batches are moved to while training.
        train_args: Not used by this training function.

    Returns:
        A ``train_function(num_epochs=num_epochs)`` callable. Calling it
        trains the model, moves it back to ``"cpu"`` and returns the loss
        of the last processed batch, or ``None`` if no batch was processed
        (zero epochs or an empty dataloader).
    """

    def train_function(num_epochs=num_epochs):
        model.to(device)

        # Stays None when the loops below never execute, so the final
        # ``return`` cannot hit an unbound local.
        loss = None

        for epoch in range(num_epochs):
            logging.info("    | epoch: %s", epoch)
            for img, labels in train_dataloader:
                img = img.to(device)
                labels = labels.to(device)

                # Clear gradients accumulated by the previous step.
                optimizer.zero_grad()

                # Forward pass; the model computes its own loss from labels.
                outputs = model(pixel_values=img, labels=labels)
                loss = outputs.loss

                loss.backward()
                optimizer.step()

        model.to("cpu")
        return loss

    return train_function


def create_local_training_aggregation_method(*args, **kwargs):
    """Placeholder: accept any arguments, do nothing, and return ``None``."""
    return None


def create_local_training_setup(
    args,
    models_path,
    dataset_path_dict,
    device,
) -> List[Client]:
    """Create the experiment setup for plain local training.

    Forwards the caller's arguments to ``create_exp_config_setup`` together
    with the local-training specifics: one local epoch,
    ``base_train_function`` as the training-function factory, and an empty
    ``train_args`` dict.

    Returns:
        Whatever ``create_exp_config_setup`` returns (annotated as a list of
        ``Client`` objects).
    """
    # Settings specific to the local-training method.
    local_training_options = {
        "num_local_epochs": 1,
        "local_train_function": base_train_function,
        "train_args": {},
    }

    return create_exp_config_setup(
        args=args,
        models_path=models_path,
        dataset_path_dict=dataset_path_dict,
        device=device,
        **local_training_options,
    )
