import copy
import os
import logging
from itertools import product
import pickle
import yaml

import numpy as np
from numpy.linalg import norm
from scipy.special import logsumexp
import GPy

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.autotune.fts.utils import *
from federatedscope.autotune.utils import parse_search_space

logger = logging.getLogger(__name__)


class FTSServer(Server):
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        super(FTSServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)
        assert self.sample_client_num == self._cfg.federate.client_num

        self.util_ts = UtilityFunction(kind="ts")
        self.M = self._cfg.hpo.fts.M

        # server file paths
        self.all_lcoal_init_path = os.path.join(self._cfg.hpo.working_folder,
                                                "all_localBO_inits.pkl")
        self.all_local_info_path = os.path.join(self._cfg.hpo.working_folder,
                                                "all_localBO_infos.pkl")
        self.rand_feat_path = os.path.join(
            self._cfg.hpo.working_folder,
            "rand_feat_M_" + str(self.M) + ".pkl")

        # prepare search space and bounds
        self._ss = parse_search_space(self._cfg.hpo.fts.ss)
        self.dim = len(self._ss)
        self.bounds = np.asarray([(0., 1.) for _ in self._ss])
        self.pbounds = {}
        for k, v in self._ss.items():
            if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
                raise ValueError("Unsupported hyper type {}".format(type(v)))
            else:
                if v.log:
                    l, u = np.log10(v.lower), np.log10(v.upper)
                else:
                    l, u = v.lower, v.upper
                self.pbounds[k] = (l, u)

        # distribution used to sample GP models
        pt = 1 - 1 / (np.arange(self._cfg.hpo.fts.fed_bo_max_iter + 5) +
                      1)**2.0
        pt[0] = pt[1]
        self.pt = pt
        self.num_other_clients = self._cfg.federate.client_num - 1
        self.ws = np.ones(self.num_other_clients) / self.num_other_clients

        # records for all GP models
        N = self._cfg.federate.client_num + 1
        self.x_max = [None for _ in range(N)]
        self.y_max = [None for _ in range(N)]
        self.X = [None for _ in range(N)]
        self.Y = [None for _ in range(N)]
        self.incumbent = [None for _ in range(N)]
        self.gp = [None for _ in range(N)]
        self.gp_params = [None for _ in range(N)]
        self.initialized = [False for _ in range(N)]
        self.res = [{
            'max_value': None,
            'max_param': None,
            'all_values': [],
            'all_params': [],
        } for _ in range(N)]
        self.res_paths = [
            os.path.join(self._cfg.hpo.working_folder,
                         "result_params_%d.pkl" % cid) for cid in range(N)
        ]

        # load or generate agent_info, agent_init, and rand_feat.
        # if files already exit, load from saved files;
        # else require the clients in the first round
        if os.path.exists(self.all_local_info_path) and \
                self._cfg.hpo.fts.allow_load_existing_info:
            logger.info('Using existing rand_feat, agent_infos, '
                        'and agent_inits')
            self.require_agent_infos = False
            self.state = 1
            self.random_feats = pickle.load(open(self.rand_feat_path, 'rb'))
            self.all_agent_info = pickle.load(
                open(self.all_local_info_path, 'rb'))
            self.all_agent_init = pickle.load(
                open(self.all_lcoal_init_path, 'rb'))
        else:
            self.require_agent_infos = True
            self.random_feats = self._generate_shared_rand_feats()
            self.all_agent_info = {}
            self.all_agent_init = {}

        # point out target clients need to be optimized by FTS
        if not config.hpo.fts.target_clients:
            self.target_clients = list(
                range(1, self._cfg.federate.client_num + 1))
        else:
            self.target_clients = list(config.hpo.fts.target_clients)

    def _generate_shared_rand_feats(self):
        # generate shared random features
        ls = self._cfg.hpo.fts.ls
        v_kernel = self._cfg.hpo.fts.v_kernel
        obs_noise = self._cfg.hpo.fts.obs_noise
        M = self.M

        s = np.random.multivariate_normal(np.zeros(self.dim),
                                          1 / ls**2 * np.identity(self.dim), M)
        b = np.random.uniform(0, 2 * np.pi, M)
        random_features = {
            "M": M,
            "length_scale": ls,
            "s": s,
            "b": b,
            "obs_noise": obs_noise,
            "v_kernel": v_kernel
        }
        pickle.dump(random_features, open(self.rand_feat_path, "wb"))

        return random_features

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        To broadcast the message to all clients or sampled clients
        """
        if filter_unseen_clients:
            # to filter out the unseen clients when sampling
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if self.require_agent_infos:
            # broadcast to all clients
            receiver = list(self.comm_manager.neighbors.keys())
            if msg_type == 'model_para':
                self.sampler.change_state(receiver, 'working')
        else:
            receiver = list(self.target_clients)

        for rcv_idx in receiver:
            if self.require_agent_infos:
                content = {
                    'require_agent_infos': True,
                    'random_feats': self.random_feats,
                }

            else:
                # initialize gp models and init points
                if not self.initialized[rcv_idx]:
                    init = self.all_agent_init[rcv_idx]
                    self.X[rcv_idx] = init['X']
                    self.Y[rcv_idx] = init['Y']
                    self.incumbent[rcv_idx] = np.max(self.Y[rcv_idx])
                    logger.info("Using pre-existing initializations "
                                "for client {} with {} points".format(
                                    rcv_idx, len(self.Y[rcv_idx])))

                    y_max = np.max(self.Y[rcv_idx])
                    self.y_max[rcv_idx] = y_max

                    ur = unique_rows(self.X[rcv_idx])
                    self.gp[rcv_idx] = GPy.models.GPRegression(
                        self.X[rcv_idx][ur],
                        self.Y[rcv_idx][ur].reshape(-1, 1),
                        GPy.kern.RBF(input_dim=self.X[rcv_idx].shape[1],
                                     lengthscale=self._cfg.hpo.fts.ls,
                                     variance=self._cfg.hpo.fts.var,
                                     ARD=False))
                    self.gp[rcv_idx]["Gaussian_noise.variance"][0] = \
                        self._cfg.hpo.fts.g_var
                    self._opt_gp(rcv_idx)
                    self.initialized[rcv_idx] = True

                # sample hyper from this client's GP or others' GP
                info_ts = copy.deepcopy(self.all_agent_info)
                del info_ts[rcv_idx]
                info_ts = list(info_ts.values())

                def _try_sample_try(func):
                    _loop = True
                    while _loop:
                        try:
                            x_max, all_ucb = func(rcv_idx, self.y_max[rcv_idx],
                                                  self.state, info_ts)
                            _loop = False
                        except:
                            _loop = True
                    return x_max, all_ucb

                if np.random.random() < self.pt[self.state - 1]:
                    x_max, all_ucb = _try_sample_try(self._sample_from_this)
                else:
                    x_max, all_ucb = _try_sample_try(self._sample_from_others)
                self.x_max[rcv_idx] = x_max

                content = {
                    'require_agent_infos': False,
                    'x_max': x_max,
                }

            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=[rcv_idx],
                        state=self.state,
                        content=content))

        if filter_unseen_clients:
            # restore the state of the unseen clients within sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def _sample_from_this(self, client, y_max, iteration, info_ts):
        M_target = self._cfg.hpo.fts.M_target

        ls_target = self.gp[client]["rbf.lengthscale"][0]
        v_kernel = self.gp[client]["rbf.variance"][0]
        obs_noise = self.gp[client]["Gaussian_noise.variance"][0]

        s = np.random.multivariate_normal(
            np.zeros(self.dim), 1 / (ls_target**2) * np.identity(self.dim),
            M_target)
        b = np.random.uniform(0, 2 * np.pi, M_target)
        random_features_target = {
            "M": M_target,
            "length_scale": ls_target,
            "s": s,
            "b": b,
            "obs_noise": obs_noise,
            "v_kernel": v_kernel
        }

        Phi = np.zeros((self.X[client].shape[0], M_target))
        for i, x in enumerate(self.X[client]):
            x = np.squeeze(x).reshape(1, -1)
            features = np.sqrt(
                2 / M_target) * np.cos(np.squeeze(np.dot(x, s.T)) + b)
            features = features / np.sqrt(np.inner(features, features))
            features = np.sqrt(v_kernel) * features
            Phi[i, :] = features

        Sigma_t = np.dot(Phi.T, Phi) + obs_noise * np.identity(M_target)
        Sigma_t_inv = np.linalg.inv(Sigma_t)
        nu_t = np.dot(np.dot(Sigma_t_inv, Phi.T),
                      self.Y[client].reshape(-1, 1))
        w_sample = np.random.multivariate_normal(np.squeeze(nu_t),
                                                 obs_noise * Sigma_t_inv, 1)

        x_max, all_ucb = acq_max(ac=self.util_ts.utility,
                                 gp=self.gp[client],
                                 M=M_target,
                                 N=self.num_other_clients,
                                 gp_samples=None,
                                 random_features=random_features_target,
                                 info_ts=info_ts,
                                 pt=self.pt,
                                 ws=self.ws,
                                 use_target_label=True,
                                 w_sample=w_sample,
                                 y_max=y_max,
                                 bounds=self.bounds,
                                 iteration=iteration)
        return x_max, all_ucb

    def _sample_from_others(self, client, y_max, iteration, info_ts):
        agent_ind = np.arange(self.num_other_clients)
        random_agent_n = np.random.choice(agent_ind, 1, p=self.ws)[0]
        w_sample = info_ts[random_agent_n]
        x_max, all_ucb = acq_max(ac=self.util_ts.utility,
                                 gp=self.gp[client],
                                 M=self.M,
                                 N=self.num_other_clients,
                                 gp_samples=None,
                                 random_features=self.random_feats,
                                 info_ts=info_ts,
                                 pt=self.pt,
                                 ws=self.ws,
                                 use_target_label=False,
                                 w_sample=w_sample,
                                 y_max=y_max,
                                 bounds=self.bounds,
                                 iteration=iteration)
        return x_max, all_ucb

    def _opt_gp(self, client):
        self.gp[client].optimize_restarts(num_restarts=10,
                                          messages=False,
                                          verbose=False)
        self.gp_params[client] = self.gp[client].parameters
        # print("---Optimized Hyper of Client %d : " % client, self.gp[client])

    def callback_funcs_model_para(self, message: Message):
        round, sender, content = message.state, message.sender, message.content
        self.sampler.change_state(sender, 'idle')
        # For a new round
        if round not in self.msg_buffer['train'].keys():
            self.msg_buffer['train'][round] = dict()

        self.msg_buffer['train'][round][sender] = content

        return self.check_and_move_on()

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message_buffer. When enough messages are receiving,
        some events (such as perform aggregation, evaluation, and move to
        the next training round) would be triggered.

        Arguments:
            check_eval_result (bool): If True, check the message buffer for
            evaluation; and check the message buffer for training otherwise.
        """
        if min_received_num is None or check_eval_result:
            min_received_num = len(self.target_clients)
        if self.require_agent_infos:
            min_received_num = self._cfg.federate.client_num

        move_on_flag = True

        # When enough messages are receiving
        if self.check_buffer(self.state, min_received_num, check_eval_result):

            if not check_eval_result:
                # The first round is to collect clients' infomation,
                # receive agent_info
                if self.require_agent_infos:
                    for _client, _content in \
                            self.msg_buffer['train'][self.state].items():
                        assert _content['is_required_agent_info']
                        self.all_agent_info[_client] = _content['agent_info']
                        self.all_agent_init[_client] = _content['agent_init']
                    pickle.dump(self.all_agent_info,
                                open(self.all_local_info_path, "wb"))
                    pickle.dump(self.all_agent_init,
                                open(self.all_lcoal_init_path, "wb"))
                    self.require_agent_infos = False

                # Other rounds are to update GP models, receive performance
                else:
                    for _client, _content in \
                            self.msg_buffer['train'][self.state].items():
                        curr_y = _content['curr_y']
                        self.Y[_client] = np.append(self.Y[_client], curr_y)
                        self.X[_client] = np.vstack(
                            (self.X[_client], self.x_max[_client].reshape(
                                (1, -1))))
                        if self.Y[_client][-1] > self.y_max[_client]:
                            self.y_max[_client] = self.Y[_client][-1]
                            self.incumbent[_client] = self.Y[_client][-1]
                        ur = unique_rows(self.X[_client])
                        self.gp[_client].set_XY(X=self.X[_client][ur],
                                                Y=self.Y[_client][ur].reshape(
                                                    -1, 1))

                        _schedule = self._cfg.hpo.fts.gp_opt_schedule
                        if self.state >= _schedule \
                                and self.state % _schedule == 0:
                            self._opt_gp(_client)

                        x_max_param = self.X[_client][self.Y[_client].argmax()]
                        hyper_param = x2conf(x_max_param, self.pbounds,
                                             self._ss)
                        self.res[_client]['max_param'] = hyper_param
                        self.res[_client]['max_value'] = self.Y[_client].max()
                        self.res[_client]['all_values'].append(
                            self.Y[_client][-1].tolist())
                        self.res[_client]['all_params'].append(
                            self.X[_client][-1].tolist())
                        pickle.dump(self.res[_client],
                                    open(self.res_paths[_client], 'wb'))

                self.state += 1
                if self.state <= self._cfg.hpo.fts.fed_bo_max_iter:
                    # Move to next round of training
                    logger.info(f'----------- GP optimizing iteration (Round '
                                f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.broadcast_model_para()
                else:
                    # Final Evaluate
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()

            else:
                _results = {}
                for _client, _content in \
                        self.msg_buffer['eval'][self.state].items():
                    _results[_client] = _content

                self.history_results = merge_dict(self.history_results,
                                                  _results)
                if self.state > self._cfg.hpo.fts.fed_bo_max_iter:
                    self.check_and_save()
        else:
            move_on_flag = False

        return move_on_flag

    def check_and_save(self):
        """
        To save the results and save model after each evaluation
        """
        # early stopping
        should_stop = False

        formatted_best_res = self._monitor.format_eval_res(
            results=self.history_results,
            rnd="Final",
            role='Server #',
            forms=["raw"],
            return_raw=True)
        logger.info('*' * 50)
        logger.info(formatted_best_res)
        self._monitor.save_formatted_results(formatted_best_res)

        if should_stop:
            self.state = self.total_round_num + 1

        logger.info('*' * 50)
        if self.state == self.total_round_num:
            # break out the loop for distributed mode
            self.state += 1
