import logging
import os
from datetime import timedelta
import copy
from typing import Dict, List, Optional, Tuple, Set
from dataclasses import dataclass
import subprocess

from fastbo.report import retrieve
from fastbo.backend.trial_backend import (
    TrialAndStatusInformation,
    TrialIdAndResultList,
)
from fastbo.backend.local_backend import LocalBackend
from fastbo.util import dump_json_with_numpy
from fastbo.backend.trial_status import TrialResult, Status, Trial
from fastbo.backend.simulator_backend.time_keeper import SimulatedTimeKeeper
from fastbo.backend.simulator_backend.events import (
    SimulatorState,
    StartEvent,
    CompleteEvent,
    StopEvent,
    OnTrialResultEvent,
)
from fastbo.constants import ST_CHECKPOINT_DIR, ST_WORKER_TIMESTAMP, ST_TUNER_TIME
from fastbo.tuner import DEFAULT_SLEEP_TIME

logger = logging.getLogger(__name__)


DEFAULT_DELAY = 0.05


@dataclass
class SimulatorConfig:
    """
    Configures the simulator:

    :param delay_on_trial_result: Time from ``report`` called on worker to result
        registered at backend, defaults to :const:`DEFAULT_DELAY`
    :param delay_complete_after_final_report: Time from final ``report`` called
        on worker to job completion being registered at backend. Defaults to
        :const:`DEFAULT_DELAY`
    :param delay_complete_after_stop: Time from stop signal received at worker
        to job completion being registered at backend. Defaults to
        :const:`DEFAULT_DELAY`
    :param delay_start: Time from start command being sent at backend and job
        starting on the worker (which is free). Defaults to :const:`DEFAULT_DELAY`
    :param delay_stop: Time from stop signal being sent at backend to signal
        received at worker (which is running). Defaults to :const:`DEFAULT_DELAY`
    """

    delay_on_trial_result: float = DEFAULT_DELAY
    delay_complete_after_final_report: float = DEFAULT_DELAY
    delay_complete_after_stop: float = DEFAULT_DELAY
    delay_start: float = DEFAULT_DELAY
    delay_stop: float = DEFAULT_DELAY

    def __post_init__(self):
        assert self.delay_on_trial_result >= 0
        assert self.delay_complete_after_final_report >= 0
        assert self.delay_complete_after_stop >= 0
        assert self.delay_start >= 0
        assert self.delay_stop >= 0
        # Otherwise, the final result may arrive after the job completion
        assert self.delay_on_trial_result <= self.delay_complete_after_final_report


class SimulatorBackend(LocalBackend):
    """
    This simulator backend drives experiments with tabulated training
    evaluation functions, which return their computation time rather than
    spend it. To this end, time (on the tuning instance) is simulated using
    a :attr:`time_keeper` and an event priority queue in :attr:`_simulator_state`.

    Time is advanced both by :meth:`~syne_tune.Tuner.run` waiting, and by non-negligible
    computations during the tuning loop (in particular, we take care of
    ``scheduler.suggest`` and ``scheduler.on_trial_result`` there).

    When the ``entry_point`` script is executed, we wait for all results to
    be returned. In each result, the value for key ``elapsed_time_attr``
    contains the time since start of the script. These values are used
    to place worker events on the simulated timeline (represented by
    ``simulator_state``).
    NOTE: If a trial is resumed, the elapsed_time value contains the time
    since start of the last recent resume, NOT the cumulative time used by
    the trial.

    Each method call starts by advancing time by what was spent outside,
    since the last recent call to the backend. Then, all events in
    ``simulator_state`` are processed whose time is before the current time
    in ``time_keeper``. The method ends by ``time_keeper.mark_exit()``.

    .. note::
       In this basic version of the simulator backend, we still call a
       Python main function as a subprocess, which returns the requested
       metrics by looking them up or running a surrogate. This is flexible,
       but has the overhead of loading a table at every call. For fast and
       convenient simulations, use
       ::class:`~syne_tune.blackbox_repository.BlackboxRepositoryBackend` after
       bringing your tabulated data or surrogate benchmark into the blackbox
       repository.

    :param entry_point: Python main file to be tuned (this should
        return all results directly, and report elapsed time in the
        ``elapsed_time_attr`` field
    :param elapsed_time_attr: See above
    :param simulator_config: Parameters for simulator, optional
    :param tuner_sleep_time: Effective sleep time in
        :meth:`~syne_tune.Tuner.run`. This information is needed in
        :class:`~syne_tune.backend.simulator_backend.SimulatorCallback`.
        Defaults to :const:`~syne_tune.tuner.DEFAULT_SLEEP_TIME`
    """

    def __init__(
        self,
        entry_point: str,
        elapsed_time_attr: str,
        simulator_config: Optional[SimulatorConfig] = None,
        tuner_sleep_time: float = DEFAULT_SLEEP_TIME,
        debug_resource_attr: Optional[str] = None,
    ):
        super().__init__(entry_point=entry_point, rotate_gpus=False)
        self.elapsed_time_attr = elapsed_time_attr
        if simulator_config is None:
            self.simulator_config = SimulatorConfig()
        else:
            self.simulator_config = simulator_config
        self.tuner_sleep_time = tuner_sleep_time
        self._debug_resource_attr = debug_resource_attr
        # Start with empty event queue
        self._simulator_state = SimulatorState()
        self._time_keeper = SimulatedTimeKeeper()
        self._next_results_to_fetch = dict()
        self._busy_trial_ids = set()
        logger.setLevel(logging.INFO)  # Suppress DEBUG for this class

    @property
    def time_keeper(self) -> SimulatedTimeKeeper:
        return self._time_keeper

    @staticmethod
    def _debug_message(
        event_name: str, time: float, trial_id: int, pushed: bool = False, **kwargs
    ):
        msg_part = "push " if pushed else ""
        msg = f"[{msg_part}{event_name}:"
        parts = [f"time = {time:.2f}", f"trial_id = {trial_id}"] + [
            f"{k} = {v}" for k, v in kwargs.items()
        ]
        msg += ", ".join(parts) + "]"
        logger.debug(msg)

    def start_trial(
        self, config: Dict, checkpoint_trial_id: Optional[int] = None
    ) -> Trial:
        # Overwritten to record the correct ``creation_time``
        trial_id = self.new_trial_id()
        if checkpoint_trial_id is not None:
            self.copy_checkpoint(
                src_trial_id=checkpoint_trial_id, tgt_trial_id=trial_id
            )
        self.trial_ids.append(trial_id)
        self._schedule(trial_id=trial_id, config=config)
        now = self._time_keeper.time_stamp()
        trial = Trial(
            trial_id=trial_id,
            config=config,
            creation_time=now,
        )
        self._trial_dict[trial_id] = trial

        return trial

    def _process_events_until_now(self):
        """Process all events in the queue with times before ``time_keeper.time()``."""
        time_now = self._time_keeper.time()
        next_event = self._simulator_state.next_until(time_now)
        while next_event is not None:
            time_event, event = next_event
            trial_id = event.trial_id
            if isinstance(event, StartEvent):
                self._process_start_event(trial_id=trial_id, time_event=time_event)
            elif isinstance(event, CompleteEvent):
                self._process_complete_event(
                    trial_id=trial_id, time_event=time_event, status=event.status
                )
            elif isinstance(event, StopEvent):
                self._process_stop_event(trial_id=trial_id, time_event=time_event)
            elif isinstance(event, OnTrialResultEvent):
                self._process_on_trial_result_event(time_event=time_event, event=event)
            else:
                raise TypeError(f"Event at time {time_event} of unknown type: {event}")
            next_event = self._simulator_state.next_until(time_now)

    def _process_start_event(
        self, trial_id: int, time_event: float, config: Optional[dict] = None
    ):
        self._debug_message("StartEvent", time=time_event, trial_id=trial_id)
        # Run training script and record results
        status, results = self._run_job_and_collect_results(trial_id, config=config)
        time_final_result = time_event
        deb_it = 0  # DEBUG
        for i, result in enumerate(results):
            elapsed_time = result.get(self.elapsed_time_attr)
            assert elapsed_time is not None, (
                f"Result for trial_id = {trial_id} does not contain "
                + f"{self.elapsed_time_attr} entry. Your code needs "
                + "to report elapsed time, and the attribute name "
                + "must be set as elapsed_time_attr here."
            )
            _time_result = time_event + float(elapsed_time)
            time_result = _time_result + self.simulator_config.delay_on_trial_result
            self._simulator_state.push(
                OnTrialResultEvent(trial_id=trial_id, result=result),
                event_time=time_result,
            )
            time_final_result = max(time_final_result, _time_result)
            # DEBUG:
            if deb_it < 10:
                if self._debug_resource_attr:
                    k = self._debug_resource_attr
                    debug_kwargs = {k: result.get(k)}
                else:
                    debug_kwargs = dict()
                self._debug_message(
                    "OnTrialResultEvent",
                    time=time_result,
                    trial_id=trial_id,
                    pushed=True,
                    **debug_kwargs,
                )
                deb_it += 1
        time_complete = (
            time_final_result + self.simulator_config.delay_complete_after_final_report
        )
        self._simulator_state.push(
            CompleteEvent(trial_id=trial_id, status=status), event_time=time_complete
        )
        self._debug_message(
            "CompleteEvent", time=time_complete, trial_id=trial_id, pushed=True
        )
        self._busy_trial_ids.add(trial_id)

    def _process_complete_event(self, trial_id: int, time_event: float, status: str):
        self._debug_message(
            "CompleteEvent", time=time_event, trial_id=trial_id, status=status
        )
        trial_result = self._trial_dict[trial_id]
        training_end_time = self._time_keeper.start_time_stamp + timedelta(
            seconds=time_event
        )
        if isinstance(trial_result, TrialResult):
            trial_result.status = status
            trial_result.training_end_time = training_end_time
        else:
            # No results reported for the trial. This can happen if
            # the trial failed
            self._trial_dict[trial_id] = trial_result.add_results(
                metrics=[], status=status, training_end_time=training_end_time
            )
        if trial_id in self._busy_trial_ids:
            self._busy_trial_ids.remove(trial_id)

    def _process_stop_event(self, trial_id: int, time_event: float):
        self._debug_message("StopEvent", time=time_event, trial_id=trial_id)
        # Remove all remaining events for ``trial_id``. This includes
        # the ``CompleteEvent`` pushed with ``StartEvent``, so there can
        # be no confusion with the 2nd ``CompleteEvent`` pushed by
        # ``_stop_trial``.
        self._simulator_state.remove_events(trial_id)
        if trial_id in self._busy_trial_ids:
            self._busy_trial_ids.remove(trial_id)

    def _process_on_trial_result_event(
        self, time_event: float, event: OnTrialResultEvent
    ):
        trial_id = event.trial_id
        result = copy.copy(event.result)
        if self._debug_resource_attr:
            k = self._debug_resource_attr
            debug_kwargs = {k: result.get(k)}
        else:
            debug_kwargs = dict()
        self._debug_message(
            "OnTrialResultEvent",
            time=time_event,
            trial_id=trial_id,
            **debug_kwargs,
        )
        # Append timestamps to ``result``. This is done here, but not in
        # the other backends, for which timestamps are only added when
        # results are written out.
        result[ST_TUNER_TIME] = time_event
        if trial_id in self._next_results_to_fetch:
            self._next_results_to_fetch[trial_id].append(result)
        else:
            self._next_results_to_fetch[trial_id] = [result]
        trial_result = self._trial_dict[trial_id]
        if isinstance(trial_result, TrialResult):
            trial_result.metrics.append(result)
        else:
            self._trial_dict[trial_id] = trial_result.add_results(
                metrics=[result],
                status=Status.in_progress,
                training_end_time=None,
            )

    def _advance_by_outside_time(self):
        self._time_keeper.advance(self._time_keeper.real_time_since_last_recent_exit())

    def fetch_status_results(
        self, trial_ids: List[int]
    ) -> (TrialAndStatusInformation, TrialIdAndResultList):
        self._advance_by_outside_time()
        # Process all events in the past
        self._process_events_until_now()
        # Results are collected in ``_next_results_to_fetch``
        results = []
        for trial_id in trial_ids:
            result_list = self._next_results_to_fetch.get(trial_id)
            if result_list is not None:
                results.extend((trial_id, result) for result in result_list)
                self._last_metric_seen_index[trial_id] += len(result_list)
                del self._next_results_to_fetch[trial_id]
        if self._next_results_to_fetch:
            # Note: This tends to happen regularly with fast-running
            # benchmarks
            warn_msg = [
                "The following trials reported results, but are not covered "
                "by trial_ids. These results will be ignored:"
            ]
            for trial_id, result_list in self._next_results_to_fetch.items():
                status = self._trial_dict[trial_id].status
                msg_line = f"  trial_id {trial_id}: status = {status}, "
                if self._debug_resource_attr is None:
                    msg_line += f"num_results = {len(result_list)}"
                else:
                    resources = [
                        result[self._debug_resource_attr] for result in result_list
                    ]
                    msg_line += f"resources = {resources}"
                warn_msg.append(msg_line)
                self._last_metric_seen_index[trial_id] += len(result_list)
            logger.debug("\n".join(warn_msg))
            self._next_results_to_fetch = dict()

        if len(results) > 0 and ST_WORKER_TIMESTAMP in results[0]:
            results = sorted(results, key=lambda result: result[1][ST_WORKER_TIMESTAMP])

        trial_status_dict = dict()
        for trial_id in trial_ids:
            trial_result = self._trial_dict[trial_id]
            status = (
                trial_result.status
                if isinstance(trial_result, TrialResult)
                else Status.in_progress
            )
            trial = Trial(
                trial_id=trial_result.trial_id,
                config=trial_result.config,
                creation_time=trial_result.creation_time,
            )
            trial_status_dict[trial_id] = (trial, status)

        self._time_keeper.mark_exit()
        return trial_status_dict, results

    def fetch_last_trial_information(
            self, trial_ids: List[int]
    ) -> (TrialAndStatusInformation, TrialIdAndResultList):
        """
        Fetches the last trial information for trial_ids from `self._trial_dict`
        This is only used in FastBO, to finish full evaluation for some well-performing trials.

        :param trial_ids: IDs of the trials of interest, only contain the uncompleted trials
        """
        self._advance_by_outside_time()

        trial_status_dict = dict()
        results = []
        for trial_id in trial_ids:
            trial_result = self._trial_dict[trial_id]
            # Note: sometimes a trial has been marked as a running trial, but its first result
            # hasn't been added in to `self._trial_dict` because the stop condition reached.
            # If that is the case, its related `trial_result` will be `Trial`, rather than
            # `TrialResult`, and doesn't have "status" or "metrics" attribute.
            if isinstance(trial_result, TrialResult):
                last_metric = trial_result.metrics[-1]
                results.append((trial_id, last_metric))

            status = (
                trial_result.status
                if isinstance(trial_result, TrialResult)
                else Status.in_progress
            )
            trial = Trial(
                trial_id=trial_result.trial_id,
                config=trial_result.config,
                creation_time=trial_result.creation_time,
            )
            trial_status_dict[trial_id] = (trial, status)

        self._time_keeper.mark_exit()
        return trial_status_dict, results

    def _schedule(self, trial_id: int, config: Dict):
        """
        This is called by :meth:`start_trial` or :meth:`resume_trial`. We register a start
        event here. ``config`` can be ignored, it will be in
        ``trial(trial_id).config`` once the start event is executed.

        Note: This call is "non-blocking": The start event is registered
        here (in the future), but is not yet processed.

        :param trial_id: ID of trial to schedule
        :param config: Configuration for this trial
        """
        self._advance_by_outside_time()
        # Process all events in the past
        self._process_events_until_now()
        _time_start = self._time_keeper.time()
        time_start = _time_start + self.simulator_config.delay_start
        self._simulator_state.push(StartEvent(trial_id=trial_id), event_time=time_start)
        self._debug_message(
            "StartEvent", time=time_start, trial_id=trial_id, pushed=True
        )
        logger.debug(f"Simulated time since start: {_time_start:.2f} secs")
        self._time_keeper.mark_exit()

    def _all_trial_results(self, trial_ids: List[int]) -> List[TrialResult]:
        """
        Note: Since this is not used anymore in :meth:`fetch_results`, it can
        simply just return all registered trials which already have some
        results.
        This will not return trials which have just been started, but did not
        report any results yet.
        """
        results = []
        for trial_id in trial_ids:
            trial_result = self._trial_dict[trial_id]
            # Filter out entries which have not obtained any results
            if isinstance(trial_result, TrialResult):
                results.append(trial_result)
        return results

    def _pause_trial(self, trial_id: int, result: Optional[dict]):
        self._stop_or_pause_trial(trial_id, status=Status.paused)

    def _pause_all_trials(
            self, trial_ids: Set[int],
            results: TrialIdAndResultList
    ):
        # Pauses all trials together. Pausing them one by one may cause
        # unexpected runs of the subsequent events
        n_trials = len(trial_ids)
        status = Status.paused

        self._advance_by_outside_time()
        time_stop = self._time_keeper.time() + n_trials * self.simulator_config.delay_stop
        for trial_id in trial_ids:
            # omit the step of pushing a StopEvent, but still add a debug message
            self._debug_message("StopEvent", time=time_stop, trial_id=trial_id, pushed=True)
        self._time_keeper.advance_to(time_stop + 1e-3)

        # Manually changes the call for the `_process_events_until_now` method.
        # Manually handles the n StopEvents here.
        for trial_id in trial_ids:
            self._simulator_state.remove_events(trial_id)
            if trial_id in self._busy_trial_ids:
                self._busy_trial_ids.remove(trial_id)
        time_complete = (
                self._time_keeper.time() +
                n_trials * self.simulator_config.delay_complete_after_stop
        )
        for trial_id in trial_ids:
            self._simulator_state.push(
                CompleteEvent(trial_id=trial_id, status=status), event_time=time_complete
            )
            self._debug_message(
                "CompleteEvent", time=time_complete, trial_id=trial_id, pushed=True
            )
        # Process final ``CompleteEvent``
        self._time_keeper.advance_to(time_complete + 1e-3)
        self._process_events_until_now()
        self._time_keeper.mark_exit()

    def _resume_trial(self, trial_id: int):
        pass

    def _stop_trial(self, trial_id: int, result: Optional[dict]):
        self._stop_or_pause_trial(trial_id, status=Status.stopped)

    def _stop_or_pause_trial(self, trial_id: int, status: str):
        """
        This is called by :meth:`stop_trial`` or :meth:`pause_trial``.

        Note: This call is "blocking": Stop and complete events
        are not just registered here, but also processed.

        :param trial_id: ID of trial to stop or pause
        :param status: ``Status.paused`` or ``Status.stopped``
        """
        self._advance_by_outside_time()
        time_stop = self._time_keeper.time() + self.simulator_config.delay_stop
        self._simulator_state.push(StopEvent(trial_id=trial_id), event_time=time_stop)
        self._debug_message("StopEvent", time=time_stop, trial_id=trial_id, pushed=True)
        # Note: We need to call ``_process_events_until_now`` twice. If we first
        # pushed the final ``CompleteEvent``, it would be removed by the
        # ``StopEvent``.
        self._time_keeper.advance_to(time_stop + 1e-3)
        # Process events up to and including ``StopEvent``
        self._process_events_until_now()
        time_complete = (
            self._time_keeper.time() + self.simulator_config.delay_complete_after_stop
        )
        self._simulator_state.push(
            CompleteEvent(trial_id=trial_id, status=status), event_time=time_complete
        )
        self._debug_message(
            "CompleteEvent", time=time_complete, trial_id=trial_id, pushed=True
        )
        # Process final ``CompleteEvent``
        self._time_keeper.advance_to(time_complete + 1e-3)
        self._process_events_until_now()
        self._time_keeper.mark_exit()

    def _run_job_and_collect_results(
        self, trial_id: int, config: Optional[dict] = None
    ) -> (str, List[dict]):
        """
        Runs training evaluation script for trial ``trial_id``, using the config
        ``trial(trial_id).config``. This is a blocking call, we wait for the
        script to finish, then parse all its results and return them.

        :param trial_id: ID for trial to be run
        :param config: Configuration, defaults to config of ``trial_id``
        :return: ``(final_status, list_of_results_reported)``
        """
        assert (
            trial_id in self._trial_dict
        ), f"Trial with trial_id = {trial_id} not registered with backend"
        if config is None:
            config = self._trial_dict[trial_id].config

        # Run training script and fetch all results
        trial_path = self.trial_path(trial_id)
        os.makedirs(trial_path, exist_ok=True)
        config_copy = config.copy()
        config_copy[ST_CHECKPOINT_DIR] = str(self.checkpoint_trial_path(trial_id))
        config_str = " ".join(
            [f"--{key} {value}" for key, value in config_copy.items()]
        )

        dump_json_with_numpy(config, trial_path / "config.json")
        cmd = f"python {self.entry_point} {config_str}"
        env = dict(os.environ)
        logging.info(f"running script with command: {cmd}")
        with open(trial_path / "std.out", "a") as stdout:
            with open(trial_path / "std.err", "a") as stderr:
                return_status = subprocess.run(
                    cmd.split(" "), stdout=stdout, stderr=stderr, env=env
                )
        if return_status.returncode == 0:
            status = Status.completed
        else:
            status = Status.failed
        # Read all reported results
        # Results are also read if the process failed
        # Note that ``retrieve`` returns all results, even those already
        # received before (in case the trial is resumed at least once).
        all_results = retrieve(log_lines=self.stdout(trial_id=trial_id))
        num_already_before = self._last_metric_seen_index[trial_id]
        assert num_already_before <= len(all_results), (
            f"Found {len(all_results)} total results, but have already "
            + f"processed {num_already_before} before!"
        )
        results = all_results[num_already_before:]
        return status, results

    def busy_trial_ids(self) -> List[Tuple[int, str]]:
        self._process_events_until_now()
        return [(trial_id, Status.in_progress) for trial_id in self._busy_trial_ids]
