from copy import deepcopy
from collections import OrderedDict

import numpy as np
import scipy.stats as ss
from localglobal.bo.localbodiscrete import LocalBO1Discrete
from localglobal.bo.globalbodiscrete import GlobalBODiscrete
from localglobal.bo.localbomdiscrete import LocalBOMDiscrete
from localglobal.bo.localbo_utils import from_unit_cube, latin_hypercube, to_unit_cube, ordinal2onehot, onehot2ordinal,\
    random_sample_within_discrete_tr_ordinal
import torch
import logging

def order_stats(X):
    """Return the (tie-aware) rank of every element of ``X``.

    Tied values all receive the *largest* rank of their group, e.g. ``[3, 1, 3, 2] -> [4, 1, 4, 2]``.
    Ranks start at 1.
    """
    # Ties must share a rank, hence ranking the unique values via cumulative counts rather than argsort.
    _, inverse, counts = np.unique(X, return_inverse=True, return_counts=True)
    upper_ranks = np.cumsum(counts)
    return upper_ranks[inverse]


def copula_standardize(X):
    """Map 1-D observations to standard-normal scores through their empirical ranks (Gaussian copula).

    Each value is replaced by ``Phi^-1(rank / (n + 1))``, where ``rank`` comes from ``order_stats``
    (ties share the largest rank). Dividing by ``n + 1`` keeps every quantile strictly inside (0, 1).
    """
    # np.nan_to_num: NaN -> 0.0, +inf/-inf -> the largest/smallest finite float.
    values = np.nan_to_num(np.asarray(X))
    assert values.ndim == 1 and np.all(np.isfinite(values))
    ranks = order_stats(values)
    denominator = len(values) + 1
    return ss.norm.ppf(np.true_divide(ranks, denominator))


class TurboOptimizer:
    """Trust-region (or global) BO wrapper over a purely categorical search space.

    Points exchanged with the caller (``suggest`` / ``observe``) are in ordinal encoding, one column per
    categorical variable, i.e. arrays of shape ``(n, true_dim)``. One-hot encoding (``self.dim`` columns) is only
    used internally to draw random Latin-hypercube designs, which are then converted with ``onehot2ordinal``.
    """

    def __init__(self, config,
                 n_init: int = None,
                 wrap_discrete: bool = True,
                 guided_restart: bool = True,
                 global_bo: bool = False,
                 **kwargs):
        """Build wrapper class to use an optimizer in benchmark.

        Parameters
        ----------
        config: array-like of int (list or numpy array). e.g. [2, 3, 4, 5] -- denotes there are 4 categorical
            variables, with numbers of categories being 2, 3, 4, and 5 respectively.
        n_init: number of initialising points per (re)start. Defaults to ``2 * len(config) + 1``.
        wrap_discrete: whether to snap the continuous one-hot Latin-hypercube samples onto valid one-hot vectors
            (argmax within each categorical group) before converting them to ordinal form.
        guided_restart: whether to fit an auxiliary GP over the best points encountered in all previous restarts,
            and start the next restart around the candidate minimising that GP's lower confidence bound.
        global_bo: whether to use the global version of the discrete GP without local modelling
        **kwargs: forwarded unchanged to the underlying BO object.
        """
        # Accept plain lists as well as numpy arrays; everything below relies on integer category counts.
        self.config = np.asarray(config).astype(int)
        # True dim is simply the number of parameters (do not care about one-hot encoding etc).
        self.true_dim = len(self.config)
        self.kwargs = kwargs
        # Number of one hot dimensions
        self.n_onehot = int(np.sum(self.config))
        # One-hot bounds
        self.lb = np.zeros(self.n_onehot)
        self.ub = np.ones(self.n_onehot)
        self.dim = len(self.lb)
        self.max_evals = np.iinfo(np.int32).max  # NOTE: Largest possible int
        self.batch_size = None
        self.history = []
        self.wrap_discrete = wrap_discrete
        self.cat_dims = self.get_dim_info(self.config)

        self.global_bo = global_bo
        bo_class = GlobalBODiscrete if self.global_bo else LocalBO1Discrete
        self.turbo = bo_class(
            dim=self.true_dim,
            n_init=n_init if n_init is not None else 2 * self.true_dim + 1,
            max_evals=self.max_evals,
            batch_size=1,  # We need to update this later
            verbose=False,
            config=self.config,
            **kwargs
        )

        # Our modification: define an auxiliary GP
        self.guided_restart = guided_restart
        # keep track of the best X and fX in each restart
        self.best_X_each_restart, self.best_fX_each_restart = None, None
        self.auxiliary_gp = None

    def restart(self):
        """Restart the trust region and queue a fresh set of initialising points in ``self.X_init``.

        With ``guided_restart`` enabled and at least one observation in the current restart, the best point of the
        current restart is added to the history of per-restart optima, an auxiliary GP is fitted to that history,
        and the next trust region is centred on the random candidate minimising the GP's lower confidence bound
        (mean - 1.96 * std). This encourages the next region to explore areas not visited by previous restarts.
        Otherwise the initialising points are a random Latin-hypercube design.
        """
        from localglobal.bo.localbo_utils import train_gp

        if self.guided_restart and len(self.turbo._fX):
            # bo._fX and bo._X get erased at each restart, but bo.X and bo.fX always store the full history.
            best_idx = self.turbo._fX.argmin()
            best_fX = deepcopy(self.turbo._fX[best_idx])
            best_X = deepcopy(self.turbo._X[best_idx])
            if self.best_fX_each_restart is None:
                self.best_fX_each_restart = best_fX
                self.best_X_each_restart = best_X
            else:
                self.best_fX_each_restart = np.vstack((self.best_fX_each_restart, best_fX))
                self.best_X_each_restart = np.vstack((self.best_X_each_restart, best_X))

            X_tr_torch = torch.tensor(self.best_X_each_restart, dtype=torch.float32).reshape(-1, self.true_dim)
            fX_tr_torch = torch.tensor(self.best_fX_each_restart, dtype=torch.float32).view(-1)
            # Train the auxiliary GP on the per-restart optima.
            self.auxiliary_gp = train_gp(X_tr_torch, fX_tr_torch, False, 300, )

            # Random candidate pool: LHS in the one-hot cube, optionally snapped to valid one-hot vectors,
            # then converted to ordinal encoding.
            X_cand = latin_hypercube(self.turbo.n_cand, self.dim)
            X_cand = from_unit_cube(X_cand, self.lb, self.ub)
            if self.wrap_discrete:
                X_cand = self.warp_discrete(X_cand, )
            X_cand = onehot2ordinal(X_cand, self.cat_dims)
            with torch.no_grad():
                self.auxiliary_gp.eval()
                X_cand_torch = torch.tensor(X_cand, dtype=torch.float32)
                # Evaluate the posterior once and read both moments from it.
                posterior = self.auxiliary_gp(X_cand_torch)
                y_cand_mean = posterior.mean.cpu().detach().numpy()
                y_cand_var = posterior.variance.cpu().detach().numpy()
                # LCB: favour candidates that look good or that the auxiliary GP is uncertain about.
                y_cand = y_cand_mean - 1.96 * np.sqrt(y_cand_var)

            # The initial trust region centre for the new restart; it is also the first point to be evaluated.
            centre = deepcopy(X_cand[np.argmin(y_cand), :])
            self.X_init = np.ones((self.turbo.n_init, self.true_dim))
            self.X_init[0, :] = centre
            for i in range(1, self.turbo.n_init):
                # Randomly sample within the initial trust region length around the centre
                self.X_init[i, :] = random_sample_within_discrete_tr_ordinal(
                    centre, self.turbo.length_init_discrete, self.config)
            self.turbo._restart()
            self.turbo._X = np.zeros((0, self.turbo.dim))
            self.turbo._fX = np.zeros((0, 1))
            del X_tr_torch, fX_tr_torch, X_cand_torch

        else:
            # Without guided restart (or without data yet), start from a random Latin-hypercube design.
            self.turbo._restart()
            self.turbo._X = np.zeros((0, self.turbo.dim))
            self.turbo._fX = np.zeros((0, 1))
            X_init = latin_hypercube(self.turbo.n_init, self.dim)
            self.X_init = from_unit_cube(X_init, self.lb, self.ub)
            if self.wrap_discrete:
                self.X_init = self.warp_discrete(self.X_init, )
            self.X_init = onehot2ordinal(self.X_init, self.cat_dims)

    def suggest(self, n_suggestions=1):
        """Return ``n_suggestions`` points to evaluate, shape ``(n_suggestions, true_dim)`` (ordinal encoding).

        Pending initialising points are served first; any remaining slots are filled by the BO model over the
        current trust region. The batch size is fixed by the first call.
        """
        if self.batch_size is None:  # Remember the batch size on the first call to suggest
            self.batch_size = n_suggestions
            self.turbo.batch_size = n_suggestions
            self.turbo.n_init = max(self.turbo.n_init, self.batch_size)
            self.restart()

        X_next = np.zeros((n_suggestions, self.true_dim))

        # Pick from the initial points
        n_init = min(len(self.X_init), n_suggestions)
        if n_init > 0:
            X_next[:n_init] = deepcopy(self.X_init[:n_init, :])
            self.X_init = self.X_init[n_init:, :]  # Remove these pending points

        # Get remaining points from the BO model.
        # NOTE: if nothing has been observed in this restart yet, these rows are left as zeros.
        n_adapt = n_suggestions - n_init
        if n_adapt > 0 and len(self.turbo._X) > 0:
            X = deepcopy(self.turbo._X)
            fX = copula_standardize(deepcopy(self.turbo._fX).ravel())  # Use Copula
            # TODO: continuous candidate generation (mixed search spaces) will be reintroduced later.
            X_next[-n_adapt:, :] = self.turbo._create_and_select_candidates(
                X, fX,
                length=self.turbo.length_discrete,
                n_training_steps=300,
                hypers={})[-n_adapt:, :]

        return X_next

    def observe(self, X, y, override_n_evals=False):
        """Send an observation of a suggestion back to the optimizer.

        Parameters
        ----------
        X : array-like, shape (n, true_dim)
            Points (ordinal encoding, as returned by ``suggest``) where the objective has been evaluated.
        y : array-like, shape (n,)
            Corresponding objective values (lower is better).
        override_n_evals : bool
            If True, count ``len(X)`` evaluations; otherwise count a full batch (``self.batch_size``).
        """
        assert len(X) == len(y)
        XX = X
        yy = np.array(y)[:, None]

        # Only adapt the trust region once the initial design of this restart has been observed.
        if len(self.turbo._fX) >= self.turbo.n_init:
            self.turbo._adjust_length(yy)

        if not override_n_evals:
            self.turbo.n_evals += self.batch_size
        else:
            self.turbo.n_evals += len(X)

        # _X/_fX hold the current restart only; X/fX keep the full history.
        self.turbo._X = np.vstack((self.turbo._X, deepcopy(XX)))
        self.turbo._fX = np.vstack((self.turbo._fX, deepcopy(yy.reshape(-1, 1))))
        self.turbo.X = np.vstack((self.turbo.X, deepcopy(XX)))
        self.turbo.fX = np.vstack((self.turbo.fX, deepcopy(yy.reshape(-1, 1))))

        # Check for a restart
        if self.turbo.length <= self.turbo.length_min or self.turbo.length_discrete <= self.turbo.length_min_discrete:
            self.restart()

    def warp_discrete(self, X, ):
        """Snap each categorical group of a one-hot-space matrix onto a valid one-hot vector.

        Within every group of columns in ``self.cat_dims`` the largest entry becomes 1 and the others 0.
        Returns a new array; ``X`` (2-D, one-hot columns) is not modified.
        """
        X_ = np.copy(X)
        if self.cat_dims is not None:
            rows = np.arange(len(X_))
            for categorical_groups in self.cat_dims:
                group = np.asarray(categorical_groups)
                winners = group[np.argmax(X[:, group], axis=1)]
                X_[:, group] = 0
                X_[rows, winners] = 1
        return X_

    def get_dim_info(self, n_categories):
        """Return, per categorical variable, the list of its column indices in the one-hot encoding.

        e.g. ``[2, 3] -> [[0, 1], [2, 3, 4]]``. Float category counts (e.g. from a float array) are cast to int.
        """
        dim_info = []
        offset = 0
        for cat in n_categories:
            cat = int(cat)
            dim_info.append(list(range(offset, offset + cat)))
            offset += cat
        return dim_info


class TurboMOptimizer(TurboOptimizer):
    """The multiple trust region variant of TurBO.

    Every suggested point is tagged in ``self.turbo._idx`` with the index of the trust region it belongs to, and
    each region is adjusted and restarted independently.
    """

    def __init__(self, n_categories,
                 n_trust_regions: int,
                 n_init: int = None,
                 warp_discrete: bool = True,
                 guided_restart: bool = True,
                 **kwargs):
        """
        Parameters
        ----------
        n_categories: array-like of int, the number of categories of each variable (the ``config`` of
            ``TurboOptimizer``).
        n_trust_regions: number of trust regions; must be > 1 (use ``TurboOptimizer`` for a single region).
        n_init: the **global** number of initialising points; it is divided evenly across the trust regions.
            Defaults to ``2 * len(n_categories) + 1``.
        warp_discrete, guided_restart: as in ``TurboOptimizer``.
        **kwargs: forwarded to the parent constructor and to ``LocalBOMDiscrete``.
        """
        assert n_trust_regions > 1, ('n_trust_region should be more than 1. if a single trust region is preferred, '
                                     'please use the TurboOptimizer.')
        self.n_trust_regions = n_trust_regions
        self.X_init = None
        # Trust-region index of every pending initialising point in self.X_init (filled by restart()).
        self.X_init_idx = None
        # the n_init supplied above is the **global** number of initialising points allowed. Need to divide this by
        # the number of trust regions. self.true_dim is only set by the parent constructor, so derive the default
        # from the config directly.
        if n_init is None:
            n_init = 2 * len(n_categories) + 1
        n_init = max(1, int(n_init / n_trust_regions))
        if n_init == 1:
            logging.warning('The number of initialising points per trust region is at the minimum value of 1. Consider '
                            'reducing the number of trust regions or increasing the number of initialising points!')

        super().__init__(n_categories, n_init, warp_discrete, guided_restart, **kwargs)
        self.turbo = LocalBOMDiscrete(
            f=None,
            dim=self.true_dim,
            true_dim=self.true_dim,
            n_init=n_init,
            max_evals=self.max_evals,
            n_cats=self.config,
            verbose=False,
            n_trust_regions=self.n_trust_regions,
            batch_size=1,  # Will be updated later.
            **kwargs
        )

    def _random_ordinal_points(self, n):
        """Draw ``n`` random points: LHS in the one-hot cube, optionally snapped to one-hot, then ordinal."""
        X = latin_hypercube(n, self.dim)
        X = from_unit_cube(X, self.lb, self.ub)
        if self.wrap_discrete:
            X = self.warp_discrete(X, )
        return onehot2ordinal(X, self.cat_dims)

    def _guided_init_points(self, X_i, fX_i):
        """Propose initialising points for a restarted region using an auxiliary GP over past restart optima.

        The best observed point of the region (X_i, fX_i) is appended to the per-restart history, an auxiliary
        GP is fitted to that history, and the random candidate minimising the GP's lower confidence bound
        (mean - 1.96 * std) becomes the new centre. The centre is the first initialising point; the others are
        sampled within the initial discrete trust-region length around it.
        """
        from localglobal.bo.localbo_utils import train_gp

        best_idx = fX_i.argmin()
        if self.best_fX_each_restart is None:
            self.best_X_each_restart = deepcopy(X_i[best_idx])
            self.best_fX_each_restart = deepcopy(fX_i[best_idx])
        else:
            self.best_X_each_restart = np.vstack((self.best_X_each_restart, deepcopy(X_i[best_idx])))
            self.best_fX_each_restart = np.vstack((self.best_fX_each_restart, deepcopy(fX_i[best_idx])))

        X_tr_torch = torch.tensor(self.best_X_each_restart, dtype=torch.float32).reshape(-1, self.true_dim)
        fX_tr_torch = torch.tensor(self.best_fX_each_restart, dtype=torch.float32).view(-1)
        # Train the auxiliary GP on the per-restart optima.
        self.auxiliary_gp = train_gp(X_tr_torch, fX_tr_torch, False, 300, )
        X_cand = self._random_ordinal_points(self.turbo.n_cand)
        with torch.no_grad():
            self.auxiliary_gp.eval()
            X_cand_torch = torch.tensor(X_cand, dtype=torch.float32)
            # Evaluate the posterior once and read both moments from it (LCB sampling).
            posterior = self.auxiliary_gp(X_cand_torch)
            y_cand_mean = posterior.mean.cpu().detach().numpy()
            y_cand_var = posterior.variance.cpu().detach().numpy()
            y_cand = y_cand_mean - 1.96 * np.sqrt(y_cand_var)

        centre = deepcopy(X_cand[np.argmin(y_cand), :])
        X_init = np.ones((self.turbo.n_init, self.true_dim))
        # The centre itself is the first initialising point of the new region.
        X_init[0, :] = centre
        for j in range(1, self.turbo.n_init):
            X_init[j, :] = random_sample_within_discrete_tr_ordinal(
                centre, self.turbo.length_init_discrete, self.config)
        return X_init

    def restart(self, idx=None):
        """Restart one, or all, of the trust region(s).

        idx: if specified, restart that trust region. Otherwise restart all trust regions.

        For each restarted region the length and success/failure counters are reset, its existing points are
        detached from it (their ``_idx`` tag becomes -1), and ``n_init`` fresh initialising points tagged with
        the region index are appended to ``self.X_init`` / ``self.X_init_idx``.
        """
        idx_selected = list(range(self.n_trust_regions)) if idx is None else [idx]
        for i in idx_selected:
            idx_i = self.turbo._idx[:, 0] == i

            # Points enter _X/_idx when suggested but _fX only when observed, so _fX can be shorter (e.g. when
            # restarting in the middle of observing a batch). Only the observed prefix is usable here.
            n_obs = len(self.turbo._fX)
            obs_mask = idx_i[:n_obs]
            X_i, fX_i = self.turbo._X[:n_obs][obs_mask], self.turbo._fX[obs_mask]

            if self.guided_restart and len(fX_i) > 0:
                X_init = self._guided_init_points(X_i, fX_i)
            else:
                # No observations for this region (or guided restart disabled): random initial design.
                X_init = self._random_ordinal_points(self.turbo.n_init)

            self.turbo.length[i] = self.turbo.length_init
            self.turbo.length_discrete[i] = self.turbo.length_init_discrete
            self.turbo.succcount[i] = 0
            self.turbo.failcount[i] = 0
            # Remove (existing) points from trust region
            self.turbo._idx[idx_i, 0] = -1
            X_init_idx = i * np.ones((len(X_init), 1))
            if self.X_init is None:
                self.X_init = deepcopy(X_init)
                self.X_init_idx = X_init_idx
            else:
                self.X_init = np.vstack((self.X_init, deepcopy(X_init)))
                self.X_init_idx = np.vstack((self.X_init_idx, X_init_idx))

    def suggest(self, n_suggestions=1):
        """Return ``n_suggestions`` points to evaluate, shape ``(n_suggestions, true_dim)`` (ordinal encoding).

        Pending initialising points are served first; remaining slots are filled with the candidates of largest
        acquisition value returned by the model across all trust regions. Every filled suggestion is appended to
        ``turbo.X`` / ``turbo._X`` and its trust-region index to ``turbo._idx`` so ``observe`` can attribute it.
        """
        if self.batch_size is None:  # Remember the batch size on the first call to suggest
            self.batch_size = n_suggestions
            self.turbo.batch_size = n_suggestions
            self.turbo.n_init = max(self.turbo.n_init, self.batch_size)
            self.restart()

        X_next = np.zeros((n_suggestions, self.true_dim))
        X_next_idx = np.zeros((n_suggestions, 1))

        # Pick from the initial points
        n_init = min(len(self.X_init), n_suggestions)
        if n_init > 0:
            X_next[:n_init] = self.X_init[:n_init, :]
            X_next_idx[:n_init] = self.X_init_idx[:n_init, :]
            # Remove these pending points
            self.X_init = self.X_init[n_init:, :]
            self.X_init_idx = self.X_init_idx[n_init:, :]
        n_filled = n_init

        # Get remaining points from the model. This runs before this call's points are recorded, so the _X, _fX
        # and _idx seen by the model all describe the same (observed) points.
        n_adapt = n_suggestions - n_init
        if n_adapt > 0 and len(self.turbo._fX) > 0:
            X = deepcopy(self.turbo._X)
            fX = copula_standardize(deepcopy(self.turbo._fX).ravel())  # Use Copula
            x_center = X[fX.argmin().item(), :]
            res = self.turbo._create_and_select_candidates(X, fX,
                                                           length=self.turbo.length_discrete,
                                                           n_training_steps=100,
                                                           hypers={},
                                                           return_acq=True,
                                                           incumbent=x_center)
            X_cand, tr_cand, acq_cand = res[0], res[1], res[2]
            # Select the n_adapt candidates with the largest acquisition values across all trust regions
            # (larger is better, as in the previous argpartition-based selection).
            top_indices = np.argsort(-np.asarray(acq_cand).ravel(), kind='stable')[:n_adapt]
            X_next[-n_adapt:, :] = X_cand[top_indices, :]
            X_next_idx[-n_adapt:, :] = tr_cand[top_indices, :]
            n_filled = n_suggestions

        # Different from TurBO1, the points (and their trust regions) are recorded at suggestion time.
        if n_filled > 0:
            new_X = X_next[:n_filled].copy()
            self.turbo.X = new_X.copy() if not len(self.turbo.X) else np.vstack((self.turbo.X, new_X))
            self.turbo._X = new_X.copy() if not len(self.turbo._X) else np.vstack((self.turbo._X, new_X))
            self.turbo._idx = np.vstack((self.turbo._idx, X_next_idx[:n_filled]))

        return X_next

    def observe(self, X, y):
        """Record objective values for the most recent suggestions.

        X must equal the last ``len(X)`` rows of ``turbo.X`` (the latest suggestions, in order); y holds the
        corresponding objective values (lower is better). Each value is attributed to the trust region that
        produced its point, which is then adjusted and, if it has collapsed, restarted.
        """
        assert len(X) == len(y)
        assert np.all(X == self.turbo.X[-len(X):])
        yy = np.array(y)[:, None]
        # Read the trust-region tags up front: restart() below resets the tags of a restarted region to -1, which
        # would otherwise corrupt the lookup for later points of the same batch.
        tr_indices = self.turbo._idx[-len(X):, 0].astype(int)
        for i in range(len(X)):
            idx = int(tr_indices[i])
            y_i = yy[i:i + 1]
            # now we update function value (we previously updated X in suggest)
            self.turbo.fX = np.vstack((self.turbo.fX, y_i))
            self.turbo._fX = np.vstack((self.turbo._fX, y_i))
            if idx < 0:
                # Point no longer belongs to any trust region; nothing to adjust.
                continue
            # Adjust the length of this trust region only if none of its initialising points are still pending.
            pending_init = self.X_init_idx is not None and np.any(self.X_init_idx == idx)
            if not pending_init:
                self.turbo._adjust_length(y_i, idx)
            # Restart if necessary
            if (self.turbo.length[idx] <= self.turbo.length_min
                    or self.turbo.length_discrete[idx] <= self.turbo.length_min_discrete):
                self.restart(idx)
