"""A simple binary to run catch for a while and record its trajectories.
"""

import operator
import os
import time

import acme
import dm_env
import envlogger
import numpy as np
import sonnet as snt
import tensorflow as tf
import tensorflow_datasets as tfds
import termcolor
import tree
from absl import app, flags, logging
from acme import specs, wrappers
from acme.agents.tf import dqn
from acme.utils import loggers
from bsuite.logging import csv_logging
from bsuite.utils import wrappers as bsuite_wrappers
from dm_env import specs as dm_specs
from envlogger.backends import tfds_backend_writer

from rosmo.data.bsuite.env import Cartpole, Catch, MountainCar

# Handle to the global absl flag registry; flag values are read from it below.
FLAGS = flags.FLAGS

# Which environment to run. Must be one of the listed names; there is no
# default, so main() checks that a value was supplied.
flags.DEFINE_enum(
  "env", None, ["cartpole", "catch", "mountain_car"], "Environment"
)
# Number of episodes handed to the environment loop in main().
flags.DEFINE_integer("num_episodes", 10, "Number of episodes to log.")
# Per-step probability with which NoiseEnvironmentLoop replaces the agent's
# action by a random one.
flags.DEFINE_float("action_noise", 0.0, "Random action probability noise.")
# Per-step probability with which NoiseEnvironmentLogger steps the wrapped
# environment with a random action instead of the one it was given.
flags.DEFINE_float("model_noise", 0.0, "Random model probability noise.")
# Output directory used for both the CSV results and the recorded
# trajectories. No default, so main() checks that a value was supplied.
flags.DEFINE_string(
  "results_dir", None, "Directory for record trajectories and csv result."
)


def load_and_record_to_csv(
  raw_env: dm_env.Environment,
  results_dir: str,
  overwrite: bool = False
) -> dm_env.Environment:
  """Wrap `raw_env` so that its results are logged to CSV.

  Args:
    raw_env: Environment to wrap. Its `bsuite_id` attribute is passed to the
      CSV logger.
    results_dir: Directory handed to the CSV logger.
    overwrite: Forwarded to `csv_logging.Logger` as its third argument.

  Returns:
    `raw_env` wrapped in `bsuite_wrappers.Logging`.
  """
  announcement = (
    f"Logging results to CSV file for each bsuite_id in {results_dir}."
  )
  termcolor.cprint(announcement, color="yellow", attrs=["bold"])

  csv_logger = csv_logging.Logger(raw_env.bsuite_id, results_dir, overwrite)
  logged_env = bsuite_wrappers.Logging(
    raw_env, csv_logger, log_by_step=False, log_every=True
  )
  return logged_env


class NoiseEnvironmentLoop(acme.EnvironmentLoop):
  """Environment loop that sometimes overrides the actor with a random action.

  On every step, with probability `noise`, a uniformly random integer action
  is drawn from `[0, action_spec().num_values)` and used in place of the
  actor's own choice.
  """

  def __init__(self, noise: float, *args, **kwargs) -> None:
    """Create the loop.

    Args:
      noise: Per-step probability of taking a random action.
      *args: Positional arguments forwarded to `acme.EnvironmentLoop`.
      **kwargs: Keyword arguments forwarded to `acme.EnvironmentLoop`.
    """
    super().__init__(*args, **kwargs)
    self.noise = noise

  def run_episode(self) -> loggers.LoggingData:
    """Run a single episode, injecting random actions with prob. `self.noise`.

    The environment is reset, then stepped until a terminal timestep is
    reached. The actor observes every transition (including those caused by
    random actions) and is updated after each step when `self._should_update`
    is set.

    Returns:
      A dict with `episode_length`, `episode_return` and `steps_per_second`,
      merged with the counts returned by the counter.
    """
    began_at = time.time()
    num_steps = 0

    # Undiscounted return accumulator: zeros with the structure, shape and
    # dtype of the reward spec.
    total_return = tree.map_structure(
      lambda spec: np.zeros(spec.shape, spec.dtype),
      self._environment.reward_spec(),
    )

    # Start the episode and show the actor the initial timestep.
    timestep = self._environment.reset()
    self._actor.observe_first(timestep)

    while not timestep.last():
      # Either draw a random action or defer to the actor's policy.
      take_random_action = np.random.rand() < self.noise
      if take_random_action:
        num_actions = self._environment.action_spec().num_values
        action = np.random.randint(low=0, high=num_actions, dtype=np.int32)
      else:
        action = self._actor.select_action(timestep.observation)

      timestep = self._environment.step(action)

      # The actor sees the action that was actually executed.
      self._actor.observe(action, next_timestep=timestep)
      if self._should_update:
        self._actor.update()

      num_steps += 1

      # Same as `total_return += timestep.reward`, but the result is rebound
      # because in-place addition may return a new object for array types
      # that cannot be mutated in place.
      total_return = tree.map_structure(
        operator.iadd, total_return, timestep.reward
      )

    # Update the global counters before measuring the elapsed time.
    counts = self._counter.increment(episodes=1, steps=num_steps)
    elapsed = time.time() - began_at

    result = {
      "episode_length": num_steps,
      "episode_return": total_return,
      "steps_per_second": num_steps / elapsed,
    }
    result.update(counts)
    return result


class NoiseEnvironmentLogger(envlogger.EnvLogger):
  """EnvLogger whose `step` may execute a random action.

  With probability `FLAGS.model_noise` the wrapped environment is stepped with
  a random integer action drawn from `[0, action_spec().num_values)`. The
  action passed in by the caller is still the one that gets recorded.
  """

  def step(self, action):
    """Step the wrapped environment (possibly noisily) and record the result.

    Args:
      action: Action requested by the caller. It is always the action that is
        recorded and passed to the episode function, even when a random
        action is executed instead.

    Returns:
      The timestep produced by the wrapped environment, or the result of
      `reset()` when the previous timestep ended the episode.
    """
    if self._reset_next_step:
      return self.reset()

    # Decide which action the environment really receives.
    executed_action = action
    if np.random.rand() < FLAGS.model_noise:
      num_actions = self._environment.action_spec().num_values
      executed_action = np.random.randint(low=0, high=num_actions)
    timestep = self._environment.step(executed_action)

    self._reset_next_step = timestep.last()
    self._backend.record_step(
      self._transform_step(timestep, action), is_new_episode=False
    )

    # Optional hook that produces per-episode metadata.
    if self._episode_fn is None:
      return timestep
    metadata = self._episode_fn(timestep, action, self._environment)
    if metadata is not None:
      self._backend.set_episode_metadata(metadata)
    return timestep


def main(unused_argv):
  """Train a DQN agent on the selected environment and record trajectories.

  The environment chosen by `--env` is wrapped with CSV logging and a
  single-precision wrapper, then with `NoiseEnvironmentLogger`, which writes
  every step to `--results_dir` through a TFDS backend. A DQN agent is then
  run for `--num_episodes` episodes with `NoiseEnvironmentLoop`.

  Args:
    unused_argv: Remaining command-line arguments supplied by `app.run`.

  Raises:
    app.UsageError: If `--env` or `--results_dir` was not provided.
  """
  # Validate explicitly instead of with `assert`, which is stripped when
  # Python runs with -O.
  if FLAGS.env is None:
    raise app.UsageError("--env must be specified.")
  if FLAGS.results_dir is None:
    raise app.UsageError("--results_dir must be specified.")
  os.makedirs(FLAGS.results_dir, exist_ok=True)

  # Map names to constructors so that only the requested environment is built.
  env_constructors = {
    "cartpole": Cartpole,
    "catch": Catch,
    "mountain_car": MountainCar,
  }

  env = load_and_record_to_csv(
    raw_env=env_constructors[FLAGS.env](),
    results_dir=FLAGS.results_dir,
    overwrite=False,
  )
  env = wrappers.SinglePrecisionWrapper(env)
  env_spec = specs.make_environment_spec(env)

  def episode_fn(timestep, unused_action, env):
    """Return episode metadata on the final timestep, otherwise None."""
    # `last` is a method: it must be called. Testing the bound method itself
    # is always truthy and would emit metadata on every step.
    if not timestep.last():
      return None
    return {
      "episode_id": np.array(env.episode_id, np.int32),
      "episode_return": np.array(env.episode_return, np.float32),
    }

  # NOTE(review): the dataset name is hard-coded to "catch" regardless of
  # --env. Confirm whether downstream loaders rely on it before changing.
  dataset_config = tfds.rlds.rlds_base.DatasetConfig(
    name="catch",
    observation_info=tfds.features.Tensor(
      shape=env_spec.observations.shape,
      dtype=tf.float32,
      encoding=tfds.features.Encoding.ZLIB,
    ),
    action_info=tf.int32,
    reward_info=tf.float32,
    discount_info=tf.float32,
    episode_metadata_info={
      "episode_id": tf.int32,
      "episode_return": tf.float32,
    },
  )

  logging.info("Wrapping environment with EnvironmentLogger...")
  with NoiseEnvironmentLogger(
    env,
    episode_fn=episode_fn,
    backend=tfds_backend_writer.TFDSBackendWriter(
      data_directory=FLAGS.results_dir,
      split_name="full",
      max_episodes_per_file=100,
      ds_config=dataset_config,
    ),
  ) as env:
    logging.info("Done wrapping environment with EnvironmentLogger.")

    logging.info("Training a dqn agent for %r episodes...", FLAGS.num_episodes)

    # Q-network: flatten the observation, then a two-hidden-layer MLP with one
    # output per discrete action.
    network = snt.Sequential(
      [snt.Flatten(),
       snt.nets.MLP([50, 50, env_spec.actions.num_values])]
    )
    # Construct the agent.
    agent = dqn.DQN(environment_spec=env_spec, network=network)

    # Run the noisy environment loop so that `--action_noise` takes effect.
    loop = NoiseEnvironmentLoop(FLAGS.action_noise, env, agent)
    loop.run(num_episodes=FLAGS.num_episodes)  # pytype: disable=attribute-error

    logging.info(
      "Done training a dqn agent for %r episodes.", FLAGS.num_episodes
    )


# Script entry point: hand control to absl, which invokes `main`.
if __name__ == "__main__":
  app.run(main)
