from typing import Callable, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import scipy.optimize
import tensorflow as tf
from scipy.optimize import OptimizeResult
from gpflow.optimizers import Scipy
from gpflow.monitor.base import Monitor


# A closure taking no arguments and returning the loss tensor.
# NOTE(review): ``Scipy_fast`` below calls its closure as
# ``closure(variables, x)`` and expects ``(loss, grads)`` back, so this alias
# does not describe that closure's real signature.
LossClosure = Callable[[], tf.Tensor]

# A step callback is either a plain callable taking
# ``(step, variables, values)`` and returning None, or a gpflow ``Monitor``.
StepCallback = Union[
    Callable[[int, Sequence[tf.Variable], Sequence[tf.Tensor]], None],
    Monitor,
]

# Deprecated alias; presumably retained so that existing imports keep working.
Variables = Iterable[tf.Variable]

class Scipy_fast(Scipy):
    """SciPy optimizer wrapper whose closure supplies its own gradient.

    ``eval_func`` does not differentiate anything itself: it calls the
    user closure as ``closure(variables, x)`` and forwards the returned
    ``(loss, grads)`` pair to ``scipy.optimize.minimize`` (``jac=True``).
    """

    def minimize(
        self,
        closure: Callable[[Sequence[tf.Variable], tf.Tensor], Tuple[tf.Tensor, tf.Tensor]],
        variables: Sequence[tf.Variable],
        method: Optional[str] = "L-BFGS-B",
        step_callback: Optional[StepCallback] = None,
        compile: bool = True,
        **scipy_kwargs,
    ) -> OptimizeResult:
        """
        Minimize a loss with ``scipy.optimize.minimize``, using a closure that
        returns both the loss and its gradient.

        The variables are packed into one flat numpy vector ``x`` (via the
        inherited ``initial_parameters``), which SciPy iterates on. At every
        evaluation the closure is called as ``closure(variables, x)`` with
        ``x`` converted to a ``tf.Tensor``.

        Args:
            closure: Callable ``closure(variables, x) -> (loss, grads)``.
                ``loss`` is the objective at ``x`` and ``grads`` its gradient
                with respect to ``x``. SciPy (``jac=True``) expects ``grads``
                to have the same shape as ``x``. Both values may be eager
                ``tf.Tensor`` objects or anything ``np.array`` accepts. They
                are converted to float64 numpy arrays.
            variables: The list (tuple) of variables to be optimized
                (typically ``model.trainable_variables``). They are passed
                unchanged to the closure and used to build the initial point.
            method: The type of solver to use in SciPy. Defaults to "L-BFGS-B".
            step_callback: If not None, a callable that gets called once after
                each optimisation step. The callable is passed the arguments
                ``step``, ``variables``, and ``values``. ``step`` is the
                optimisation step counter, ``variables`` is the list of
                trainable variables as above, and ``values`` is the
                corresponding list of tensors of matching shape that contains
                their value at this optimisation step.
            compile: Accepted for signature compatibility with gpflow's
                ``Scipy.minimize`` but currently IGNORED. The closure is always
                called eagerly and is never wrapped in ``tf.function``.
            scipy_kwargs: Arguments passed through to ``scipy.optimize.minimize``.
                Note that Scipy's minimize() takes a ``callback`` argument, but
                you probably want to use our wrapper and pass in ``step_callback``.

        Returns:
            The optimization result represented as a Scipy ``OptimizeResult``
            object. See the Scipy documentation for description of attributes.

        Raises:
            TypeError: If ``closure`` is not callable, or ``variables`` contains
                anything other than ``tf.Variable`` instances.
            ValueError: If both ``step_callback`` and ``callback`` are given.

        NOTE(review): Neither this method nor ``eval_func`` assigns ``x`` back
        into ``variables``. Unless the closure (or the inherited
        ``callback_func``) does so, assign ``result.x`` yourself if the
        variables must hold the optimum afterwards.
        """
        if not callable(closure):
            raise TypeError(
                "The 'closure' argument is expected to be a callable object."
            )  # pragma: no cover
        variables = tuple(variables)
        if not all(isinstance(v, tf.Variable) for v in variables):
            raise TypeError(
                "The 'variables' argument is expected to only contain tf.Variable instances (use model.trainable_variables, not model.trainable_parameters)"
            )  # pragma: no cover
        initial_params = self.initial_parameters(variables)

        func = self.eval_func(closure, variables, compile=compile)
        if step_callback is not None:
            if "callback" in scipy_kwargs:
                raise ValueError("Callback passed both via `step_callback` and `callback`")
            scipy_kwargs["callback"] = self.callback_func(variables, step_callback)

        # jac=True: ``func`` returns (loss, gradient) from a single call.
        return scipy.optimize.minimize(
            func, initial_params, jac=True, method=method, **scipy_kwargs
        )

    @classmethod
    def eval_func(
        cls,
        closure: Callable[[Sequence[tf.Variable], tf.Tensor], Tuple[tf.Tensor, tf.Tensor]],
        variables: Sequence[tf.Variable],
        compile: bool = True,
    ) -> Callable[[np.ndarray], Tuple[np.ndarray, np.ndarray]]:
        """Build the objective passed to ``scipy.optimize.minimize`` (``jac=True``).

        Args:
            closure: Called as ``closure(variables, x)``, where ``x`` is SciPy's
                current flat parameter vector converted to a ``tf.Tensor``.
                It must return ``(loss, grads)``.
            variables: Passed through to ``closure`` unchanged.
            compile: Ignored; kept only for signature compatibility. The closure
                is always run eagerly, never wrapped in ``tf.function``.

        Returns:
            A function mapping a numpy vector ``x`` to ``(loss, grads)`` as
            float64 numpy arrays.
        """

        def _eval(x: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
            loss, grads = closure(variables, tf.convert_to_tensor(x))
            # np.array (instead of Tensor.numpy()) also accepts numpy arrays and
            # Python floats from the closure. It always copies, so SciPy never
            # aliases a buffer the closure might reuse between calls.
            return np.array(loss, dtype=np.float64), np.array(grads, dtype=np.float64)

        return _eval
