from __future__ import annotations

import asyncio
from typing import Any, Iterator

import time
from pathlib import Path

import ray
from ConfigSpace import Configuration

from smac.runhistory import StatusType, TrialInfo, TrialValue
from smac.runner.abstract_runner import AbstractRunner
from smac.utils.logging import get_logger

__copyright__ = "Copyright 2022, automl.org"
__license__ = "3-clause BSD"


# Module-level logger, named after this module.
logger = get_logger(__name__)


class RayParallelRunner(AbstractRunner):
    """Interface to submit and collect trials in a distributed fashion via Ray.

    To reduce the amount of code within single-vs-parallel implementations, RayParallelRunner
    wraps a serial ``AbstractRunner`` (``single_worker``) and executes its ``run_wrapper`` as
    Ray tasks, so that several trials can run in parallel on a Ray cluster.

    The work model is:

    1. The intensifier dictates "what" to run (a configuration/instance/seed) via a TrialInfo object.
    2. ``submit_trial`` launches ``single_worker.run_wrapper(trial_info=...)`` as a Ray task that
       reserves the resources given in ``ray_options``, and blocks until a Ray worker has
       actually started that task. This throttles submission to what the cluster can run.
    3. Each task returns a ``(TrialInfo, TrialValue)`` tuple. Finished tasks are collected by
       ``iter_results`` and then passed to SMBO.
    4. An exception raised inside a task is re-raised by ``ray.get`` when its result is collected.

    In-flight tasks are tracked through their Ray ``ObjectRef`` in ``self._pending_trials``.

    Parameters
    ----------
    single_worker : AbstractRunner
        A runner capable of running trials serially; its ``run_wrapper`` is executed inside Ray tasks.
    patience : int, defaults to 5
        How long to wait (seconds) for workers to become available if one fails. Stored, but
        currently not used by this class.
    ray_options : dict[str, Any] | None, defaults to None
        Keyword arguments forwarded to Ray's ``RemoteFunction.options`` for every trial, e.g.
        ``{"num_cpus": 1, "num_gpus": 0}`` (the value used when None). A resource key that is
        missing falls back to the remote function's defaults (``num_cpus=1``, ``num_gpus=0.25``).
        Note that a task requesting a GPU share can never start on a cluster without GPUs, in
        which case ``submit_trial`` blocks indefinitely.
    """

    # Resources reserved by ``_submit_ray`` for any key that ``ray_options`` does not override.
    # Also used by the resource helpers so they agree with what Ray actually reserves.
    _DEFAULT_NUM_CPUS = 1
    _DEFAULT_NUM_GPUS = 0.25

    def __init__(
        self,
        single_worker: AbstractRunner,
        patience: int = 5,
        ray_options: dict[str, Any] | None = None,
    ):
        super().__init__(
            scenario=single_worker._scenario,
            required_arguments=single_worker._required_arguments,
        )

        # Keys must be valid Ray task options: they are passed verbatim to ``.options()``.
        # Copy the dict so later mutation of the caller's object does not affect this runner.
        if ray_options is None:
            ray_options = {"num_cpus": 1, "num_gpus": 0}
        self.ray_options: dict[str, Any] = dict(ray_options)

        # The single worker to hold on to and call run on
        self._single_worker = single_worker

        # References to the Ray tasks of submitted-but-not-yet-collected trials
        self._pending_trials: list[ray._raylet.ObjectRef] = []

        # Leftover from the Dask implementation; not used by this class.
        self._scheduler_file: Path | None = None
        self._patience = patience

        # Kept for compatibility; currently not updated by this class.
        self.cnt_intend_to_run = 0

    def submit_trial(self, trial_info: TrialInfo, **dask_data_to_scatter: dict[str, Any]) -> None:
        """Submit the configuration embedded in ``trial_info`` as a Ray task.

        The execution of a configuration follows this procedure:

        #. The SMBO/intensifier generates a `TrialInfo`.
        #. SMBO calls `submit_trial` so that a Ray worker launches the `trial_info`.
        #. The task calls ``single_worker.run_wrapper`` with ``trial_info`` (plus any extra
           keyword data), which contains the code common to every runner.

        This call blocks until a Ray worker has started the task. When the cluster has no free
        resources, it therefore waits until a running trial finishes. Results are not returned
        here; collect them with `iter_results`.

        Parameters
        ----------
        trial_info : TrialInfo
            An object containing the configuration launched.
        dask_data_to_scatter : dict[str, Any]
            Extra keyword arguments forwarded unchanged to ``run_wrapper``. With Ray they are
            serialized with the task rather than scattered.
        """
        start_signal = Signal.remote()
        trial = self._submit_ray.options(**self.ray_options).remote(
            self._single_worker.run_wrapper,
            start_signal,
            trial_info=trial_info,
            **dask_data_to_scatter,
        )
        try:
            # Waiting for the task to *start* (not finish) is what throttles submissions to
            # the cluster's capacity.
            while not ray.get(start_signal.is_set.remote()):
                time.sleep(0.1)  # Sleep briefly to avoid busy waiting
        finally:
            # The flag is only needed to detect the start. Release the actor now instead of
            # keeping it alive until the task drops its handle.
            ray.kill(start_signal)
        self._pending_trials.append(trial)

    def iter_results(self) -> Iterator[tuple[TrialInfo, TrialValue]]:  # noqa: D102
        self._process_pending_trials()
        while self._results_queue:
            yield self._results_queue.pop(0)

    def wait(self) -> None:  # noqa: D102
        # Block until at least one pending trial has finished (no-op when nothing is pending).
        if self.is_running():
            _, _ = ray.wait(self._pending_trials, num_returns=1, timeout=None)

    def is_running(self) -> bool:  # noqa: D102
        return len(self._pending_trials) > 0

    @staticmethod
    @ray.remote(num_cpus=_DEFAULT_NUM_CPUS, num_gpus=_DEFAULT_NUM_GPUS)
    def _submit_ray(f, start_signal, **kwargs):
        """Ray task body: announce the start through ``start_signal``, then return ``f(**kwargs)``."""
        start_signal.set.remote()
        return f(**kwargs)

    def run(
        self,
        config: Configuration,
        instance: str | None = None,
        budget: float | None = None,
        seed: int | None = None,
        **dask_data_to_scatter: dict[str, Any],
    ) -> tuple[StatusType, float | list[float], float, dict]:  # noqa: D102
        # Runs synchronously in the calling process by delegating to the wrapped worker.
        return self._single_worker.run(
            config=config, instance=instance, seed=seed, budget=budget, **dask_data_to_scatter
        )

    def _requested_resources(self, ray_options: dict[str, Any] | None) -> tuple[float, float]:
        """Return the ``(num_cpus, num_gpus)`` one trial reserves under ``ray_options``.

        ``None`` means ``self.ray_options``. Missing keys take the remote function's defaults.
        """
        options = self.ray_options if ray_options is None else ray_options
        return (
            options.get("num_cpus", self._DEFAULT_NUM_CPUS),
            options.get("num_gpus", self._DEFAULT_NUM_GPUS),
        )

    def can_run_another_task(self, ray_options: dict[str, Any] | None = None) -> bool:
        """Whether the cluster currently has enough free CPUs and GPUs for one more trial.

        Parameters
        ----------
        ray_options : dict[str, Any] | None, defaults to None
            Task options to check against; ``None`` uses ``self.ray_options``.
        """
        num_cpus, num_gpus = self._requested_resources(ray_options)
        available_resources = ray.available_resources()

        can_run_cpus = available_resources.get("CPU", 0) >= num_cpus
        can_run_gpus = available_resources.get("GPU", 0) >= num_gpus

        return can_run_cpus and can_run_gpus

    def _process_pending_trials(self) -> None:
        """Move finished trials from ``self._pending_trials`` to ``self._results_queue``.

        Does not block. If ``ray.get`` raises for a finished trial (the task failed), the
        exception propagates and that trial stays pending. Trials handled before it have been
        queued exactly once and removed from the pending list.
        """
        if not self._pending_trials:
            return

        # timeout=0: only inspect which trials are already done, never wait
        done_refs, _ = ray.wait(
            self._pending_trials, num_returns=len(self._pending_trials), timeout=0
        )

        for done_ref in done_refs:
            self._results_queue.append(ray.get(done_ref))
            # Remove only after a successful get, so a result can never be queued twice.
            self._pending_trials.remove(done_ref)

    def count_available_workers(self, ray_options: dict[str, Any] | None = None) -> int:
        """Returns the number of available workers.

        This is the number of additional trials that the currently free CPUs and GPUs could
        host. Resources that a trial does not request (requirement 0) impose no limit.

        Parameters
        ----------
        ray_options : dict[str, Any] | None, defaults to None
            Task options to size workers by; ``None`` uses ``self.ray_options``.

        Raises
        ------
        ValueError
            If a trial requests neither CPUs nor GPUs, so the count is not bounded by them.
        """
        num_cpus, num_gpus = self._requested_resources(ray_options)
        available_resources = ray.available_resources()

        limits = [
            int(available_resources.get(resource, 0) // required)
            for resource, required in (("CPU", num_cpus), ("GPU", num_gpus))
            if required > 0
        ]
        if not limits:
            raise ValueError(
                "Cannot count available workers: ray_options request neither CPUs nor GPUs."
            )
        return min(limits)

@ray.remote
class Signal:
    """Single boolean flag hosted in a Ray actor.

    ``RayParallelRunner.submit_trial`` creates one per trial: the remote task calls ``set`` when
    it begins executing, and the driver polls ``is_set`` until it reports True.
    """

    def __init__(self) -> None:
        # Starts cleared; only ``set`` ever changes it.
        self.is_set_flag: bool = False

    def set(self) -> None:
        """Raise the flag. Repeated calls leave it raised."""
        self.is_set_flag = True

    def is_set(self) -> bool:
        """Report whether ``set`` has been called on this actor."""
        flag = self.is_set_flag
        return flag