"""Environment loop extensions."""
import time
from abc import abstractmethod
from typing import Dict, Optional, Sequence, Union

import acme
import dm_env
import numpy as np
from acme import core
from acme.utils import counting, loggers
from acme.utils import observers as observers_lib
from acme.utils import signals
from importlib_metadata import collections

# Scalar numeric type used for metric values (see the ``get_metrics`` return
# annotation of ``LearningStepObserver``).
Number = Union[int, float]


class EvaluationLoop(acme.EnvironmentLoop):
  """Environment loop whose ``run`` logs results averaged over all episodes.

  ``run`` collects the result dict of every episode it executes and, once the
  stopping criterion is met, writes a single log entry containing the per-key
  mean of those results.
  """

  def __init__(
    self,
    environment: dm_env.Environment,
    actor: core.Actor,
    counter: Optional[counting.Counter] = None,
    logger: Optional[loggers.Logger] = None,
    should_update: bool = True,
    label: str = "environment_loop",
    # Default to an empty tuple: the previous default ``...`` (Ellipsis) is not
    # iterable, so it breaks any code that loops over the observers.
    observers: Sequence[observers_lib.EnvLoopObserver] = (),
  ):
    super().__init__(
      environment, actor, counter, logger, should_update, label, observers
    )

  def run(
    self, num_episodes: Optional[int] = None, num_steps: Optional[int] = None
  ) -> int:
    """Runs episodes and logs the mean of each result key once at the end.

    Args:
      num_episodes: Stop after this many episodes have completed.
      num_steps: Stop once at least this many environment steps have been
        taken. Episodes are never cut short, so the total may overshoot.
        At most one of ``num_episodes`` / ``num_steps`` may be given; if both
        are None the loop never terminates on its own.

    Returns:
      The total number of environment steps taken (sum of the
      ``"episode_length"`` of every episode run).

    Raises:
      ValueError: If both ``num_episodes`` and ``num_steps`` are given.
    """
    # Import from the stdlib directly rather than relying on the module-level
    # ``collections`` name, which is re-exported from ``importlib_metadata``.
    from collections import defaultdict

    if not (num_episodes is None or num_steps is None):
      raise ValueError('Either "num_episodes" or "num_steps" should be None.')

    def should_terminate(episode_count: int, step_count: int) -> bool:
      return (num_episodes is not None and episode_count >= num_episodes
             ) or (num_steps is not None and step_count >= num_steps)

    episode_count, step_count = 0, 0
    all_results: Dict[str, list] = defaultdict(list)
    with signals.runtime_terminator():
      while not should_terminate(episode_count, step_count):
        result = self.run_episode()
        episode_count += 1
        step_count += result["episode_length"]
        for key, value in result.items():
          all_results[key].append(value)
      # Log a single entry holding the per-key mean over all episodes run.
      self._logger.write({k: np.mean(v) for k, v in all_results.items()})
    return step_count


class ExtendedEnvLoopObserver(observers_lib.EnvLoopObserver):
  """Env-loop observer that can additionally be stepped and restored.

  On top of the ``EnvLoopObserver`` interface, subclasses implement ``step``
  (presumably invoked once per learner update by the owner — TODO confirm
  against callers) and ``restore`` (to resume from a saved step counter).
  """

  @abstractmethod
  def step(self) -> None:
    """Advances the observer by one step."""

  @abstractmethod
  def restore(self, learning_step: int) -> None:
    """Restores the observer state.

    Args:
      learning_step: Step counter value to resume from.
    """


class LearningStepObserver(ExtendedEnvLoopObserver):
  """Counts learning steps and observed episodes, and times training.

  The observer alternates between a train phase and an eval phase.
  ``step()`` counts a learning step and enters the train phase;
  ``observe_first()`` (start of an observed episode) enters the eval phase.
  The time between two consecutive ``step()`` calls is added to the training
  time only when no episode started in between, so intervals that include
  evaluation are excluded.
  """

  # Values of ``self._status``.
  _TRAIN = 0
  _EVAL = 1

  def __init__(self) -> None:
    super().__init__()
    self._learning_step = 0  # Number of step() calls, or the restored value.
    self._eval_step = 0  # Number of observe_first() calls.
    # Start in the eval phase so the first step() accumulates no time.
    self._status = self._EVAL
    self._train_elapsed = 0  # Accumulated training time, in seconds.
    # perf_counter() reading taken at the last step(); None before any step().
    self._last_time: Optional[float] = None

  def step(self) -> None:
    """Counts one learning step and accumulates training time."""
    # Monotonic clock: time.time() can jump on system clock adjustments,
    # corrupting (or making negative) the measured durations. A single
    # reading also avoids dropping the gap between two separate reads.
    now = time.perf_counter()
    self._learning_step += 1

    if self._status == self._TRAIN:
      self._train_elapsed += now - self._last_time
    else:
      # Coming back from evaluation: do not count the interval, just resume.
      self._status = self._TRAIN

    self._last_time = now

  def observe_first(
    self, env: dm_env.Environment, timestep: dm_env.TimeStep
  ) -> None:
    """Marks the start of an observed episode and enters the eval phase."""
    self._status = self._EVAL
    self._eval_step += 1

  def observe(
    self, env: dm_env.Environment, timestep: dm_env.TimeStep,
    action: np.ndarray
  ) -> None:
    """Records one environment step; intentionally a no-op."""

  def get_metrics(self) -> Dict[str, Number]:
    """Returns cumulative counters (not per-episode values).

    Returns:
      ``"step"``: learning steps counted (or restored).
      ``"eval_step"``: number of observed episode starts.
      ``"learning_time"``: seconds accumulated between consecutive steps
      that were not interrupted by an observed episode.
    """
    return {
      "step": self._learning_step,
      "eval_step": self._eval_step,
      "learning_time": self._train_elapsed,
    }

  def restore(self, learning_step: int) -> None:
    """Restores the learning-step counter.

    Only the step counter is restored; the eval count and accumulated
    training time are left unchanged.

    Args:
      learning_step: Step counter value to resume from.
    """
    self._learning_step = learning_step
