import logging
import numpy as np
import os.path as osp

from .message_def import MyMessage
from fedml_core.distributed.client.client_manager import ClientManager
from fedml_core.distributed.communication.message import Message

class FedXDDClientMananger(ClientManager):
    """Client-side message handler for the FedXDD training loop.

    On the server's init message the client trains locally and uploads its
    weights to rank 0. Each later global-model sync is loaded into the trainer
    and evaluated; the accuracy is appended to ``acc_record``. The client then
    trains again, until ``args.comm_round`` global models have been received.
    At that point it writes its results and reports its final test counts to
    the server instead of training again.
    """

    def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"):
        super().__init__(args, comm, rank, size, backend)

        self.trainer = trainer
        # Number of global-model syncs to receive before stopping.
        self.num_rounds = args.comm_round
        # Number of global models received (and evaluated) so far.
        self.round_idx = 0
        # Accuracy reported by trainer.evaluation() after each received model.
        self.acc_record = []

    def run(self):
        super().run()

    # TODO: needs revision
    def register_message_receive_handlers(self):
        """Route server init and model-sync messages to their handlers."""
        self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG,
                                              self.handle_message_init)
        self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
                                              self.handle_message_receive_compact_model_from_server)

    def handle_message_init(self, msg_params):
        """Reset the round counter and start the first round of local training."""
        logging.info("handle_message_init. Rank = %s", self.rank)
        self.round_idx = 0
        self.__train()

    # TODO: needs revision
    def handle_message_receive_compact_model_from_server(self, msg_params):
        """Load and evaluate the received global model, then train or finish.

        Args:
            msg_params: incoming message; its MSG_ARG_KEY_MODEL_PARAMS entry is
                passed to ``trainer.set_global_model_params``.
        """
        logging.info("handle_message_receive_compact_model_from_server. Rank = %s", self.rank)
        global_model = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
        self.trainer.set_global_model_params(global_model)
        self.round_idx += 1
        acc, correct_num, test_sample_num = self.trainer.evaluation()
        self.acc_record.append(acc)

        # ``>=`` (not ``==``) so that a non-positive comm_round still stops here
        # instead of training and uploading forever.
        if self.round_idx >= self.num_rounds:
            self._finish_training(correct_num, test_sample_num)
            return

        self.__train()

    def _finish_training(self, correct_num, test_sample_num):
        """Save the accuracy record, close the run log and report final test counts.

        Args:
            correct_num: correct predictions from the last evaluation.
            test_sample_num: number of samples in the last evaluation.
        """
        if self.rank == 1:
            # Only rank 1 persists the per-round accuracies, as a (1, rounds) array.
            savename = self.args.client_record_name
            np.save(savename + '.npy', np.array(self.acc_record).reshape(1, -1))

        out_file = self.args.out_file
        out_file.write("#################finish training##############################" + '\n')
        out_file.flush()
        out_file.close()
        self.send_acc_to_server(0, correct_num, test_sample_num)

    def send_model_to_server(self, receive_id, weights, local_sample_num):
        """Send locally trained weights and the local sample count to ``receive_id``."""
        message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_AND_DUAL, self.get_sender_id(), receive_id)
        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights)
        message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num)
        self.send_message(message)

    def send_acc_to_server(self, receive_id, correct_num, test_sample_num):
        """Send the final evaluation counts (correct / total samples) to ``receive_id``."""
        message = Message(MyMessage.MSG_TYPE_C2S_SEND_ACC, self.get_sender_id(), receive_id)
        message.add_params(MyMessage.MSG_ARG_KEY_CORRECT_NUM, correct_num)
        message.add_params(MyMessage.MSG_ARG_KEY_TEST_SAMPLES_NUM, test_sample_num)
        self.send_message(message)

    # TODO: needs revision
    def __train(self):
        """Run one round of local training and upload the result to the server (rank 0)."""
        # print with flush so progress shows up promptly regardless of logging config.
        print("#######training########### round_id = %d" % self.round_idx, flush=True)
        weights, local_sample_num = self.trainer.train()
        self.send_model_to_server(0, weights, local_sample_num)

