import copy
import json
import os
import logging
from itertools import product
import pickle

import torch.nn
import yaml

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.autotune.pfedhpo.utils import *
from federatedscope.autotune.utils import parse_search_space

logger = logging.getLogger(__name__)


class pFedHPOFLServer(Server):
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        """
        Set up the server: action space, per-hyperparameter bounds and,
        unless anchors are being trained, the pre-trained hyper-network.

        Arguments:
            ID, state, config, data, model, client_num, total_round_num,
            device, strategy, **kwargs: forwarded unchanged to
                ``Server.__init__``.
            config: additionally read here for ``hpo.pfedhpo.ss``, the path
                of the YAML search-space file.
            client_num: additionally used as the number of rows of the
                placeholder client encoding and as ``n_clients`` of the
                hyper-network.

        Raises:
            ValueError: a hyperparameter of the search space exposes neither
                ``lower``/``upper`` nor (discrete mode only) ``choices``.
            NotImplementedError: no encoding size is known for
                ``cfg.data.type`` (only checked when not training anchors).
        """
        # initialize action space and the policy
        # NOTE(review): FullLoader can build more than plain data types; if
        # the search-space file may come from an untrusted source, switch
        # to yaml.safe_load.
        with open(config.hpo.pfedhpo.ss, 'r') as ips:
            ss = yaml.load(ips, Loader=yaml.FullLoader)

        if next(iter(ss)).startswith('arm'):
            # This is a flattened action space ("arm<idx>" keys):
            # order the arms by their numeric index so the order is stable
            arms = sorted([(int(k[3:]), v) for k, v in ss.items()],
                          key=lambda x: x[0])
            self._grid = []
            self._cfsp = [[arm for _, arm in arms]]
        else:
            # This is not a flat search space
            # sort the names so the order is deterministic
            self._grid = sorted(ss.keys())
            self._cfsp = [ss[pn] for pn in self._grid]

        super(pFedHPOFLServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)
        os.makedirs(self._cfg.hpo.working_folder, exist_ok=True)

        self.train_anchor = self._cfg.hpo.pfedhpo.train_anchor
        self.discrete = self._cfg.hpo.pfedhpo.discrete

        # prepare search space and bounds
        self._ss = parse_search_space(self._cfg.hpo.pfedhpo.ss)
        self.dim = len(self._ss)
        self.bounds = np.asarray([(0., 1.) for _ in self._ss])
        # continuous mode: name -> (low, high)
        # discrete mode:   name -> list of candidate values
        self.pbounds = {
            name: self._pfedhpo_bounds_of(hyper, self.discrete)
            for name, hyper in self._ss.items()
        }

        # prepare hyper-net
        # client id -> row index, filled lazily on the first broadcast
        self.client2idx = None

        if not self.train_anchor:
            # Load onto the CPU first so that a checkpoint written on a GPU
            # can still be read on a host without one; load_state_dict
            # below copies the values into the module's own parameters.
            hyper_enc = torch.load(os.path.join(self._cfg.hpo.working_folder,
                                                'hyperNet_encoding.pt'),
                                   map_location='cpu')
            dim = self._pfedhpo_encoding_dim(self._cfg.data.type)

            # Placeholder encoding of the right shape.
            # NOTE(review): presumably overwritten by the checkpoint loaded
            # below -- confirm against HyperNet/DisHyperNet.
            self.client_encoding = torch.ones(client_num, dim)
            if not self.discrete:
                self.HyperNet = HyperNet(encoding=self.client_encoding,
                                         num_params=len(self.pbounds),
                                         n_clients=client_num,
                                         device=self._cfg.device,
                                         var=0.01).to(self._cfg.device)
            else:
                self.HyperNet = DisHyperNet(
                    encoding=self.client_encoding,
                    cands=self.pbounds,
                    n_clients=client_num,
                    device=self._cfg.device,
                ).to(self._cfg.device)

            self.HyperNet.load_state_dict(hyper_enc['hyperNet'])
            self.HyperNet.eval()
            # The hyper-network is queried once here; broadcasts only index
            # into the cached result.
            if not self.discrete:
                self.raw_params = self.HyperNet()[0].detach().cpu().numpy()
            else:
                self.logits = self.HyperNet()[0]

    @staticmethod
    def _pfedhpo_bounds_of(hyper, discrete, num_candidates=10):
        """
        Return the bounds (continuous) or candidate list (discrete) of one
        hyperparameter of the parsed search space.

        Ranges flagged with ``log`` are expressed in log10 space. In
        discrete mode a range is cut into ``num_candidates`` evenly spaced
        values starting at the lower end (the upper end is not included),
        while a hyperparameter with ``choices`` contributes those choices.

        Raises:
            ValueError: ``hyper`` has no ``lower``/``upper`` and, in
                discrete mode, no ``choices`` either.
        """
        if not (hasattr(hyper, 'lower') and hasattr(hyper, 'upper')):
            if discrete and hasattr(hyper, 'choices'):
                return list(hyper.choices)
            raise ValueError("Unsupported hyper type {}".format(type(hyper)))

        if hyper.log:
            low, high = np.log10(hyper.lower), np.log10(hyper.upper)
        else:
            low, high = hyper.lower, hyper.upper

        if not discrete:
            return (low, high)
        return [(high - low) / num_candidates * i + low
                for i in range(num_candidates)]

    @staticmethod
    def _pfedhpo_encoding_dim(data_type):
        """
        Return the client-encoding width used for ``data_type``.

        Raises:
            NotImplementedError: ``data_type`` matches no known dataset.
        """
        if data_type == 'mini-graph-dc':
            return 60
        lowered = str(data_type).lower()
        # first matching keyword wins, in this order
        for keyword, dim in (('cifar', 60), ('femnist', 124),
                             ('twitter', 100)):
            if keyword in lowered:
                return dim
        raise NotImplementedError(
            f'No client encoding size is defined for data type '
            f'{data_type!r}')

    def callback_funcs_model_para(self, message: Message):
        """
        Buffer a client's training message and try to advance the round.

        The sender is marked ``'idle'`` in the sampler, its content is
        stored in ``self.msg_buffer['train']`` under the round carried by
        the message, and, when online aggregation is enabled, the first two
        content entries are handed to the aggregator right away.

        Arguments:
            message: the received message; ``state`` is used as the round,
                ``content`` must be sliceable.

        Returns:
            The result of ``self.check_and_move_on()``.
        """
        rnd, sender, content = message.state, message.sender, message.content
        self.sampler.change_state(sender, 'idle')
        # The first message of a round creates that round's buffer
        self.msg_buffer['train'].setdefault(rnd, dict())[sender] = content

        if self._cfg.federate.online_aggr:
            # Online aggregation stays best-effort, but a failure is now
            # reported instead of vanishing, and KeyboardInterrupt /
            # SystemExit are no longer swallowed.
            try:
                self.aggregator.inc(tuple(content[0:2]))
            except Exception:
                logger.warning(
                    'Server: online aggregation of the update from client '
                    '%s failed; the update is kept in the buffer only.',
                    sender,
                    exc_info=True)

        return self.check_and_move_on()

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        To broadcast the message to all clients or sampled clients. Each
        receiver gets its own message whose content holds the model
        parameters (``'model_param'``) and the hyperparameters chosen for
        that client (``'hyper_param'``, ``None`` while training anchors).

        Arguments:
            msg_type: 'model_para' or other user defined msg_type
            sample_client_num: the number of sampled clients in the broadcast
                behavior. And sample_client_num = -1 denotes to broadcast to
                all the clients.
            filter_unseen_clients: whether filter out the unseen clients that
                do not contribute to FL process by training on their local
                data and uploading their local model update. The splitting is
                useful to check participation generalization gap in [ICLR'22,
                What Do We Mean by Generalization in Federated Learning?]
                You may want to set it to be False when in evaluation stage
        """

        if self.train_anchor:
            # keep a snapshot of the global model for the current round
            ckpt_path = os.path.join(
                self._cfg.hpo.working_folder,
                'temp_model_round_%d.pt' % (int(self.state)))
            torch.save(self.model.state_dict(), ckpt_path)

        if filter_unseen_clients:
            # to filter out the unseen clients when sampling
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if sample_client_num > 0:
            receiver = self.sampler.sample(size=sample_client_num)
        else:
            # broadcast to all clients
            receiver = list(self.comm_manager.neighbors.keys())
            if msg_type == 'model_para':
                self.sampler.change_state(receiver, 'working')

        if self._noise_injector is not None and msg_type == 'model_para':
            # Inject noise only when broadcast parameters; the sample
            # counts do not depend on the model, so gather them once
            num_sample_clients = [
                v["num_sample"] for v in self.join_in_info.values()
            ]
            for model in self.models:
                self._noise_injector(self._cfg, num_sample_clients, model)

        skip_broadcast = self._cfg.federate.method in ["local", "global"]
        if self.model_num > 1:
            model_para = [{} if skip_broadcast else model.state_dict()
                          for model in self.models]
        else:
            model_para = {} if skip_broadcast else self.model.state_dict()

        if not self.client2idx:
            # Built on first use: position of every neighbor in the
            # communication manager's listing
            self.client2idx = {
                client_id: idx
                for idx, client_id in enumerate(
                    self.comm_manager.neighbors.keys())
            }

        for rcv_idx in receiver:
            content = {
                'model_param': model_para,
                'hyper_param': self._pfedhpo_hyper_param_for(rcv_idx)
            }
            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=[rcv_idx],
                        state=self.state,
                        content=content))

        if self._cfg.federate.online_aggr:
            # Resetting stays best-effort, but a failure is now reported
            # instead of vanishing, and KeyboardInterrupt / SystemExit are
            # no longer swallowed.
            try:
                for idx in range(self.model_num):
                    self.aggregators[idx].reset()
            except Exception:
                logger.warning(
                    'Server: resetting the online aggregators failed.',
                    exc_info=True)

        if filter_unseen_clients:
            # restore the state of the unseen clients within sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def _pfedhpo_hyper_param_for(self, client_id):
        """
        Return the hyperparameter configuration to send to ``client_id``:
        ``None`` while training anchors, otherwise a dict derived from the
        cached hyper-network output for that client.
        """
        if self.train_anchor:
            return None

        client_idx = self.client2idx[client_id]
        if not self.discrete:
            return map_value_to_param(self.raw_params[client_idx],
                                      self.pbounds, self._ss)

        sampled_cfg = {}
        for i, (name, candidates) in enumerate(self.pbounds.items()):
            # greedy choice: the candidate with the largest score
            probs = self.logits[i][client_idx]
            value = candidates[torch.argmax(probs).item()]

            hyper = self._ss[name]
            if getattr(hyper, 'log', False):
                # candidates of log-scaled ranges are stored as log10
                value = 10**value
            # NOTE(review): anything whose type name lacks "int" is cast
            # to float, which includes hyperparameters given as
            # ``choices`` -- confirm that this is intended.
            if 'int' in str(type(hyper)).lower():
                sampled_cfg[name] = int(value)
            else:
                sampled_cfg[name] = float(value)
        return sampled_cfg

    def save_res(self, feedbacks):
        """
        Append ``feedbacks``, tagged with the current round, as one line to
        ``anchor_eval_results.log`` in the HPO working folder.

        Arguments:
            feedbacks: the results to record; written via ``str()``.
        """
        record = {'round': self.state, 'results': feedbacks}
        log_path = os.path.join(self._cfg.hpo.working_folder,
                                'anchor_eval_results.log')
        # append mode: earlier rounds stay in the file
        with open(log_path, "a") as log_file:
            log_file.write(str(record) + "\n")

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message_buffer, when enough messages are receiving,
        trigger some events (such as perform aggregation, evaluation,
        and move to the next training round)

        Arguments:
            check_eval_result: ``True`` to process evaluation results,
                ``False`` to process training updates.
            min_received_num: number of messages required before acting;
                defaults to ``cfg.federate.sample_client_num``. Replaced by
                the number of neighbors when ``check_eval_result`` is set.

        Returns:
            ``True`` when the buffer held enough messages and the
            corresponding step was carried out, ``False`` otherwise.
        """
        if min_received_num is None:
            min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result:
            # evaluation waits for every neighbor
            min_received_num = len(list(self.comm_manager.neighbors.keys()))

        if not self.check_buffer(self.state, min_received_num,
                                 check_eval_result):
            # not enough messages yet: nothing to do
            return False

        if not check_eval_result:  # in the training process
            self._pfedhpo_aggregate_round()

            self.state += 1
            #  Evaluate
            logger.info('Server: Starting evaluation at round {:d}.'.format(
                self.state))
            self.eval()

            if self.state < self.total_round_num:
                # Move to next round of training
                logger.info(
                    f'----------- Starting a new training round (Round '
                    f'#{self.state}) -------------')
                # Clean the msg_buffer
                self.msg_buffer['train'][self.state - 1].clear()

                self.broadcast_model_para(
                    msg_type='model_para',
                    sample_client_num=self.sample_client_num)
            else:
                # Final Evaluate
                # NOTE(review): eval() has already been triggered just
                # above for this same state, so the last round is
                # evaluated twice -- confirm that this is intended.
                logger.info('Server: Training is finished! Starting '
                            'evaluation.')
                self.eval()

        else:  # in the evaluation process
            # Get all the message & aggregate
            logger.info('-' * 30)
            formatted_eval_res = self.merge_eval_results_from_all_clients()
            self.history_results = merge_dict(self.history_results,
                                              formatted_eval_res)
            self.check_and_save()

        if self.train_anchor and self.history_results:
            # rewritten in full every time, so the file always holds the
            # complete history
            with open(
                    os.path.join(self._cfg.hpo.working_folder,
                                 'anchor_eval_results.json'), 'w') as f:
                json.dump(self.history_results, f)

        return True

    def _pfedhpo_aggregate_round(self):
        """
        Aggregate the training updates buffered for the current round into
        every model held by the server.
        """
        # NOTE(review): the feedbacks are collected but not consumed
        # anywhere in this method.
        mab_feedbacks = dict()
        # Get all the message
        train_msg_buffer = self.msg_buffer['train'][self.state]
        for model_idx in range(self.model_num):
            model = self.models[model_idx]
            aggregator = self.aggregators[model_idx]
            msg_list = list()
            for client_id, client_content in train_msg_buffer.items():
                if self.model_num == 1:
                    msg_list.append(tuple(client_content[0:2]))
                else:
                    # with several models the second entry holds one
                    # parameter set per model
                    train_data_size, model_para_multiple = \
                        client_content[0:2]
                    msg_list.append(
                        (train_data_size, model_para_multiple[model_idx]))

                # collect feedbacks for updating the policy
                if model_idx == 0:
                    mab_feedbacks[client_id] = client_content[2]

            # Trigger the monitor here (for training)
            if 'dissim' in self._cfg.eval.monitoring:
                from federatedscope.core.auxiliaries.utils import \
                    calc_blocal_dissim
                # The reference passed here is the model's state before
                # aggregation. The previous code called
                # load_state_dict(strict=False) without a state dict,
                # which raises TypeError.
                B_val = calc_blocal_dissim(model.state_dict(), msg_list)
                formatted_eval_res = self._monitor.format_eval_res(
                    B_val, rnd=self.state, role='Server #')
                logger.info(formatted_eval_res)

            # Aggregate
            agg_info = {
                'client_feedback': msg_list,
                'recover_fun': self.recover_fun
            }
            result = aggregator.aggregate(agg_info)
            model.load_state_dict(result, strict=False)
