import os
import json
import logging
import torch.nn
import yaml
from torch.nn import functional as F

try:
    import habana_frameworks.torch.core as htcore
except ImportError:
    htcore = None

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.autotune.pfedhpo.utils import *
from federatedscope.autotune.utils import parse_search_space

logger = logging.getLogger(__name__)


class pFedHPOServer(Server):
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        """
        Build the server together with its hyper-network policy.

        Besides the regular ``Server`` initialization, this constructor
        (1) reads the search space from ``config.hpo.pfedhpo.ss``,
        (2) derives the per-hyper-parameter bounds (continuous mode) or
        candidate lists (discrete mode) stored in ``self.pbounds``,
        (3) loads the available ``client_%d_encoding.pt`` files and
        ``anchor_eval_results.json`` from ``hpo.working_folder``, and
        (4) creates the hyper-network and the Adam optimizer of its encoder.

        Arguments:
            ID, state, config, data, model, client_num, total_round_num,
            device, strategy, **kwargs: forwarded unchanged to the base
                ``Server``. ``config`` is dereferenced before the base
                class is initialized, so it must not be ``None``.

        Raises:
            ValueError: if a hyper-parameter of the search space exposes
                neither ``lower``/``upper`` nor (discrete mode only)
                ``choices``.
        """

        # initialize action space and the policy
        with open(config.hpo.pfedhpo.ss, 'r') as ips:
            ss = yaml.load(ips, Loader=yaml.FullLoader)

        if next(iter(ss.keys())).startswith('arm'):
            # This is a flattened action space
            # ensure the order is unchanged
            ss = sorted([(int(k[3:]), v) for k, v in ss.items()],
                        key=lambda x: x[0])
            self._grid = []
            self._cfsp = [[tp[1] for tp in ss]]
        else:
            # This is not a flat search space
            # be careful for the order
            self._grid = sorted(ss.keys())
            self._cfsp = [ss[pn] for pn in self._grid]

        super(pFedHPOServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        os.makedirs(self._cfg.hpo.working_folder, exist_ok=True)
        self.discrete = self._cfg.hpo.pfedhpo.discrete
        # prepare search space and bounds
        self._ss = parse_search_space(self._cfg.hpo.pfedhpo.ss)
        self.dim = len(self._ss)
        self.bounds = np.asarray([(0., 1.) for _ in self._ss])
        self.pbounds = {}

        # number of grid points used to discretize a numeric range
        n_samp = 10
        for name, hyper in self._ss.items():
            is_range = hasattr(hyper, 'lower') and hasattr(hyper, 'upper')
            if not is_range:
                # categorical hyper-parameters are only usable in the
                # discrete mode, where their choices are the candidates
                if self.discrete and hasattr(hyper, 'choices'):
                    self.pbounds[name] = list(hyper.choices)
                    continue
                raise ValueError("Unsupported hyper type {}".format(
                    type(hyper)))

            # log-scaled ranges are handled in the log10 domain
            if hyper.log:
                low, high = np.log10(hyper.lower), np.log10(hyper.upper)
            else:
                low, high = hyper.lower, hyper.upper

            if self.discrete:
                # evenly spaced candidates starting at ``low``; the upper
                # bound itself is not part of the grid
                self.pbounds[name] = [(high - low) / n_samp * i + low
                                      for i in range(n_samp)]
            else:
                self.pbounds[name] = (low, high)

        # prepare hyper-net
        self.client2idx = None

        if not self.discrete:
            self.var = 0.01
            dist = MultivariateNormal(
                loc=torch.zeros(len(self.pbounds)),
                covariance_matrix=torch.eye(len(self.pbounds)) * self.var)
            # log-density at the mean, i.e. the largest attainable value
            self.logprob_max = dist.log_prob(dist.sample() * 0)
        else:
            self.logprob_max = 1.

        # collect the encodings that exist on disk; missing files are
        # skipped (torch.stack fails if none was found)
        encoding_tensor = []
        for i in range(self._cfg.federate.client_num + 1):
            p = os.path.join(self._cfg.hpo.working_folder,
                             'client_%d_encoding.pt' % i)
            if os.path.exists(p):
                encoding_tensor.append(torch.load(p))
        encoding_tensor = torch.stack(encoding_tensor)

        if not self.discrete:
            self.HyperNet = HyperNet(encoding=encoding_tensor,
                                     num_params=len(self.pbounds),
                                     n_clients=client_num,
                                     device=self._cfg.device,
                                     var=self.var).to(self._cfg.device)
        else:
            self.HyperNet = DisHyperNet(
                encoding=encoding_tensor, cands=self.pbounds,
                n_clients=client_num, device=self._cfg.device,)\
                .to(self._cfg.device)

        self.saved_models = [None] * self._cfg.hpo.pfedhpo.\
            target_fl_total_round
        # ``parameters()`` returns a one-shot generator; keep a list so the
        # parameters can be iterated again on every gradient clipping
        # instead of being exhausted after the first policy update.
        self.opt_params = list(self.HyperNet.EncNet.parameters())

        self.opt = torch.optim.Adam([
            {
                'params': self.opt_params,
                'lr': 0.001,
                'weight_decay': 1e-4
            },
        ])

        with open(
                os.path.join(self._cfg.hpo.working_folder,
                             'anchor_eval_results.json'), 'r') as f:
            self.anchor_res = json.load(f)
        self.anchor_res_smooth = None

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        To broadcast the message to all clients or sampled clients. Every
        receiver gets the global model parameters together with a
        hyper-parameter configuration generated by the hyper-network.

        Arguments:
            msg_type: type of the broadcast message. For 'model_para' a
                start round is sampled at random and the global model
                saved for that round is loaded before broadcasting.
            sample_client_num: number of clients to sample; a non-positive
                value broadcasts to all neighbors.
            filter_unseen_clients: whether the unseen clients are marked
                as 'unseen' in the sampler while sampling (and restored to
                'seen' afterwards).

        Raises:
            NotImplementedError: for 'model_para' when more than one model
                is maintained.
        """
        if filter_unseen_clients:
            # to filter out the unseen clients when sampling
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if sample_client_num > 0:
            self.receiver = self.sampler.sample(size=sample_client_num)
        else:
            # broadcast to all clients
            self.receiver = list(self.comm_manager.neighbors.keys())
            if msg_type == 'model_para':
                self.sampler.change_state(self.receiver, 'working')

        if msg_type == 'model_para':
            # random sample start round and load saved global model
            # (the upper bound of np.random.randint is exclusive)
            self.start_round = np.random.randint(
                1, self._cfg.hpo.pfedhpo.target_fl_total_round)
            logger.info('==> Sampled start round: %d' % self.start_round)
            ckpt_path = os.path.join(
                self._cfg.hpo.working_folder,
                'temp_model_round_%d.pt' % self.start_round)
            if self.model_num > 1:
                raise NotImplementedError
            else:
                self.model.load_state_dict(torch.load(ckpt_path))

        model_para = self.model.state_dict()

        # lazily build the mapping from client id to its row in the
        # hyper-network outputs
        if not self.client2idx:
            self.client2idx = {
                client_id: i
                for i, client_id in enumerate(
                    self.comm_manager.neighbors.keys())
            }

        # generate hyper-params for all clients
        if not self.discrete:
            # linearly anneal the sampling variance from var_max towards
            # var_min over the first half of the rounds, then keep it at
            # the floor
            var_max = 2.0
            var_min = 0.1
            var = var_max + (var_min - var_max) / (
                0.5 * self.total_round_num) * self.state
            self.HyperNet.var = max(var, var_min)
            param_raw, self.logprob, self.entropy = self.HyperNet()
            xs = param_raw.detach().cpu().numpy()
        else:
            logits, self.enc_loss = self.HyperNet()
            # one slot per known client; only the receivers are filled
            self.logprob = [None] * len(self.client2idx)
            self.p_idx = {}
            for k in self.pbounds.keys():
                self.p_idx[k] = [None] * len(self.client2idx)

        # sample the hyper-parameter config specific to the clients
        self.sampled = False
        for rcv_idx in self.receiver:
            client_pos = self.client2idx[rcv_idx]
            if not self.discrete:
                sampled_cfg = map_value_to_param(xs[client_pos],
                                                 self.pbounds, self._ss)
            else:
                client_logprob = 0.
                sampled_cfg = {}

                for i, (k, v) in enumerate(self.pbounds.items()):
                    probs = logits[i][client_pos]
                    m = torch.distributions.Categorical(probs)

                    idx = m.sample()
                    p = v[idx.item()]
                    # candidates of log-scaled hypers live in the log10
                    # domain; map them back
                    if hasattr(self._ss[k], 'log') and self._ss[k].log:
                        p = 10**p
                    if 'int' in str(type(self._ss[k])).lower():
                        sampled_cfg[k] = int(p)
                    else:
                        sampled_cfg[k] = float(p)

                    client_logprob += m.log_prob(idx)
                    # remember the currently most likely candidate
                    self.p_idx[k][client_pos] = torch.argmax(probs)

                # average log-probability over the hyper-parameters
                self.logprob[client_pos] = client_logprob / len(
                    self.pbounds)

            content = {'model_param': model_para, 'hyper_param': sampled_cfg}
            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=[rcv_idx],
                        state=self.state,
                        content=content))
        if self._cfg.federate.online_aggr:
            # best effort: a failing reset must not abort the broadcast,
            # but it should not go unnoticed either
            try:
                for idx in range(self.model_num):
                    self.aggregators[idx].reset()
            except Exception as error:
                logger.warning(
                    'Failed to reset the online aggregator: %s', error)

        if filter_unseen_clients:
            # restore the state of the unseen clients within sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def callback_funcs_model_para(self, message: Message):
        """
        Handle the training feedback returned by a client: mark the sender
        as idle, buffer the content under its round, optionally feed it to
        the online aggregator, and check whether the server can move on.

        Arguments:
            message: the received message; ``message.state`` is used as
                the round, ``message.sender`` as the client id and
                ``message.content`` as the feedback.

        Returns:
            The result of ``self.check_and_move_on()``.
        """
        rnd, sender, content = message.state, message.sender, message.content
        self.sampler.change_state(sender, 'idle')
        # For a new round
        if rnd not in self.msg_buffer['train']:
            self.msg_buffer['train'][rnd] = dict()

        self.msg_buffer['train'][rnd][sender] = content

        if self._cfg.federate.online_aggr:
            # best effort: keep processing the message if the incremental
            # aggregation fails, but report the failure
            try:
                self.aggregator.inc(tuple(content[0:2]))
            except Exception as error:
                logger.warning(
                    'Online aggregation failed for client %s: %s', sender,
                    error)

        return self.check_and_move_on()

    def update_policy(self):
        """
        Update the encoder of the hyper-network with one policy-gradient
        step.

        The reward is the non-negative improvement of the latest
        evaluation result over the anchor result recorded for the sampled
        start round (in the discrete mode additionally scaled by that
        anchor result). The loss is the negative reward-weighted
        log-probability of the sampled hyper-parameters, averaged over
        the clients.
        """
        key1 = 'Results_weighted_avg'
        key2 = 'val_acc'

        if 'twitter' in str(self._cfg.data.type).lower():
            anchor_res_start = \
                self.anchor_res['Results_raw']['test_acc'][self.start_round-1]
            res_end = \
                self.history_results['Results_weighted_avg']['test_acc'][-1]
        else:
            anchor_res_start = self.anchor_res[key1][key2][self.start_round -
                                                           1]
            res_end = self.history_results[key1][key2][-1]

        improvement = np.maximum(0, res_end - anchor_res_start)
        if not self.discrete:
            reward = improvement
            losses = -reward * self.logprob
        else:
            reward = improvement * anchor_res_start
            # slots of clients that did not receive the last broadcast are
            # still None and cannot be stacked
            self.logprob = torch.stack(
                [lp for lp in self.logprob if lp is not None], dim=-1)
            losses = F.relu(-reward * self.logprob * 100)

        self.opt.zero_grad()
        loss = losses.mean()
        loss.backward()
        # Request the parameters anew for every step: ``parameters()``
        # yields a one-shot generator, so a generator stored once would be
        # empty from the second update on and clipping would silently
        # become a no-op.
        torch.nn.utils.clip_grad_norm_(self.HyperNet.EncNet.parameters(),
                                       max_norm=10,
                                       norm_type=2)
        self.opt.step()
        if htcore is not None:
            # only reached when the optional habana_frameworks import
            # succeeded
            htcore.mark_step()

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message_buffer, when enough messages are receiving,
        trigger some events (such as perform aggregation, evaluation,
        and move to the next training round)

        Arguments:
            check_eval_result: ``False`` to process training feedback
                (aggregate, advance the round and start an evaluation);
                ``True`` to process evaluation results (merge them, update
                the hyper-network and broadcast the next training round).
            min_received_num: minimal number of buffered messages needed
                to proceed; defaults to ``federate.sample_client_num`` and
                is replaced by the number of neighbors when
                ``check_eval_result`` is ``True``.

        Returns:
            ``True`` if the buffer held enough messages and the events
            were triggered, ``False`` otherwise.
        """
        if min_received_num is None:
            min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result:
            min_received_num = len(list(self.comm_manager.neighbors.keys()))

        move_on_flag = True  # To record whether moving to a new training
        # round or finishing the evaluation
        if self.check_buffer(self.state, min_received_num, check_eval_result):

            if not check_eval_result:  # in the training process
                mab_feedbacks = dict()
                # Get all the message
                train_msg_buffer = self.msg_buffer['train'][self.state]
                for model_idx in range(self.model_num):
                    model = self.models[model_idx]
                    aggregator = self.aggregators[model_idx]
                    msg_list = list()
                    for client_id in train_msg_buffer:
                        if self.model_num == 1:
                            msg_list.append(
                                tuple(train_msg_buffer[client_id][0:2]))
                        else:
                            train_data_size, model_para_multiple = \
                                train_msg_buffer[client_id][0:2]
                            msg_list.append((train_data_size,
                                             model_para_multiple[model_idx]))

                        # collect feedbacks for updating the policy
                        if model_idx == 0:
                            mab_feedbacks[client_id] = train_msg_buffer[
                                client_id][2]

                    # Trigger the monitor here (for training)
                    if 'dissim' in self._cfg.eval.monitoring:
                        from federatedscope.core.auxiliaries.utils import \
                            calc_blocal_dissim
                        # The monitor needs the weights of the model before
                        # aggregation; ``load_state_dict`` without a state
                        # dict argument raises a TypeError, so read them
                        # with ``state_dict()``.
                        B_val = calc_blocal_dissim(model.state_dict(),
                                                   msg_list)
                        formatted_eval_res = self._monitor.format_eval_res(
                            B_val, rnd=self.state, role='Server #')
                        logger.info(formatted_eval_res)

                    # Aggregate
                    agg_info = {
                        'client_feedback': msg_list,
                        'recover_fun': self.recover_fun
                    }
                    result = aggregator.aggregate(agg_info)
                    model.load_state_dict(result, strict=False)
                self.fb = mab_feedbacks

                self.state += 1
                #  Evaluate
                logger.info(
                    'Server: Starting evaluation at begin of round {:d}.'.
                    format(self.state))
                self.eval()

            else:  # in the evaluation process
                # Get all the message & aggregate
                logger.info('-' * 30)
                formatted_eval_res = self.merge_eval_results_from_all_clients()
                self.history_results = merge_dict(self.history_results,
                                                  formatted_eval_res)
                # may advance ``self.state`` beyond ``total_round_num``
                self.check_and_save()

                if self.state < self.total_round_num:
                    if len(self.history_results) > 0:
                        logger.info('=' * 10 + ' updating hypernet at round ' +
                                    str(self.state) + ' ' + '=' * 10)
                        self.update_policy()

                    # Move to next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    logger.info(self._cfg.device)
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()

                    self.broadcast_model_para(
                        msg_type='model_para',
                        sample_client_num=self.sample_client_num)

        else:
            move_on_flag = False

        return move_on_flag

    def check_and_save(self):
        """
        To save the results and save model after each evaluation

        The hyper-network state is written to ``hyperNet_encoding.pt`` in
        ``hpo.working_folder`` on every call. When early stopping triggers
        or the last round is reached, the best results are saved, the
        policy checkpoint is dumped (if ``federate.save_to`` is set) and a
        'finish' message carrying the model parameters is sent to all
        neighbors. ``self.state`` is moved beyond ``total_round_num`` in
        both cases.
        """
        # early stopping: track the weighted average if it provides the
        # round-wise key, otherwise fall back to the plain average
        should_stop = False
        round_wise_key = self._cfg.eval.best_res_update_round_wise_key
        for res_key in ('Results_weighted_avg', 'Results_avg'):
            if res_key in self.history_results and \
                    round_wise_key in self.history_results[res_key]:
                should_stop = self.early_stopper.track_and_check(
                    self.history_results[res_key][round_wise_key])
                break

        if should_stop:
            self.state = self.total_round_num + 1

        _path = os.path.join(self._cfg.hpo.working_folder,
                             'hyperNet_encoding.pt')
        hyper_enc = {
            'hyperNet': self.HyperNet.state_dict(),
        }
        torch.save(hyper_enc, _path)

        if should_stop or self.state == self.total_round_num:
            logger.info('Server: Final evaluation is finished! Starting '
                        'merging results.')
            # last round
            self.save_best_results()

            if self._cfg.federate.save_to != '':
                # save the policy
                # NOTE(review): `_z`, `_store`, `_stop_exploration` and
                # `trace` are not defined in this class; presumably they
                # have to be provided elsewhere, otherwise this branch
                # raises AttributeError - confirm.
                ckpt = dict()
                z_list = [z.tolist() for z in self._z]
                ckpt['z'] = z_list
                ckpt['store'] = self._store
                ckpt['stop'] = self._stop_exploration
                ckpt['global'] = self.trace('global').tolist()
                ckpt['refine'] = self.trace('refine').tolist()
                ckpt['entropy'] = self.trace('entropy').tolist()
                ckpt['mle'] = self.trace('mle').tolist()
                # strip only a real file extension: slicing at
                # ``rfind('.')`` drops the last character when there is no
                # dot and cuts at dots inside directory names
                pi_ckpt_path = os.path.splitext(
                    self._cfg.federate.save_to)[0] + "_pfedhpo.yaml"
                with open(pi_ckpt_path, 'w') as ops:
                    yaml.dump(ckpt, ops)

            if self.model_num > 1:
                model_para = [model.state_dict() for model in self.models]
            else:
                model_para = self.model.state_dict()
            self.comm_manager.send(
                Message(msg_type='finish',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state,
                        content=model_para))

        if self.state == self.total_round_num:
            # break out the loop for distributed mode
            self.state += 1
