import torch
import logging
import copy
import numpy as np

from federatedscope.core.message import Message
from federatedscope.core.workers.server import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.cl.fedgc.utils import global_NT_xentloss

logger = logging.getLogger(__name__)


class GlobalContrastFLServer(Server):
    r"""
    GlobalContrastFL (FedGC) server. Training consists of two parts: a
    FedAvg aggregator for the clients' model weights, and the calculation of
    a global loss from the embeddings of all sampled clients, which is then
    sent back to those clients so that they can train their models.
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        """
        Initialise the server through the parent ``Server`` constructor and
        create the per-client containers used for the global loss.

        Arguments:
            ID, state, config, data, model, client_num, total_round_num,
            device, strategy, **kwargs: Forwarded unchanged (positionally,
                in this order) to the parent constructor.
        """
        super().__init__(ID, state, config, data, model, client_num,
                         total_round_num, device, strategy, **kwargs)

        # Both containers are keyed by 1 .. federate.client_num, read from
        # the configuration rather than from the ``client_num`` argument.
        client_ids = range(1, self._cfg.federate.client_num + 1)

        # Latest embeddings reported by each client; an empty tuple until
        # the first 'pred_embedding' message has been processed.
        self.seqs_embedding = dict.fromkeys(client_ids, ())
        # Latest global loss computed for each client.
        self.loss_list = dict.fromkeys(client_ids, 0)

    def _register_default_handlers(self):
        """
        Register the callback for every message type this server handles.

        'join_in' and 'join_in_info' share the same callback, and
        'pred_embedding' is routed to the global-loss callback.
        """
        handler_table = (
            ('join_in', self.callback_funcs_for_join_in),
            ('join_in_info', self.callback_funcs_for_join_in),
            ('model_para', self.callback_funcs_model_para),
            ('metrics', self.callback_funcs_for_metrics),
            ('pred_embedding', self.callback_funcs_global_loss),
        )
        for msg_type, callback in handler_table:
            self.register_handlers(msg_type, callback)

    def check_and_move_on_for_global_loss(self):
        """
        Check the training message buffer for client embeddings. When
        ``sample_client_num`` messages have been received for the current
        state, compute one global loss per reporting client, advance
        ``self.state`` and send each client its loss in a 'global_loss'
        message.

        Returns:
            bool: True if enough embeddings were received and the state was
                advanced, False otherwise (mirrors ``check_and_move_on``).

        Raises:
            ValueError: If the server holds more than one model.
        """
        minimal_number = self.sample_client_num

        if not self.check_buffer(self.state,
                                 minimal_number,
                                 check_eval_result=False):
            # Not enough feedback yet; wait for more messages
            return False

        if self.model_num > 1:
            raise ValueError(
                'GlobalContrastFL server not support multi-model.')

        # Receiving enough feedback in the training process:
        # record the embeddings reported by every client in this round
        train_msg_buffer = self.msg_buffer['train'][self.state]
        for client_id, pred_embedding in train_msg_buffer.items():
            self.seqs_embedding[client_id] = pred_embedding

        global_loss_fn = global_NT_xentloss(device=self.device)
        for client_id in train_msg_buffer:
            # Each content is indexed as (z1, z2) -- presumably the
            # embeddings of two views of the client's batch; confirm
            # against the client-side sender.
            z1 = self.seqs_embedding[client_id][0]
            z2 = self.seqs_embedding[client_id][1]
            # The second embeddings of all *other* reporting clients
            others_z2 = [
                self.seqs_embedding[other_client_id][1]
                for other_client_id in train_msg_buffer
                if other_client_id != client_id
            ]
            self.loss_list[client_id] = global_loss_fn(z1, z2, others_z2)
            logger.info(f'client {client_id} '
                        f'global_loss:{self.loss_list[client_id]}')

        # The state is advanced here; check_and_move_on (triggered by the
        # subsequent 'model_para' messages) does not advance it again.
        self.state += 1
        if self.state <= self.total_round_num:
            for client_id in train_msg_buffer:
                content = {
                    'global_loss': self.loss_list[client_id],
                }
                self.comm_manager.send(
                    Message(msg_type='global_loss',
                            sender=self.ID,
                            receiver=[client_id],
                            state=self.state,
                            content=content))

        return True

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        Inspect the message buffer and, once enough messages have arrived,
        trigger the follow-up events: federated aggregation, evaluation and
        the start of the next training round.

        Arguments:
            check_eval_result (bool): If True, check the message buffer for
                evaluation; otherwise check the message buffer for training.
            min_received_num (int): Number of messages required before
                moving on. When None, it is read from the configuration
                (``asyn.min_received_num`` if ``asyn.use`` is set, else
                ``federate.sample_client_num``).

        Returns:
            bool: True if the buffer held enough messages and the events
                were triggered, False otherwise.
        """
        if min_received_num is None:
            min_received_num = (self._cfg.asyn.min_received_num
                                if self._cfg.asyn.use else
                                self._cfg.federate.sample_client_num)
        assert min_received_num <= self.sample_client_num

        if check_eval_result and self._cfg.federate.mode.lower(
        ) == "standalone":
            # Evaluation in standalone simulation assumes strong
            # synchronization: wait for responses from all neighbors
            min_received_num = len(self.comm_manager.get_neighbors().keys())

        if not self.check_buffer(self.state, min_received_num,
                                 check_eval_result):
            # Neither a new training round nor the end of an evaluation
            return False

        if check_eval_result:
            # Receiving enough feedback in the evaluation process
            self._merge_and_format_eval_results()
            return True

        # Receiving enough feedback in the training process
        aggregated_num = self._perform_federated_aggregation()

        # Note: self.state is not incremented in this method.
        if self.state % self._cfg.eval.freq == 0 and self.state != \
                self.total_round_num:
            # Periodic evaluation
            logger.info(f'Server: Starting evaluation at the end '
                        f'of round {self.state - 1}.')
            self.eval()

        if self.state < self.total_round_num:
            # Move to the next round of training
            logger.info(f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
            # Reset the buffers before the new round begins
            self.msg_buffer['train'][self.state - 1].clear()
            self.msg_buffer['train'][self.state] = dict()
            self.staled_msg_buffer.clear()
            self._start_new_training_round(aggregated_num)
        else:
            # Final evaluation
            logger.info('Server: Training is finished! Starting '
                        'evaluation.')
            self.eval()

        return True

    def callback_funcs_global_loss(self, message: Message):
        """
        The handling function for receiving model embeddings, which triggers
            check_and_move_on_for_global_loss (calculate the global loss
            when enough feedback has been received).

        Arguments:
            message: The received message, which includes sender, receiver,
                state, and content. More detail can be found in
                federatedscope.core.message

        Returns:
            'finish' if the server has already finished; otherwise the value
            returned by check_and_move_on_for_global_loss.
        """
        if self.is_finish:
            return 'finish'

        # Named msg_round to avoid shadowing the builtin ``round``
        msg_round = message.state
        sender = message.sender
        timestamp = message.timestamp
        content = message.content
        self.sampler.change_state(sender, 'idle')

        # update the current timestamp according to the received message
        assert timestamp >= self.cur_timestamp  # for test
        self.cur_timestamp = timestamp

        if msg_round == self.state:
            if msg_round not in self.msg_buffer['train']:
                self.msg_buffer['train'][msg_round] = dict()
            # Save the messages in this round
            self.msg_buffer['train'][msg_round][sender] = content
        elif msg_round >= self.state - self.staleness_toleration:
            # Save the staled messages
            self.staled_msg_buffer.append((msg_round, sender, content))
        else:
            # Drop the out-of-date messages, accounted for in the same way
            # as callback_funcs_model_para does
            logger.info(
                f'Drop a out-of-date message from round #{msg_round}')
            self.dropout_num += 1

        move_on_flag = self.check_and_move_on_for_global_loss()

        return move_on_flag

    def callback_funcs_model_para(self, message: Message):
        """
        Handle a message carrying model parameters: store it in the
        appropriate buffer and call check_and_move_on, which performs the
        aggregation once enough feedback has been received.

        Arguments:
            message: The received message, which includes sender, receiver,
                state, and content. More detail can be found in
                federatedscope.core.message

        Returns:
            'finish' if the server has already finished; otherwise the flag
            returned by check_and_move_on.
        """
        if self.is_finish:
            return 'finish'

        msg_round = message.state
        client = message.sender
        msg_time = message.timestamp
        payload = message.content
        # The sender is done with its task; mark it idle in the sampler
        self.sampler.change_state(client, 'idle')

        # Advance the server's timestamp to that of the received message
        assert msg_time >= self.cur_timestamp  # for test
        self.cur_timestamp = msg_time

        train_buffer = self.msg_buffer['train']
        if msg_round == self.state:
            # Message of the current round: keep it in the round's buffer
            if msg_round not in train_buffer:
                train_buffer[msg_round] = dict()
            train_buffer[msg_round][client] = payload
        elif msg_round >= self.state - self.staleness_toleration:
            # Stale, but still within the tolerated staleness
            self.staled_msg_buffer.append((msg_round, client, payload))
        else:
            # Too old to be used: drop it and count the dropout
            logger.info(f'Drop a out-of-date message from round #{msg_round}')
            self.dropout_num += 1

        if self._cfg.federate.online_aggr:
            # Online aggregation consumes the first two items of the content
            self.aggregator.inc(payload[:2])

        move_on_flag = self.check_and_move_on()

        broadcast_now = (self._cfg.asyn.use
                         and self._cfg.asyn.broadcast_manner
                         == 'after_receiving')
        if broadcast_now:
            self.broadcast_model_para(msg_type='model_para',
                                      sample_client_num=1)

        return move_on_flag
