import logging
from time import sleep

import wandb

from fedml_api.distributed.fednas_extension.message_define import MyMessage
from fedml_api.distributed.fednas_extension.utils import post_complete_message_to_sweep_process
from fedml_api.distributed.fednas_extension.FedNAS_local_train import FedNAS_Local_Train
from fedml_core.distributed.communication.message import Message
from fedml_core.distributed.server.server_manager import ServerManager


class FedNASServerManager(ServerManager):
    """Server-side coordinator for distributed FedNAS.

    The server broadcasts the initial global model to every worker process
    (ranks ``1 .. size - 1``). Each communication round it aggregates the
    workers' locally trained results, optionally evaluates the global model,
    samples clients and sends the new global parameters back out. After the
    fine-tuning phase it runs personal evaluation and a final local
    train/infer step, then reports completion to the sweep process.
    """

    def __init__(self, args, comm, rank, size, aggregator):
        """Create the server manager.

        Args:
            args: experiment configuration. Fields read in this class include
                ``comm_round``, ``stage``, ``frequency_of_the_test``,
                ``client_num_in_total`` and ``client_num_per_round``.
            comm: communicator forwarded to ``ServerManager``.
            rank: this process's rank.
            size: number of processes; workers are addressed as ``1 .. size - 1``.
            aggregator: object that stores local results, aggregates them,
                evaluates the global model and samples clients.
        """
        super().__init__(args, comm, rank, size)
        # NOTE(review): ``self.args`` / ``self.size`` are used below but not set
        # here; presumably ServerManager.__init__ stores them — confirm.
        self.round_num = args.comm_round
        self.round_idx = 0

        self.aggregator = aggregator

    def run(self):
        """Send the initial global model to every worker, then defer to ``ServerManager.run()``."""
        global_model = self.aggregator.get_model()
        global_model_params = global_model.state_dict()

        # Architecture parameters are only shipped in the search stage.
        # NOTE(review): the round handler also treats 'backbone_cell_search' as a
        # stage that aggregates architecture parameters, yet this initial
        # broadcast sends an empty list for it — confirm this is intended.
        if self.args.stage == "fednas_search":
            global_arch_params = global_model.arch_parameters()
        else:
            global_arch_params = []

        for process_id in range(1, self.size):
            self.__send_initial_config_to_client(process_id, global_model_params, global_arch_params)
        super().run()

    def register_message_receive_handlers(self):
        """Bind the client-to-server message types to their handlers."""
        self.register_message_receive_handler(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER,
                                              self.__handle_msg_server_receive_model_from_client_opt_send)

        self.register_message_receive_handler(MyMessage.MSG_TYPE_C2S_SEND_FINE_TUNE_RESULT_TO_SERVER,
                                              self.__handle_msg_server_receive_fine_tune_result_from_client)

    def __send_initial_config_to_client(self, process_id, global_model_params, global_arch_params):
        """Send the INIT_CONFIG message carrying the initial model/arch params to ``process_id``."""
        message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), process_id)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
        message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, global_arch_params)
        logging.info("MSG_TYPE_S2C_INIT_CONFIG. receiver: %s", process_id)
        self.send_message(message)

    def __handle_msg_server_receive_model_from_client_opt_send(self, msg_params):
        """Store one worker's training result; once all workers reported, finish the round.

        Finishing a round means: aggregate, evaluate (every
        ``frequency_of_the_test`` rounds and on the final round), sample the
        clients for the next round, and either stop after the last round or
        send the new global parameters to every worker.
        """
        process_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS)
        local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES)
        train_acc = msg_params.get(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_ACC)
        train_loss = msg_params.get(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_LOSS)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)
        flops = msg_params.get(MyMessage.MSG_ARG_KEY_FLOPS)
        model_size = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_SIZE)
        acc_personalized_model_on_local_data = msg_params.get(MyMessage.MSG_ARG_KEY_ACC_PERSONALIZED_MODEL_ON_LOCAL_DATA)

        # Workers are ranks 1..size-1, so process_id - 1 is the worker slot
        # (distinct from the sampled client_index carried in the message).
        self.aggregator.add_local_trained_result(process_id - 1, model_params, arch_params, local_sample_number,
                                                 train_acc, train_loss, acc_personalized_model_on_local_data,
                                                 client_index, flops, model_size, self.round_idx)

        b_all_received = self.aggregator.check_whether_all_processes_receive()
        logging.info("==============Client index %d ROUND = %d =============", process_id - 1, self.round_idx)
        logging.info("b_all_received = %s", b_all_received)
        if not b_all_received:
            return

        if self.args.stage == "fednas_search" or self.args.stage == 'backbone_cell_search':
            global_model_params, global_arch_params = self.aggregator.aggregate()
        else:
            global_model_params = self.aggregator.aggregate()
            global_arch_params = []

        # Also evaluate on the final round; otherwise the last aggregated model
        # is only evaluated when round_num - 1 is a multiple of the frequency.
        is_last_round = self.round_idx == self.round_num - 1
        if self.round_idx % self.args.frequency_of_the_test == 0 or is_last_round:
            self.aggregator.evaluation(self.round_idx)

        # sampling clients: client number (1 million) >> worker number (200 processes)
        client_indexes = self.aggregator.client_sampling(self.round_idx, self.args.client_num_in_total,
                                                         self.args.client_num_per_round)
        # start the next round
        self.round_idx += 1
        if self.round_idx == self.round_num:  # all communication rounds done: stop the server
            logging.info("Global Training is Finished!")
            self.finish()
            return
        # Broadcast the new global params; worker i is assigned client_indexes[i - 1].
        for receiver_id in range(1, self.size):
            self.__send_model_to_client_message(receiver_id, global_model_params, global_arch_params,
                                                client_indexes[receiver_id - 1])

    def __send_model_to_client_message(self, process_id, global_model_params, global_arch_params, client_index):
        """Send SYNC_MODEL_TO_CLIENT with the global params and assigned client index to ``process_id``."""
        message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), process_id)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
        message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, global_arch_params)
        # The client index travels as a string.
        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index))
        logging.info("__send_model_to_client_message. MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT. receiver: %s", process_id)
        self.send_message(message)

    def __handle_msg_server_receive_fine_tune_result_from_client(self, msg_params):
        """Store one worker's fine-tuning result; once all arrived, run the final evaluation steps.

        When ``aggregator.check_whether_all_clients_receive()`` is True this
        runs personal evaluation, trains/infers locally adapted models with
        ``FedNAS_Local_Train`` on the aggregator's local datasets, records the
        resulting validation accuracy, finishes the wandb run and notifies the
        sweep process.
        """
        logging.info("__handle_msg_server_receive_fine_tune_result_from_client")
        process_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)
        model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS)
        local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES)
        train_acc = msg_params.get(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_ACC)
        train_loss = msg_params.get(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_LOSS)

        pi_params = msg_params.get(MyMessage.MSG_ARG_KEY_PI_PARAM)
        personalized_arch = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_CLIENT)
        client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)

        acc_global_model_on_local_data = msg_params.get(MyMessage.MSG_ARG_KEY_ACC_GLOBAL_MODEL_ON_LOCAL_DATA)
        acc_personalized_model_on_local_data = msg_params.get(MyMessage.MSG_ARG_KEY_ACC_PERSONALIZED_MODEL_ON_LOCAL_DATA)

        # Lazy %-args: if an accuracy is missing (None), the logging handler
        # reports the formatting error instead of raising here and dropping
        # this worker's result.
        logging.info("client index = %d, acc_global_model_on_local_data = %f, "
                     "acc_personalized_model_on_local_data = %f",
                     process_id - 1, acc_global_model_on_local_data, acc_personalized_model_on_local_data)
        self.aggregator.add_local_fine_tuned_result(process_id - 1, model_params, arch_params, local_sample_number,
                                                    train_acc, train_loss,
                                                    acc_global_model_on_local_data,
                                                    acc_personalized_model_on_local_data,
                                                    pi_params, personalized_arch, client_index)
        b_all_received = self.aggregator.check_whether_all_clients_receive()
        logging.info("b_all_received = %s", b_all_received)
        if not b_all_received:
            return

        logging.info(" Personal Evaluation ")
        self.aggregator.personal_evaluation(self.round_idx)

        # Step 3: training of locally adapted models on the server.
        train_set, valid_set = self.aggregator.get_local_datasets()
        device = self.aggregator.get_device()
        local_trainer = FedNAS_Local_Train(self.args, train_set, valid_set, device)
        valid_acc_dict = local_trainer.local_train_and_infer()
        logging.info("Validation Accuracy %s", valid_acc_dict)
        # Record Third Step's Accuracy
        self.aggregator.record_local_train_accuracy_step3(valid_acc_dict)
        wandb.finish()
        post_complete_message_to_sweep_process(self.args)
        # NOTE(review): presumably gives the sweep process time to read the
        # completion message before this process moves on — confirm.
        sleep(5)
        # NOTE(review): self.finish() is deliberately not called here (it was
        # commented out); confirm whether the server should shut down now.

    def __send_msg_of_fine_tuning_to_client(self, process_id):
        """Send START_FINE_TUNING to ``process_id``.

        NOTE(review): not called anywhere in this class as visible here.
        """
        message = Message(MyMessage.MSG_TYPE_S2C_START_FINE_TUNING, self.get_sender_id(), process_id)
        logging.info("__send_msg_of_fine_tuning_to_client")
        self.send_message(message)
