import time
import math
import gpytorch
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf
from botorch.fit import fit_gpytorch_model
from src.triggers import Trigger
from typing import Callable
import numpy as np
import torch

from src.models import TimeInvariantGP, TimeVariantGP


class EventTriggeredBO:
    """Bayesian optimization loop with an event trigger that can reset the data set.

    At every time step a GP model is built from the current training data,
    an upper-confidence-bound acquisition is maximized (continuously via
    ``optimize_acqf`` or over a finite candidate set), the objective is
    evaluated at the chosen point and the ``trigger`` decides whether the
    training data has to be replaced.

    The last column of every input tensor is the time stamp; the first
    ``spatial_dimensions`` columns are the (scaled) spatial coordinates.
    """

    # Maximum number of most recent samples kept when a time-varying model is used.
    sliding_window_size = 700

    def __init__(self,
                 objective_function: Callable,
                 trigger: "Trigger",
                 spatial_dimensions: int,
                 discrete_sample_space=False,
                 bounds: torch.Tensor = torch.tensor([[0], [1]],
                                                     dtype=torch.float),
                 scaling_factors=1.,
                 hyperparameter: dict = None,
                 beta_parameter: dict = None,
                 use_time_varying_model: bool = False,
                 forgetting_factor: float = 0.03,
                 add_noise=True,
                 learn_hyper_parameters=True,
                 verbose=True,
                 emp_kernel=None):
        """Store the configuration of the optimizer.

        Args:
            objective_function: Called as ``f(x_spatial, t)`` and expected to
                return a tensor that can be reshaped to a column.
            trigger: Object providing ``check_trigger`` and
                ``provide_new_data_set`` (and ``r`` / ``delta_T`` when the
                theoretical beta is used).
            spatial_dimensions: Number of spatial input columns.
            discrete_sample_space: If true, ``bounds`` is indexed row-wise as
                the finite set of candidate points instead of being used as
                box bounds.
            bounds: Box bounds (continuous case) or candidate points
                (discrete case). The tensor passed in is not modified.
            scaling_factors: Divisor applied to ``bounds``; queries are
                multiplied by it again before the objective is evaluated in
                the continuous path.
            hyperparameter: Dict with keys ``'noise'``,
                ``'lengthscale_constraint'``, ``'lengthscale_hyperprior'``
                and optionally ``'lengthscales'``. ``None`` selects the
                defaults built below.
            beta_parameter: Dict with keys ``'c1'`` and ``'c2'``; ``None``
                selects ``{'c1': 0.4, 'c2': 4}``.
            use_time_varying_model: Falsy for a time-invariant GP; the
                strings ``"TV_GP_UCB"``, ``"TV_GP_UCB_MLE"`` and
                ``"UI_TVBO"`` select a ``TimeVariantGP`` variant. Any other
                truthy value uses the time-invariant GP but still enables
                the sliding window in ``optimize``.
            forgetting_factor: Passed to ``TimeVariantGP``.
            add_noise: Add Gaussian noise with variance ``noise`` to every
                objective evaluation.
            learn_hyper_parameters: Fit the model hyperparameters at every
                step instead of fixing the lengthscales.
            verbose: Print progress information.
            emp_kernel: Optional empirical kernel forwarded to the models.
        """
        # Build the defaults per instance: mutable default arguments would be
        # shared (including the gpytorch constraint/prior modules) between
        # all optimizers created without explicit values.
        if hyperparameter is None:
            hyperparameter = {
                'lengthscales': [0.2],
                'noise': 0.01,
                'lengthscale_constraint': gpytorch.constraints.Interval(0.3, 4),
                'lengthscale_hyperprior': gpytorch.priors.GammaPrior(1.5, 1),
            }
        if beta_parameter is None:
            beta_parameter = {'c1': 0.4, 'c2': 4}

        self.use_time_varying_model = use_time_varying_model
        self.spatial_dimensions = spatial_dimensions
        self.forgetting_factor = forgetting_factor
        self.function = objective_function
        self.discrete_sample_space = discrete_sample_space
        self.hyperparameter = hyperparameter
        self.learn_hyper_parameters = learn_hyper_parameters
        self.scaling_factors = scaling_factors
        # Out-of-place division: an in-place `/=` would rescale the caller's
        # tensor and, worse, the shared default `bounds` on every construction.
        self.spatial_bounds = bounds / scaling_factors
        self.trigger = trigger
        self.triggered = []
        self.beta_parameter = beta_parameter
        self.emp_kernel = emp_kernel
        self.eps = 10e-10  # i.e. 1e-9; replacement for a negative beta

        # hyperparameter
        self.noise = hyperparameter['noise']
        self.fixed_lengthscales = False
        # Always define the attribute; None when no lengthscales were given.
        self.lengthscales = hyperparameter.get('lengthscales')
        self.lengthscale_constraint = hyperparameter['lengthscale_constraint']
        self.lengthscale_hyperprior = hyperparameter['lengthscale_hyperprior']

        # parameters
        self.run_seed = 0
        self.time_horizon = None
        self.verbose = verbose
        self.add_noise = add_noise

    def optimize(self, n_training_points: int = 0,
                 time_horizon: int = 200,
                 model_id: int = 0,
                 iteration: int = 0,
                 break_when_triggered: bool = False):
        """Run the optimization loop over the remaining time steps.

        Args:
            n_training_points: Number of initial samples drawn before the
                loop starts; the loop then skips that many time steps.
            time_horizon: Last time step (inclusive, 1-based).
            model_id: Used together with ``iteration`` to derive the seed.
            iteration: Trial index, used for seeding and logging.
            break_when_triggered: Return as soon as the trigger fires.

        Returns:
            Tuple ``(queries, observations, triggered)`` with all queried
            inputs, the corresponding observations and one boolean per
            sample telling whether the trigger fired at that step.
        """
        self.run_seed = model_id * 10000 + iteration * 1000
        self.time_horizon = time_horizon
        self.triggered = [False] * n_training_points
        train_x, train_y, t_remain = self.get_initial_data(n_training_points)
        queries, observations = train_x.clone().detach(), train_y.clone().detach()

        for t in t_remain:
            # A fresh model is built from the current data set at every step.
            model = self.initialize_model(train_x, train_y, t)
            self._fit_hyperparameters(model, train_x)
            new_x, new_y = self.get_next_query(model, t, iteration)

            if self.verbose:
                self._print_model_state(model, t, iteration, model_id)

            # add new training points
            train_x = torch.cat((train_x, new_x), dim=0)
            train_y = torch.cat((train_y, new_y))

            queries = torch.cat((queries, new_x), dim=0)
            observations = torch.cat((observations, new_y))

            if self.trigger.check_trigger(model, (new_x, new_y), t):
                train_x, train_y = self.trigger.provide_new_data_set(train_x, train_y, t)
                self.triggered.append(True)
                if break_when_triggered:
                    return queries, observations, self.triggered
            else:
                self.triggered.append(False)

            if self.use_time_varying_model:
                train_x, train_y = self._apply_sliding_window(train_x, train_y)

        return queries.clone().detach(), observations.clone().detach(), self.triggered

    def _fit_hyperparameters(self, model, train_x):
        """Fit the model by maximizing the exact marginal log likelihood, if enabled."""
        if not self.learn_hyper_parameters or train_x.shape[0] == 0:
            return
        # With an empirical kernel only the "TV_GP_UCB_MLE" variant is trained.
        if self.emp_kernel is not None and self.use_time_varying_model != "TV_GP_UCB_MLE":
            return

        model.train()
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
        try:
            fit_gpytorch_model(mll)
        except RuntimeError:
            # Best effort: keep going with the hyperparameters set at initialization.
            print('Runtime Error \t continue with untrained model')

    def _print_model_state(self, model, t, iteration, model_id):
        """Print the current time step, the noise level and the spatial lengthscales."""
        print('')
        print(f'Timestep {t:4.0f} of Trail {iteration:2.0f} (Model ID: {model_id}).')
        print(f'Parameter name: noise value = {model.likelihood.noise_covar.noise.item():0.3f}')
        if self.emp_kernel is None:
            lengthscales = model.spatial_kernel.lengthscale[0].clone().detach()
            for i, lengthscale in enumerate(lengthscales):
                print(f'Parameter name: lengthscale {i} value = {lengthscale.item():0.3f}')

    def _apply_sliding_window(self, train_x, train_y):
        """Keep only the ``sliding_window_size`` samples with the largest time stamps."""
        window = self.sliding_window_size
        if train_x.shape[0] <= window:
            return train_x, train_y

        # Sort by the time column (last one), newest first, and truncate.
        newest_first = torch.argsort(train_x[:, -1], descending=True)[:window]
        print(f'Data section using sliding window of {window}.')
        return train_x[newest_first, :], train_y[newest_first]

    def _select_for_timestep(self, setting, t, name):
        """Return ``setting`` itself, or its entry for time step ``t`` if it is a list.

        Raises:
            ValueError: If ``setting`` is a list whose length differs from
                ``self.time_horizon``.
        """
        if not isinstance(setting, list):
            return setting
        if len(setting) != self.time_horizon:
            raise ValueError(
                f'{name} is a list of length {len(setting)} but must contain one '
                f'entry per time step (time_horizon={self.time_horizon}).')
        return setting[t - 1]  # time steps are 1-based

    def initialize_model(self, train_x, train_y, t):
        """Build the GP model for time step ``t`` from the given training data.

        Args:
            train_x: Training inputs (spatial columns followed by time).
            train_y: Training targets.
            t: Current 1-based time step; selects per-step constraints,
                priors and lengthscales when those are given per time step.

        Returns:
            A ``TimeVariantGP`` or ``TimeInvariantGP`` with the noise fixed
            to ``self.noise``.

        Raises:
            ValueError: If the lengthscale constraint or hyperprior is a list
                whose length differs from ``self.time_horizon``.
        """
        constraint = self._select_for_timestep(
            self.lengthscale_constraint, t, 'lengthscale_constraint')
        hyperprior = self._select_for_timestep(
            self.lengthscale_hyperprior, t, 'lengthscale_hyperprior')

        gp_kwargs = {
            'train_x': train_x,
            'train_y': train_y,
            'lengthscale_constraint': constraint,
            'lengthscale_hyperprior': hyperprior,
            'empirical_kernel': self.emp_kernel,
        }

        # Model variant -> `type_of_forgetting` argument (None: model default).
        forgetting_types = {
            "TV_GP_UCB": None,
            "TV_GP_UCB_MLE": "TV-MLE",
            "UI_TVBO": "UI",  # for the policy search comparison
        }
        variant = self.use_time_varying_model
        if isinstance(variant, str) and variant in forgetting_types:
            if forgetting_types[variant] is not None:
                gp_kwargs['type_of_forgetting'] = forgetting_types[variant]
            model = TimeVariantGP(forgetting_factor=self.forgetting_factor, **gp_kwargs)
        else:
            model = TimeInvariantGP(**gp_kwargs)

        if not self.learn_hyper_parameters and self.emp_kernel is None:
            n_lengthscales = len(self.hyperparameter["lengthscales"])
            if n_lengthscales == self.time_horizon:
                # one lengthscale setting per time step
                model.spatial_kernel.lengthscale = self.lengthscales[t - 1]
            elif n_lengthscales < model.spatial_kernel.ard_num_dims:
                # fewer values than ARD dimensions: use the first one for all
                model.spatial_kernel.lengthscale = self.lengthscales[0]
            else:
                model.spatial_kernel.lengthscale = self.lengthscales

        # The observation noise is fixed and excluded from training.
        model.likelihood.noise_covar.noise = self.noise
        model.likelihood.noise_covar.raw_noise.requires_grad = False

        return model

    def get_next_query(self, model, t, iteration):
        """Select the next query point and evaluate the objective there.

        Returns:
            Tuple ``(new_x, new_y)`` of the queried input (including the time
            column) and the observation.
        """
        beta = self.get_beta(model, t)
        if beta < 0:
            # the acquisition needs a non-negative exploration weight
            beta = self.eps
        if not self.discrete_sample_space:
            acq_func = self.get_acquisition_function(model, beta)
            new_x, new_y = self.optimize_acquisition_function(acq_func, t, iteration)
        else:
            new_x, new_y = self.discrete_acquisition_function_optimization(model, beta, t, iteration)
        return new_x, new_y

    def get_acquisition_function(self, model, beta):
        """Return the ``UpperConfidenceBound`` acquisition for ``model`` and ``beta``."""
        return UpperConfidenceBound(model=model, beta=beta, )

    def optimize_acquisition_function(self, acq_func, t, iteration):
        """Maximize ``acq_func`` over the spatial box at the fixed time ``t``.

        The time column is pinned by using ``t`` as both its lower and upper
        bound. ``iteration`` is accepted for interface symmetry and not used.

        Returns:
            Tuple ``(new_x, new_y)``; ``new_y`` has shape ``(1, 1)``.
        """
        bounds = torch.cat((self.spatial_bounds, torch.ones(2, 1) * t), dim=1)

        # seed per run and time step so the restarts are reproducible
        torch.manual_seed(t + self.run_seed)
        t0_acqf = time.time()
        candidates, _ = optimize_acqf(
            acq_function=acq_func,
            bounds=bounds,
            q=1,
            num_restarts=10,  # 20
            raw_samples=100,
            options={}, )

        if self.verbose:
            print(f'Time for optimizing acquisition functions: {time.time() - t0_acqf:0.3f}s.')

        new_x = candidates.detach()
        # undo the scaling of the spatial coordinates before calling the objective
        new_y = self.function(new_x[:, 0:self.spatial_dimensions] * self.scaling_factors,
                              new_x[0, -1].unsqueeze(0)).reshape(1, 1)
        if self.add_noise:
            new_y += torch.normal(mean=0, std=math.sqrt(self.noise), size=new_y.size())
        return new_x, new_y

    def discrete_acquisition_function_optimization(self, model, beta, t, iteration):
        """Pick the candidate row of ``spatial_bounds`` with the largest UCB value.

        Ties are broken uniformly at random (seeded by ``t`` and the run
        seed). ``iteration`` is accepted for interface symmetry and not used.

        Returns:
            Tuple ``(new_x, new_y)`` as float32 tensors; ``new_y`` has shape
            ``(1, 1)``.
        """
        torch.manual_seed(t + self.run_seed)

        # create x_test: every candidate point at the current time step
        time_column = torch.ones_like(self.spatial_bounds[:, 0]) * t
        discrete_points = torch.cat((self.spatial_bounds, time_column.reshape(-1, 1)), dim=1)

        # get prediction
        model.eval()
        prediction_at_t = model.likelihood(model(discrete_points.to(torch.float32)))
        mean = prediction_at_t.mean.detach()
        stdv = torch.sqrt(prediction_at_t.variance.detach())

        # evaluate acquisition function for all points in the feasible set
        acquisition_function = mean + torch.sqrt(torch.tensor(beta).to(torch.float32)) * stdv

        # select one of the maximizers at random
        max_value = torch.max(acquisition_function)
        max_indices = torch.where(acquisition_function == max_value)[0]
        idx = torch.randperm(len(max_indices))[0]
        chosen_max_idx = max_indices[idx]

        # get new_x and new_y
        new_x = self.spatial_bounds[chosen_max_idx, :].reshape(1, self.spatial_dimensions)
        new_x = torch.cat((new_x, torch.tensor([[t]])), dim=1)
        # NOTE(review): unlike the continuous path and get_initial_data, the
        # spatial part is not multiplied by scaling_factors here -- confirm
        # whether that is intended for scaling_factors != 1.
        new_y = self.function(new_x[:, 0:self.spatial_dimensions], new_x[0, -1].unsqueeze(0))
        if self.add_noise:
            new_y += torch.normal(mean=0, std=math.sqrt(self.noise), size=new_y.size())
        return new_x.to(torch.float32), new_y.reshape(1, 1).to(torch.float32)

    def get_beta(self, model, t):
        """Return the exploration weight beta for time step ``t``.

        With non-zero ``c1`` and ``c2`` this is ``c1 * log(c2 * t)``. If
        either is zero, beta is computed from the model's approximated
        Lipschitz constant and the trigger parameters instead.
        """
        # parameter from Bogunovic2016
        c1 = self.beta_parameter['c1']
        c2 = self.beta_parameter['c2']

        # choose according to theory
        if c1 == 0 or c2 == 0:
            Lf = model.approximate_Lf(self.trigger.r, self.trigger.delta_T)
            pi_t = math.pi ** 2 * t ** 2 / 6
            # 0.1 is presumably the confidence level delta -- TODO confirm
            first_term = 2 * math.log(2 * pi_t / 0.1)
            second_term = 2 * model.D * math.log(model.D * self.trigger.r * t ** 2 * Lf.item())
            beta = first_term + second_term
            return beta

        return c1 * math.log(c2 * t)

    def reset_data(self, new_data=None):
        """Return ``new_data`` unpacked, or an empty data set if none is given.

        Returns:
            Tuple ``(train_x, train_y)``; the empty data set has shapes
            ``(0, spatial_dimensions + 1)`` and ``(0, 1)``.
        """
        if new_data:
            (new_x, new_y) = new_data
            train_x = new_x
            train_y = new_y
        else:
            train_x = torch.empty(0, self.spatial_dimensions + 1)
            train_y = torch.empty(0, 1)

        return train_x, train_y

    def get_initial_data(self, n_training_points, ):
        """Create the initial data set and the time steps left to optimize.

        Requires ``self.time_horizon`` to be set (done by ``optimize``).

        Returns:
            Tuple ``(train_x, train_y, t_remain)`` where ``t_remain`` holds
            the time steps ``n_training_points + 1 .. time_horizon``.
        """
        if n_training_points > 0:
            if self.emp_kernel is None:
                n_grid = 100 + n_training_points
                columns = []
                for i in range(self.spatial_dimensions):
                    x_i = torch.linspace(self.spatial_bounds[0, i], self.spatial_bounds[1, i],
                                         n_grid)

                    # shuffling using numpy, seeded per run and dimension
                    idx_x_i = np.arange(n_grid)
                    np.random.seed(self.run_seed + i)
                    np.random.shuffle(idx_x_i)

                    columns.append(x_i[idx_x_i[:n_training_points]])

                # add temporal dimension and create train_x
                # NOTE(review): time stamps start at 0 here but at 1 in the
                # empirical-kernel branch below -- confirm this is intended.
                columns.append(torch.arange(0, 0 + n_training_points))
                train_x = torch.stack(columns, dim=1)
            else:
                # draw distinct rows of the candidate set at random
                torch.manual_seed(self.run_seed)
                perm = torch.randperm(self.spatial_bounds.size(0))
                idx = perm[:n_training_points]
                samples = self.spatial_bounds[idx].flatten()
                t = torch.arange(1, 1 + n_training_points)
                train_x = torch.stack((samples, t), dim=1)

            # create training data (objective is evaluated on unscaled inputs)
            train_y = self.function(train_x[:, 0:self.spatial_dimensions] * self.scaling_factors,
                                    train_x[:, -1]).reshape(-1, 1)

            if self.add_noise:
                train_y += torch.normal(mean=0, std=math.sqrt(self.noise), size=train_y.size())  # noisy measurement
        else:
            train_x, train_y = self.reset_data()

        t_remain = torch.arange(1, self.time_horizon + 1)[n_training_points:]
        return train_x, train_y, t_remain