import logging
import math

from fedml_api.distributed.fednas_extension.message_define import MyMessage
from fedml_core.distributed.client.client_manager import ClientManager
from fedml_core.distributed.communication.message import Message


class FedNASClientManager(ClientManager):
    """Client-side manager for FedNAS-style federated training.

    Registers handlers for three server-to-client messages (initial config,
    model sync, start of fine-tuning), drives ``trainer`` according to
    ``args.stage`` and sends the results back to rank 0 (the server).
    """

    # Stages in which a round runs ``trainer.search()``.
    _SEARCH_STAGES = ("fednas_search", "personalized_search", "backbone_cell_search")
    # Stages after which the locally adapted model is saved before upload.
    _SAVE_PERSONAL_MODEL_STAGES = ("personalized_search", "fednas_search")

    def __init__(self, args, comm, rank, size, trainer):
        super().__init__(args, comm, rank, size)
        self.trainer = trainer
        self.num_rounds = args.comm_round
        logging.info("num_rounds = %d", self.num_rounds)
        self.round_idx = 0

    def run(self):
        super().run()

    def register_message_receive_handlers(self):
        self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG,
                                              self.__handle_msg_client_receive_config)
        self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
                                              self.__handle_msg_client_receive_model_from_server)
        self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_START_FINE_TUNING,
                                              self.__handle_msg_client_start_fine_tuning)

    def __handle_msg_client_receive_config(self, msg_params):
        """Load the initial global model from the server and start round 0."""
        logging.info("__handle_msg_client_receive_config")
        global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        self.trainer.update_model(global_model_params)
        # NOTE: the server may also send MSG_ARG_KEY_ARCH_PARAMS here, but
        # applying it at initialization (trainer.update_arch) is disabled.

        self.round_idx = 0
        # start to train
        self.__train()

    def __handle_msg_client_receive_model_from_server(self, msg_params):
        """Apply a model sync from the server and run the next local round.

        Messages whose sender is not rank 0 are ignored.
        """
        process_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
        if process_id != 0:
            return

        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)

        self.trainer.update_dataset(int(client_index))

        if self.args.stage == "personalized_search":
            # First load the personalized model saved locally (if it exists),
            # then apply the server's params through update_base_model.
            self.trainer.load_personal_model(self.args.path_of_local_model)
            self.trainer.update_base_model(model_params)
        else:
            self.trainer.update_model(model_params)

        if self.args.stage == "fednas_search":
            self.trainer.update_arch(arch_params)

        self.round_idx += 1
        logging.info("client received message. self.round_idx = %d", self.round_idx)
        self.trainer.update_training_progress(self.round_idx)

        self.__train()

    def __handle_msg_client_start_fine_tuning(self, msg_params):
        """Fine-tune every client assigned to this process, reporting each result.

        Rank 0 is the server (all client->server messages target rank 0), so
        the receiver rank minus one is taken as this process's 0-based index.
        """
        process_idx = msg_params.get(MyMessage.MSG_ARG_KEY_RECEIVER) - 1
        logging.info("start fine-tuning. process_idx = %d", process_idx)
        logging.info("client_num_in_total = %d. client_num_per_round = %d",
                     self.args.client_num_in_total, self.args.client_num_per_round)
        client_idx_start, client_idx_end = self._client_index_range(
            process_idx, self.args.client_num_in_total, self.args.client_num_per_round)

        logging.info("client_idx_start = %d, client_idx_end = %d", client_idx_start, client_idx_end)
        for client_idx in range(client_idx_start, client_idx_end):
            self.trainer.update_dataset(client_idx)
            self.__fine_tune()
        logging.info("Client Local Adaptation Finished!")

    @staticmethod
    def _client_index_range(process_idx, client_num_in_total, client_num_per_round):
        """Return the half-open ``(start, end)`` client-index range of a process.

        Clients are split into contiguous chunks of
        ``ceil(client_num_in_total / client_num_per_round)``; the last process
        (``process_idx == client_num_per_round - 1``) takes everything up to
        ``client_num_in_total``. Because the chunk size is rounded up, trailing
        chunks can overrun the client count, so both bounds are clamped to
        ``client_num_in_total`` (an empty range means no clients to handle).
        """
        client_num_per_process = math.ceil(client_num_in_total / client_num_per_round)
        start = min(process_idx * client_num_per_process, client_num_in_total)
        if process_idx < client_num_per_round - 1:
            end = min((process_idx + 1) * client_num_per_process, client_num_in_total)
        else:
            end = client_num_in_total
        return start, end

    def __train(self):
        """Run one local round for the configured stage and upload the result.

        Raises:
            ValueError: if ``args.stage`` is not a supported training stage.
        """
        logging.info("#######__train########### round_id = %d", self.round_idx)
        stage = self.args.stage
        if stage in self._SEARCH_STAGES:
            (weights, alphas, local_sample_num,
             train_acc, train_loss, personalized_acc,
             client_idx, flops, model_size) = self.trainer.search()
        elif stage == 'train':
            (weights, local_sample_num, train_acc, train_loss,
             personalized_acc, client_idx) = self.trainer.train()
            # train() reports no architecture params, FLOPs or model size.
            alphas = []
            flops = 0.0
            model_size = 0.0
        elif stage == 'per_train':
            (weights, local_sample_num, train_acc, train_loss,
             personalized_acc, client_idx, flops, model_size) = self.trainer.per_train()
            alphas = []
        else:
            # Previously an unknown stage surfaced as an UnboundLocalError on
            # `weights`; fail early with an explicit message instead.
            raise ValueError("Unsupported stage for training: %r" % stage)

        if stage in self._SAVE_PERSONAL_MODEL_STAGES:
            # Save the locally adapted model before sending to the server.
            self.trainer.save_personal_model(self.args.path_of_local_model)

        self.__send_msg_fedavg_send_model_to_server(weights, alphas, local_sample_num,
                                                    train_acc, train_loss, personalized_acc,
                                                    client_idx, flops, model_size)

    def __fine_tune(self):
        """Run global_train() and local_fine_tune() for the current client and report results."""
        logging.info("#######__fine_tune########### round_id = %d", self.round_idx)
        # after fine-tune/local adaptation, make sure you save results of local models
        global_model_acc_after_finetunning = self.trainer.global_train()
        weights, alphas, local_sample_num, train_acc, train_loss, \
            acc_global_model_on_local_data, acc_personalized_model_on_local_data, \
            pi_params, personalized_arch, client_index = self.trainer.local_fine_tune()
        # NOTE(review): acc_global_model_on_local_data from local_fine_tune() is
        # not sent; the accuracy returned by global_train() is reported in its
        # place. Confirm this is intended.
        self.__send_msg_fedavg_send_fine_tune_result_to_server(weights, alphas, local_sample_num, train_acc, train_loss,
                                                               global_model_acc_after_finetunning,
                                                               acc_personalized_model_on_local_data,
                                                               pi_params, personalized_arch, client_index)
        logging.info("message of fine tunning sent to server")

    def __send_msg_fedavg_send_model_to_server(self, weights, alphas, local_sample_num, train_acc, train_loss,
                                               acc_personalized_model_on_local_data, client_index, flops, model_size):
        """Send one round's training result (weights, alphas, metrics) to rank 0."""
        message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.rank, 0)
        message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
        message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, alphas)
        message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_ACC, train_acc)
        message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_LOSS, train_loss)
        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, client_index)
        message.add_params(MyMessage.MSG_ARG_KEY_ACC_PERSONALIZED_MODEL_ON_LOCAL_DATA,
                           acc_personalized_model_on_local_data)
        message.add_params(MyMessage.MSG_ARG_KEY_FLOPS, flops)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_SIZE,
                           model_size)

        self.send_message(message)

    def __send_msg_fedavg_send_fine_tune_result_to_server(self, weights, alphas, local_sample_num, train_acc,
                                                          train_loss,
                                                          acc_global_model_on_local_data,
                                                          acc_personalized_model_on_local_data,
                                                          pi_params, personalized_arch, client_index):
        """Send one client's fine-tuning result (weights, metrics, pi params, arch) to rank 0."""
        logging.info("__send_msg_fedavg_send_fine_tune_result_to_server")
        message = Message(MyMessage.MSG_TYPE_C2S_SEND_FINE_TUNE_RESULT_TO_SERVER, self.rank, 0)
        message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
        message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, alphas)
        message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_ACC, train_acc)
        message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_LOSS, train_loss)

        message.add_params(MyMessage.MSG_ARG_KEY_ACC_GLOBAL_MODEL_ON_LOCAL_DATA, acc_global_model_on_local_data)
        message.add_params(MyMessage.MSG_ARG_KEY_ACC_PERSONALIZED_MODEL_ON_LOCAL_DATA,
                           acc_personalized_model_on_local_data)

        message.add_params(MyMessage.MSG_ARG_KEY_PI_PARAM, pi_params)
        message.add_params(MyMessage.MSG_ARG_KEY_ARCH_CLIENT, personalized_arch)
        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, client_index)
        self.send_message(message)
        logging.info("__send_msg_fedavg_send_fine_tune_result_to_server. END")

