import collections
import os
from typing import NamedTuple

import numpy as np
import rlds
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import logging
from bsuite.environments.cartpole import Cartpole
from bsuite.environments.catch import Catch
from bsuite.environments.mountain_car import MountainCar

from rosmo.buffer.dataset_buffer import UniformBuffer
from rosmo.types import Array


# Container for the transition arrays handed to the replay buffer.
# `env_loader` fills each field from the same-named key of the loaded dataset;
# field order is part of the tuple layout and must not change.
ActorOutput = NamedTuple(
  "ActorOutput",
  [
    ("observation", Array),
    ("reward", Array),
    ("is_first", Array),
    ("is_last", Array),
    ("action", Array),
    ("discount", Array),
  ],
)


# Supported bsuite environments: name -> [environment class, num_trajectory].
# `create_bsuite_ds_loader` reads the second entry as the episode count used to
# build its episode-id subsample; `env_loader` instantiates the first entry.
_ENV_FACTORY = {
  "cartpole": [Cartpole, 1000],
  "catch": [Catch, 2000],
  "mountain_car": [MountainCar, 500],
}

# Batch size used by `env_loader` to pull the whole dataset in one go; the
# loaded transition count must stay strictly below it.
_LOAD_SIZE = 1e7

# Per-environment "random" and "online" reference scores. Not used in this
# module section -- presumably returns used for score normalisation; confirm
# with the consumers of SCORES.
SCORES = {
  "cartpole": {
    "random": 64.833,
    "online": 1001.0,
  },
  "catch": {
    "random": -0.667,
    "online": 1.0,
  },
  "mountain_car": {
    "random": -1000.0,
    "online": -102.167,
  },
}


def create_bsuite_ds_loader(
  env_name: str, dataset_name: str, dataset_percentage: int
):
  """Load an RLDS bsuite dataset as a flat stream of steps.

  Args:
    env_name: Key into ``_ENV_FACTORY``; its second entry is the trajectory
      count used to build the episode-id subsample.
    dataset_name: Directory of the TFDS dataset, passed to
      ``tfds.builder_from_directory``.
    dataset_percentage: Approximate percentage of episodes to keep. Values of
      100 or more keep every episode; smaller values keep one randomly chosen
      episode id out of every ``100 // dataset_percentage`` consecutive ids.

  Returns:
    A ``tf.data.Dataset`` yielding the individual elements of each kept
    episode's ``rlds.STEPS`` field.

  Raises:
    ValueError: If ``dataset_percentage`` is not positive.
  """
  if dataset_percentage <= 0:
    # Previously 0 hit a ZeroDivisionError and negatives failed deep inside
    # np.random.randint; report the bad argument directly instead.
    raise ValueError(
      f"dataset_percentage must be positive, got {dataset_percentage}"
    )
  dataset = tfds.builder_from_directory(dataset_name).as_dataset(split="all")
  num_trajectory = _ENV_FACTORY[env_name][1]
  if dataset_percentage < 100:
    # Split the ids into windows of `stride` and keep one id per window; the
    # random offset lies in [1, stride] (the +1 presumably targets 1-based
    # episode ids -- TODO confirm against the dataset).
    stride = 100 // dataset_percentage
    idx = np.arange(0, num_trajectory, stride)
    idx += np.random.randint(0, stride, idx.shape) + 1
    idx = tf.convert_to_tensor(idx, "int32")

    def filter_fn(episode):
      # Windows are disjoint, so at most one id can match: "any" is exactly
      # equivalent to the former "count == 1" test.
      return tf.reduce_any(episode["episode_id"] == idx)

    dataset = dataset.filter(filter_fn)

  def parse_fn(episode):
    return episode[rlds.STEPS]

  # cycle_length=1 flattens episodes one after another into a step stream.
  dataset = dataset.interleave(
    parse_fn,
    cycle_length=1,
    block_length=1,
    deterministic=False,
    num_parallel_calls=tf.data.AUTOTUNE,
  )
  return dataset


def env_loader(
  env_name: str,
  dataset_dir: str,
  data_percentage: int,
  batch_size: int = 8,
  trajectory_length: int = 1,
  noisy_obs: bool = False,
  noise_scale: float = 0.,
):
  """Load an offline bsuite dataset into a uniform replay buffer.

  Args:
    env_name: Either a key of ``_ENV_FACTORY`` or a dataset variant name built
      from one (e.g. ``"cartpole_<setting>"``). The full name selects the
      dataset directory; the resolved base name selects the environment.
    dataset_dir: Root directory containing one TFDS dataset per name.
    data_percentage: Approximate percentage of episodes to load; forwarded to
      ``create_bsuite_ds_loader``.
    batch_size: Batch size passed to ``UniformBuffer``.
    trajectory_length: Trajectory length passed to ``UniformBuffer``.
    noisy_obs: If True, multiply observations by Gaussian noise with mean 1
      and standard deviation ``noise_scale``. Only allowed with a base
      environment name.
    noise_scale: Standard deviation of the multiplicative observation noise;
      must be positive when ``noisy_obs`` is set.

  Returns:
    ``(env, buffer)``: a new instance of the matching bsuite environment class
    and a ``UniformBuffer`` whose storage holds every loaded transition.

  Raises:
    ValueError: On an unsupported env name, invalid noise settings, or a
      dataset that yields no transitions.
    RuntimeError: If the dataset reaches ``_LOAD_SIZE`` transitions and may
      therefore have been truncated by the single-batch load.
  """
  # Validate cheap arguments before the (potentially huge) dataset load.
  if noisy_obs and env_name not in _ENV_FACTORY:
    raise ValueError("noisy obs should use original data")
  if noisy_obs and not noise_scale > 0:
    raise ValueError(f"noise_scale must be positive, got {noise_scale}")

  data_name = env_name
  if env_name not in _ENV_FACTORY:
    # Variant datasets look like "<env>_<setting>": drop the last
    # underscore-separated component, then map onto a known base env by
    # substring (checked in this order).
    parts = env_name.split("_")
    if len(parts) > 1:
      env_name = "_".join(parts[:-1])
    for base_name in ("cartpole", "catch", "mountain_car"):
      if base_name in env_name:
        env_name = base_name
        break
  if env_name not in _ENV_FACTORY:
    raise ValueError(f"env {env_name} not supported")

  if "obs_noise" in data_name:
    logging.info("Use rebuttal directory")
    dataset_dir = dataset_dir.replace("bsuite-v2", "bsuite-rebuttal")

  dataset_name = os.path.join(dataset_dir, data_name)
  ds = create_bsuite_ds_loader(env_name, dataset_name, data_percentage)
  # Pull the whole dataset as a single numpy batch of up to _LOAD_SIZE steps.
  dl = ds.batch(int(_LOAD_SIZE)).as_numpy_iterator()
  try:
    data = next(dl)
  except StopIteration:
    raise ValueError(
      f"No transitions found in dataset {dataset_name}"
    ) from None
  logging.info("[Data] loaded keys: %s", list(data.keys()))

  observation = data["observation"]
  if noisy_obs:
    logging.info(f"Use online observation noise: {noise_scale}")
    noise = np.random.normal(1., noise_scale, observation.shape)
    observation = observation * noise

  timesteps = ActorOutput(
    observation=observation,
    reward=data["reward"],
    is_first=data["is_first"],
    is_last=data["is_last"],
    action=data["action"],
    discount=data["discount"],
  )
  data_size = len(timesteps.reward)
  if data_size >= _LOAD_SIZE:
    # A full batch means the dataset may not have fit in one load.
    raise RuntimeError(
      f"Loaded {data_size} transitions, which reaches the load limit "
      f"{int(_LOAD_SIZE)}; the dataset may have been truncated."
    )

  iterator = UniformBuffer(
    0,
    data_size,
    trajectory_length,
    batch_size,
  )
  logging.info(f"[Data] {data_size} transitions totally.")
  iterator.init_storage(timesteps)
  return _ENV_FACTORY[env_name][0](), iterator


if __name__ == "__main__":
  # Manual smoke test: load 10% of the catch dataset and print one batch.
  from acme.specs import make_environment_spec
  logging.set_verbosity("info")

  demo_dir = "/mnt_central/datasets/rl_unplugged/tensorflow_datasets/bsuite-v1"
  env, dataloader = env_loader(
    "catch",
    demo_dir,
    data_percentage=10,
    trajectory_length=10,
  )
  batch = next(dataloader)
  print(batch.observation.shape)
  print(f"Env: {make_environment_spec(env)}")
  first_obs, first_rewards = batch.observation[0], batch.reward[0]
  print(list(zip(first_obs, first_rewards)))
