import torch
import logging
import copy
import numpy as np

from federatedscope.core.message import Message
from federatedscope.core.workers.client import Client
from federatedscope.core.auxiliaries.utils import merge_dict

logger = logging.getLogger(__name__)


class GlobalContrastFLClient(Client):
    r"""
    GlobalContrastFL (FedGC) client.

    The client receives aggregated model weights from the server, loads them
    into its local trainer, trains locally and sends back the embedding
    returned by ``trainer.get_train_pred_embedding()``. It also receives a
    global loss from the server, uses it to train and update its local
    weights, and sends the resulting model parameters back.
    """
    def _register_default_handlers(self):
        """Map incoming message types to this client's callbacks.

        Compared with a plain client, ``'model_para'`` is routed to
        :meth:`callback_funcs_for_pred_embedding`, and ``'global_loss'`` is
        routed to :meth:`callback_funcs_for_local_backward`.
        ``'ss_model_para'`` goes to ``callback_funcs_for_model_para``, which
        is not defined in this class (presumably inherited from ``Client``).
        """
        self.register_handlers('assign_client_id',
                               self.callback_funcs_for_assign_id)
        self.register_handlers('ask_for_join_in_info',
                               self.callback_funcs_for_join_in_info)
        self.register_handlers('address', self.callback_funcs_for_address)
        self.register_handlers('model_para',
                               self.callback_funcs_for_pred_embedding)
        self.register_handlers('global_loss',
                               self.callback_funcs_for_local_backward)
        self.register_handlers('ss_model_para',
                               self.callback_funcs_for_model_para)

        self.register_handlers('evaluate', self.callback_funcs_for_evaluate)
        self.register_handlers('finish', self.callback_funcs_for_finish)
        self.register_handlers('converged', self.callback_funcs_for_converged)

    def callback_funcs_for_local_backward(self, message: Message):
        """Train with the server's global loss and upload the local model.

        Args:
            message: incoming ``'global_loss'`` message; its ``content`` must
                be a mapping containing the key ``'global_loss'``.

        Sends a ``'model_para'`` message back to the sender whose content is
        the tuple ``(sample_size, model_para)``.
        """
        # ``rnd`` rather than ``round`` so the builtin is not shadowed.
        rnd, sender, content = message.state, message.sender, message.content
        global_loss = content['global_loss']
        # The parameters produced by the global-loss step are loaded into the
        # trainer; the uploaded parameters are then re-read from the trainer.
        trained_para = self.trainer.train_with_global_loss(global_loss)
        self.trainer.update(trained_para)
        self.state = rnd
        sample_size = self.trainer.num_samples
        model_para = self.trainer.get_model_para()

        self.comm_manager.send(
            Message(msg_type='model_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=(sample_size, model_para)))

    def callback_funcs_for_pred_embedding(self, message: Message):
        """Load the server's model, train locally and upload the embedding.

        Args:
            message: incoming ``'model_para'`` message; its ``content`` is
                passed unchanged to ``self.trainer.update``.

        Logs the local training results and sends a ``'pred_embedding'``
        message back to the sender. Its content is the object returned by
        ``self.trainer.get_train_pred_embedding()``, not wrapped in a tuple.
        """
        rnd, sender, content = message.state, message.sender, message.content
        self.trainer.update(content)
        # Only the training metrics are used here; the sample size and model
        # parameters returned by train() are not sent in this exchange.
        _, _, results = self.trainer.train()
        self.state = rnd
        pred_embedding = self.trainer.get_train_pred_embedding()

        train_log_res = self._monitor.format_eval_res(
            results,
            rnd=self.state,
            role='Client #{}'.format(self.ID),
            return_raw=True)
        logger.info(train_log_res)

        self.comm_manager.send(
            Message(msg_type='pred_embedding',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=pred_embedding))
