from mpi4py import MPI
import sys
import os
import torch
import json

from .FedXDDClientManager import FedXDDClientMananger
from .FedXDDClientTrainer import FedXDDClientTrainer, FedXDDClientTrainer_DA#, FedXDDClientTrainer_DP
from .FedXDDServerManager import FedXDDServerMananger
from .FedXDDServerTrainer import FedXDDServerTrainer


sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "./../")))
from data_preprocessing.lda_run_config import LDARunConfig
from model import LiteResidualModule, build_network_from_config


def FedML_init():
    """Return the MPI world communicator, this process's rank, and the world size.

    Returns:
        tuple: ``(comm, process_id, worker_number)`` where ``comm`` is
        ``MPI.COMM_WORLD``, ``process_id`` is ``comm.Get_rank()`` and
        ``worker_number`` is ``comm.Get_size()``.
    """
    world = MPI.COMM_WORLD
    # Tuple elements are evaluated left to right, so Get_rank() is still
    # queried before Get_size(), as in the original sequence.
    return world, world.Get_rank(), world.Get_size()


# FedXDD single device (DA completed)

def FedML_FedXDD_DA_completed(args, process_id, worker_number, device, comm, compact_model, netF, netB, netC,
                              source_data, train_data_local_dict, train_data_local_num_dict, test_data_local_dict):
    """Route this MPI process to the server role or to a DA client role.

    Rank 0 runs ``init_server`` with ``source_data``; every other rank runs
    ``init_client_DA`` with the networks and its local data dictionaries.

    Note that this function receives the data dicts in the order
    (train, num, test) while ``init_client_DA`` declares them as
    (train, test, num); they are forwarded by keyword so the mapping is
    explicit.
    """
    if process_id != 0:
        init_client_DA(
            args=args,
            device=device,
            comm=comm,
            process_id=process_id,
            size=worker_number,
            compact_model=compact_model,
            netF=netF,
            netB=netB,
            netC=netC,
            train_data_local_dict=train_data_local_dict,
            test_data_local_dict=test_data_local_dict,
            train_data_local_num_dict=train_data_local_num_dict,
        )
        return
    init_server(
        args=args,
        device=device,
        comm=comm,
        rank=process_id,
        size=worker_number,
        compact_model=compact_model,
        source_data=source_data,
    )


def init_client_DA(args, device, comm, process_id, size, compact_model, netF, netB, netC,
                   train_data_local_dict, test_data_local_dict, train_data_local_num_dict):
    """Build a ``FedXDDClientTrainer_DA`` for this rank and run its client manager.

    Args:
        process_id: MPI rank of this process; the client index passed to the
            trainer is ``process_id - 1``.
        size: MPI world size; the trainer is told there are ``size - 1`` clients.
        Remaining arguments are forwarded unchanged to the trainer / manager.
    """
    # Rank 0 is taken by the server (see the dispatchers in this module), so
    # client indices and the client count are both shifted down by one.
    trainer = FedXDDClientTrainer_DA(
        process_id - 1,
        size - 1,
        train_data_local_dict,
        test_data_local_dict,
        train_data_local_num_dict,
        device,
        compact_model,
        netF,
        netB,
        netC,
        args,
    )
    FedXDDClientMananger(args, trainer, comm, process_id, size).run()



# FedXDD multiple devices

def FedML_FedXDD_distributed(args, process_id, worker_number, device, comm, compact_model, netF, netB, netC, optimizer,
                             source_data, train_data_local_dict, train_data_local_num_dict, test_data_local_dict,
                             test_data, secure_agg=False, seed_list=None, sign_list=None):
    """Route this MPI process to the server role or to a training client role.

    Rank 0 runs ``init_server`` with ``source_data``; every other rank runs
    ``init_client`` with the networks, optimizer, local data and the
    secure-aggregation options (``secure_agg``, ``seed_list``, ``sign_list``),
    which are passed through untouched.

    The data dicts arrive here as (train, num, test) but ``init_client``
    declares them as (train, test, num); keyword forwarding keeps the
    mapping explicit.
    """
    if process_id != 0:
        init_client(
            args=args,
            device=device,
            comm=comm,
            process_id=process_id,
            size=worker_number,
            compact_model=compact_model,
            netF=netF,
            netB=netB,
            netC=netC,
            optimizer=optimizer,
            train_data_local_dict=train_data_local_dict,
            test_data_local_dict=test_data_local_dict,
            train_data_local_num_dict=train_data_local_num_dict,
            test_data=test_data,
            secure_agg=secure_agg,
            seed_list=seed_list,
            sign_list=sign_list,
        )
        return
    init_server(
        args=args,
        device=device,
        comm=comm,
        rank=process_id,
        size=worker_number,
        compact_model=compact_model,
        source_data=source_data,
    )



def init_server(args, device, comm, rank, size, compact_model, source_data):
    """Build the FedXDD aggregator and run the server manager.

    Args:
        rank: MPI rank of this (server) process, forwarded to the manager.
        size: MPI world size; the aggregator is told there are ``size - 1``
            clients, i.e. every rank other than the server's.
        Remaining arguments are forwarded unchanged to the trainer / manager.
    """
    aggregator = FedXDDServerTrainer(size - 1, source_data, device, compact_model, args)
    manager = FedXDDServerMananger(args, aggregator, comm, rank, size)
    manager.run()


def init_client(args, device, comm, process_id, size, compact_model, netF, netB, netC, optimizer,
                train_data_local_dict, test_data_local_dict, train_data_local_num_dict, test_data,
                secure_agg=False, seed_list=None, sign_list=None):
    """Build a ``FedXDDClientTrainer`` for this rank and run its client manager.

    Args:
        process_id: MPI rank of this process; the client index passed to the
            trainer is ``process_id - 1``.
        size: MPI world size; the trainer is told there are ``size - 1`` clients.
        secure_agg, seed_list, sign_list: forwarded as-is to the trainer.
        Remaining arguments are forwarded unchanged to the trainer / manager.
    """
    # Rank 0 belongs to the server, so client numbering starts at rank 1.
    trainer = FedXDDClientTrainer(
        process_id - 1,
        size - 1,
        train_data_local_dict,
        test_data_local_dict,
        train_data_local_num_dict,
        device,
        compact_model,
        netF,
        netB,
        netC,
        optimizer,
        args,
        test_data,
        secure_agg,
        seed_list,
        sign_list,
    )
    FedXDDClientMananger(args, trainer, comm, process_id, size).run()

