import torch
import logging
import copy

from torch_geometric.loader import NeighborSampler

from federatedscope.core.message import Message
from federatedscope.core.workers.server import Server
from federatedscope.core.workers.client import Client
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.core.data import ClientData

from federatedscope.gfl.trainer.nodetrainer import NodeMiniBatchTrainer
from federatedscope.gfl.model.fedsageplus import LocalSage_Plus, FedSage_Plus
from federatedscope.gfl.fedsageplus.utils import GraphMender, HideGraph
from federatedscope.gfl.fedsageplus.trainer import LocalGenTrainer, \
    FedGenTrainer

logger = logging.getLogger(__name__)


class FedSagePlusServer(Server):
    """Server-side worker that drives the FedSage+ protocol.

    The server moves through three phases, all keyed on ``self.state``:

    1. Local pre-training: once every client has joined, the server sends
       ``local_pretrain`` and waits for ``[gen_para, embedding, label]``
       from each client.
    2. Federated generator training: on even states each client's
       generator payload is relayed to all other clients; on odd states
       the gradients accumulated per client are sent back to that client.
    3. Federated training of the classifier: aggregate the received
       parameters, evaluate periodically and broadcast the new model.
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        """Initialise the base server and the FedSage+ specific counters.

        All arguments are forwarded unchanged to ``Server.__init__``.
        ``total_round_num`` is afterwards extended by the number of states
        reserved for the generator phase.
        """
        super().__init__(ID, state, config, data, model, client_num,
                         total_round_num, device, strategy, **kwargs)

        assert self.model_num == 1, (
            "Not supported multi-model for FedSagePlusServer")

        # Each federated generator epoch is counted as two states: an even
        # one in which the server receives [model, embedding, label] and an
        # odd one in which it receives [gradient].
        # NOTE(review): ``check_and_move_on`` compares ``self.state`` with
        # ``self._cfg.fedsageplus.fedgen_epoch`` and not with this doubled
        # value -- confirm which bound is the intended one.
        self.fedgen_epoch = 2 * self._cfg.fedsageplus.fedgen_epoch
        self.total_round_num = total_round_num + self.fedgen_epoch
        # Number of 'gradient' messages received in the current state.
        self.grad_cnt = 0

    def _register_default_handlers(self):
        """Bind every message type handled by this server to its callback."""
        handler_table = (
            ('join_in', self.callback_funcs_for_join_in),
            ('join_in_info', self.callback_funcs_for_join_in),
            ('clf_para', self.callback_funcs_model_para),
            ('gen_para', self.callback_funcs_model_para),
            ('gradient', self.callback_funcs_gradient),
            ('metrics', self.callback_funcs_for_metrics),
        )
        for msg_type, callback in handler_table:
            self.register_handlers(msg_type, callback)

    def callback_funcs_for_join_in(self, message: Message):
        """Handle ``join_in`` and ``join_in_info`` messages.

        A message whose type contains ``'info'`` carries the join-in
        information requested from a client and is stored per sender. Any
        other message registers the sender as a neighbor (assigning it an
        ID when it reports ``-1``) and, if configured, asks it for join-in
        information. As soon as ``check_client_join_in`` succeeds, the
        ``local_pretrain`` message is sent to all neighbors.
        """
        if 'info' in message.msg_type:
            client_id, client_info = message.sender, message.content
            for key in self._cfg.federate.join_in_info:
                assert key in client_info
            self.join_in_info[client_id] = client_info
            logger.info(
                'Server: Client #{:d} has joined in !'.format(client_id))
        else:
            self.join_in_client_num += 1
            client_id, address = message.sender, message.content
            # A sender of -1 means the client has no ID yet; it gets the
            # running join-in counter as its ID.
            needs_id = int(client_id) == -1
            if needs_id:
                client_id = self.join_in_client_num
            self.comm_manager.add_neighbors(neighbor_id=client_id,
                                            address=address)
            if needs_id:
                id_msg = Message(msg_type='assign_client_id',
                                 sender=self.ID,
                                 receiver=[client_id],
                                 state=self.state,
                                 content=str(client_id))
                self.comm_manager.send(id_msg)

            if len(self._cfg.federate.join_in_info) > 0:
                info_request = Message(
                    msg_type='ask_for_join_in_info',
                    sender=self.ID,
                    receiver=[client_id],
                    state=self.state,
                    content=self._cfg.federate.join_in_info.copy())
                self.comm_manager.send(info_request)

        if not self.check_client_join_in():
            return

        if self._cfg.federate.use_ss:
            self.broadcast_client_address()

        # Everybody is here: kick off the local pre-training.
        self.comm_manager.send(
            Message(msg_type='local_pretrain',
                    sender=self.ID,
                    receiver=list(self.comm_manager.neighbors.keys()),
                    state=self.state))

    def callback_funcs_gradient(self, message: Message):
        """Accumulate a generator gradient reported by a client.

        The message content is ``(gen_grad, ID)``; gradients are summed
        key-wise per ``ID`` inside the training buffer of the message's
        round. Afterwards ``check_and_move_on`` is triggered.
        """
        rnd = message.state
        gen_grad, target_id = message.content

        # First message of this round: open a fresh buffer for it.
        if rnd not in self.msg_buffer['train']:
            self.msg_buffer['train'][rnd] = dict()
        self.grad_cnt += 1

        round_buffer = self.msg_buffer['train'][rnd]
        if target_id in round_buffer:
            # Add this gradient onto what was already received for the ID.
            summed = round_buffer[target_id]
            for key in gen_grad.keys():
                summed[key] += torch.FloatTensor(gen_grad[key].cpu())
        else:
            round_buffer[target_id] = {
                key: torch.FloatTensor(gen_grad[key].cpu())
                for key in gen_grad.keys()
            }
        self.check_and_move_on()

    def check_and_move_on(self, check_eval_result=False):
        """Advance the protocol if enough messages have been received.

        The stages are checked one after another within a single call, so
        a state increment made by one stage is visible to the next one.

        Arguments:
            check_eval_result: if True, the last stage handles evaluation
                results and waits for all clients; otherwise it handles
                training results and waits for ``sample_client_num``.
        """
        all_client_ids = list(range(1, self.client_num + 1))
        # Evaluation involves all clients, training only the sampled ones.
        required = (self.client_num
                    if check_eval_result else self.sample_client_num)
        gen_limit = self._cfg.fedsageplus.fedgen_epoch

        # Even generator state: pass every client's generator payload on to
        # all other clients to get gradients back.
        if (self.check_buffer(self.state, self.client_num)
                and self.state < gen_limit and self.state % 2 == 0):
            # FedGen needs the messages of all clients.
            payloads = self.msg_buffer['train'][self.state]
            for origin, (gen_para, embedding, label) in payloads.items():
                # Everyone except the client the payload came from.
                peers = all_client_ids[:origin - 1] + all_client_ids[origin:]
                self.comm_manager.send(
                    Message(msg_type='gen_para',
                            sender=self.ID,
                            receiver=peers,
                            state=self.state + 1,
                            content=[gen_para, embedding, label, origin]))
                logger.info(f'\tServer: Transmit gen_para to {peers} '
                            f'@{self.state//2}.')
            self.state += 1

        # Odd generator state: every client must have reported a gradient
        # for every other client before the sums are returned.
        if (self.check_buffer(self.state, self.client_num)
                and self.state < gen_limit and self.state % 2 == 1
                and self.grad_cnt == self.client_num *
                (self.client_num - 1)):
            summed_grads = self.msg_buffer['train'][self.state]
            for target_id, grad in summed_grads.items():
                self.comm_manager.send(
                    Message(msg_type='gradient',
                            sender=self.ID,
                            receiver=[target_id],
                            state=self.state + 1,
                            content=grad))
            # Start counting gradients from scratch for the next epoch.
            self.grad_cnt = 0
            self.state += 1

        # Generator phase finished: let the clients set up their
        # classifier trainers.
        if (self.check_buffer(self.state, self.client_num)
                and self.state == gen_limit):
            self.state += 1
            self.comm_manager.send(
                Message(msg_type='setup',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state))

        # Classifier phase: nothing to do until enough messages are in.
        if not (self.check_buffer(self.state, required, check_eval_result)
                and self.state >= gen_limit):
            return

        if check_eval_result:
            # Evaluation: merge the clients' results into the history.
            formatted_eval_res = self.merge_eval_results_from_all_clients()
            self.history_results = merge_dict_of_results(
                self.history_results, formatted_eval_res)
            self.check_and_save()
            return

        # Training: collect the feedback of this round.
        feedback = list(self.msg_buffer['train'][self.state].values())

        # Trigger the monitor (for training).
        self._monitor.calc_model_metric(self.model.state_dict(),
                                        feedback,
                                        rnd=self.state)

        # Aggregate and install the new global model.
        aggregated = self.aggregator.aggregate({
            'client_feedback': feedback,
            'recover_fun': self.recover_fun
        })
        self.model.load_state_dict(aggregated)
        self.aggregator.update(aggregated)

        self.state += 1
        is_last_round = self.state == self.total_round_num
        if self.state % self._cfg.eval.freq == 0 and not is_last_round:
            # Periodic evaluation.
            logger.info('Server : Starting evaluation at round {:d}.'.format(
                self.state))
            self.eval()

        if self.state < self.total_round_num:
            # Move on to the next training round.
            logger.info(f'----------- Starting a new training round(Round '
                        f'#{self.state}) -------------')
            self.broadcast_model_para(
                msg_type='model_para',
                sample_client_num=self.sample_client_num)
        else:
            # Final evaluation.
            logger.info('Server: Training is finished! Starting '
                        'evaluation.')
            self.eval()


class FedSagePlusClient(Client):
    """Client-side worker of the FedSage+ protocol.

    The client keeps the original local graph (``self.data``), a copy
    produced by ``HideGraph`` (``self.hide_data``), a local generator
    (``self.gen``) and the classifier passed in as ``model``
    (``self.clf``). Its handlers cover local pre-training, the exchange of
    generator parameters and gradients, the setup of the classifier
    trainer and the regular classifier training rounds.
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):
        """Initialise the base client, the generator and its local trainer.

        All arguments are forwarded unchanged to ``Client.__init__``;
        ``data['data']`` is expected to hold the local graph.
        """
        super().__init__(ID, server_id, state, config, data, model, device,
                         strategy, *args, **kwargs)
        self.data = data
        raw_graph = data['data']
        hidden_graph = HideGraph(
            self._cfg.fedsageplus.hide_portion)(raw_graph)
        # Wrap the hidden graph as `ClientData`; the same graph serves as
        # train, val and test split.
        self.hide_data = ClientData(self._cfg,
                                    train=[hidden_graph],
                                    val=[hidden_graph],
                                    test=[hidden_graph],
                                    data=hidden_graph)
        self.device = device
        # Batch size of the classifier's training sampler; it is also
        # written to `dataloader.batch_size` when the classifier is set up.
        self.sage_batch_size = 64
        model_cfg = self._cfg.model
        self.gen = LocalSage_Plus(
            raw_graph.x.shape[-1],
            model_cfg.out_channels,
            hidden=model_cfg.hidden,
            gen_hidden=self._cfg.fedsageplus.gen_hidden,
            dropout=model_cfg.dropout,
            num_pred=self._cfg.fedsageplus.num_pred)
        self.clf = model
        self.trainer_loc = LocalGenTrainer(self.gen,
                                           self.hide_data,
                                           self.device,
                                           self._cfg,
                                           monitor=self._monitor)

        handler_table = (
            ('clf_para', self.callback_funcs_for_model_para),
            ('local_pretrain', self.callback_funcs_for_local_pre_train),
            ('gradient', self.callback_funcs_for_gradient),
            ('gen_para', self.callback_funcs_for_gen_para),
            ('setup', self.callback_funcs_for_setup_fedsage),
        )
        for msg_type, callback in handler_table:
            self.register_handlers(msg_type, callback)

    def callback_funcs_for_local_pre_train(self, message: Message):
        """Pre-train the local generator and start the federated phase.

        Runs ``fedsageplus.loc_epoch`` epochs with the local trainer,
        wraps the generator into ``FedSage_Plus`` with its own trainer and
        reports ``[gen_para, embedding, num_missing]`` to the sender.
        """
        rnd, server = message.state, message.sender

        # Local pre-training.
        logger.info(f'\tClient #{self.ID} pre-train start...')
        for epoch in range(self._cfg.fedsageplus.loc_epoch):
            # The trainer returns a triple; none of it is needed here.
            _, _, _ = self.trainer_loc.train()
            logger.info(
                f'\tClient #{self.ID} local pre-train @Epoch {epoch}.')

        # The federated generator is built on top of the local one.
        self.fedgen = FedSage_Plus(self.gen)
        self.trainer_fedgen = FedGenTrainer(self.fedgen,
                                            self.hide_data,
                                            self.device,
                                            self._cfg,
                                            monitor=self._monitor)

        gen_para = self.fedgen.cpu().state_dict()
        embedding = self.trainer_fedgen.embedding()
        self.state = rnd
        logger.info(f'\tClient #{self.ID} pre-train finish!')

        # Start the training of fedgen.
        payload = [gen_para, embedding, self.hide_data['data'].num_missing]
        self.comm_manager.send(
            Message(msg_type='gen_para',
                    sender=self.ID,
                    receiver=[server],
                    state=self.state,
                    content=payload))
        logger.info(f'\tClient #{self.ID} send gen_para to Server #{server}.')

    def callback_funcs_for_gen_para(self, message: Message):
        """Compute a gradient for another client's generator payload.

        The content is ``[gen_para, embedding, label, ID]``; the gradient
        is computed on the original local graph and returned to the sender
        together with the ``ID`` it belongs to.
        """
        rnd, server = message.state, message.sender
        gen_para, embedding, label, owner_id = message.content

        gen_grad = self.trainer_fedgen.cal_grad(self.data['data'], gen_para,
                                                embedding, label)
        self.state = rnd
        reply = Message(msg_type='gradient',
                        sender=self.ID,
                        receiver=[server],
                        state=self.state,
                        content=[gen_grad, owner_id])
        self.comm_manager.send(reply)
        logger.info(f'\tClient #{self.ID}: send gradient to Server #{server}.')

    def callback_funcs_for_gradient(self, message):
        """Train the federated generator and apply the received gradient.

        The content is the gradient aggregated on the server. After a
        call to the trainer's ``train`` and ``update_by_grad``, the new
        ``[gen_para, embedding, num_missing]`` is reported to the sender.
        """
        rnd, server = message.state, message.sender
        gen_grad = message.content

        self.trainer_fedgen.train()
        gen_para = self.trainer_fedgen.update_by_grad(gen_grad)
        embedding = self.trainer_fedgen.embedding()
        self.state = rnd

        payload = [gen_para, embedding, self.hide_data['data'].num_missing]
        self.comm_manager.send(
            Message(msg_type='gen_para',
                    sender=self.ID,
                    receiver=[server],
                    state=self.state,
                    content=payload))
        logger.info(f'\tClient #{self.ID}: send gen_para to Server #{server}.')

    def callback_funcs_for_setup_fedsage(self, message: Message):
        """Build the classifier trainer and run its first training round.

        ``GraphMender`` produces ``self.filled_data`` from the federated
        generator, the hidden graph and the original graph. Neighbor
        samplers over that graph feed a ``NodeMiniBatchTrainer``, whose
        first training result is sent to the sender as ``clf_para``.
        """
        rnd, server = message.state, message.sender

        mended = GraphMender(model=self.fedgen,
                             impaired_data=self.hide_data['data'].cpu(),
                             original_data=self.data['data'])
        self.filled_data = mended
        loader_cfg = self._cfg.dataloader

        # One sampler is shared by the validation and the test split.
        eval_sampler = NeighborSampler(mended.edge_index,
                                       sizes=[-1],
                                       batch_size=4096,
                                       shuffle=False,
                                       num_workers=loader_cfg.num_workers)
        train_sampler = NeighborSampler(mended.edge_index,
                                        node_idx=mended.train_idx,
                                        sizes=loader_cfg.sizes,
                                        batch_size=self.sage_batch_size,
                                        shuffle=loader_cfg.shuffle,
                                        num_workers=loader_cfg.num_workers)
        fill_dataloader = {
            'data': mended,
            'train': train_sampler,
            'val': eval_sampler,
            'test': eval_sampler
        }

        # Keep the config in line with the batch size actually used.
        self._cfg.merge_from_list(
            ['dataloader.batch_size', self.sage_batch_size])
        self.trainer_clf = NodeMiniBatchTrainer(self.clf,
                                                fill_dataloader,
                                                self.device,
                                                self._cfg,
                                                monitor=self._monitor)
        sample_size, clf_para, results = self.trainer_clf.train()
        self.state = rnd

        role = 'Client #{}'.format(self.ID)
        logger.info(
            self._monitor.format_eval_res(results, rnd=self.state, role=role))
        self.comm_manager.send(
            Message(msg_type='clf_para',
                    sender=self.ID,
                    receiver=[server],
                    state=self.state,
                    content=(sample_size, clf_para)))

    def callback_funcs_for_model_para(self, message: Message):
        """Run one classifier training round on received parameters.

        The content is loaded into the classifier trainer, a local
        training pass is run and ``(sample_size, clf_para)`` is sent back
        to the sender.
        """
        rnd, server = message.state, message.sender

        self.trainer_clf.update(message.content)
        self.state = rnd
        sample_size, clf_para, results = self.trainer_clf.train()

        federate_cfg = self._cfg.federate
        if federate_cfg.share_local_model and not federate_cfg.online_aggr:
            # Send a private copy of the parameters in this setting.
            clf_para = copy.deepcopy(clf_para)

        role = 'Client #{}'.format(self.ID)
        logger.info(
            self._monitor.format_eval_res(results, rnd=self.state, role=role))
        self.comm_manager.send(
            Message(msg_type='clf_para',
                    sender=self.ID,
                    receiver=[server],
                    state=self.state,
                    content=(sample_size, clf_para)))
