# !
#  * Copyright (c) Microsoft Corporation. All rights reserved.
#  * Licensed under the MIT License. See LICENSE file in the
#  * project root for license information.
from typing import Optional

# try:
#     from ray import __version__ as ray_version
#     assert ray_version >= '1.0.0'
#     from ray.tune.trial import Trial
# except (ImportError, AssertionError):
from .trial import Trial
import logging

logger = logging.getLogger(__name__)


class Nologger:
    """Logger without logging."""

    def on_result(self, result):
        pass


class SimpleTrial(Trial):
    """A simple trial class."""

    def __init__(self, config, trial_id=None):
        self.trial_id = Trial.generate_id() if trial_id is None else trial_id
        self.config = config or {}
        self.status = Trial.PENDING
        self.start_time = None
        self.last_result = {}
        self.last_update_time = -float("inf")
        self.custom_trial_name = None
        self.trainable_name = "trainable"
        self.experiment_tag = "exp"
        self.verbose = False
        self.result_logger = Nologger()
        self.metric_analysis = {}
        self.n_steps = [5, 10]
        self.metric_n_steps = {}


class BaseTrialRunner:
    """Implementation of a simple trial runner.

    Note that the caller usually should not mutate trial state directly.
    """

    def __init__(
        self,
        search_alg=None,
        scheduler=None,
        metric: Optional[str] = None,
        mode: Optional[str] = "min",
    ):
        self._search_alg = search_alg
        self._scheduler_alg = scheduler
        self._trials = []
        self._metric = metric
        self._mode = mode

    def get_trials(self):
        """Returns the list of trials managed by this TrialRunner.

        Note that the caller usually should not mutate trial state directly.
        """
        return self._trials

    def add_trial(self, trial):
        """Adds a new trial to this TrialRunner.

        Trials may be added at any time.

        Args:
            trial (Trial): Trial to queue.
        """
        self._trials.append(trial)
        if self._scheduler_alg:
            self._scheduler_alg.on_trial_add(self, trial)

    def process_trial_result(self, trial, result):
        trial.update_last_result(result)
        if "time_total_s" not in result.keys():
            result["time_total_s"] = trial.last_update_time - trial.start_time
        self._search_alg.on_trial_result(trial.trial_id, result)
        if self._scheduler_alg:
            decision = self._scheduler_alg.on_trial_result(self, trial, result)
            if decision == "STOP":
                trial.set_status(Trial.TERMINATED)
            elif decision == "PAUSE":
                trial.set_status(Trial.PAUSED)

    def stop_trial(self, trial):
        """Stops trial."""
        if trial.status not in [Trial.ERROR, Trial.TERMINATED]:
            if self._scheduler_alg:
                self._scheduler_alg.on_trial_complete(
                    self, trial.trial_id, trial.last_result
                )
            self._search_alg.on_trial_complete(trial.trial_id, trial.last_result)
            trial.set_status(Trial.TERMINATED)
        elif self._scheduler_alg:
            self._scheduler_alg.on_trial_remove(self, trial)
            if trial.status == Trial.ERROR:
                self._search_alg.on_trial_complete(
                    trial.trial_id, trial.last_result, error=True
                )


class SequentialTrialRunner(BaseTrialRunner):
    """Implementation of the sequential trial runner."""

    def step(self) -> Trial:
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().

        Returns:
            a trial to run.
        """
        trial_id = Trial.generate_id()
        config = self._search_alg.suggest(trial_id)
        if config is not None:
            trial = SimpleTrial(config, trial_id)
            self.add_trial(trial)
            trial.set_status(Trial.RUNNING)
        else:
            trial = None
        self.running_trial = trial
        return trial

    def stop_trial(self, trial):
        super().stop_trial(trial)
        self.running_trial = None


class SparkTrialRunner(BaseTrialRunner):
    """Implementation of the spark trial runner."""

    def __init__(
        self,
        search_alg=None,
        scheduler=None,
        metric: Optional[str] = None,
        mode: Optional[str] = "min",
    ):
        super().__init__(search_alg, scheduler, metric, mode)
        self.running_trials = []

    def step(self) -> Trial:
        """Runs one step of the trial event loop.

        Callers should typically run this method repeatedly in a loop. They
        may inspect or modify the runner's state in between calls to step().

        Returns:
            a trial to run.
        """
        trial_id = Trial.generate_id()
        config = self._search_alg.suggest(trial_id)
        if config is not None:
            trial = SimpleTrial(config, trial_id)
            self.add_trial(trial)
            trial.set_status(Trial.RUNNING)
            self.running_trials.append(trial)
        else:
            trial = None
        return trial

    def stop_trial(self, trial):
        super().stop_trial(trial)
        self.running_trials.remove(trial)
