import argparse
import pickle

from itertools import tee
from pathlib import Path

import h5py
import numpy as np

from transfer.envs.metaworld import (
    MT50,
    get_single_env,
)


def pairwise(iterable):
    """Return an iterator over overlapping neighbours: (s0, s1), (s1, s2), ...

    An input with fewer than two items produces no pairs.
    """
    leading, trailing = tee(iterable, 2)
    # Move the second copy one element ahead so that zipping the two copies
    # lines every item up with its successor.
    for _ in trailing:
        break
    return zip(leading, trailing)


def _match_env_name(dataset_name):
    """Return the MT50 task name that occurs in ``dataset_name``.

    Raises:
        ValueError: if no MT50 train/test task name is a substring of the name.
    """
    task_names = [*MT50.train_classes.keys(), *MT50.test_classes.keys()]
    matches = [task for task in task_names if task in dataset_name]
    if not matches:
        raise ValueError(f"No MT50 task name found in dataset name '{dataset_name}'")
    # A short task name can itself be a substring of a longer one, so the
    # first hit is not necessarily the right task; the longest match is the
    # most specific one.
    return max(matches, key=len)


def _globalize_done_idxs(done_idxs):
    """Rebase episode end indices so they index into the full arrays.

    Workaround for concatenated datasets: the end indices are assumed to
    restart in every concatenated chunk, so each chunk is shifted by the sum
    of the final end indices of all chunks before it. Returns a new int array;
    the input is not modified.
    """
    done_idxs = np.asarray(done_idxs).astype(int)
    # A chunk boundary is wherever the sequence stops increasing. `>=` rather
    # than `>`: end indices grow strictly inside a chunk, so two equal
    # neighbours can only come from a boundary (and would otherwise yield a
    # zero-length episode).
    split_idxs = np.where(done_idxs[:-1] >= done_idxs[1:])[0]
    chunk_sizes = [len(chunk) for chunk in np.split(done_idxs, split_idxs + 1)]
    offsets = np.insert(np.cumsum(done_idxs[split_idxs]), 0, 0)
    return done_idxs + np.repeat(offsets, chunk_sizes)


def _split_episodes(dataset, done_idxs):
    """Cut the flat dataset arrays into one dict per episode.

    ``done_idxs`` are used as exclusive end indices: every episode starts
    where the previous one ended.
    """
    # Drop the last state column once, up front, to get rid of the one-hot.
    observations = dataset["states"][:, :-1]
    actions = dataset["actions"]
    rewards = dataset["stepwise_returns"]

    paths = []
    start_index = 0
    for end_index in done_idxs:
        # Only the final step of an episode is flagged as terminal.
        terminals = np.zeros(end_index - start_index, dtype=bool)
        terminals[-1] = True
        paths.append(
            {
                "observations": observations[start_index:end_index],
                "actions": actions[start_index:end_index],
                "rewards": rewards[start_index:end_index],
                "terminals": terminals,
            }
        )
        start_index = end_index
    return paths


def process_h5_datasets(datasets_path: Path, output_path: Path) -> None:
    """Convert every HDF5 dataset in ``datasets_path`` into a pickled episode list.

    For each file the MT50 task name is derived from the file stem, the flat
    arrays are split into episodes, summary statistics are printed, and the
    list of episode dicts (``observations``, ``actions``, ``rewards``,
    ``terminals``) is written to ``output_path / f"{env_name}.pkl"``.

    Args:
        datasets_path: Directory holding the HDF5 files. Entries that are not
            regular files (e.g. sub-directories) are skipped.
        output_path: Directory for the pickle files; created if missing.

    Raises:
        ValueError: if a file name contains no MT50 task name.
        AssertionError: if the mean of the recomputed episode returns differs
            from the mean of the stored ``returns`` by 0.01 or more.
    """
    for dataset_path in datasets_path.iterdir():
        if not dataset_path.is_file():
            continue

        # Copy everything into memory so the file handle can be released
        # right away instead of staying open for the rest of the run.
        with h5py.File(dataset_path, "r") as h5_file:
            dataset = {key: np.array(h5_file[key]) for key in h5_file.keys()}

        env_name = _match_env_name(dataset_path.stem)

        try:
            # test if env can be loaded
            get_single_env(env_name)
        except Exception as e:
            # Best-effort smoke test: report the problem but still convert the data.
            print(f"Env {env_name} could not be loaded. Error: {e}")

        done_idxs = _globalize_done_idxs(dataset["done_idxs"])
        paths = _split_episodes(dataset, done_idxs)

        returns = np.array([np.sum(p["rewards"]) for p in paths])
        num_samples = np.sum([p["rewards"].shape[0] for p in paths])
        # Compare the absolute gap: a signed difference would let any
        # over-estimate of the returns pass unnoticed.
        mean_gap = abs(np.mean(dataset["returns"]) - np.mean(returns))
        assert mean_gap < 0.01, f"Mean return mismatch for {env_name}: gap = {mean_gap}"
        print(f"{env_name}")
        print(f"Number of samples collected: {num_samples}")
        print(
            f"Trajectory returns: mean = {np.mean(returns)}, std = {np.std(returns)}, max = {np.max(returns)}, min = {np.min(returns)}"
        )

        output_path.mkdir(parents=True, exist_ok=True)
        with open(output_path / f"{env_name}.pkl", "wb") as f:
            pickle.dump(paths, f)


if __name__ == "__main__":
    # Both paths are mandatory: without `required=True` a missing flag would
    # be passed on as None and only fail later with an opaque AttributeError.
    parser = argparse.ArgumentParser(
        description="Convert HDF5 datasets into per-task pickled episode lists."
    )
    parser.add_argument(
        "--datasets_path",
        type=Path,
        required=True,
        help="Directory containing the HDF5 dataset files.",
    )
    parser.add_argument(
        "--output_path",
        type=Path,
        required=True,
        help="Directory the per-task .pkl files are written to.",
    )
    args = parser.parse_args()
    # The argparse destination names match the function's parameter names.
    process_h5_datasets(**vars(args))
