from mpi4py import MPI

from fedml_api.distributed.fednas_extension.FedNASAggregator import FedNASAggregator
from fedml_api.distributed.fednas_extension.FedNASClientManager import FedNASClientManager
from fedml_api.distributed.fednas_extension.FedNASServerManager import FedNASServerManager
from fedml_api.distributed.fednas_extension.FedNASTrainer import FedNASTrainer


def FedML_init():
    """Return the MPI world communicator with this process's rank and the world size.

    Returns:
        A 3-tuple ``(comm, process_id, worker_number)`` where ``comm`` is
        ``MPI.COMM_WORLD``, ``process_id`` is ``comm.Get_rank()`` and
        ``worker_number`` is ``comm.Get_size()``.
    """
    world = MPI.COMM_WORLD
    return world, world.Get_rank(), world.Get_size()


def FedML_FedNAS_distributed(process_id, worker_number, device, comm,
                             model, train_data_num, train_data_global, test_data_global,
                             train_data_local_num_dict, train_data_local_dict, test_data_local_dict, args, client_num_in_total, teacher_model,
                             local_model):
    """Start the FedNAS role that corresponds to this process.

    The process with ``process_id == 0`` runs ``init_server``; every other
    process runs ``init_client``. All arguments are forwarded unchanged:

    * only the server receives ``train_data_global``, ``test_data_global``
      and ``client_num_in_total``;
    * only the client receives ``train_data_local_num_dict`` and
      ``local_model``;
    * both receive ``teacher_model``.

    Returns:
        None.
    """
    # Any non-zero process id is treated as a client.
    if process_id != 0:
        init_client(
            args, device, comm, process_id, worker_number,
            model, train_data_num,
            train_data_local_num_dict, train_data_local_dict, test_data_local_dict,
            teacher_model, local_model,
        )
        return

    # NOTE(review): the original code carried the remark "else remove last two
    # arguments" next to this call; what it refers to is unclear -- confirm
    # with the author before changing the argument list.
    init_server(
        args, device, comm, process_id, worker_number,
        model, train_data_num,
        train_data_global, test_data_global,
        train_data_local_dict, test_data_local_dict,
        client_num_in_total, teacher_model,
    )


def init_server(args, device, comm, process_id, worker_number, model, train_data_num, train_data_global,
                test_data_global, train_data_local_dict, test_data_local_dict, client_num_in_total, teacher_model):
    """Build the FedNAS aggregator and run the server manager.

    The aggregator is given ``worker_number - 1`` as its client count,
    along with the global and local data, the model, the device, ``args``
    and ``client_num_in_total``. The server manager is then constructed
    around the aggregator and its ``run()`` method is called.

    Note:
        ``teacher_model`` is accepted but not used in this function.

    Returns:
        None.
    """
    # One less than the worker count is passed as the number of clients.
    num_clients = worker_number - 1

    nas_aggregator = FedNASAggregator(
        train_data_global, test_data_global,
        train_data_local_dict, test_data_local_dict,
        train_data_num, num_clients, model, device, args, client_num_in_total,
    )

    # Hand control to the server manager to start the distributed training.
    FedNASServerManager(args, comm, process_id, worker_number, nas_aggregator).run()


def init_client(args, device, comm, process_id, worker_number, model, train_data_num,
                train_data_local_num_dict, train_data_local_dict, test_data_local_dict, teacher_model, local_model):
    """Build the FedNAS trainer for this process and run the client manager.

    The trainer is created with the id ``process_id - 1``, the data
    arguments, the model, the device, ``args``, ``teacher_model`` and
    ``local_model``. The client manager is then constructed around the
    trainer and its ``run()`` method is called.

    Returns:
        None.
    """
    # The trainer's id is the process id shifted down by one.
    client_index = process_id - 1

    nas_trainer = FedNASTrainer(
        client_index,
        train_data_num, train_data_local_num_dict,
        train_data_local_dict, test_data_local_dict,
        model, device, args, teacher_model, local_model,
    )

    FedNASClientManager(args, comm, process_id, worker_number, nas_trainer).run()
