import functools

import bsuite
import numpy as np
import tensorflow as tf


def _parse_seq_tf_example(example, shapes, dtypes):
  """Parse tf.Example containing one or two episode steps."""

  def to_feature(shape, dtype):
    if np.issubdtype(dtype, np.floating):
      return tf.io.FixedLenSequenceFeature(
        shape=shape, dtype=tf.float32, allow_missing=True
      )
    elif dtype == np.bool or np.issubdtype(dtype, np.integer):
      return tf.io.FixedLenSequenceFeature(
        shape=shape, dtype=tf.int64, allow_missing=True
      )
    else:
      raise ValueError(
        f"Unsupported type {dtype} to "
        f"convert from TF Example."
      )

  feature_map = {}
  for k, v in shapes.items():
    feature_map[k] = to_feature(v, dtypes[k])

  parsed = tf.io.parse_single_example(example, features=feature_map)

  restructured = {}
  for k, v in parsed.items():
    dtype = tf.as_dtype(dtypes[k])
    if v.dtype == dtype:
      restructured[k] = parsed[k]
    else:
      restructured[k] = tf.cast(parsed[k], dtype)

  return restructured


def bsuite_dataset_params(env):
  """Collect per-feature shapes and dtypes for a bsuite offline dataset.

  The observation, action, discount and reward entries come from the
  corresponding `env.*_spec()` objects; "episodic_reward" reuses the reward
  spec, and "step_type" is always a scalar `np.int64`.

  Args:
    env: object exposing `observation_spec`, `action_spec`, `discount_spec`
      and `reward_spec`, each returning something with `.shape` and `.dtype`.

  Returns:
    Dict with keys "shapes" and "dtypes", each mapping feature name to the
    per-step shape or dtype.
  """
  spec_getters = (
    ("observation", env.observation_spec),
    ("action", env.action_spec),
    ("discount", env.discount_spec),
    ("reward", env.reward_spec),
    ("episodic_reward", env.reward_spec),
  )

  shapes = {key: getter().shape for key, getter in spec_getters}
  shapes["step_type"] = ()

  dtypes = {key: getter().dtype for key, getter in spec_getters}
  dtypes["step_type"] = np.int64

  return {"shapes": shapes, "dtypes": dtypes}


# Root of the bsuite part of the RL Unplugged dataset on GCS.
dataset_root = "gs://rl_unplugged/bsuite"
level = "cartpole"
# Named `run_dir` rather than `dir` to avoid shadowing the builtin.
run_dir = "0_0.0"  # 0_0.0 - 0_0.5
filename = "0_full"  # 0_full - 4_full
path = f"{dataset_root}/{level}/{run_dir}/{filename}"
bsuite_id = level + "/0"

# Shard files follow the "<path>-NNNNN-of-MMMMM" naming pattern.
num_shards = 1
filenames = [f"{path}-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
# Read all shards back to back; this keeps record order deterministic and
# honours `num_shards` instead of silently reading shard 0 only.
example_ds = tf.data.TFRecordDataset(filenames, compression_type="GZIP")

env = bsuite.load_from_id(bsuite_id)
params = bsuite_dataset_params(env)
print(f"params: {params}")


def map_func(example):
  """Decode one serialized record with the module-level bsuite `params`."""
  return _parse_seq_tf_example(example, **params)


example_ds = example_ds.map(map_func, num_parallel_calls=1)
count = 0
for element in example_ds.as_numpy_iterator():
  obs = element["observation"]
  print(f"index: {count}")
  # The leading axis is the step axis added by FixedLenSequenceFeature, and a
  # record may hold one or two steps; print whatever steps are present rather
  # than indexing obs[1] unconditionally.
  for step_obs in obs:
    print(step_obs)
  count += 1

print(f"total number: {count}")
