# Copyright 2020 The Ray Authors.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This source file is adapted here because ray does not fully support Windows.

# Copyright (c) Microsoft Corporation.
import time
import functools
import warnings
import copy
import logging
from typing import Any, Dict, Optional, Union, List, Tuple, Callable
import pickle
from .variant_generator import parse_spec_vars
from ..sample import (
    Categorical,
    Domain,
    Float,
    Integer,
    LogUniform,
    Quantized,
    Uniform,
)
from ..trial import flatten_dict, unflatten_dict

logger = logging.getLogger(__name__)

UNRESOLVED_SEARCH_SPACE = str(
    "You passed a `{par}` parameter to {cls} that contained unresolved search "
    "space definitions. {cls} should however be instantiated with fully "
    "configured search spaces only. To use Ray Tune's automatic search space "
    "conversion, pass the space definition as part of the `config` argument "
    "to `tune.run()` instead."
)

UNDEFINED_SEARCH_SPACE = str(
    "Trying to sample a configuration from {cls}, but no search "
    "space has been defined. Either pass the `{space}` argument when "
    "instantiating the search algorithm, or pass a `config` to "
    "`tune.run()`."
)

UNDEFINED_METRIC_MODE = str(
    "Trying to sample a configuration from {cls}, but the `metric` "
    "({metric}) or `mode` ({mode}) parameters have not been set. "
    "Either pass these arguments when instantiating the search algorithm, "
    "or pass them to `tune.run()`."
)


class Searcher:
    """Abstract class for wrapping suggesting algorithms.
    Custom algorithms can extend this class easily by overriding the
    `suggest` method provide generated parameters for the trials.
    Any subclass that implements ``__init__`` must also call the
    constructor of this class: ``super(Subclass, self).__init__(...)``.
    To track suggestions and their corresponding evaluations, the method
    `suggest` will be passed a trial_id, which will be used in
    subsequent notifications.
    Not all implementations support multi objectives.
    Args:
        metric (str or list): The training result objective value attribute. If
            list then list of training result objective value attributes
        mode (str or list): If string One of {min, max}. If list then
            list of max and min, determines whether objective is minimizing
            or maximizing the metric attribute. Must match type of metric.

    ```python
    class ExampleSearch(Searcher):
        def __init__(self, metric="mean_loss", mode="min", **kwargs):
            super(ExampleSearch, self).__init__(
                metric=metric, mode=mode, **kwargs)
            self.optimizer = Optimizer()
            self.configurations = {}
        def suggest(self, trial_id):
            configuration = self.optimizer.query()
            self.configurations[trial_id] = configuration
        def on_trial_complete(self, trial_id, result, **kwargs):
            configuration = self.configurations[trial_id]
            if result and self.metric in result:
                self.optimizer.update(configuration, result[self.metric])
    tune.run(trainable_function, search_alg=ExampleSearch())
    ```

    """

    FINISHED = "FINISHED"
    CKPT_FILE_TMPL = "searcher-state-{}.pkl"

    def __init__(
        self,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        max_concurrent: Optional[int] = None,
        use_early_stopped_trials: Optional[bool] = None,
    ):
        self._metric = metric
        self._mode = mode

        if not mode or not metric:
            # Early return to avoid assertions
            return

        assert isinstance(
            metric, type(mode)
        ), "metric and mode must be of the same type"
        if isinstance(mode, str):
            assert mode in ["min", "max"], "if `mode` is a str must be 'min' or 'max'!"
        elif isinstance(mode, list):
            assert len(mode) == len(metric), "Metric and mode must be the same length"
            assert all(
                mod in ["min", "max", "obs"] for mod in mode
            ), "All of mode must be 'min' or 'max' or 'obs'!"
        else:
            raise ValueError("Mode must either be a list or string")

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict
    ) -> bool:
        """Pass search properties to searcher.
        This method acts as an alternative to instantiating search algorithms
        with their own specific search spaces. Instead they can accept a
        Tune config through this method. A searcher should return ``True``
        if setting the config was successful, or ``False`` if it was
        unsuccessful, e.g. when the search space has already been set.
        Args:
            metric (str): Metric to optimize
            mode (str): One of ["min", "max"]. Direction to optimize.
            config (dict): Tune config dict.
        """
        return False

    def on_trial_result(self, trial_id: str, result: Dict):
        """Optional notification for result during training.
        Note that by default, the result dict may include NaNs or
        may not include the optimization metric. It is up to the
        subclass implementation to preprocess the result to
        avoid breaking the optimization process.
        Args:
            trial_id (str): A unique string ID for the trial.
            result (dict): Dictionary of metrics for current training progress.
                Note that the result dict may include NaNs or
                may not include the optimization metric. It is up to the
                subclass implementation to preprocess the result to
                avoid breaking the optimization process.
        """
        pass

    @property
    def metric(self) -> str:
        """The training result objective value attribute."""
        return self._metric

    @property
    def mode(self) -> str:
        """Specifies if minimizing or maximizing the metric."""
        return self._mode


class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.
    Args:
        searcher (Searcher): Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent (int): Maximum concurrent samples from the underlying
            searcher.
        batch (bool): Whether to wait for all concurrent samples
            to finish before updating the underlying searcher.
    Example:
    ```python
    from ray.tune.suggest import ConcurrencyLimiter  # ray version < 2
    search_alg = HyperOptSearch(metric="accuracy")
    search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
    tune.run(trainable, search_alg=search_alg)
    ```
    """

    def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
        assert type(max_concurrent) is int and max_concurrent > 0
        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        self.live_trials = set()
        self.cached_results = {}
        super(ConcurrencyLimiter, self).__init__(
            metric=self.searcher.metric, mode=self.searcher.mode
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        assert (
            trial_id not in self.live_trials
        ), f"Trial ID {trial_id} must be unique: already found in set."
        if len(self.live_trials) >= self.max_concurrent:
            logger.debug(
                f"Not providing a suggestion for {trial_id} due to "
                "concurrency limit: %s/%s.",
                len(self.live_trials),
                self.max_concurrent,
            )
            return

        suggestion = self.searcher.suggest(trial_id)
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
        return suggestion

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        if trial_id not in self.live_trials:
            return
        elif self.batch:
            self.cached_results[trial_id] = (result, error)
            if len(self.cached_results) == self.max_concurrent:
                # Update the underlying searcher once the
                # full batch is completed.
                for trial_id, (result, error) in self.cached_results.items():
                    self.searcher.on_trial_complete(
                        trial_id, result=result, error=error
                    )
                    self.live_trials.remove(trial_id)
                self.cached_results = {}
            else:
                return
        else:
            self.searcher.on_trial_complete(trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)

    def get_state(self) -> Dict:
        state = self.__dict__.copy()
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        self.searcher.restore(checkpoint_path)

    def on_pause(self, trial_id: str):
        self.searcher.on_pause(trial_id)

    def on_unpause(self, trial_id: str):
        self.searcher.on_unpause(trial_id)

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict
    ) -> bool:
        return self.searcher.set_search_properties(metric, mode, config)


try:
    import optuna as ot
    from optuna.distributions import BaseDistribution as OptunaDistribution
    from optuna.samplers import BaseSampler
    from optuna.trial import TrialState as OptunaTrialState
    from optuna.trial import Trial as OptunaTrial
except ImportError:
    ot = None
    OptunaDistribution = None
    BaseSampler = None
    OptunaTrialState = None
    OptunaTrial = None

# (Optional) Default (anonymous) metric when using tune.report(x)
DEFAULT_METRIC = "_metric"

# (Auto-filled) The index of this training iteration.
TRAINING_ITERATION = "training_iteration"

# print a warning if define by run function takes longer than this to execute
DEFINE_BY_RUN_WARN_THRESHOLD_S = 1  # 1 is arbitrary


def validate_warmstart(
    parameter_names: List[str],
    points_to_evaluate: List[Union[List, Dict]],
    evaluated_rewards: List,
    validate_point_name_lengths: bool = True,
):
    """Generic validation of a Searcher's warm start functionality.
    Raises exceptions in case of type and length mismatches between
    parameters.
    If ``validate_point_name_lengths`` is False, the equality of lengths
    between ``points_to_evaluate`` and ``parameter_names`` will not be
    validated.
    """
    if points_to_evaluate:
        if not isinstance(points_to_evaluate, list):
            raise TypeError(
                "points_to_evaluate expected to be a list, got {}.".format(
                    type(points_to_evaluate)
                )
            )
        for point in points_to_evaluate:
            if not isinstance(point, (dict, list)):
                raise TypeError(
                    f"points_to_evaluate expected to include list or dict, "
                    f"got {point}."
                )

            if validate_point_name_lengths and (not len(point) == len(parameter_names)):
                raise ValueError(
                    "Dim of point {}".format(point)
                    + " and parameter_names {}".format(parameter_names)
                    + " do not match."
                )

    if points_to_evaluate and evaluated_rewards:
        if not isinstance(evaluated_rewards, list):
            raise TypeError(
                "evaluated_rewards expected to be a list, got {}.".format(
                    type(evaluated_rewards)
                )
            )
        if not len(evaluated_rewards) == len(points_to_evaluate):
            raise ValueError(
                "Dim of evaluated_rewards {}".format(evaluated_rewards)
                + " and points_to_evaluate {}".format(points_to_evaluate)
                + " do not match."
            )


class _OptunaTrialSuggestCaptor:
    """Utility to capture returned values from Optuna's suggest_ methods.
    This will wrap around the ``optuna.Trial` object and decorate all
    `suggest_` callables with a function capturing the returned value,
    which will be saved in the ``captured_values`` dict.
    """

    def __init__(self, ot_trial: OptunaTrial) -> None:
        self.ot_trial = ot_trial
        self.captured_values: Dict[str, Any] = {}

    def _get_wrapper(self, func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # name is always the first arg for suggest_ methods
            name = kwargs.get("name", args[0])
            ret = func(*args, **kwargs)
            self.captured_values[name] = ret
            return ret

        return wrapper

    def __getattr__(self, item_name: str) -> Any:
        item = getattr(self.ot_trial, item_name)
        if item_name.startswith("suggest_") and callable(item):
            return self._get_wrapper(item)
        return item


class OptunaSearch(Searcher):
    """A wrapper around Optuna to provide trial suggestions.
        [Optuna](https://optuna.org/)
        is a hyperparameter optimization library.
        In contrast to other libraries, it employs define-by-run style
        hyperparameter definitions.
        This Searcher is a thin wrapper around Optuna's search algorithms.
        You can pass any Optuna sampler, which will be used to generate
        hyperparameter suggestions.
        Args:
            space (dict|Callable): Hyperparameter search space definition for
                Optuna's sampler. This can be either a class `dict` with
                parameter names as keys and ``optuna.distributions`` as values,
                or a Callable - in which case, it should be a define-by-run
                function using ``optuna.trial`` to obtain the hyperparameter
                values. The function should return either a class `dict` of
                constant values with names as keys, or None.
                For more information, see
                [tutorial](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html).
                Warning - No actual computation should take place in the define-by-run
                function. Instead, put the training logic inside the function
                or class trainable passed to tune.run.
            metric (str): The training result objective value attribute. If None
                but a mode was passed, the anonymous metric `_metric` will be used
                per default.
            mode (str): One of {min, max}. Determines whether objective is
                minimizing or maximizing the metric attribute.
            points_to_evaluate (list): Initial parameter suggestions to be run
                first. This is for when you already have some good parameters
                you want to run first to help the algorithm make better suggestions
                for future parameters. Needs to be a list of dicts containing the
                configurations.
            sampler (optuna.samplers.BaseSampler): Optuna sampler used to
                draw hyperparameter configurations. Defaults to ``TPESampler``.
            seed (int): Seed to initialize sampler with. This parameter is only
                used when ``sampler=None``. In all other cases, the sampler
                you pass should be initialized with the seed already.
            evaluated_rewards (list): If you have previously evaluated the
                parameters passed in as points_to_evaluate you can avoid
                re-running those trials by passing in the reward attributes
                as a list so the optimiser can be told the results without
                needing to re-compute the trial. Must be the same length as
                points_to_evaluate.

        Tune automatically converts search spaces to Optuna's format:

    ````python
    from ray.tune.suggest.optuna import OptunaSearch  # ray version < 2
    config = { "a": tune.uniform(6, 8),
               "b": tune.loguniform(1e-4, 1e-2)}
    optuna_search = OptunaSearch(metric="loss", mode="min")
    tune.run(trainable, config=config, search_alg=optuna_search)
    ````

        If you would like to pass the search space manually, the code would
        look like this:

    ```python
    from ray.tune.suggest.optuna import OptunaSearch  # ray version < 2
    import optuna
    config = { "a": optuna.distributions.UniformDistribution(6, 8),
               "b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2)}
    optuna_search = OptunaSearch(space,metric="loss",mode="min")
    tune.run(trainable, search_alg=optuna_search)
    # Equivalent Optuna define-by-run function approach:
    def define_search_space(trial: optuna.Trial):
        trial.suggest_float("a", 6, 8)
        trial.suggest_float("b", 1e-4, 1e-2, log=True)
        # training logic goes into trainable, this is just
        # for search space definition
    optuna_search = OptunaSearch(
        define_search_space,
        metric="loss",
        mode="min")
    tune.run(trainable, search_alg=optuna_search)
    .. versionadded:: 0.8.8
    ```

    """

    def __init__(
        self,
        space: Optional[
            Union[
                Dict[str, "OptunaDistribution"],
                List[Tuple],
                Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
            ]
        ] = None,
        metric: Optional[str] = None,
        mode: Optional[str] = None,
        points_to_evaluate: Optional[List[Dict]] = None,
        sampler: Optional["BaseSampler"] = None,
        seed: Optional[int] = None,
        evaluated_rewards: Optional[List] = None,
    ):
        assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
        super(OptunaSearch, self).__init__(
            metric=metric, mode=mode, max_concurrent=None, use_early_stopped_trials=None
        )

        if isinstance(space, dict) and space:
            resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
            if domain_vars or grid_vars:
                logger.warning(
                    UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self).__name__)
                )
                space = self.convert_search_space(space)
            else:
                # Flatten to support nested dicts
                space = flatten_dict(space, "/")

        self._space = space

        self._points_to_evaluate = points_to_evaluate or []
        self._evaluated_rewards = evaluated_rewards

        self._study_name = "optuna"  # Fixed study name for in-memory storage

        if sampler and seed:
            logger.warning(
                "You passed an initialized sampler to `OptunaSearch`. The "
                "`seed` parameter has to be passed to the sampler directly "
                "and will be ignored."
            )

        self._sampler = sampler or ot.samplers.TPESampler(seed=seed)

        assert isinstance(self._sampler, BaseSampler), (
            "You can only pass an instance of `optuna.samplers.BaseSampler` "
            "as a sampler to `OptunaSearcher`."
        )

        self._ot_trials = {}
        self._ot_study = None
        if self._space:
            self._setup_study(mode)

    def _setup_study(self, mode: str):
        if self._metric is None and self._mode:
            # If only a mode was passed, use anonymous metric
            self._metric = DEFAULT_METRIC

        pruner = ot.pruners.NopPruner()
        storage = ot.storages.InMemoryStorage()

        self._ot_study = ot.study.create_study(
            storage=storage,
            sampler=self._sampler,
            pruner=pruner,
            study_name=self._study_name,
            direction="minimize" if mode == "min" else "maximize",
            load_if_exists=True,
        )

        if self._points_to_evaluate:
            validate_warmstart(
                self._space,
                self._points_to_evaluate,
                self._evaluated_rewards,
                validate_point_name_lengths=not callable(self._space),
            )
            if self._evaluated_rewards:
                for point, reward in zip(
                    self._points_to_evaluate, self._evaluated_rewards
                ):
                    self.add_evaluated_point(point, reward)
            else:
                for point in self._points_to_evaluate:
                    self._ot_study.enqueue_trial(point)

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict
    ) -> bool:
        if self._space:
            return False
        space = self.convert_search_space(config)
        self._space = space
        if metric:
            self._metric = metric
        if mode:
            self._mode = mode

        self._setup_study(mode)
        return True

    def _suggest_from_define_by_run_func(
        self,
        func: Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
        ot_trial: "OptunaTrial",
    ) -> Dict:
        captor = _OptunaTrialSuggestCaptor(ot_trial)
        time_start = time.time()
        ret = func(captor)
        time_taken = time.time() - time_start
        if time_taken > DEFINE_BY_RUN_WARN_THRESHOLD_S:
            warnings.warn(
                "Define-by-run function passed in the `space` argument "
                f"took {time_taken} seconds to "
                "run. Ensure that actual computation, training takes "
                "place inside Tune's train functions or Trainables "
                "passed to `tune.run`."
            )
        if ret is not None:
            if not isinstance(ret, dict):
                raise TypeError(
                    "The return value of the define-by-run function "
                    "passed in the `space` argument should be "
                    "either None or a `dict` with `str` keys. "
                    f"Got {type(ret)}."
                )
            if not all(isinstance(k, str) for k in ret.keys()):
                raise TypeError(
                    "At least one of the keys in the dict returned by the "
                    "define-by-run function passed in the `space` argument "
                    "was not a `str`."
                )
        return {**captor.captured_values, **ret} if ret else captor.captured_values

    def suggest(self, trial_id: str) -> Optional[Dict]:
        if not self._space:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(
                    cls=self.__class__.__name__, space="space"
                )
            )
        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(
                    cls=self.__class__.__name__, metric=self._metric, mode=self._mode
                )
            )

        if isinstance(self._space, list):
            # Keep for backwards compatibility
            # Deprecate: 1.5
            if trial_id not in self._ot_trials:
                self._ot_trials[trial_id] = self._ot_study.ask()

            ot_trial = self._ot_trials[trial_id]

            # getattr will fetch the trial.suggest_ function on Optuna trials
            params = {
                args[0]
                if len(args) > 0
                else kwargs["name"]: getattr(ot_trial, fn)(*args, **kwargs)
                for (fn, args, kwargs) in self._space
            }
        elif callable(self._space):
            if trial_id not in self._ot_trials:
                self._ot_trials[trial_id] = self._ot_study.ask()

            ot_trial = self._ot_trials[trial_id]

            params = self._suggest_from_define_by_run_func(self._space, ot_trial)
        else:
            # Use Optuna ask interface (since version 2.6.0)
            if trial_id not in self._ot_trials:
                self._ot_trials[trial_id] = self._ot_study.ask(
                    fixed_distributions=self._space
                )
            ot_trial = self._ot_trials[trial_id]
            params = ot_trial.params

        return unflatten_dict(params)

    def on_trial_result(self, trial_id: str, result: Dict):
        metric = result[self.metric]
        step = result[TRAINING_ITERATION]
        ot_trial = self._ot_trials[trial_id]
        ot_trial.report(metric, step)

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        ot_trial = self._ot_trials[trial_id]

        val = result.get(self.metric, None) if result else None
        ot_trial_state = OptunaTrialState.COMPLETE
        if val is None:
            if error:
                ot_trial_state = OptunaTrialState.FAIL
            else:
                ot_trial_state = OptunaTrialState.PRUNED
        try:
            self._ot_study.tell(ot_trial, val, state=ot_trial_state)
        except ValueError as exc:
            logger.warning(exc)  # E.g. if NaN was reported

    def add_evaluated_point(
        self,
        parameters: Dict,
        value: float,
        error: bool = False,
        pruned: bool = False,
        intermediate_values: Optional[List[float]] = None,
    ):
        if not self._space:
            raise RuntimeError(
                UNDEFINED_SEARCH_SPACE.format(
                    cls=self.__class__.__name__, space="space"
                )
            )
        if not self._metric or not self._mode:
            raise RuntimeError(
                UNDEFINED_METRIC_MODE.format(
                    cls=self.__class__.__name__, metric=self._metric, mode=self._mode
                )
            )

        ot_trial_state = OptunaTrialState.COMPLETE
        if error:
            ot_trial_state = OptunaTrialState.FAIL
        elif pruned:
            ot_trial_state = OptunaTrialState.PRUNED

        if intermediate_values:
            intermediate_values_dict = {
                i: value for i, value in enumerate(intermediate_values)
            }
        else:
            intermediate_values_dict = None

        trial = ot.trial.create_trial(
            state=ot_trial_state,
            value=value,
            params=parameters,
            distributions=self._space,
            intermediate_values=intermediate_values_dict,
        )

        self._ot_study.add_trial(trial)

    def save(self, checkpoint_path: str):
        save_object = (
            self._sampler,
            self._ot_trials,
            self._ot_study,
            self._points_to_evaluate,
            self._evaluated_rewards,
        )
        with open(checkpoint_path, "wb") as outputFile:
            pickle.dump(save_object, outputFile)

    def restore(self, checkpoint_path: str):
        with open(checkpoint_path, "rb") as inputFile:
            save_object = pickle.load(inputFile)
        if len(save_object) == 5:
            (
                self._sampler,
                self._ot_trials,
                self._ot_study,
                self._points_to_evaluate,
                self._evaluated_rewards,
            ) = save_object
        else:
            # Backwards compatibility
            (
                self._sampler,
                self._ot_trials,
                self._ot_study,
                self._points_to_evaluate,
            ) = save_object

    @staticmethod
    def convert_search_space(spec: Dict) -> Dict[str, Any]:
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        if not domain_vars and not grid_vars:
            return {}

        if grid_vars:
            raise ValueError(
                "Grid search parameters cannot be automatically converted "
                "to an Optuna search space."
            )

        # Flatten and resolve again after checking for grid search.
        spec = flatten_dict(spec, prevent_delimiter=True)
        resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)

        def resolve_value(domain: Domain) -> ot.distributions.BaseDistribution:
            quantize = None

            sampler = domain.get_sampler()
            if isinstance(sampler, Quantized):
                quantize = sampler.q
                sampler = sampler.sampler
                if isinstance(sampler, LogUniform):
                    logger.warning(
                        "Optuna does not handle quantization in loguniform "
                        "sampling. The parameter will be passed but it will "
                        "probably be ignored."
                    )

            if isinstance(domain, Float):
                if isinstance(sampler, LogUniform):
                    if quantize:
                        logger.warning(
                            "Optuna does not support both quantization and "
                            "sampling from LogUniform. Dropped quantization."
                        )
                    return ot.distributions.LogUniformDistribution(
                        domain.lower, domain.upper
                    )

                elif isinstance(sampler, Uniform):
                    if quantize:
                        return ot.distributions.DiscreteUniformDistribution(
                            domain.lower, domain.upper, quantize
                        )
                    return ot.distributions.UniformDistribution(
                        domain.lower, domain.upper
                    )

            elif isinstance(domain, Integer):
                if isinstance(sampler, LogUniform):
                    return ot.distributions.IntLogUniformDistribution(
                        domain.lower, domain.upper - 1, step=quantize or 1
                    )
                elif isinstance(sampler, Uniform):
                    # Upper bound should be inclusive for quantization and
                    # exclusive otherwise
                    return ot.distributions.IntUniformDistribution(
                        domain.lower,
                        domain.upper - int(bool(not quantize)),
                        step=quantize or 1,
                    )
            elif isinstance(domain, Categorical):
                if isinstance(sampler, Uniform):
                    return ot.distributions.CategoricalDistribution(domain.categories)

            raise ValueError(
                "Optuna search does not support parameters of type "
                "`{}` with samplers of type `{}`".format(
                    type(domain).__name__, type(domain.sampler).__name__
                )
            )

        # Parameter name is e.g. "a/b/c" for nested dicts
        values = {"/".join(path): resolve_value(domain) for path, domain in domain_vars}

        return values
