import logging

from .utils import transform_tensor_to_list
from .message_def import MyMessage
from fedml_core.distributed.communication.message import Message
from fedml_core.distributed.server.server_manager import ServerManager

class FedXDDServerMananger(ServerManager):
    """Server-side message manager for the FedXDD training protocol.

    The manager broadcasts the initial global model to every worker process,
    collects the locally trained models sent back by the clients, asks the
    server trainer to aggregate them, and pushes the aggregated model back
    out.  It shuts itself down once every client has reported its accuracy.
    """

    def __init__(self, args, server_trainer, comm=None, rank=0, size=0, backend="MPI"):
        """Set up the manager.

        Args:
            args: Run configuration; ``args.comm_round`` is read here.
            server_trainer: Object that stores client results and performs
                aggregation / accuracy computation.
            comm: Communicator handed to the base ``ServerManager``.
            rank: Rank of this process, handed to the base class.
            size: Total number of processes; ranks ``1 .. size - 1`` are
                addressed as clients by this class.
            backend: Communication backend name, handed to the base class.
        """
        super().__init__(args, comm, rank, size, backend)

        self.server_trainer = server_trainer
        self.round_num = args.comm_round
        # NOTE(review): round_idx is initialised but never advanced in this
        # class; the manager finishes from the accuracy handler instead.
        self.round_idx = 0
        # Results of the most recent aggregation, forwarded to
        # ``cal_acc_total`` once all client accuracies have arrived.
        self.correct = None
        self.test_sample_num = None

        # Not used anywhere in this class; kept for compatibility.
        self.count = 0

    def run(self):
        """Send the initial global model to every client, then start the base loop."""
        global_model_params = self.server_trainer.get_global_model_params()
        for process_id in range(1, self.size):
            self.send_message_init_config(process_id, global_model_params)
        super().run()

    # TODO: marked "need to revise" by the original author.
    def register_message_receive_handlers(self):
        """Register the handlers for the two client-to-server message types."""
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_C2S_SEND_MODEL_AND_DUAL,
            self.handle_message_receive_model_from_client,
        )
        self.register_message_receive_handler(
            MyMessage.MSG_TYPE_C2S_SEND_ACC,
            self.handle_message_receive_acc_from_client,
        )

    # TODO: marked "need to revise" by the original author.
    def handle_message_receive_model_from_client(self, msg_params):
        """Store one client's trained model and aggregate once all have arrived.

        Args:
            msg_params: Received message; the sender id, model parameters and
                local sample count are read from it via ``get``.
        """
        logging.info("handle_message_receive_model_from_client")
        sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES)

        # Client ranks start at 1 (rank 0 is the server), so shift to a
        # zero-based index for the trainer.
        self.server_trainer.add_local_trained_result(sender_id - 1, model_params, local_sample_number)
        b_all_received = self.server_trainer.check_whether_all_receive()
        logging.info("b_all_received = %s", b_all_received)
        if not b_all_received:
            return

        global_model_params, correct, test_sample_num = self.server_trainer.aggregate_and_train()
        self.correct = correct
        self.test_sample_num = test_sample_num

        # NOTE(review): no round counting happens here; the only call to
        # ``finish()`` in this class is in the accuracy handler.
        logging.info("size = %d", self.size)

        # NOTE(review): parameters are sent as returned by the trainer; no
        # tensor-to-list conversion is applied for mobile clients.
        for receiver_id in range(1, self.size):
            self.send_message_sync_model_to_client(receiver_id, global_model_params)

    def handle_message_receive_acc_from_client(self, msg_params):
        """Store one client's accuracy report and finish once all have arrived.

        Args:
            msg_params: Received message; the sender id, number of correct
                predictions and number of test samples are read from it.
        """
        logging.info("handle_message_receive_acc_from_client")
        sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
        correct_num = msg_params.get(MyMessage.MSG_ARG_KEY_CORRECT_NUM)
        test_sample_num = msg_params.get(MyMessage.MSG_ARG_KEY_TEST_SAMPLES_NUM)

        self.server_trainer.add_local_acc_result(sender_id - 1, correct_num, test_sample_num)

        b_all_received = self.server_trainer.check_whether_all_receive()
        logging.info("b_all_received = %s", b_all_received)
        if b_all_received:
            # NOTE(review): self.correct / self.test_sample_num are still None
            # if no aggregation has completed yet - confirm the trainer copes.
            self.server_trainer.cal_acc_total(self.correct, self.test_sample_num)

            logging.info("all client accuracy results received; finishing server manager")
            self.finish()

    def send_message_init_config(self, receive_id, global_model_params):
        """Send the initial global model parameters to process ``receive_id``."""
        message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
        self.send_message(message)

    def send_message_sync_model_to_client(self, receive_id, global_model_params):
        """Send the aggregated global model parameters to process ``receive_id``."""
        logging.info("send_message_sync_model_to_client. receive_id = %d", receive_id)
        message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
        self.send_message(message)

