#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import time
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
from botorch.exceptions.errors import OptimizationTimeoutError
from scipy import optimize


def minimize_with_timeout(
    fun: Callable[..., Union[float, Tuple[float, np.ndarray]]],
    x0: np.ndarray,
    args: Tuple[Any, ...] = (),
    method: Optional[str] = None,
    jac: Optional[Union[str, Callable, bool]] = None,
    hess: Optional[Union[str, Callable, optimize.HessianUpdateStrategy]] = None,
    hessp: Optional[Callable] = None,
    bounds: Optional[Union[Sequence[Tuple[float, float]], optimize.Bounds]] = None,
    constraints=(),  # Typing this properly is a s**t job
    tol: Optional[float] = None,
    callback: Optional[Callable] = None,
    options: Optional[Dict[str, Any]] = None,
    timeout_sec: Optional[float] = None,
) -> optimize.OptimizeResult:
    r"""Wrapper around scipy.optimize.minimize to support timeout.

    This method calls scipy.optimize.minimize with all arguments forwarded
    verbatim. The only difference is that if provided a `timeout_sec` argument,
    it will automatically stop the optimization after the timeout is reached.

    Internally, this is achieved by automatically constructing a wrapper callback
    method that is injected to the scipy.optimize.minimize call and that keeps
    track of the runtime and the optimization variables at the current iteration.
    The timeout is therefore only checked when scipy invokes the callback.

    Args:
        fun: The objective, called as `fun(x, *args)`.
        x0: The initial guess.
        args: Extra arguments passed to `fun`.
        method: Forwarded to `scipy.optimize.minimize`.
        jac: Forwarded to `scipy.optimize.minimize`.
        hess: Forwarded to `scipy.optimize.minimize`.
        hessp: Forwarded to `scipy.optimize.minimize`.
        bounds: Forwarded to `scipy.optimize.minimize`.
        constraints: Forwarded to `scipy.optimize.minimize`.
        tol: Forwarded to `scipy.optimize.minimize`.
        callback: Optional user callback. If `timeout_sec` is given, it is
            wrapped together with the timeout check; it is called as
            `callback(xk, state)` for `method="trust-constr"` and as
            `callback(xk)` otherwise.
        options: Forwarded to `scipy.optimize.minimize`.
        timeout_sec: Timeout in seconds. If `None`, no timeout is applied and
            `callback` is passed through unchanged.

    Returns:
        The `OptimizeResult` of `scipy.optimize.minimize`. If the timeout was
        hit, a result with `success=False` and `status=1` holding the iterate
        passed to the last callback invocation (`x`), the objective value at
        that iterate (`fun`), the number of callback invocations (`nit`), and
        a `message` stating the elapsed runtime.

    Raises:
        NotImplementedError: If `timeout_sec` and `callback` are both given
            and `method` is a callable.
    """
    # Defined unconditionally so that the timeout handler below never
    # refers to an unbound name; updated from within the callback below.
    callback_data = {"num_iterations": 0}

    # Compare against None (not truthiness) so that `timeout_sec=0` is
    # honored as a timeout rather than silently disabling the timeout.
    if timeout_sec is not None:

        start_time = time.monotonic()

        def timeout_callback(xk: np.ndarray, *_: Any) -> bool:
            # Extra positional arguments are accepted and ignored, since
            # "trust-constr" calls its callback as `callback(xk, state)`.
            runtime = time.monotonic() - start_time
            callback_data["num_iterations"] += 1
            if runtime > timeout_sec:
                raise OptimizationTimeoutError(current_x=xk, runtime=runtime)
            return False

        if callback is None:
            wrapped_callback = timeout_callback

        elif callable(method):
            raise NotImplementedError(
                "Custom callable not supported for `method` argument."
            )

        elif method == "trust-constr":  # special signature

            def wrapped_callback(
                xk: np.ndarray, state: optimize.OptimizeResult
            ) -> bool:
                # order here is important to make sure base callback gets executed
                return callback(xk, state) or timeout_callback(xk=xk)

        else:

            def wrapped_callback(xk: np.ndarray) -> None:
                timeout_callback(xk=xk)
                callback(xk)

    else:
        wrapped_callback = callback

    try:
        return optimize.minimize(
            fun=fun,
            x0=x0,
            args=args,
            method=method,
            jac=jac,
            hess=hess,
            hessp=hessp,
            bounds=bounds,
            constraints=constraints,
            tol=tol,
            callback=wrapped_callback,
            options=options,
        )
    except OptimizationTimeoutError as e:
        msg = f"Optimization timed out after {e.runtime} seconds."
        fun_out = fun(e.current_x, *args)
        try:
            # `fun` may return `(value, gradient, ...)`, e.g. with `jac=True`.
            current_fun, *_ = fun_out
        except TypeError:
            # `fun` returned a non-iterable scalar objective value.
            current_fun = fun_out

        return optimize.OptimizeResult(
            fun=current_fun,
            x=e.current_x,
            nit=callback_data["num_iterations"],
            success=False,  # same as when maxiter is reached
            status=1,  # same as when L-BFGS-B reaches maxiter
            message=msg,
        )
