from __future__ import annotations
import numpy as np
import tensorflow as tf
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    try:
        # Let TensorFlow grow GPU memory on demand instead of reserving it all.
        tf.config.experimental.set_memory_growth(device, True)
    except (ValueError, RuntimeError):
        # ValueError: invalid device. RuntimeError: memory growth cannot be
        # modified once the device has been initialized. Best-effort, so ignore.
        pass
    
import pickle

import numpy as np
import trieste
import gpflow
import networkit
from trieste.objectives.utils import mk_observer
from trieste.space import Box
from trieste.data import Dataset
from trieste.models.gpflow import GaussianProcessRegression, build_gpr, SparseVariational
from trieste.models.gpflow.builders import _get_data_stats, _get_mean_function, _set_gaussian_likelihood_variance, _get_inducing_points
from trieste.acquisition.rule import EfficientGlobalOptimization
import pickle
from trieste.models import ProbabilisticModel
from trieste.acquisition.interface import AcquisitionFunction, SingleModelAcquisitionBuilder, SingleModelVectorizedAcquisitionBuilder, AcquisitionFunctionClass
import scipy
from scipy_fast_optimizer import Scipy_fast

from typing import Any, Callable, List, Optional, Tuple, Union, cast
from trieste.types import TensorType
import tensorflow_probability as tfp
from trieste.models.optimizer import Optimizer

from typing import TypeVar, Union, Tuple, Callable, Optional, Any

# Aliases for the (X, Y) training-data tensors. ``Union`` over a single member
# collapses to that member, so both aliases are simply ``tf.Tensor``.
InputData = tf.Tensor
OutputData = tf.Tensor
RegressionData = Tuple[InputData, OutputData]
Data = TypeVar("Data", RegressionData, InputData, OutputData)
import datetime

class SVGP_Opt(gpflow.models.SVGP):
    """SVGP whose training closure returns ``(loss, flat_gradient)`` for ``Scipy_fast``.

    On each evaluation the data is shuffled and split into mini-batches of
    ``self.batch_size`` rows. A compiled per-batch loss/gradient function runs
    on every batch, and losses and gradients are summed over the batches.

    NOTE(review): ``self.batch_size`` is read but never set in this class; it is
    presumably assigned on the instance by the caller. Confirm.
    """

    def training_loss_closure(
        self, data: Data, *, compile=True,
    ) -> Callable[[Any, Any], Tuple[tf.Tensor, tf.Tensor]]:
        """Return a closure computing the summed mini-batch loss and its gradient.

        The compiled per-batch function is built on the first call and cached on
        the instance as ``tf_compiled_return_closure``. Later calls reuse it, so
        the column layout of ``data`` must not change between calls.

        :param data: ``(X, Y)`` training data. ``Y`` is reshaped to one column per
            row, so it must have exactly one output column. The row count must be
            a multiple of ``self.batch_size``, otherwise the reshape fails.
        :param compile: Unused; kept for signature compatibility with gpflow.
        :return: ``c_instantiate(variables, x)``. It assigns the flat parameter
            vector ``x`` to ``variables`` and returns ``(loss, flat_gradient)``,
            each summed over all mini-batches.
        """
        try:
            return_closure = self.tf_compiled_return_closure
        except AttributeError:
            # First call: compile the per-batch loss/gradient function and cache it.
            return_closure = self._build_batch_loss_and_grads(data)
            self.tf_compiled_return_closure = return_closure

        def c_instantiate(variables, x):
            X_data = data[0]
            Y_data = data[1]
            # Fresh random row order on every evaluation.
            order = tf.random.shuffle(tf.range(X_data.shape[0], dtype=tf.int32))
            X_shuf = tf.gather(X_data, order)
            Y_shuf = tf.gather(Y_data, order)
            X_batches = tf.reshape(X_shuf, (-1, self.batch_size, tf.shape(X_shuf)[-1]))
            Y_batches = tf.reshape(Y_shuf, (-1, self.batch_size, 1))
            per_batch = [
                return_closure((X_batches[i], Y_batches[i]), variables, x)
                for i in range(X_batches.shape[0])
            ]
            losses = tf.stack([result[0] for result in per_batch])
            grads = tf.stack([Scipy_fast.pack_tensors(result[1]) for result in per_batch])
            return tf.math.reduce_sum(losses, axis=0), tf.math.reduce_sum(grads, axis=0)

        return c_instantiate

    def _build_batch_loss_and_grads(self, data):
        """Compile ``c_inner(batch, variables, x) -> (loss, grads)`` for one mini-batch.

        The loss is traced with a ``[None, D_x + D_y]`` input signature. Its dtype
        follows ``X``, so both float32 and float64 data match the signature.
        """
        X, Y = data
        shapespec = (
            tf.TensorSpec([None, X.shape[-1] + Y.shape[-1]], dtype=X.dtype),
        )
        compiled = self._training_loss_builder(shapespec)

        def c_inner(batch, variables, x):
            # Load the optimizer's flat parameter vector into the model variables.
            values = Scipy_fast.unpack_tensors(variables, x)
            Scipy_fast.assign_tensors(variables, values)
            packed = tf.concat([batch[0], batch[1]], axis=-1)
            loss = tf.math.reduce_sum(compiled(packed))
            # tf.gradients only works in graph mode; c_inner is wrapped in tf.function.
            grads = tf.gradients(loss, variables)
            return loss, grads

        return tf.function(c_inner)

    def _training_loss_builder(self, shapespec):
        """Return a tf.function mapping packed ``[N, D_x + 1]`` rows to the loss.

        The last input column is treated as ``Y`` and the remaining ones as ``X``.

        :param shapespec: ``input_signature`` for the returned tf.function.
        """

        @tf.function(experimental_follow_type_hints=True)
        def build(XpY: tf.Tensor):
            X = XpY[:, 0:XpY.shape[-1] - 1]
            Y = tf.expand_dims(XpY[:, -1], -1)
            return self._training_loss((X, Y))

        return tf.function(build, input_signature=shapespec)

    @tf.function(experimental_relax_shapes=True)
    def _training_loss(self, data) -> tf.Tensor:
        """Return ``-(maximum_log_likelihood_objective(data) + log_prior_density())``."""
        a = self.maximum_log_likelihood_objective(data)
        b = self.log_prior_density()
        return -(a + b)
    


class GPR_Opt(gpflow.models.SVGP):
    """Exact-GP log-marginal-likelihood loss summed over shuffled, fixed-size batches.

    NOTE(review): despite its name this subclasses ``gpflow.models.SVGP``. Only the
    ``kernel``, ``mean_function`` and ``likelihood`` attributes are used by the
    loss below.
    """

    def training_loss_closure(
        self, data: Data, *, compile=True, batch_size: int = 400,
    ) -> Callable[[], tf.Tensor]:
        """Return a zero-argument closure evaluating the batched training loss.

        On every call the data is reshuffled and padded into batches of
        ``batch_size`` rows. The compiled loss function is built once and cached
        on the instance as ``tf_compiled_return_closure`` / ``compiled``.

        :param data: ``(X, Y)`` eager tensors (``.numpy()`` is called on both).
            ``Y`` must have a single output column.
        :param compile: Unused; kept for signature compatibility with gpflow.
        :param batch_size: Rows per batch (default 400).
        :return: Closure returning the negated sum over batches of the exact GP
            log marginal likelihood, i.e. a loss to be minimised.
        """
        X, Y = data
        X_all, Y_all = self._make_padded_batches(X.numpy(), Y.numpy(), batch_size)

        try:
            make_closure = self.tf_compiled_return_closure
        except AttributeError:
            # First call: compile the batched loss once and cache a factory that
            # binds freshly batched data to it.
            compiled = self._training_loss_builder()
            self.compiled = compiled

            def make_closure(X_batches, Y_batches):
                def c_inner():
                    return compiled(X_batches, Y_batches)
                return c_inner

            self.tf_compiled_return_closure = make_closure
        return make_closure(X_all, Y_all)

    @staticmethod
    def _make_padded_batches(X_numpy, Y_numpy, batch_size):
        """Shuffle rows and pad them by resampling so they split evenly into batches.

        :return: ``(X_all, Y_all)`` shaped ``[B, batch_size, D]`` and
            ``[B, batch_size, 1]``.
        """
        num_rows = X_numpy.shape[0]
        order = np.random.permutation(num_rows)
        X_numpy = X_numpy[order]
        Y_numpy = Y_numpy[order]

        # Rows needed to complete the last batch; 0 when num_rows is already a
        # multiple of batch_size (``batch_size - rem`` would add a whole batch).
        pad = (-num_rows) % batch_size
        extra = np.random.choice(num_rows, size=pad, replace=True)

        X_all = np.concatenate([X_numpy, X_numpy[extra]], axis=0)
        Y_all = np.concatenate([Y_numpy, Y_numpy[extra]], axis=0)
        return (
            X_all.reshape((-1, batch_size, X_numpy.shape[1])),
            Y_all.reshape((-1, batch_size, 1)),
        )

    def _training_loss_builder(self):
        """Return a tf.function mapping batched ``(X, Y)`` to the training loss.

        ``X`` is ``[B, batch_size, D]`` and ``Y`` is ``[B, batch_size, 1]``. The
        exact GP log marginal likelihood is evaluated per batch and the negated
        sum is returned.
        """

        @tf.function(experimental_relax_shapes=True, experimental_follow_type_hints=True, autograph=False)
        def build(X, Y):
            # map_fn passes each element as a single ``(x_b, y_b)`` tuple. Its
            # output (one scalar per batch) differs in structure from the input,
            # so the output signature must be given explicitly.
            per_batch = tf.map_fn(
                lambda xy: self.log_marginal_likelihood(xy[0], xy[1]),
                (X, Y),
                fn_output_signature=X.dtype,
            )
            # Negate: the closure is minimised while the likelihood is maximised.
            return -tf.math.reduce_sum(per_batch)

        # ``build`` already relaxes shapes across calls, so no input signature is needed.
        return build

    @tf.function
    def log_marginal_likelihood(self, X, Y) -> tf.Tensor:
        r"""
        Computes the exact GP log marginal likelihood of one batch.

        .. math::
            \log p(Y | \theta).
        """
        K = self.kernel(X)
        # Add the likelihood noise variance to the diagonal. This is done inline
        # because ``_add_noise_cov`` is a gpflow GPR helper the SVGP base lacks.
        # NOTE(review): assumes a Gaussian likelihood with a scalar ``variance``.
        k_diag = tf.linalg.diag_part(K)
        ks = tf.linalg.set_diag(K, k_diag + self.likelihood.variance)
        L = tf.linalg.cholesky(ks)
        m = self.mean_function(X)

        # [R,] log-likelihoods for each independent dimension of Y
        log_prob = gpflow.logdensities.multivariate_normal(Y, m, L)
        return tf.reduce_sum(log_prob)







class MultipleOptimismNegativeLowerConfidenceBound(
    SingleModelVectorizedAcquisitionBuilder[ProbabilisticModel]
):
    """
    Builder for the multiple-optimism confidence-bound acquisition function.

    Each point of a batch gets its own exploration weight, so the resulting
    vectorized acquisition function can be optimised efficiently even for large
    batches. See :cite:`torossian2020bayesian` for details.
    """

    def __init__(self, search_space: SearchSpace):
        """
        :param search_space: Search space of the optimisation problem; only its
            ``dimension`` is used when building the acquisition function.
        """
        self._search_space = search_space

    def __repr__(self) -> str:
        """Return a representation that includes the search space."""
        return f"MultipleOptimismNegativeLowerConfidenceBound({self._search_space!r})"

    def prepare_acquisition_function(
        self,
        model: ProbabilisticModel,
        dataset: Optional[Dataset] = None,
    ) -> AcquisitionFunction:
        """
        :param model: The model of the objective.
        :param dataset: Ignored.
        :return: A new :class:`multiple_optimism_lower_confidence_bound` instance.
        """
        dimension = self._search_space.dimension
        return multiple_optimism_lower_confidence_bound(model, dimension)

    def update_acquisition_function(
        self,
        function: AcquisitionFunction,
        model: ProbabilisticModel,
        dataset: Optional[Dataset] = None,
    ) -> AcquisitionFunction:
        """
        :param function: A function previously built by this builder.
        :param model: The model (ignored; the function keeps its own reference).
        :param dataset: Ignored.
        :return: ``function`` unchanged.
        """
        built_by_us = isinstance(function, multiple_optimism_lower_confidence_bound)
        tf.debugging.Assert(built_by_us, [])
        return function




def lower_confidence_bound(model: ProbabilisticModel, beta: float) -> AcquisitionFunction:
    r"""
    Confidence-bound acquisition function for single-objective global optimization.

    NOTE: despite its name, the returned function computes the *upper* bound

    .. math:: x^* \mapsto \mathbb{E} [f(x^*)|x, y] + \beta \sqrt{ \mathrm{Var}[f(x^*)|x, y] }

    so larger predicted means and larger uncertainty both score higher. This is
    the same sign convention as ``multiple_optimism_lower_confidence_bound`` in
    this module. See :cite:`Srinivas:2010` for details.

    :param model: The model of the objective function.
    :param beta: The weight given to the standard deviation contribution. Must not
        be negative.
    :return: The confidence-bound function. It raises :exc:`ValueError` or
        :exc:`~tf.errors.InvalidArgumentError` if used with a batch size greater
        than one.
    :raise tf.errors.InvalidArgumentError: If ``beta`` is negative.
    """
    tf.debugging.assert_non_negative(
        beta, message="Standard deviation scaling parameter beta must not be negative"
    )

    @tf.function
    def acquisition(x: TensorType) -> TensorType:
        # Inputs are [..., 1, D]: exactly one point per batch.
        tf.debugging.assert_shapes(
            [(x, [..., 1, None])],
            message="This acquisition function only supports batch sizes of one.",
        )
        mean, variance = model.predict(tf.squeeze(x, -2))
        return mean + beta * tf.sqrt(variance)

    return acquisition


class NegativeLowerConfidenceBound_(SingleModelAcquisitionBuilder[ProbabilisticModel]):
    """
    Builder for the acquisition function produced by ``lower_confidence_bound``.

    NOTE(review): despite the class name, the built function is *not* negated. It
    returns ``lower_confidence_bound(model, beta)`` unchanged, which evaluates
    ``mean + beta * sqrt(variance)``. Confirm this maximisation-friendly sign is
    intended.
    """

    def __init__(self, beta: float = 1.96):
        """
        :param beta: Weighting given to the standard-deviation contribution of the
            confidence bound. Must not be negative.
        """
        self._beta = beta

    def __repr__(self) -> str:
        """"""
        return f"NegativeLowerConfidenceBound({self._beta!r})"

    def prepare_acquisition_function(
        self,
        model: ProbabilisticModel,
        dataset: Optional[Dataset] = None,
    ) -> AcquisitionFunction:
        """
        :param model: The model.
        :param dataset: Unused.
        :return: The confidence-bound function. It raises :exc:`ValueError` or
            :exc:`~tf.errors.InvalidArgumentError` if used with a batch size
            greater than one.
        :raise tf.errors.InvalidArgumentError: If ``beta`` is negative (raised by
            ``lower_confidence_bound``'s ``assert_non_negative`` check).
        """
        lcb = lower_confidence_bound(model, self._beta)
        return tf.function(lambda at: lcb(at))

    def update_acquisition_function(
        self,
        function: AcquisitionFunction,
        model: ProbabilisticModel,
        dataset: Optional[Dataset] = None,
    ) -> AcquisitionFunction:
        """
        :param function: The acquisition function to update.
        :param model: The model.
        :param dataset: Unused.
        :return: ``function`` unchanged; the built function already captures ``model``.
        """
        return function  # no need to update anything


class multiple_optimism_lower_confidence_bound(AcquisitionFunctionClass):
    r"""
    The multiple optimism confidence bound acquisition function for single-objective
    global optimization.

    Each batch dimension of this acquisition function corresponds to a confidence
    bound with a different beta value, i.e. each point in a batch chosen by this
    acquisition function lies on a gradient of exploration/exploitation trade-offs.
    The beta values are chosen following the cdf method of :cite:`torossian2020bayesian`.
    See their paper for more details.

    NOTE(review): despite the name, ``__call__`` returns ``mean + sqrt(variance) * beta_i``
    (the mean is not negated), i.e. an upper-confidence-bound form where larger
    predicted means score higher. Confirm this maximisation convention is intended.
    """

    def __init__(self, model: ProbabilisticModel, search_space_dim: int):
        """
        :param model: The model of the objective function.
        :param search_space_dim: The dimensions of the optimisation problem's search space.
        :raise tf.errors.InvalidArgumentError: If ``search_space_dim`` is not positive.
        """

        tf.debugging.assert_positive(search_space_dim)
        self._search_space_dim = search_space_dim

        self._model = model
        self._initialized = tf.Variable(False)  # Keep track of when we need to resample
        # Betas are computed lazily on the first call, once the batch size is known.
        # NOTE(review): the variable is float32, but __call__ assigns values cast to
        # x.dtype, so inputs presumably must be float32. Confirm.
        self._betas = tf.Variable(tf.ones([0], dtype=tf.float32), shape=[None])  # [0] lazy init

    @tf.function
    def __call__(self, x: TensorType) -> TensorType:
        """Score each of the B points per batch with its own beta.

        The first call fixes the batch size B and computes the betas; later calls
        assert that B is unchanged.

        :param x: Query points; presumably shaped ``[N, B, D]`` (only ``tf.shape(x)[-2]``
            as B and ``x.shape[-1]`` as D are read directly) — confirm.
        :return: ``mean + sqrt(variance) * betas`` reshaped to ``[-1, B]``.
        """

        batch_size = tf.shape(x)[-2]
        tf.debugging.assert_positive(batch_size)

        # NOTE(review): this f-string is formatted in Python while tracing, so the
        # message shows symbolic tensors rather than the runtime batch sizes.
        if self._initialized:  # check batch size hasn't changed during BO
            tf.debugging.assert_equal(
                batch_size,
                tf.shape(self._betas)[0],
                f"{type(self).__name__} requires a fixed batch size. Got batch size {batch_size}"
                f" but previous batch size was {tf.shape(self._betas)[0]}.",
            )

        if not self._initialized:
            # Quantile levels 0.5 + 0.5 * i / (B + 1) for i = 1..B lie in (0.5, 1),
            # so every beta is a positive standard-normal quantile.
            normal = tfp.distributions.Normal(
                tf.cast(0.0, dtype=x.dtype), tf.cast(1.0, dtype=x.dtype)
            )
            spread = 0.5 + 0.5 * tf.range(1, batch_size + 1, dtype=x.dtype) / (
                tf.cast(batch_size, dtype=x.dtype) + 1.0
            )  # [B]
            betas = normal.quantile(spread)  # [B]
            # Scale by 2.8 * search-space dimension.
            scaled_betas = 2.8 * tf.cast(self._search_space_dim, dtype=x.dtype) * betas  # [B]
            self._betas.assign(scaled_betas)  # [B]
            self._initialized.assign(True)

        # Flatten every leading dimension so the model sees a 2-D [M * B, D] input.
        x_res = tf.reshape(x, (-1, x.shape[-1]))

        mean, variance = self._model.predict(x_res)  # presumably [M * B, 1]; confirm model output shape
        mean = tf.reshape(mean, (-1, batch_size, 1))
        variance = tf.reshape(variance, (-1, batch_size, 1))
        mean, variance = tf.squeeze(mean, -1), tf.squeeze(variance, -1)
        return mean + tf.sqrt(variance) * self._betas  # [M, B], betas broadcast over M
