import gzip
import os
import sys

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# Output naming: each array is written to "<prefix><elem>_ckpt.<shard>.gz".
# NOTE(review): presumably chosen to match Dopamine's replay-buffer checkpoint
# file layout — confirm against the consumer of ./data/dqn-dataset.
STORE_FILENAME_PREFIX = "$store$_"
ELEMS = ["observation", "action", "reward", "episode_end", "terminal"]

# Fail with a readable usage message instead of a bare IndexError when the
# positional arguments are missing.
if len(sys.argv) < 3:
    sys.exit("usage: python <this script> GAME RUN")
GAME = sys.argv[1]  # game sub-directory under both the source and destination roots
RUN = sys.argv[2]  # run id; selects input shards named run_<RUN>-*
source_data_dir = "./data/dataset"
destination_data_dir = "./data/dqn-dataset"

# Create the per-game/per-run output directory so the output files can be written there.
os.makedirs(f"{destination_data_dir}/{GAME}/{RUN}", exist_ok=True)


def _feature_description():
    """Return the ``tf.io.parse_single_example`` spec for one episode record.

    Per-episode metadata (indices and return) are fixed-length scalars; the
    per-step sequences (actions, PNG-encoded observations, unclipped rewards,
    discounts) are variable-length and therefore parse to sparse tensors.
    """
    scalar_dtypes = {
        "episode_idx": tf.int64,
        "checkpoint_idx": tf.int64,
        "episode_return": tf.float32,
    }
    sequence_dtypes = {
        "actions": tf.int64,
        "observations": tf.string,
        "unclipped_rewards": tf.float32,
        "discounts": tf.float32,
    }

    spec = {key: tf.io.FixedLenFeature([], dtype) for key, dtype in scalar_dtypes.items()}
    spec.update((key, tf.io.VarLenFeature(dtype)) for key, dtype in sequence_dtypes.items())
    return spec


def atari_example_to_rlds(example_bytes: tf.Tensor):
    """Parse one serialized episode record into an RLDS-style nested dict.

    Args:
        example_bytes: Scalar string tensor holding one serialized
            ``tf.train.Example`` matching ``_feature_description()``.

    Returns:
        A dict with the episode metadata ``episode_id``, ``checkpoint_id`` and
        ``episode_return``, plus a ``steps`` dict of per-step tensors:
        ``observation`` (each PNG decoded with ``channels=1`` to uint8, stacked
        along a leading axis), ``action``, ``reward`` (the unclipped rewards),
        ``discount``, and boolean ``is_first`` / ``is_last`` / ``is_terminal``
        flags of length ``T`` = number of actions.

    Notes:
        Assumes ``T >= 1``. An empty episode fails at ``discounts[-1]`` and
        at ``tf.zeros(T - 1)``.
        NOTE(review): presumably there is one observation per action — confirm.
    """
    data = tf.io.parse_single_example(example_bytes, _feature_description())

    actions = tf.sparse.to_dense(data["actions"])
    rewards = tf.sparse.to_dense(data["unclipped_rewards"])
    discounts = tf.sparse.to_dense(data["discounts"])
    obs_png_bytes = tf.sparse.to_dense(data["observations"], default_value=b"")

    # Decode every frame. `fn_output_signature` replaces the deprecated `dtype=`
    # argument. The deprecated `back_prop=False` is dropped because nothing in
    # this pipeline is differentiated.
    observations = tf.map_fn(
        lambda png: tf.io.decode_png(png, channels=1),
        obs_png_bytes,
        fn_output_signature=tf.uint8,
    )

    T = tf.shape(actions)[0]
    all_false_prefix = tf.zeros(T - 1, dtype=tf.bool)
    is_first = tf.concat([[True], all_false_prefix], axis=0)
    is_last = tf.concat([all_false_prefix, [True]], axis=0)

    # The episode is treated as terminal when its final recorded discount is 0.
    terminal_cond = tf.equal(discounts[-1], 0.0)
    # Only the final step can be terminal. The result equals `is_last` for a
    # terminal episode and is all False otherwise (scalar broadcast).
    is_term = tf.logical_and(is_last, terminal_cond)

    # For terminal episodes, shift discounts left by one and pad the end with 0.
    # NOTE(review): presumably re-aligns discount[t] with the transition leaving
    # step t — confirm the intended convention.
    discounts = tf.cond(
        terminal_cond,
        lambda: tf.concat([discounts[1:], [0.0]], axis=0),
        lambda: discounts,
    )

    return {
        "episode_id": data["episode_idx"],
        "checkpoint_id": data["checkpoint_idx"],
        "episode_return": data["episode_return"],
        "steps": {
            "observation": observations,
            "action": actions,
            "reward": rewards,
            "discount": discounts,
            "is_first": is_first,
            "is_last": is_last,
            "is_terminal": is_term,
        },
    }


def episode_generator(ds: tf.data.Dataset):
    """Iterate *ds* eagerly and yield each episode as plain NumPy data.

    Each yielded item is the 6-tuple
    ``(episode_id, observation, action, reward, is_last, is_terminal)``.
    ``episode_id`` is read from the episode's top level. The other five come
    from its ``steps`` dict, each converted with ``.numpy()``.
    """
    step_fields = ("observation", "action", "reward", "is_last", "is_terminal")
    for episode in ds:
        episode_id = episode["episode_id"].numpy()
        step_data = episode["steps"]
        arrays = [step_data[field].numpy() for field in step_fields]
        yield (episode_id, *arrays)


# Number of input shards per run; input files are named "...-<idx:05d>-of-<NUM_SHARDS:05d>".
NUM_SHARDS = 50


def _save_gz_npy(path, array):
    """Write *array* to *path* as a gzip-compressed ``.npy`` payload (no pickling)."""
    with gzip.open(path, "wb") as outfile:
        np.save(outfile, array, allow_pickle=False)


for shard_idx in tqdm(range(NUM_SHARDS)):
    shard_path = f"{source_data_dir}/{GAME}/run_{RUN}-{shard_idx:05d}-of-{NUM_SHARDS:05d}"
    raw_ds = tf.data.TFRecordDataset([shard_path], compression_type="GZIP")
    ds = raw_ds.map(atari_example_to_rlds, num_parallel_calls=tf.data.AUTOTUNE)

    episodes = list(episode_generator(ds))
    if not episodes:
        # np.concatenate would otherwise fail with an opaque "need at least one array".
        raise RuntimeError(f"no episodes found in {shard_path}")

    # Transpose per-episode tuples into per-field sequences; episode ids are unused.
    _, obs, acts, rews, last_f, term_f = zip(*episodes)
    _store = {
        # Drop only the trailing channel axis added by decode_png(channels=1).
        # A bare .squeeze() would also collapse the step axis if a shard held one step.
        "observation": np.concatenate(obs).squeeze(axis=-1),
        "action": np.concatenate(acts),
        "reward": np.concatenate(rews),
        "episode_end": np.concatenate(last_f),
        "terminal": np.concatenate(term_f),
    }
    for elem in ELEMS:
        filename = f"{destination_data_dir}/{GAME}/{RUN}/{STORE_FILENAME_PREFIX}{elem}_ckpt.{shard_idx}.gz"
        _save_gz_npy(filename, _store[elem])
