from mpi4py import MPI

from .FedAVGAggregator import FedAVGAggregator
from .FedAVGTrainer import FedAVGTrainer
from .FedAvgClientManager import FedAVGClientManager
from .FedAvgServerManager import FedAVGServerManager
from .MyModelTrainer import MyModelTrainer
from .SSLDittoTrainer import SSLDittoTrainer
from .SSLperFedAvgTrainer import SSLperFedAvgTrainer
from .utils import clear_cache_for_personalized_model


def FedML_init():
    """Return the MPI world communicator plus this process's rank and the world size.

    Returns:
        tuple: ``(comm, process_id, worker_number)`` where ``comm`` is
        ``MPI.COMM_WORLD``, ``process_id`` is ``comm.Get_rank()`` and
        ``worker_number`` is ``comm.Get_size()``.
    """
    world = MPI.COMM_WORLD
    return world, world.Get_rank(), world.Get_size()


def FedML_FedPer_distributed(process_id, worker_number, device, comm, model, train_data_num, train_data_global,
                             test_data_global,
                             train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args,
                             model_trainer=None, preprocessed_sampling_lists=None):
    """Run this MPI process in its role: rank 0 is the server, every other rank is a client.

    Args:
        process_id: MPI rank of this process.
        worker_number: Total number of MPI processes (server + clients).
        device: Device forwarded to the trainers.
        comm: MPI communicator.
        model: Model used to build the server/client trainer.
        train_data_num: Total number of training samples.
        train_data_global: Global training data (used only by the server).
        test_data_global: Global test data (used only by the server).
        train_data_local_num_dict: Per-client training-sample counts.
        train_data_local_dict: Per-client training data.
        test_data_local_dict: Per-client test data.
        args: Run configuration.
        model_trainer: Forwarded to ``init_server`` / ``init_client``.
        preprocessed_sampling_lists: Optional precomputed client-sampling lists,
            forwarded to the server only.
    """
    if process_id != 0:
        init_client(
            args=args,
            device=device,
            comm=comm,
            process_id=process_id,
            size=worker_number,
            model=model,
            train_data_num=train_data_num,
            train_data_local_num_dict=train_data_local_num_dict,
            train_data_local_dict=train_data_local_dict,
            test_data_local_dict=test_data_local_dict,
            model_trainer=model_trainer,
        )
        return

    init_server(
        args=args,
        device=device,
        comm=comm,
        rank=process_id,
        size=worker_number,
        model=model,
        train_data_num=train_data_num,
        train_data_global=train_data_global,
        test_data_global=test_data_global,
        train_data_local_dict=train_data_local_dict,
        test_data_local_dict=test_data_local_dict,
        train_data_local_num_dict=train_data_local_num_dict,
        model_trainer=model_trainer,
        preprocessed_sampling_lists=preprocessed_sampling_lists,
    )


def init_server(args, device, comm, rank, size, model, train_data_num, train_data_global, test_data_global,
                train_data_local_dict, test_data_local_dict, train_data_local_num_dict, model_trainer,
                preprocessed_sampling_lists=None):
    """Build the FedAvg aggregator and server manager, then start the server loop.

    Args:
        args: Run configuration; ``args.backend`` selects the communication backend.
        device: Device forwarded to the trainer and aggregator.
        comm: MPI communicator.
        rank: MPI rank of this (server) process.
        size: Total number of MPI processes; ``size - 1`` is passed to the
            aggregator as the worker count.
        model: Model used to build the server-side trainer.
        train_data_num: Total number of training samples.
        train_data_global: Global training data.
        test_data_global: Global test data.
        train_data_local_dict: Per-client training data.
        test_data_local_dict: Per-client test data.
        train_data_local_num_dict: Per-client training-sample counts.
        model_trainer: Ignored; a trainer is always built via
            ``_create_trainer_by_per_alg``.
        preprocessed_sampling_lists: When not ``None``, the server manager is
            created with ``is_preprocessed=True`` and these client lists.
    """
    # NOTE(review): the caller-supplied ``model_trainer`` is intentionally(?) discarded
    # and replaced by the trainer selected from args.per_optimizer — confirm.
    server_trainer = _create_trainer_by_per_alg(model, args, device)
    server_trainer.set_id(-1)  # presumably -1 marks the server (clients use rank - 1)

    aggregator = FedAVGAggregator(
        train_data_global, test_data_global, train_data_num,
        train_data_local_dict, test_data_local_dict, train_data_local_num_dict,
        size - 1,  # every rank except the server is a worker
        device, args, server_trainer,
    )

    extra_manager_kwargs = {}
    if preprocessed_sampling_lists is not None:
        extra_manager_kwargs = {
            "is_preprocessed": True,
            "preprocessed_client_lists": preprocessed_sampling_lists,
        }
    server_manager = FedAVGServerManager(args, aggregator, comm, rank, size, args.backend,
                                         **extra_manager_kwargs)

    # start the distributed training
    server_manager.send_init_msg()
    server_manager.run()


def init_client(args, device, comm, process_id, size, model, train_data_num, train_data_local_num_dict,
                train_data_local_dict, test_data_local_dict, model_trainer=None):
    """Build the client-side trainer and client manager, then start the client loop.

    Args:
        args: Run configuration; ``args.backend`` selects the communication backend.
        device: Device forwarded to the trainers.
        comm: MPI communicator.
        process_id: MPI rank of this process; the client index is ``process_id - 1``.
        size: Total number of MPI processes.
        model: Model used to build the client-side trainer.
        train_data_num: Total number of training samples.
        train_data_local_num_dict: Per-client training-sample counts.
        train_data_local_dict: Per-client training data.
        test_data_local_dict: Per-client test data.
        model_trainer: Ignored; a trainer is always built via
            ``_create_trainer_by_per_alg``.
    """
    # Rank 0 is the server, so client indices start at 0 for rank 1.
    index = process_id - 1

    # NOTE(review): the caller-supplied ``model_trainer`` is discarded here too — confirm intent.
    local_trainer = _create_trainer_by_per_alg(model, args, device)
    local_trainer.set_id(index)

    fedavg_trainer = FedAVGTrainer(
        index, train_data_local_dict, train_data_local_num_dict, test_data_local_dict,
        train_data_num, device, args, local_trainer,
    )
    FedAVGClientManager(args, fedavg_trainer, comm, process_id, size, args.backend).run()


def _create_trainer_by_per_alg(model, args, device):
    """Build the model trainer selected by ``args.per_optimizer``.

    ``clear_cache_for_personalized_model(args)`` is always called first, before the
    optimizer name is validated.

    Args:
        model: Model handed to the trainer constructor.
        args: Run configuration; ``args.per_optimizer`` selects the trainer, and the
            whole object is forwarded to the trainer and to the cache helper.
        device: Device forwarded to ``SSLDittoTrainer`` and ``SSLperFedAvgTrainer``
            (``MyModelTrainer`` is constructed without it).

    Returns:
        A new ``MyModelTrainer`` ("FedAvg"), ``SSLDittoTrainer`` ("Ditto") or
        ``SSLperFedAvgTrainer`` ("perFedAvg").

    Raises:
        ValueError: If ``args.per_optimizer`` is not one of the names above.
    """
    # clear cache of personalized model
    clear_cache_for_personalized_model(args)

    per_optimizer = args.per_optimizer
    if per_optimizer == "FedAvg":
        return MyModelTrainer(model, args)
    if per_optimizer == "Ditto":
        return SSLDittoTrainer(model, args, device)
    if per_optimizer == "perFedAvg":
        return SSLperFedAvgTrainer(model, args, device)

    # ValueError subclasses Exception, so existing ``except Exception`` handlers still
    # catch this; the message now reports the offending value.
    raise ValueError(
        f"no such trainer! per_optimizer={per_optimizer!r}; "
        "expected one of 'FedAvg', 'Ditto', 'perFedAvg'"
    )
