import numpy as np
import logging
import copy

from federatedscope.core.workers import Client
from federatedscope.core.message import Message
from federatedscope.vertical_fl.Paillier import \
            abstract_paillier

# Module-level logger for this file (named after the module).
logger = logging.getLogger(__name__)


class TreeClient(Client):
    """Base client for tree-based vertical federated learning.

    Each client owns a slice of the feature space (see
    ``feature_partition``). The client whose training split contains
    ``'y'`` is the label owner. Each round, the label owner fetches a
    training batch and sends its index to every non-server neighbor
    (``data_sample``). The other clients reply with ``training_info``.
    Once one entry per client is buffered, the label owner calls
    :meth:`train`. Subclasses implement :meth:`train` and :meth:`eval`.
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):

        super().__init__(ID, server_id, state, config, data, model, device,
                         strategy, *args, **kwargs)

        self.data = data
        # Only the label owner has 'y' in its training split.
        self.own_label = ('y' in data['train'])
        # Per-phase message buffers; 'train' entries are keyed by sender ID.
        self.msg_buffer = {'train': {}, 'eval': {}}
        self.client_num = self._cfg.federate.client_num

        if self._cfg.vertical.eval_protection == 'he':
            # Paillier keypair, generated only when evaluation is protected
            # with homomorphic encryption.
            keys = abstract_paillier.generate_paillier_keypair(
                n_length=self._cfg.vertical.key_size)
            self.public_key, self.private_key = keys

        self.feature_order = None
        self.merged_feature_order = None

        # vertical.dims is treated as cumulative feature boundaries: the
        # per-client feature counts are its successive differences and its
        # last entry is the total number of features.
        self.feature_partition = np.diff(self._cfg.vertical.dims, prepend=0)
        self.total_num_of_feature = self._cfg.vertical.dims[-1]
        # NOTE(review): indexing with ID - 1 assumes client IDs start at 1.
        self.num_of_feature = self.feature_partition[self.ID - 1]
        self.feature_importance = [0] * self.num_of_feature

        self._init_data_related_var()

        self.register_handlers('model_para', self.callback_func_for_model_para)
        self.register_handlers('data_sample',
                               self.callback_func_for_data_sample)
        self.register_handlers('training_info',
                               self.callback_func_for_training_info)
        self.register_handlers('finish', self.callback_func_for_finish)

    def train(self, tree_num, node_num=None, training_info=None):
        """Train tree ``tree_num``; must be implemented by subclasses."""
        raise NotImplementedError

    def eval(self, tree_num):
        """Evaluate tree ``tree_num``; must be implemented by subclasses."""
        raise NotImplementedError

    def _init_data_related_var(self):
        """Initialize trainer-side training state and reset test data."""
        self.trainer._init_for_train()
        self.test_x = None
        self.test_y = None

    def _training_info_for_mode(self, feature_order_info):
        """Return the training info to exchange under ``vertical.mode``.

        Both the label owner (for its own buffer entry) and the other
        clients (for their reply) use this helper, so an unsupported mode
        is rejected the same way everywhere.

        Args:
            feature_order_info: the dict returned by
                ``trainer.fetch_train_data`` for the current batch.

        Returns:
            ``feature_order_info`` itself in ``'feature_gathering'`` mode, or
            the placeholder string ``'dummy_info'`` in ``'label_scattering'``
            mode.

        Raises:
            TypeError: if ``vertical.mode`` is neither of the above.
        """
        mode = self._cfg.vertical.mode
        if mode == 'feature_gathering':
            return feature_order_info
        if mode == 'label_scattering':
            return 'dummy_info'
        raise TypeError(
            f'The expected types of vertical.mode include '
            f'["label_scattering", "feature_gathering"], but got '
            f'{self._cfg.vertical.mode}.')

    def callback_func_for_model_para(self, message: Message):
        """Handle ``model_para``: prepare for training and start round 0.

        Every client calls ``trainer.prepare_for_train()``. Per the
        original notes, this presumably initializes the tree list and, on
        the label owner, the initial predictions; confirm in the trainer.
        Only the label owner then fetches a batch and starts the first
        training round.
        """
        self.state = message.state

        self.trainer.prepare_for_train()
        if self.own_label:
            batch_index, feature_order_info = self.trainer.fetch_train_data()
            self.start_a_new_training_round(batch_index,
                                            feature_order_info,
                                            tree_num=0)

    def callback_func_for_data_sample(self, message: Message):
        """Handle ``data_sample`` sent by the label owner.

        Fetches the local training data for the received batch index,
        records the local feature order, and replies to the sender with a
        ``training_info`` message whose content depends on
        ``vertical.mode``.

        Raises:
            TypeError: if ``vertical.mode`` is not supported.
        """
        self.state = message.state
        batch_index, sender = message.content, message.sender
        _, feature_order_info = self.trainer.fetch_train_data(
            index=batch_index)
        self.feature_order = feature_order_info['feature_order']

        training_info = self._training_info_for_mode(feature_order_info)

        self.comm_manager.send(
            Message(msg_type='training_info',
                    sender=self.ID,
                    state=self.state,
                    receiver=[sender],
                    content=training_info))

    def callback_func_for_training_info(self, message: Message):
        """Buffer a peer's ``training_info`` and try to start training."""
        feature_order_info, sender = message.content, message.sender
        self.msg_buffer['train'][sender] = feature_order_info
        self.check_and_move_on()

    def callback_func_for_finish(self, message: Message):
        """Handle ``finish``: log that this client received it."""
        logger.info(
            f"================= client {self.ID} received finish message "
            f"=================")
        # NOTE(review): monitor finalization (self._monitor.finish_fl()) is
        # deliberately not called here in the original; confirm whether it
        # should be re-enabled.

    def start_a_new_training_round(self,
                                   batch_index,
                                   feature_order_info,
                                   tree_num=0):
        """Start training tree ``tree_num`` on the label owner.

        Stores the label owner's own training info in the train buffer and
        sends the batch index to every neighbor except the server.

        Args:
            batch_index: index of the sampled training batch, forwarded
                verbatim to the other clients.
            feature_order_info: dict with at least a ``'feature_order'`` key.
            tree_num: index of the tree to train; becomes ``self.state``.

        Raises:
            TypeError: if ``vertical.mode`` is not supported. The check runs
                before any state is changed or any message is sent.
        """
        own_training_info = self._training_info_for_mode(feature_order_info)

        self.msg_buffer['train'].clear()
        self.feature_order = feature_order_info['feature_order']
        self.msg_buffer['train'][self.ID] = own_training_info
        self.state = tree_num
        receiver = [
            neighbor_id for neighbor_id in self.comm_manager.neighbors.keys()
            if neighbor_id != self.server_id
        ]
        send_message = Message(msg_type='data_sample',
                               sender=self.ID,
                               state=self.state,
                               receiver=receiver,
                               content=batch_index)
        self.comm_manager.send(send_message)

    def check_and_move_on(self):
        """Call :meth:`train` once training info from all clients arrived.

        The buffer counts the label owner's own entry, so training starts
        when it holds ``client_num`` entries.
        """
        if len(self.msg_buffer['train']) == self.client_num:
            # Deep-copy so train() gets a snapshot. The buffered objects may
            # alias live state, e.g. the label owner's own
            # feature_order_info, which also backs self.feature_order.
            received_training_infos = copy.deepcopy(self.msg_buffer['train'])
            self.msg_buffer['train'].clear()
            self.train(tree_num=self.state,
                       training_info=received_training_infos)
