import logging
import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../../FedML")))

try:
    from fedml_core.distributed.client.client_manager import ClientManager
    from fedml_core.distributed.communication.message import Message
except ImportError:
    from FedML.fedml_core.distributed.client.client_manager import ClientManager
    from FedML.fedml_core.distributed.communication.message import Message
from .message_define import MyMessage
from .utils import transform_list_to_tensor


class FedSSLClientManager(ClientManager):
    """Client-side message handler for federated self-supervised learning (FedSSL).

    Receives model parameters from the server, trains them locally through
    ``trainer`` and uploads the resulting weights plus evaluation metrics to
    receiver id 0 (the server). After training the round whose index is
    ``args.comm_round - 1`` in response to a server sync, it notifies a sweep
    process through the FIFO ``./ssfl`` and calls ``finish()``.

    Fields of ``args`` read by this class: ``comm_round``, ``is_mobile``,
    ``pssl_optimizer``, ``ssl_is_linear_eval``, ``using_personalized_data``
    and ``frequency_of_the_test``.
    """

    def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"):
        """Create the client manager.

        Args:
            args: Experiment configuration namespace (fields listed in the
                class docstring).
            trainer: Local trainer; must provide ``train``, ``update_model``,
                ``update_dataset``, ``test_global_model_on_local_data`` and
                ``test_personalized_model_on_local_data``.
            comm, rank, size, backend: Forwarded unchanged to ``ClientManager``.
        """
        super().__init__(args, comm, rank, size, backend)
        self.trainer = trainer
        self.num_rounds = args.comm_round
        self.round_idx = 0

        # Latest evaluation metrics; both are sent to the server with every upload.
        self.global_model_on_local_accuracy = 0.0
        self.personalized_accuracy = 0.0

    def run(self):
        """Start the message loop (delegates to ``ClientManager.run``)."""
        super().run()

    def register_message_receive_handlers(self):
        """Map the server-to-client message types to their handlers."""
        self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG,
                                              self.handle_message_init)
        self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
                                              self.handle_message_receive_model_from_server)

    def handle_message_init(self, msg_params):
        """Handle the server's initial config: load the model and train round 0.

        Args:
            msg_params: Received message. ``MSG_ARG_KEY_MODEL_PARAMS`` holds the
                global model parameters and ``MSG_ARG_KEY_CLIENT_INDEX`` the
                index passed to ``trainer.update_dataset``.
        """
        global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)

        if self.args.is_mobile == 1:
            # Presumably mobile clients exchange parameters as plain lists; convert back.
            global_model_params = transform_list_to_tensor(global_model_params)

        self.trainer.update_dataset(int(client_index))
        self.trainer.update_model(global_model_params)
        self.round_idx = 0
        self.__train()

    def start_training(self):
        """Reset the round counter and train/upload round 0 with the current model."""
        self.round_idx = 0
        self.__train()

    def handle_message_receive_model_from_server(self, msg_params):
        """Handle a model sync from the server and run the next training round.

        Loads the received model and optionally evaluates it on local data
        (using the current, not yet incremented, ``round_idx``). It then
        switches to the dataset given by the message's client index, advances
        ``round_idx``, and trains and uploads. When the new ``round_idx``
        equals ``num_rounds - 1``, it notifies the sweep process and calls
        ``finish()``.

        Args:
            msg_params: Received message carrying ``MSG_ARG_KEY_MODEL_PARAMS``
                and ``MSG_ARG_KEY_CLIENT_INDEX``.
        """
        logging.info("handle_message_receive_model_from_server.")
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)

        if self.args.is_mobile == 1:
            model_params = transform_list_to_tensor(model_params)

        self.trainer.update_model(model_params)

        # The global model is evaluated on local data only for FedAvg without
        # linear evaluation and with personalized data; otherwise report 0.0.
        if self.args.pssl_optimizer == "FedAvg" and self.args.ssl_is_linear_eval == 0 \
                and self.args.using_personalized_data:
            self.global_model_on_local_accuracy = self.trainer.test_global_model_on_local_data(self.round_idx)
        else:
            self.global_model_on_local_accuracy = 0.0

        # start a new round of training
        self.trainer.update_dataset(int(client_index))
        self.round_idx += 1
        self.__train()

        if self.round_idx == self.num_rounds - 1:
            self._post_complete_message_to_sweep_process(self.args)
            self.finish()

    @staticmethod
    def _post_complete_message_to_sweep_process(args):
        """Write a completion notice followed by ``str(args)`` to the FIFO ``./ssfl``.

        The FIFO is created if it does not exist. Opening a FIFO for writing
        blocks until a reader (the sweep process) opens the other end.
        """
        logging.info("post_complete_message_to_sweep_process")
        pipe_path = "./ssfl"
        try:
            os.mkfifo(pipe_path)
        except FileExistsError:
            # Already present, e.g. created by another client process sharing
            # this working directory. A separate os.path.exists() check before
            # mkfifo() races in that case and crashed the process before finish().
            pass
        pipe_fd = os.open(pipe_path, os.O_WRONLY)

        with os.fdopen(pipe_fd, 'w') as pipe:
            pipe.write("training is finished! \n%s" % (str(args)))

    def __train(self):
        """Train one local round, refresh the personalized metric and upload.

        When ``pssl_optimizer != "FedAvg"`` or ``ssl_is_linear_eval == 1``,
        the personalized model is evaluated on local data only on rounds
        where ``round_idx % frequency_of_the_test == 0`` and
        ``using_personalized_data == 1``. On the other rounds the previously
        computed value is sent again. In every other configuration 0.0 is sent.
        """
        logging.info("#######training########### round_id = %d, pssl_optimizer = %s ",
                     self.round_idx, self.args.pssl_optimizer)
        weights, local_sample_num, averaged_loss = self.trainer.train(self.round_idx)
        if self.args.pssl_optimizer != "FedAvg" or self.args.ssl_is_linear_eval == 1:
            if self.round_idx % self.args.frequency_of_the_test == 0 and self.args.using_personalized_data == 1:
                self.personalized_accuracy = self.trainer.test_personalized_model_on_local_data(self.round_idx)
        else:
            self.personalized_accuracy = 0.0

        self.send_model_to_server(0, weights, local_sample_num, averaged_loss,
                                  self.personalized_accuracy, self.global_model_on_local_accuracy)

    def send_model_to_server(self, receive_id, weights, local_sample_num, averaged_loss,
                             personalized_accuracy, global_model_on_local_accuracy):
        """Send local training results to ``receive_id``.

        Builds a ``MSG_TYPE_C2S_SEND_MODEL_TO_SERVER`` message from this
        client's sender id carrying the weights, local sample count, averaged
        loss and both accuracy metrics, then sends it.
        """
        message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
        message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
        message.add_params(MyMessage.MSG_ARG_KEY_AVERAGED_LOSS, averaged_loss)
        message.add_params(MyMessage.MSG_ARG_KEY_ACC_PER, personalized_accuracy)
        message.add_params(MyMessage.MSG_ARG_KEY_ACC_GLOBAL_MODEL_ON_LOCAL_DATA_ACC, global_model_on_local_accuracy)
        self.send_message(message)
