import abc
import math
import torch
import gpytorch


def get_trigger(name, options, trigger_parameter):
    """Create the trigger object registered under ``name``.

    Args:
        name: Identifier of the algorithm/trigger variant.
        options: Options mapping forwarded unchanged to the trigger
            constructor.
        trigger_parameter: Trigger-specific parameter forwarded unchanged to
            the trigger constructor.

    Returns:
        A new trigger instance for the requested variant.

    Raises:
        NotImplementedError: If ``name`` matches none of the known variants.
    """
    if name == "GP_UCB":
        return GP_UCB_Trigger(options, trigger_parameter)
    # "TV_GP_UCB" and "TV_GP_UCB_MLE" share the same trigger class.
    if name in ("TV_GP_UCB", "TV_GP_UCB_MLE"):
        return TV_GP_UCB_Trigger(options, trigger_parameter)
    if name == "UI_TVBO":
        return UI_TVBO_Trigger(options, trigger_parameter)
    if name == "R_GP_UCB":
        return R_GP_UCB_Trigger(options, trigger_parameter)
    if name == "ET_GP_UCB":
        return ET_GP_UCB_Trigger(options, trigger_parameter)
    # The exact backtracking name must be tested BEFORE the substring rule
    # below: "ET_GP_UCB_backtracking" also contains "ET_GP_UCB_" and would
    # otherwise be swallowed by the theory trigger.
    if name == "ET_GP_UCB_backtracking":
        return ET_GP_UCB_Backtracking_Trigger(options, trigger_parameter)
    # Any other name containing "ET_GP_UCB_" selects the theory trigger.
    if "ET_GP_UCB_" in name:
        return ET_GP_UCB_TheoryTrigger(options, trigger_parameter)
    if name == "Cyclic_Trigger":
        return Cyclic_Trigger(options, trigger_parameter)
    raise NotImplementedError(f"Unknown trigger name: {name!r}")


class Trigger(abc.ABC):
    """Abstract base class shared by all triggers in this module.

    Stores the options and trigger parameter, keeps a handle to the model
    that is currently inspected and offers small helpers for predictions and
    Gaussian (log-)likelihoods. Subclasses implement ``check_trigger``.
    """

    def __init__(self, options, trigger_parameter):
        """Store the configuration and initialise the bookkeeping fields.

        Args:
            options: Mapping that must contain the keys ``'r'`` and
                ``'delta_T'``.
            trigger_parameter: Trigger-specific parameter, stored as is.

        Raises:
            KeyError: If ``'r'`` or ``'delta_T'`` is missing (for dict-like
                ``options``).
        """
        self.options = options
        self.trigger_parameter = trigger_parameter

        # state that subclasses fill in while checking the trigger
        self.model = None
        self.eps = 1e-4
        self.n_training_points = None
        self.Lf = None
        self.Lk = None
        self.beta = None

        # delta_T is the delta of the P(.) >= 1 - delta statements; the
        # meaning of 'r' is not visible in this file
        self.r = options['r']
        self.delta_T = options['delta_T']

    @abc.abstractmethod
    def check_trigger(self, model: gpytorch.models.GP, new_data: tuple, t, return_bound=False) -> bool:
        """Decide whether the trigger fires; implemented by subclasses."""

    def provide_new_data_set(self, train_x, train_y, t):
        """Return an empty data set with the column counts of the inputs.

        Args:
            train_x: 2-D tensor; only ``shape[1]`` is read.
            train_y: 2-D tensor; only ``shape[1]`` is read.
            t: Unused here.

        Returns:
            Tuple ``(x, y)`` of tensors with zero rows.
        """
        empty_x = torch.empty(0, train_x.shape[1])
        empty_y = torch.empty(0, train_y.shape[1])
        return empty_x, empty_y

    def calculate_loglikelihood(self, mean, var, y):
        """Return the log-density of ``y`` under ``Normal(mean, sqrt(var))``."""
        gaussian = torch.distributions.Normal(mean, torch.sqrt(var))
        log_density = gaussian.log_prob(y)
        return log_density

    def get_prediction(self, new_x):
        """Evaluate ``self.model`` at ``new_x`` in eval mode.

        Switches the model and its likelihood to eval mode as a side effect.

        Returns:
            Tuple ``(mean, variance)`` of the model output, both detached.
        """
        gp = self.model
        gp.eval()
        gp.likelihood.eval()
        posterior = gp(new_x)
        return posterior.mean.detach(), posterior.variance.detach()

    def calculate_likelihood(self, mean, var, y):
        """Return the density of ``y`` under ``Normal(mean, sqrt(var))``."""
        log_density = self.calculate_loglikelihood(mean, var, y)
        return log_density.exp()


class R_GP_UCB_Trigger(Trigger):
    """Trigger that fires periodically, every ``N`` time steps.

    ``N`` is ``ceil(min(time_horizon, 12 * trigger_parameter ** (-1/4)))``.
    NOTE(review): this looks like the block length of R-GP-UCB -- confirm
    against the reference the experiments follow.
    """

    def __init__(self, options, trigger_parameter):
        """Compute the reset period ``N``.

        Args:
            options: Mapping with ``'r'``, ``'delta_T'`` and
                ``'time_horizon'``.
            trigger_parameter: Positive number the period is derived from;
                a value of zero raises ``ZeroDivisionError``.
        """
        super().__init__(options, trigger_parameter)
        time_horizon = self.options["time_horizon"]
        block_length = min(time_horizon, 12 * trigger_parameter ** (-0.25))
        # Round up with math.ceil. The former ``round(x + 0.5)`` emulation is
        # wrong for integer inputs because Python rounds halves to the
        # nearest even number (e.g. 3 -> round(3.5) == 4, 2 -> round(2.5) == 2).
        self.N = int(math.ceil(block_length))

    def check_trigger(self, model, new_data, t, return_bound=False):
        """Return True exactly when ``t`` is a multiple of ``self.N``.

        ``model``, ``new_data`` and ``return_bound`` are ignored.
        """
        return bool(t % self.N == 0)


class Cyclic_Trigger(Trigger):
    """Trigger that fires every ``options["const_optimal"]`` time steps."""

    def __init__(self, options, trigger_parameter):
        """Read the fixed period ``N`` from ``options["const_optimal"]``."""
        super().__init__(options, trigger_parameter)
        self.N = self.options["const_optimal"]

    def check_trigger(self, model, new_data, t, return_bound=False):
        """Return True exactly when ``t`` is a multiple of ``self.N``.

        ``model``, ``new_data`` and ``return_bound`` are ignored.
        """
        is_period_boundary = t % self.N == 0
        return bool(is_period_boundary)


class GP_UCB_Trigger(Trigger):
    """Trigger used for the "GP_UCB" variant: it never fires."""

    def check_trigger(self, model, new_data, t, return_bound=False):
        """Always return False; all arguments are ignored."""
        return False


class TV_GP_UCB_Trigger(Trigger):
    """Trigger used for the "TV_GP_UCB" variants: it never fires."""

    def check_trigger(self, model, new_data, t, return_bound=False):
        """Always return False; all arguments are ignored."""
        return False


class UI_TVBO_Trigger(Trigger):
    """Trigger used for the "UI_TVBO" variant: it never fires."""

    def check_trigger(self, model, new_data, t, return_bound=False):
        """Always return False; all arguments are ignored."""
        return False


class ET_GP_UCB_TheoryTrigger(Trigger):
    """Event trigger with lower and upper limits on the data-set size.

    Below ``N_lower`` training points the trigger never fires, at or above
    ``N_upper`` it always fires. In between it fires when the new
    observation deviates from the model's predicted mean by more than a
    confidence bound.
    """

    def __init__(self, options, trigger_parameter):
        """Read ``N_lower`` and ``N_upper`` from ``options``.

        Raises:
            ValueError: If ``N_lower`` or ``N_upper`` is missing.
            AssertionError: If ``N_lower > N_upper``.
        """
        super().__init__(options, trigger_parameter)

        if "N_lower" not in self.options:
            raise ValueError("N_lower not in options")
        self.N_lower = self.options["N_lower"]

        # Name the key that is actually missing (the message used to say
        # "N_lower" here as well).
        if "N_upper" not in self.options:
            raise ValueError("N_upper not in options")
        self.N_upper = self.options["N_upper"]

        assert self.N_lower <= self.N_upper, \
            "N_lower must be smaller than or equal to N_upper"

    def reset(self):
        """Drop the model handle and ``beta`` kept from a previous check."""
        self.model = None
        self.beta = None

    def _get_rho(self):
        """Return ``2 * log(2 * pi_t / delta_T)``, pi_t = pi^2 * n^2 / 6.

        ``n`` is ``self.n_training_points``.
        """
        pi_t = math.pi ** 2 * self.n_training_points ** 2 / 6
        out = 2 * math.log(2 * pi_t / (1 * self.delta_T))
        return out

    def provide_new_data_set(self, train_x, train_y, t):
        """Return only the last row of ``train_x`` and ``train_y``."""
        return train_x[-1:, :], train_y[-1:, :]

    def check_trigger(self, model, new_data, t, return_bound=False):
        """Check whether the new observation violates the confidence bound.

        Args:
            model: GP model; its ``train_targets``, likelihood noise and
                prediction at the new input are used. It is switched to
                eval mode as a side effect of the prediction.
            new_data: Tuple ``(new_x, new_y)`` with the newest observation.
            t: Current time step; ``t == 1`` never triggers.
            return_bound: If True, return the bound instead of the decision.
                The early exits (``t == 1`` and the size limits) still
                return a bool.

        Returns:
            True if the trigger fires, False otherwise, or the bound tensor
            when ``return_bound`` is set and no early exit applies.
        """
        # first iteration can not result in a reset
        if t == 1:
            return False

        # explicitly reset everything from last iteration (just to be sure)
        self.reset()

        self.model = model
        self.n_training_points = len(self.model.train_targets)

        # the size limits take precedence over the statistical test
        if self.n_training_points < self.N_lower:
            return False
        if self.n_training_points >= self.N_upper:
            return True

        new_x, new_y = new_data

        (mean_pred, var_pred) = self.get_prediction(new_x)
        noise = self.model.likelihood.noise_covar.noise.item()

        # bound = sqrt(rho_t * predictive variance) + noise-dependent term
        rho_t = self._get_rho()
        upper_bound = torch.sqrt(rho_t * var_pred)
        pi_t = math.pi ** 2 * self.n_training_points ** 2 / 6
        noise_bound = math.sqrt(2 * noise * math.log(2 * pi_t / self.delta_T))
        upper_bound += noise_bound

        if return_bound:
            return upper_bound

        abs_diff = torch.abs(new_y - mean_pred)

        # fire when the observation lies outside the confidence bound
        if abs_diff <= upper_bound:
            return False
        return True


class ET_GP_UCB_Trigger(Trigger):
    """Event trigger based on a confidence bound around the GP prediction.

    Fires when the newest observation deviates from the model's predicted
    mean by more than ``sqrt(rho_t * variance)`` plus a noise-dependent term.
    Construction is inherited unchanged from ``Trigger``.
    """

    def reset(self):
        """Drop the model handle and ``beta`` kept from a previous check."""
        self.model, self.beta = None, None

    def _get_rho(self):
        """Return ``2 * log(2 * pi_t / delta_T)``, pi_t = pi^2 * n^2 / 6.

        ``n`` is ``self.n_training_points``.
        """
        pi_t = math.pi ** 2 * self.n_training_points ** 2 / 6
        return 2 * math.log(2 * pi_t / (1 * self.delta_T))

    def check_trigger(self, model, new_data, t, return_bound=False):
        """Check whether the new observation violates the confidence bound.

        Args:
            model: GP model; its ``train_targets``, likelihood noise and
                prediction at the new input are used. It is switched to
                eval mode as a side effect of the prediction.
            new_data: Tuple ``(new_x, new_y)`` with the newest observation.
            t: Current time step; ``t == 1`` never triggers.
            return_bound: If True, return the bound instead of the decision.
                The early exits (``t == 1`` and an empty training set) still
                return a bool.

        Returns:
            True if the trigger fires, False otherwise, or the bound tensor
            when ``return_bound`` is set and no early exit applies.
        """
        # nothing can be reset in the very first iteration
        if t == 1:
            return False

        # start from a clean state, then attach the current model
        self.reset()
        self.model = model
        self.n_training_points = len(model.train_targets)

        # without training data there is nothing to compare against
        if self.n_training_points == 0:
            return False

        query_x, observed_y = new_data
        post_mean, post_var = self.get_prediction(query_x)
        noise_var = self.model.likelihood.noise_covar.noise.item()

        # model part of the bound, then the noise part added in place
        confidence_width = torch.sqrt(self._get_rho() * post_var)
        pi_t = math.pi ** 2 * self.n_training_points ** 2 / 6
        confidence_width += math.sqrt(
            2 * noise_var * math.log(2 * pi_t / self.delta_T))

        if return_bound:
            return confidence_width

        deviation = torch.abs(observed_y - post_mean)

        # fire unless the observation lies within the confidence bound
        return not (deviation <= confidence_width)


class ET_GP_UCB_Backtracking_Trigger(ET_GP_UCB_Trigger):
    """Event trigger that keeps the most recent compatible data on a reset."""

    def provide_new_data_set(self, train_x, train_y, t):
        """Select the trailing part of the data that is still compatible.

        Starting from the newest observation, older observations are added
        one at a time (newest first) for as long as ``check_trigger`` does
        not fire for them against a model conditioned on the data kept so
        far.

        Side effects: ``self.model`` must already be set (it is set by
        ``check_trigger``); its training data is overwritten via
        ``set_train_data`` while backtracking. The size of the selected data
        set is printed.

        Args:
            train_x: 2-D tensor of inputs, oldest row first.
            train_y: 2-D tensor of targets, rows aligned with ``train_x``.
            t: Current time step, forwarded to ``check_trigger``.

        Returns:
            Tuple ``(new_train_x, new_train_y)`` with the retained rows in
            their original order.
        """
        # the newest data point has to be in the new data set!
        new_train_x, new_train_y = train_x[-1:, :], train_y[-1:, :]
        for i in range(1, train_x.shape[0]):
            # i-th observation counted back from the newest one
            candidate_train_x = train_x[-i - 1:-i, :]
            candidate_train_y = train_y[-i - 1:-i, :]
            # condition the model on the data kept so far
            self.model.set_train_data(
                inputs=new_train_x, targets=new_train_y.squeeze(-1), strict=False)
            not_compatible = self.check_trigger(
                self.model, (candidate_train_x, candidate_train_y), t)
            if not_compatible:
                # if no longer compatible -> break and keep the obtained data set
                break
            # prepend so the rows stay in chronological order
            new_train_x = torch.cat((candidate_train_x, new_train_x), dim=0)
            new_train_y = torch.cat((candidate_train_y, new_train_y))
        print("New data set size: ", new_train_x.shape[0])
        return new_train_x, new_train_y