import os
import sys
from pathlib import Path

import d4rl_atari
import gym
import numpy as np
from tqdm.auto import tqdm


def fix_obs(obs):
    """Drop every length-1 axis from the observation array.

    Returns the squeezed array (e.g. a trailing/inner singleton channel axis
    is removed).
    """
    squeezed = np.squeeze(obs)
    return squeezed


def fix_terms(terms):
    """Convert per-step terminal flags into per-step episode-start (reset) flags.

    Step 0 is always marked as a reset, and every step that immediately follows
    a terminal step is marked as a reset. A terminal flag on the very last step
    has no following step, so it is ignored instead of indexing past the end.

    Args:
        terms: per-step terminal flags (truthy where an episode ends).

    Returns:
        Boolean array of length ``len(terms)``; empty if ``terms`` is empty.
    """
    resets = np.zeros(len(terms), dtype=bool)
    if resets.size == 0:
        return resets
    resets[0] = True
    idxs = np.where(terms)[0] + 1
    # A terminal on the final step would point one past the end; drop it.
    resets[idxs[idxs < len(resets)]] = True
    return resets


def fix_actions(actions, reset, cats=18):
    """Convert scalar actions to one-hot vectors shifted forward by one step.

    Dreamer pairs each observation with the action that led to it (the
    "pre-action") rather than the action taken after it. Row ``t`` of the
    result therefore holds the one-hot of ``actions[t - 1]``. Rows where
    ``reset`` is set have no preceding action and are all zeros.

    Args:
        actions: integer action ids, flattened to 1-D (one per step).
        reset: per-step episode-start flags, aligned with the flattened
            ``actions``.
        cats: one-hot width (number of action categories). The default of 18
            presumably matches the full Atari action set — confirm.

    Returns:
        Float array of shape ``(len(actions), cats)``.
    """
    ridxs = np.where(reset)[0]
    targets = np.asarray(actions).reshape(-1)
    one_hot = np.eye(cats)[targets]
    # Shift along the time axis only. Without axis=0, np.roll flattens the
    # array and shifts by one element, which moves each 1 into the next action
    # column instead of to the next time step.
    one_hot = np.roll(one_hot, 1, axis=0)
    # Episode starts have no pre-action. This also clears the row that wrapped
    # around from the last step, provided step 0 is flagged as a reset.
    one_hot[ridxs] = 0.0
    return one_hot



def split_into_chunks(reset, action, image, episodes_per_chunk=1):
    """Yield ``(reset, action, image)`` slices holding whole episodes.

    Episodes begin at the indices where ``reset`` is truthy. Each yielded chunk
    spans ``episodes_per_chunk`` consecutive episodes. The final chunk holds
    whatever remains, so it may contain fewer episodes. Data before the first
    reset is not emitted. If there are no resets at all, the arrays are yielded
    whole as a single chunk.

    Args:
        reset: per-step episode-start flags.
        action: per-step actions, sliced in lockstep with ``reset``.
        image: per-step observations, sliced in lockstep with ``reset``.
        episodes_per_chunk: number of episodes per chunk (positive int).

    Yields:
        Tuples ``(reset_slice, action_slice, image_slice)``.
    """
    reset_idxs = np.where(reset)[0]
    # Start of the not-yet-emitted tail. It must be initialised because the loop
    # may not run at all (too few episodes for one full chunk).
    end_idx = reset_idxs[0] if len(reset_idxs) else 0
    for i in range(episodes_per_chunk, len(reset_idxs), episodes_per_chunk):
        start_idx = reset_idxs[i - episodes_per_chunk]
        end_idx = reset_idxs[i]

        chunk_resets = reset[start_idx:end_idx]
        chunk_actions = action[start_idx:end_idx]
        chunk_images = image[start_idx:end_idx]

        yield chunk_resets, chunk_actions, chunk_images

    yield reset[end_idx:], action[end_idx:], image[end_idx:]


def get_data(dataset_name):
    """Create the gym env ``dataset_name`` and return its offline dataset.

    Python-level stdout is silenced while the env is built and the dataset is
    loaded. It is restored to whatever ``sys.stdout`` was before the call, even
    if loading raises. The devnull handle is always closed.

    Args:
        dataset_name: gym env id, e.g. ``'breakout-expert-v2'`` (presumably
            registered by the ``d4rl_atari`` import — confirm).

    Returns:
        The value of ``env.get_dataset()``. Callers in this file index it with
        ``"terminals"``, ``"actions"`` and ``"observations"``.
    """
    import contextlib

    with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
        env = gym.make(dataset_name)
        data = env.get_dataset()
    return data


if __name__ == "__main__":
    # Dataset ids have the form '[GAME]-{mixed,medium,expert}-v{N}'.
    # These are the (level, versions) combinations converted for Breakout.
    ids = {"mixed": [0, 1, 2, 3, 4],
           "medium": [0, 1, 2, 3, 4],
           "expert": [0, 1, 2]}
    episodes_per_chunk = 100
    # Derive the outer total from `ids` instead of hard-coding it.
    n_datasets = sum(len(versions) for versions in ids.values())
    with tqdm(total=n_datasets) as pbar:
        for level, versions in ids.items():
            for version in versions:
                dataset_name = f"breakout-{level}-v{version}"
                data = get_data(dataset_name)

                reset = fix_terms(data["terminals"])
                action = fix_actions(data["actions"], reset)
                image = fix_obs(data["observations"])
                n_episodes = int(np.count_nonzero(reset))

                out_dir = Path("../data/breakout/episodes") / dataset_name
                out_dir.mkdir(parents=True, exist_ok=True)

                chunk_gen = split_into_chunks(
                    reset, action, image, episodes_per_chunk=episodes_per_chunk)

                # One chunk per `episodes_per_chunk` episodes, the last possibly
                # partial: ceil(n_episodes / episodes_per_chunk), at least 1.
                # (`n // k + 1` over-counts when n is a multiple of k.)
                n_chunks = max(1, -(-n_episodes // episodes_per_chunk))
                for i, (chunk_reset, chunk_action, chunk_image) in tqdm(
                        enumerate(chunk_gen), leave=False, desc="chunking",
                        position=1, total=n_chunks):
                    np.savez(out_dir / f"chunk-{i}.npz", reset=chunk_reset,
                             action=chunk_action, image=chunk_image)
                pbar.update(1)
