# Modified from the original TuRBO code by Xingchen Wan <xwan@robots.ox.ac.uk>. The
# original copyright declaration below:

# This is for ablation -- GlobalBODiscrete is the discrete TuRBO variant but without local (trust-region) modelling
import math
import sys
from copy import deepcopy

import gpytorch
import numpy as np
import torch
from torch.quasirandom import SobolEngine

from localglobal.bo.localbo_utils import train_gp
from localglobal.bo.localbo_utils import from_unit_cube, latin_hypercube, to_unit_cube
from localglobal.bo.localbo_utils import onehot2ordinal
from localglobal.bo.localbo_utils import random_sample_within_discrete_tr_ordinal


class GlobalBODiscrete:
    """Global (non-trust-region) Bayesian optimisation over discrete/categorical inputs.

    Ablation counterpart of the discrete TuRBO variant: it uses the same GP surrogate and
    acquisition machinery, but the trust region is effectively disabled by fixing its size to a
    very large value (1e6). Candidates are therefore drawn from the whole search space.

    This class has no optimisation loop of its own. The caller evaluates points and calls
    ``_create_and_select_candidates`` to obtain the next batch.

    Parameters
    ----------
    dim : Dimensionality of the inputs, int.
    n_init : Number of initial points, int.
    max_evals : Total evaluation budget, int. Must exceed ``n_init`` and ``batch_size``.
    config : Forwarded to ``train_gp`` (as ``cat_configs``) and to the candidate samplers.
        Presumably the number of categories of each dimension -- TODO confirm against ``localbo_utils``.
    batch_size : Number of points in each batch, int.
    verbose : If you want to print information about the optimization progress, bool.
    use_ard : If you want to use ARD for the GP kernel, bool.
    max_cholesky_size : Largest number of training points where we use Cholesky, int.
    n_training_steps : Number of training steps for learning the GP hypers, int (>= 30).
    min_cuda : With fewer training points than this, the GP is fit on the CPU in float32.
    device : Device to use for GP fitting ("cpu" or "cuda").
    dtype : Dtype to use for GP fitting ("float32" or "float64").
    acq : Acquisition function: 'thompson', 'ei' or 'ucb'.
    kernel_type : Kernel identifier forwarded to ``train_gp`` as ``kern``.
    **kwargs : Optional extras. Recognised keys:
        ``n_cand``, stored as ``self.n_cand``;
        ``noise_variance``, forwarded to ``train_gp``.
    """

    def __init__(
            self,
            dim,
            n_init,
            max_evals,
            config,
            batch_size=1,
            verbose=True,
            use_ard=True,
            max_cholesky_size=2000,
            n_training_steps=50,
            min_cuda=1024,
            device="cpu",
            dtype="float32",
            acq='thompson',
            kernel_type='type2',
            **kwargs
    ):
        # Very basic input checks
        assert max_evals > 0 and isinstance(max_evals, int)
        assert n_init > 0 and isinstance(n_init, int)
        assert batch_size > 0 and isinstance(batch_size, int)
        assert isinstance(verbose, bool) and isinstance(use_ard, bool)
        # Type-check max_cholesky_size itself (this used to check batch_size by mistake).
        assert max_cholesky_size >= 0 and isinstance(max_cholesky_size, int)
        assert n_training_steps >= 30 and isinstance(n_training_steps, int)
        assert max_evals > n_init and max_evals > batch_size
        assert device == "cpu" or device == "cuda"
        assert dtype == "float32" or dtype == "float64"
        if device == "cuda":
            assert torch.cuda.is_available(), "can't use cuda if it's not available"

        # Save problem information
        self.dim = dim
        self.config = config
        self.kwargs = kwargs

        # Settings
        self.n_init = n_init
        self.max_evals = max_evals
        self.batch_size = batch_size
        self.verbose = verbose
        self.use_ard = use_ard
        self.max_cholesky_size = max_cholesky_size
        self.n_training_steps = n_training_steps

        self.acq = acq
        self.kernel_type = kernel_type

        # Hyperparameters
        self.mean = np.zeros((0, 1))
        self.signal_var = np.zeros((0, 1))
        self.noise_var = np.zeros((0, 1))
        self.lengthscales = np.zeros((0, self.dim)) if self.use_ard else np.zeros((0, 1))

        # Counters
        self.n_cand = kwargs.get('n_cand', min(100 * self.dim, 5000))
        self.n_evals = 0

        # Trust region sizes: set to ridiculously large values since the global model does not use them
        self.length_min = -1e6
        self.length_max = 1e6
        self.length_init = 1e6

        # Trust region sizes (in terms of Hamming distance) of the discrete variables; unused likewise.
        self.length_min_discrete = int(-1e6)
        self.length_max_discrete = int(1e6)
        self.length_init_discrete = int(1e6)

        # Save the full history
        self.X = np.zeros((0, self.dim))
        self.fX = np.zeros((0, 1))

        # Device and dtype for GPyTorch
        self.min_cuda = min_cuda
        self.dtype = torch.float32 if dtype == "float32" else torch.float64
        self.device = torch.device("cuda") if device == "cuda" else torch.device("cpu")
        if self.verbose:
            print("Using dtype = %s \nUsing device = %s" % (self.dtype, self.device))
            sys.stdout.flush()

        # Initialize parameters
        self._restart()

    def _restart(self):
        """Reset the per-run data buffers and the cached GP."""
        self._X = []
        self._fX = []
        # The trust-region approach is not used here, so the lengths are set to arbitrarily large values
        self.length = 1e6
        self.length_discrete = int(1e6)
        self.gp = None

    def _adjust_length(self, fX_next):
        """No-op: the global model never expands or shrinks a trust region."""
        return

    def _create_and_select_candidates(self, X, fX, length, n_training_steps, hypers, return_acq=False):
        """Fit a GP to ``(X, fX)`` and propose the next batch of points.

        Parameters
        ----------
        X : Evaluated inputs, indexable as ``X[i, :]``.
        fX : Observed objective values (minimised). The incumbent is ``X[fX.argmin()]``.
        length : Ignored. The global model always uses an effectively unbounded region (1e6).
        n_training_steps : Number of optimisation steps used to fit the GP hyperparameters.
        hypers : GP state dict passed to ``train_gp`` (may be empty).
        return_acq : If True, also return the acquisition value(s) of the proposed points.

        Returns
        -------
        X_next : numpy.ndarray of proposed points.
        acq_next : Only returned when ``return_acq`` is True. For Thompson sampling these are the
            sampled GP values of the chosen points; for EI/UCB, whatever ``local_search`` reports.

        Raises
        ------
        ValueError : If ``self.acq`` is not one of 'thompson', 'ei', 'ucb'.
        """
        # Global model: override any trust-region length supplied by the caller.
        length = 1e6
        # Small problems are fit on the CPU in float32; larger ones use the configured device/dtype.
        # Every tensor built below uses the same device/dtype so it is compatible with the GP.
        if len(X) < self.min_cuda:
            device, dtype = torch.device("cpu"), torch.float32
        else:
            device, dtype = self.device, self.dtype
        noise_variance = self.kwargs.get('noise_variance')

        with gpytorch.settings.max_cholesky_size(self.max_cholesky_size):
            X_torch = torch.tensor(X).to(device=device, dtype=dtype)
            y_torch = torch.tensor(fX).to(device=device, dtype=dtype)
            gp = train_gp(
                train_x=X_torch, train_y=y_torch, use_ard=self.use_ard, num_steps=n_training_steps, hypers=hypers,
                kern=self.kernel_type,
                cat_configs=self.config,
                noise_variance=noise_variance,
            )
            # Save the state dict (reused to warm-start refits in the batch setting)
            hypers = gp.state_dict()
            # Save the GP object for later use
            self.gp = gp

        from .localbo_utils import local_search
        # Incumbent: the best point observed so far (minimisation).
        x_center = X[fX.argmin().item(), :][None, :]
        # For the global model, local search just starts from a randomly sampled location.
        x0 = random_sample_within_discrete_tr_ordinal(x_center[0], length, self.config)

        # NOTE: the acquisition closures below read `gp` at call time, so in the batch setting
        # they see each re-fitted (fantasised) GP.
        def thompson(n_cand=50):
            """Thompson sampling (similar to the original TuRBO); returns (X_next, sampled values)."""
            X_cand = np.array([
                random_sample_within_discrete_tr_ordinal(x_center[0], length, self.config)
                for _ in range(n_cand)
            ])
            with torch.no_grad(), gpytorch.settings.max_cholesky_size(self.max_cholesky_size):
                X_cand_torch = torch.tensor(X_cand, dtype=dtype, device=device)
                y_cand = gp.likelihood(gp(X_cand_torch)).sample(torch.Size([self.batch_size])).t().cpu().detach().numpy()

            # For each posterior sample pick its best candidate, never selecting the same point twice.
            X_next = np.ones((self.batch_size, self.dim))
            y_next = np.ones((self.batch_size, 1))
            for i in range(self.batch_size):
                indbest = np.argmin(y_cand[:, i])
                X_next[i, :] = deepcopy(X_cand[indbest, :])
                y_next[i, :] = deepcopy(y_cand[indbest, i])
                y_cand[indbest, :] = np.inf
            return X_next, y_next

        def _ei(X, augmented=True):
            """Expected improvement over the incumbent's predicted mean (optionally augmented EI)."""
            from torch.distributions import Normal
            if not isinstance(X, torch.Tensor):
                X = torch.tensor(X, dtype=dtype, device=device)
            if X.dim() == 1:
                X = X.reshape(1, -1)
            gauss = Normal(torch.zeros(1, dtype=dtype, device=device), torch.ones(1, dtype=dtype, device=device))
            # Negate for minimisation problems so that larger EI is better
            preds = gp(X)
            mean, std = -preds.mean, preds.stddev
            # In-fill criterion: improvement is measured against the predicted mean at the incumbent
            x_star = torch.tensor(x_center[0].reshape(1, -1), dtype=dtype, device=device)
            mu_star = -gp.likelihood(gp(x_star)).mean

            u = (mean - mu_star) / std
            ucdf = gauss.cdf(u)
            updf = torch.exp(gauss.log_prob(u))
            ei = std * updf + (mean - mu_star) * ucdf
            if augmented:
                sigma_n = gp.likelihood.noise
                # sigma_n is already a tensor; detach() matches the old torch.tensor(sigma_n) copy without the warning
                ei *= (1. - torch.sqrt(sigma_n.detach()) / torch.sqrt(sigma_n + std ** 2))
            return ei

        def _ucb(X, beta=5.):
            """Return ``-(mean + beta * std)`` of the GP predictive distribution at X."""
            # NOTE(review): for a minimisation problem whose acquisition is maximised (as _ei is),
            # the usual optimistic form would be -(mean - beta * std). Confirm local_search's convention.
            if not isinstance(X, torch.Tensor):
                X = torch.tensor(X, dtype=dtype, device=device)
            if X.dim() == 1:
                X = X.reshape(1, -1)
            preds = gp.likelihood(gp(X))
            mean, std = preds.mean, preds.stddev
            return -(mean + beta * std)

        # With conventional acquisition functions the batch setting needs fantasised points, whereas
        # Thompson sampling handles batches natively.
        if self.acq in ['ei', 'ucb']:
            # (acquisition function, number of local-search restarts)
            acq_func, n_restart = (_ei, 1) if self.acq == 'ei' else (_ucb, 3)
            if self.batch_size == 1:
                # Sequential setting
                X_next, acq_next = local_search(x0, acq_func, self.config, length, n_restart, self.batch_size)
            else:
                # Batch setting: add each chosen point, with the GP posterior mean as its fantasy
                # observation, and refit the GP before choosing the next point.
                X_next = torch.tensor([], dtype=dtype, device=device)
                acq_next = np.array([])
                for _ in range(self.batch_size):
                    x_next, acq = local_search(x0, acq_func, self.config, length, n_restart, 1)
                    x_next = torch.tensor(x_next, dtype=dtype, device=device)
                    y_next = gp(x_next).mean.detach()
                    with gpytorch.settings.max_cholesky_size(self.max_cholesky_size):
                        X_torch = torch.cat((X_torch, x_next), dim=0)
                        y_torch = torch.cat((y_torch, y_next), dim=0)
                        gp = train_gp(
                            train_x=X_torch, train_y=y_torch, use_ard=self.use_ard, num_steps=n_training_steps,
                            kern=self.kernel_type,
                            cat_configs=self.config,
                            hypers=hypers,
                            noise_variance=noise_variance,
                        )
                    X_next = torch.cat((X_next, x_next), dim=0)
                    acq_next = np.hstack((acq_next, acq))

        elif self.acq == 'thompson':
            X_next, acq_next = thompson()
        else:
            raise ValueError('Unknown acquisition function choice %s' % self.acq)

        # Remove the torch tensors
        del X_torch, y_torch
        # Torch results may live on the GPU; move them to host memory before converting to numpy.
        if isinstance(X_next, torch.Tensor):
            X_next = X_next.detach().cpu().numpy()
        else:
            X_next = np.array(X_next)
        if return_acq:
            return X_next, acq_next
        return X_next
