from typing import Any, Literal

import optuna

from ..types import KWArgs


def _sample_value(
    trial: optuna.trial.Trial,
    distribution: Literal['int', 'uniform', 'loguniform', 'categorical'],
    label: str,
    *args,
):
    trial_suggest, kwargs = {
        'int': (trial.suggest_int, {}),
        'uniform': (trial.suggest_float, {}),
        'loguniform': (trial.suggest_float, {'log': True}),
        'categorical': (trial.suggest_categorical, {}),
    }[distribution]
    if distribution in ('int', 'uniform', 'loguniform') and len(args) == 3:
        args, kwargs['step'] = args[:2], args[2]
    return trial_suggest(label, *args, **kwargs)


def _sample_from_space(
    trial: optuna.trial.Trial,
    space: bool | int | float | str | bytes | list | dict,
    label_parts: list,
) -> Any:
    if isinstance(space, bool | int | float | str | bytes):
        # This is a constant value, nothing to sample from.
        return space

    elif isinstance(space, list):
        if space and space[0] == '_tune_':
            # space: ["_tune_", distribution, arg_0, arg_1, ...]
            _, distribution, *args = space
            label = '.'.join(map(str, label_parts))

            # At this point, `distribution` can be one of the following:
            # 1. One of the built-in Optuna distributions expected in `_sample_value`.
            # 2. Same as 1., but prefixed with "?".
            # 3. Custom distributions. By convention, they start with "$".

            if distribution.startswith('?'):
                # space: ["_tune_", "?distribution", default_value, *actual_args]
                default, args_ = args[0], args[1:]
                if trial.suggest_categorical(f'?{label}', [False, True]):
                    return _sample_value(trial, distribution.lstrip('?'), label, *args_)
                else:
                    return default

            elif distribution == '$list':
                # space: ["_tune_", "$list", size, distribution, *actual_args]
                # A list of hyperparameters of a fixed size. For example,
                # this can be useful if a model allows configuring some hyperparameter
                # separately for each feature. Then, `size` is the number of features.
                size, item_distribution, *item_args = args
                return [
                    _sample_value(trial, item_distribution, f'{label}.{i}', *item_args)
                    for i in range(size)
                ]

            else:
                return _sample_value(trial, distribution, label, *args)

        else:
            return [
                _sample_from_space(trial, subspace, [*label_parts, i])
                for i, subspace in enumerate(space)
            ]

    elif isinstance(space, dict):
        if '_tune_' in space:
            # A custom sampling rule of any complexity. For example, in config:
            #
            # [space.my_model]
            # _tune_ = "$hyperparameter-distribution-for-my-model"
            # a = 0    # <-- any key and value
            # b = 1.0  # <-- any key and value
            # c = '2'  # <-- any key and value
            distribution = space['_tune_']
            if distribution == '$hyperparameter-distribution-for-my-model':
                assert space['a'] == 0
                assert space['b'] == 1.0
                assert space['c'] == '2'
                raise NotImplementedError()
            else:
                raise ValueError(f'Unknown distibution: "{distribution}"')

        else:
            return {
                key: _sample_from_space(trial, subspace, [*label_parts, key])
                for key, subspace in space.items()
            }


class HyperparameterSampler:
    """A simple wrapper around `optuna.study.Study` with an index-based ask/tell API.

    `ask(index)` starts a new Optuna trial, remembers it under ``index`` and
    returns a config sampled from the search space; `tell(index, value)`
    later reports the objective value of that trial to the study.
    """

    _WARMUP_SEED_SHIFT = 1000
    _BASIC_SAMPLERS = (
        'BruteForceSampler',
        'GridSampler',
        'RandomSampler',
        'QMCSampler',
    )

    @staticmethod
    def _make_sampler(type: str, **kwargs) -> optuna.samplers.BaseSampler:
        """Instantiate ``optuna.samplers.<type>`` with the given keyword arguments."""
        sampler_cls = getattr(optuna.samplers, type)
        return sampler_cls(**kwargs)

    def __init__(
        self,
        *,
        space: dict[str, Any],
        type: str = 'TPESampler',
        strict_n_startup_trials: bool = False,
        study_kwargs: None | KWArgs = None,
        **kwargs,
    ) -> None:
        """Create the underlying Optuna study.

        Args:
            space: the search space passed to `_sample_from_space` on each `ask`.
            type: name of a class in ``optuna.samplers``.
            strict_n_startup_trials: if True, `tell` does not report a trial
                that was tagged as a startup trial when asked if, by the time
                it is told, at least `n_startup_trials` trials have finished.
            study_kwargs: extra keyword arguments for ``optuna.create_study``.
            **kwargs: keyword arguments for the sampler constructor.
        """
        sampler = HyperparameterSampler._make_sampler(type, **kwargs)
        extra_study_kwargs = study_kwargs if study_kwargs is not None else {}
        self._study = optuna.create_study(sampler=sampler, **extra_study_kwargs)
        self._strict_n_startup_trials = strict_n_startup_trials
        self._space = space
        # Trials started by `ask`, keyed by the caller-provided index.
        self._trials: dict[int, optuna.trial.Trial] = {}

    @property
    def n_startup_trials(self) -> None | int:
        """The sampler's private ``_n_startup_trials`` value, or None if it has none."""
        sampler = self._study.sampler
        return getattr(sampler, '_n_startup_trials', None)

    def _load_n_finished_trials(self) -> int:
        """Count the study's trials in the COMPLETE or PRUNED state."""
        finished_states = (
            optuna.trial.TrialState.COMPLETE,
            optuna.trial.TrialState.PRUNED,
        )
        # The first positional argument is `deepcopy`; copies are not needed
        # because only the number of trials is used.
        finished = self._study.get_trials(False, finished_states)
        return len(finished)

    def ask(self, index: int) -> dict[str, Any]:
        """Start a new trial under ``index`` and return a config sampled for it.

        When `n_startup_trials` is known, the trial gets the user attribute
        ``'startup'``: whether fewer than that many trials had finished when
        it was asked.
        """
        assert index not in self._trials
        trial = self._study.ask()
        n_startup = self.n_startup_trials
        if n_startup is not None:
            is_startup = self._load_n_finished_trials() < n_startup
            trial.set_user_attr('startup', is_startup)
        self._trials[index] = trial
        return _sample_from_space(trial, self._space, [])

    def tell(self, index: int, value: float) -> bool:
        """Report ``value`` for the trial started by ``ask(index)``.

        Returns:
            True if the value was reported to the study; False if it was
            skipped because of ``strict_n_startup_trials``.
        """
        trial = self._trials[index]
        if self._strict_n_startup_trials:
            n_startup = self.n_startup_trials
            if (
                n_startup is not None
                and trial.user_attrs.get('startup', False)
                and self._load_n_finished_trials() >= n_startup
            ):
                # A startup trial that finished after enough trials had
                # already completed is not reported.
                return False
        self._study.tell(trial, value)
        return True
