from mpi4py import MPI

from .FedSSLAggregator import FedSSLAggregator
from .FedSSLClientManager import FedSSLClientManager
from .FedSSLServerManager import FedSSLServerManager
from .FedSSLTrainer import FedSSLTrainer
from .SSLDittoTrainer import SSLDittoTrainer
from .SSLEvaluationTrainer import SSLEvaluationTrainer
from .SSLFedAvgTrainer import SSLFedAvgTrainer
from .SSLpFedMeTrainer import SSLpFedMeTrainer
from .SSLperFedAvgTrainer import SSLperFedAvgTrainer
from .SSLperFedSimSiamTrainer import SSLperFedSimSiamTrainer
from .SSLperFedTrainer import SSLperFedTrainer


def FedML_init():
    """Return the MPI world communicator together with this process's rank and the world size.

    Returns:
        tuple: ``(MPI.COMM_WORLD, rank, size)``, where rank and size come from
        ``Get_rank()`` and ``Get_size()`` on that communicator.
    """
    world = MPI.COMM_WORLD
    return world, world.Get_rank(), world.Get_size()


def FedML_FedSSL_distributed(process_id, worker_number, device, comm, model, train_data_num, train_data_global,
                             test_data_global,
                             train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args,
                             model_trainer=None, preprocessed_sampling_lists=None):
    """Start the FedSSL role for this MPI process.

    Rank 0 is set up as the server via ``init_server``. Every other rank is
    set up as a client via ``init_client``.

    Args:
        process_id: MPI rank of this process.
        worker_number: total number of MPI processes (server included).
        device: device forwarded to the trainers.
        comm: MPI communicator.
        model: model used to build a trainer when ``model_trainer`` is None.
        train_data_num, train_data_global, test_data_global,
        train_data_local_num_dict, train_data_local_dict, test_data_local_dict:
            dataset objects forwarded unchanged to the server/client setup.
        args: run configuration namespace.
        model_trainer: optional pre-built trainer; one is created when None.
        preprocessed_sampling_lists: optional client sampling lists, used
            only by the server.
    """
    if process_id != 0:
        # Client ranks never receive the global datasets or sampling lists.
        init_client(args, device, comm, process_id, worker_number, model, train_data_num,
                    train_data_local_num_dict=train_data_local_num_dict,
                    train_data_local_dict=train_data_local_dict,
                    test_data_local_dict=test_data_local_dict,
                    model_trainer=model_trainer)
        return

    init_server(args, device, comm,
                rank=process_id,
                size=worker_number,
                model=model,
                train_data_num=train_data_num,
                train_data_global=train_data_global,
                test_data_global=test_data_global,
                train_data_local_dict=train_data_local_dict,
                test_data_local_dict=test_data_local_dict,
                train_data_local_num_dict=train_data_local_num_dict,
                model_trainer=model_trainer,
                preprocessed_sampling_lists=preprocessed_sampling_lists)


def init_server(args, device, comm, rank, size, model, train_data_num, train_data_global, test_data_global,
                train_data_local_dict, test_data_local_dict, train_data_local_num_dict, model_trainer,
                preprocessed_sampling_lists=None):
    """Build the FedSSL aggregator and server manager, then start the server.

    Args:
        args: run configuration namespace.
        device: device forwarded to the aggregator.
        comm: MPI communicator.
        rank: MPI rank of this (server) process.
        size: total number of MPI processes.
        model: model used to build a trainer when ``model_trainer`` is None.
        train_data_num, train_data_global, test_data_global,
        train_data_local_dict, test_data_local_dict, train_data_local_num_dict:
            dataset objects forwarded to ``FedSSLAggregator``.
        model_trainer: optional pre-built trainer; created from ``args`` when None.
        preprocessed_sampling_lists: when not None, the server manager is
            built with ``backend="MPI"``, ``is_preprocessed=True`` and these
            lists as ``preprocessed_client_lists``.
    """
    server_trainer = model_trainer if model_trainer is not None else _create_trainer_by_per_alg(model, args, device)
    # The server-side trainer is tagged -1; client trainers get ids 0..size-2.
    server_trainer.set_id(-1)

    # One process is this server; the remaining size - 1 are client workers.
    client_worker_count = size - 1
    aggregator = FedSSLAggregator(train_data_global, test_data_global, train_data_num,
                                  train_data_local_dict, test_data_local_dict, train_data_local_num_dict,
                                  client_worker_count, device, args, server_trainer)

    # Extra manager options apply only when sampling lists were precomputed;
    # otherwise the manager is constructed with its own defaults.
    manager_options = {}
    if preprocessed_sampling_lists is not None:
        manager_options = {
            "backend": "MPI",
            "is_preprocessed": True,
            "preprocessed_client_lists": preprocessed_sampling_lists,
        }
    server_manager = FedSSLServerManager(args, aggregator, comm, rank, size, **manager_options)

    # Kick off the first round, then enter the manager's run loop.
    server_manager.send_init_msg()
    server_manager.run()


def init_client(args, device, comm, process_id, size, model, train_data_num, train_data_local_num_dict,
                train_data_local_dict, test_data_local_dict, model_trainer=None):
    """Build a FedSSL client (trainer + client manager) for this rank and run it.

    Args:
        args: run configuration namespace.
        device: device forwarded to ``FedSSLTrainer``.
        comm: MPI communicator.
        process_id: MPI rank of this client process (>= 1).
        size: total number of MPI processes.
        model: model used to build a trainer when ``model_trainer`` is None.
        train_data_num, train_data_local_num_dict, train_data_local_dict,
        test_data_local_dict: dataset objects forwarded to ``FedSSLTrainer``.
        model_trainer: optional pre-built trainer; created from ``args`` when None.
    """
    # Rank 0 hosts the server, so client ranks 1..size-1 become indices 0..size-2.
    client_index = process_id - 1
    local_trainer = model_trainer if model_trainer is not None else _create_trainer_by_per_alg(model, args, device)
    local_trainer.set_id(client_index)

    fed_trainer = FedSSLTrainer(client_index, train_data_local_dict, train_data_local_num_dict, test_data_local_dict,
                                train_data_num, device, args, local_trainer)
    manager = FedSSLClientManager(args, fed_trainer, comm, process_id, size)
    manager.run()


def _create_trainer_by_per_alg(model, args, device):
    """Instantiate the model trainer selected by the run configuration.

    Args:
        model: model passed to the trainer constructor.
        args: configuration namespace. Reads ``ssl_is_linear_eval`` and, when
            that is not 1, ``pssl_optimizer``.
        device: device passed to the trainer constructor.

    Returns:
        A trainer built as ``TrainerClass(model, args, device)``.

    Raises:
        ValueError: if ``pssl_optimizer`` names no known trainer. ValueError
            subclasses Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    # Linear evaluation takes precedence over the personalization algorithm.
    if args.ssl_is_linear_eval == 1:
        return SSLEvaluationTrainer(model, args, device)

    optimizer = args.pssl_optimizer
    if optimizer in ("FedAvg", "FedAvg_LocalAdaptation"):
        # Both names map to the same trainer class here.
        return SSLFedAvgTrainer(model, args, device)
    if optimizer == "pFedMe":
        return SSLpFedMeTrainer(model, args, device)
    if optimizer == "Ditto":
        return SSLDittoTrainer(model, args, device)
    if optimizer == "perFedAvg":
        return SSLperFedAvgTrainer(model, args, device)
    if optimizer == "perSSL":
        # our method 1
        return SSLperFedTrainer(model, args, device)
    if optimizer in ("perSimSiam", "perSimSiam2"):
        # our methods 2 and 3 share one trainer class
        return SSLperFedSimSiamTrainer(model, args, device)
    raise ValueError(f"no such trainer! unsupported pssl_optimizer: {optimizer!r}")
