# Most of the functionality recycled from src/buffers/trajectory_buffer.py
import argparse
import collections
import pickle
import hydra
import h5py
import numpy as np
import json
from pathlib import Path
from tqdm import tqdm
from stable_baselines3.common.buffers import ReplayBuffer

# necessary to make buffer class visible
import sys
sys.path.insert(0, "../..")


def discount_cumsum_np(x, gamma):
    """Compute the discounted reverse cumulative sum (returns-to-go) of ``x``.

    For every index ``t`` the result is
    ``sum_{k >= t} gamma ** (k - t) * x[k]``.

    Args:
        x: Array-like of per-step values. It is flattened and treated as a
            single 1-D sequence.
        gamma: Discount factor. ``gamma == 1`` yields the plain reverse
            cumulative sum.

    Returns:
        A contiguous 1-D ``np.float32`` array with one entry per step.
    """
    x = np.asarray(x).reshape(-1)
    if gamma == 1:
        # Undiscounted case: a reversed cumulative sum is exact and vectorised.
        out = np.flip(np.cumsum(np.flip(x, 0)), 0)
    else:
        # Backward recurrence G[t] = x[t] + gamma * G[t + 1]. Scaling the whole
        # suffix sum by one power of gamma is only correct for gamma == 1, and
        # dividing by gamma ** t under/overflows on long sequences.
        out = np.zeros(x.shape[0], dtype=np.float64)
        running = 0.0
        for t in range(x.shape[0] - 1, -1, -1):
            running = x[t] + gamma * running
            out[t] = running
    return np.ascontiguousarray(out).astype(np.float32)


def load_trajectory_dataset(path):
    """Load a stored dataset file and return its trajectories.

    Args:
        path: Location of the dataset file, as a ``Path`` or anything
            ``Path`` accepts (e.g. a ``str``). The suffix selects the loader:
            ``.pkl`` is unpickled; ``.npz`` / ``.npy`` are read with
            ``np.load``.

    Returns:
        For a pickled ``ReplayBuffer`` or a NumPy file, the trajectories
        extracted from the stored transitions. Any other pickled object is
        returned unchanged.

    Raises:
        NotImplementedError: If the suffix is not ``.pkl``, ``.npz`` or ``.npy``.
    """
    path = Path(path)
    suffix = path.suffix
    if suffix == ".pkl":
        # NOTE(security): unpickling can execute arbitrary code; only load
        # dataset files that come from a trusted source.
        with open(str(path), "rb") as f:
            obj = pickle.load(f)
        if isinstance(obj, ReplayBuffer):
            return extract_trajectories_from_buffer(obj)
        return obj
    if suffix in (".npz", ".npy"):
        # NOTE(review): the extractor indexes the loaded object by string key,
        # which a plain-array .npy file would not support -- confirm that .npy
        # inputs are really expected here.
        obj = np.load(str(path))
        try:
            return extract_trajectories_from_npz(obj)
        finally:
            # np.load returns an NpzFile for archives, which keeps the file
            # handle open until closed; plain arrays have no close().
            close = getattr(obj, "close", None)
            if close is not None:
                close()
    raise NotImplementedError("Only .pkl, .npz and .npy files are supported.")


def extract_trajectories_from_buffer(obj):
    """Extract trajectories from the filled part of a replay buffer.

    Args:
        obj: Buffer object exposing ``pos``, ``full``, ``observations``,
            ``next_observations``, ``actions``, ``rewards`` and ``dones``.

    Returns:
        The list produced by ``extract_trajectories`` for the stored
        transitions.
    """
    # A full buffer contributes every slot, otherwise only the first `pos`.
    # NOTE(review): if a full buffer has wrapped around, slot 0 may not be the
    # oldest transition -- confirm the ordering assumption.
    n_valid = len(obj.observations) if obj.full else obj.pos
    fields = (obj.observations, obj.next_observations, obj.actions, obj.rewards, obj.dones)
    return extract_trajectories(*(arr[:n_valid] for arr in fields))


def extract_trajectories_from_npz(obj):
    """Extract trajectories from a key-indexed container of transition arrays.

    Args:
        obj: Mapping-like object providing the keys ``"observations"``,
            ``"next_observations"``, ``"actions"``, ``"rewards"`` and
            ``"dones"``.

    Returns:
        The list produced by ``extract_trajectories`` for those arrays.
    """
    keys = ("observations", "next_observations", "actions", "rewards", "dones")
    return extract_trajectories(*(obj[key] for key in keys))


def extract_trajectories(observations, next_observations, actions, rewards, dones):
    """Split a flat sequence of transitions into per-episode trajectories.

    The five inputs are iterated in lockstep. A trajectory is closed whenever
    ``done`` is truthy. Transitions after the last truthy ``done`` are not
    part of any returned trajectory.

    Args:
        observations: Per-step observations (each must support ``astype``).
        next_observations: Per-step successor observations.
        actions: Per-step actions.
        rewards: Per-step rewards.
        dones: Per-step episode-end flags.

    Returns:
        A list of ``defaultdict(list)`` objects with the list-valued keys
        ``"observations"``, ``"next_observations"``, ``"actions"``,
        ``"rewards"`` and ``"terminals"``, plus an integer ``"trj_id"`` that
        counts completed trajectories from 0. Observations are cast to
        ``np.float32``.
    """
    trajectories = []
    trj_id = 0
    current_trj = collections.defaultdict(list)
    transitions = zip(observations, next_observations, actions, rewards, dones)
    for s, s1, a, r, done in tqdm(transitions, total=len(observations),
                                  desc="Extracting trajectories"):
        # NaNs are only reported, not removed. Every check is reduced with
        # .any() so that an array-valued reward yields a single flag instead
        # of an ambiguous truth value inside any().
        nans = [np.isnan(s).any(), np.isnan(s1).any(), np.isnan(a).any(), np.isnan(r).any()]
        if any(nans):
            print("NaNs found:", nans)
        current_trj["observations"].append(s.astype(np.float32))
        current_trj["next_observations"].append(s1.astype(np.float32))
        current_trj["actions"].append(a)
        current_trj["rewards"].append(r)
        current_trj["terminals"].append(done)
        if done:
            # Episode finished: tag it, store it and start a fresh container.
            current_trj["trj_id"] = trj_id
            trajectories.append(current_trj)
            current_trj = collections.defaultdict(list)
            trj_id += 1
    return trajectories


def save_episode(to_save, save_path, save_format="hdf5", compress=False):
    """Write one episode dictionary to disk.

    Args:
        to_save: Mapping from field name to the value to store.
        save_path: Target path without extension (``str`` or ``Path``).
        save_format: ``"hdf5"`` writes ``save_path + ".hdf5"`` with one
            dataset per key; ``"npzc"`` uses ``np.savez_compressed``;
            ``"pkl"`` pickles the mapping to ``save_path + ".pkl"``; any
            other value falls back to ``np.savez``.
        compress: Only used for ``"hdf5"``: store non-scalar datasets with
            gzip compression (level 1).
    """
    # Accept Path objects as well; the suffixes below are appended as text.
    save_path = str(save_path)
    if save_format == "hdf5":
        compress_kwargs = {"compression": "gzip", "compression_opts": 1} if compress else {}
        with h5py.File(save_path + ".hdf5", "w") as f:
            for k, v in to_save.items():
                # Scalar datasets cannot take compression options, so they are
                # always written plain. NumPy scalars and 0-d arrays count as
                # scalars too (e.g. an np.int64 trajectory id).
                is_scalar = isinstance(v, (int, float, str, bool, np.generic)) or (
                    isinstance(v, np.ndarray) and v.ndim == 0
                )
                if is_scalar:
                    f.create_dataset(k, data=v)
                else:
                    f.create_dataset(k, data=v, **compress_kwargs)
    elif save_format == "npzc":
        np.savez_compressed(save_path, **to_save)
    elif save_format == "pkl":
        with open(save_path + ".pkl", "wb") as f:
            pickle.dump(to_save, f)
    else:
        np.savez(save_path, **to_save)
        
        
def prepare_trj(trj_dict):
    """Convert a trajectory dictionary into the per-episode layout to store.

    Args:
        trj_dict: Mapping with the keys ``"observations"``, ``"actions"``,
            ``"rewards"``, ``"terminals"`` and ``"trj_id"``.

    Returns:
        A dict with ``"states"``, ``"actions"``, ``"rewards"``,
        ``"returns_to_go"``, ``"dones"`` and ``"trj_id"``. Rewards and dones
        are flattened to 1-D; returns-to-go are computed with a discount
        factor of 1.
    """
    states = trj_dict["observations"]
    flat_rewards = np.stack(trj_dict["rewards"]).reshape(-1)
    # Only a list of per-step observations needs stacking; anything else is
    # passed through untouched.
    if isinstance(states, list):
        states = np.vstack(states)
    stacked_actions = np.vstack(trj_dict["actions"])
    returns_to_go = discount_cumsum_np(flat_rewards, 1)
    flat_dones = np.stack(trj_dict["terminals"]).reshape(-1)
    return {
        "states": states,
        "actions": stacked_actions,
        "rewards": flat_rewards,
        "returns_to_go": returns_to_go,
        "dones": flat_dones,
        "trj_id": trj_dict["trj_id"],
    }


def extract_array_stats(vals, prefix="", round=4):
    """Summarise ``vals`` with rounded descriptive statistics.

    Args:
        vals: Array-like of numbers.
        prefix: Optional key prefix; when non-empty, keys become
            ``"<prefix>_<stat>"``.
        round: Number of decimals each statistic is rounded to.

    Returns:
        A dict with, in order, the keys ``min``, ``max``, ``mean``, ``std``,
        ``q25``, ``q50``, ``q75``, ``q90`` and ``q99`` (each optionally
        prefixed).
    """
    tag = f"{prefix}_" if prefix else ""
    reducers = (("min", np.min), ("max", np.max), ("mean", np.mean), ("std", np.std))
    quantiles = (("q25", 0.25), ("q50", 0.5), ("q75", 0.75), ("q90", 0.9), ("q99", 0.99))
    stats = {}
    for name, reduce_fn in reducers:
        stats[f"{tag}{name}"] = reduce_fn(vals).round(round)
    for name, q in quantiles:
        stats[f"{tag}{name}"] = np.quantile(vals, q).round(round)
    return stats


def save_json_stats(epname_to_len, epname_to_total_returns, epname_to_trjid, save_dir):
    """Compute episode statistics and dump them, with the raw maps, as JSON.

    Writes ``stats.json``, ``episode_lengths.json``, ``episode_returns.json``
    and ``episode_trjids.json`` into ``save_dir`` and prints the statistics
    on a single line.

    Args:
        epname_to_len: Mapping from episode name to episode length.
        epname_to_total_returns: Mapping from episode name to total return.
        epname_to_trjid: Mapping from episode name to trajectory id.
        save_dir: Directory (``Path``) the JSON files are written to.

    Returns:
        The statistics dict that was written to ``stats.json``.
    """
    lengths = list(epname_to_len.values())
    returns = list(epname_to_total_returns.values())
    # Aggregate counts first, then the distribution summaries.
    stats = {"episodes": len(epname_to_len), "transitions": sum(lengths)}
    stats.update(extract_array_stats(lengths, prefix="episode_len"))
    stats.update(extract_array_stats(returns, prefix="episode_return"))
    print(" | ".join(f"{k}: {v}" for k, v in stats.items()))
    outputs = (
        ("stats.json", stats),
        ("episode_lengths.json", epname_to_len),
        ("episode_returns.json", epname_to_total_returns),
        ("episode_trjids.json", epname_to_trjid),
    )
    for file_name, payload in outputs:
        with open(save_dir / file_name, "w") as f:
            json.dump(payload, f)
    return stats


def convert_to_single_trjs(data_paths, save_dir, compress=False, save_format="hdf5"):
    """Convert each dataset into one file per episode plus JSON statistics.

    For every dataset path a sub-directory named after the file stem is
    created under ``save_dir``. Episodes are written there as ``0``, ``1``,
    ... (the extension depends on ``save_format``), followed by the JSON
    statistics for that dataset.

    Args:
        data_paths: Iterable of dataset file paths.
        save_dir: Root output directory; created if missing.
        compress: Forwarded to ``save_episode``.
        save_format: Forwarded to ``save_episode``.
    """
    # Use the argument, not the CLI namespace, so the function also works
    # when imported from another module.
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    # iterate tasks
    for p in data_paths:
        p = Path(p)
        print(f"Converting dataset: {p.name}")
        trajectories = load_trajectory_dataset(p)
        save_data_dir = save_dir / p.stem
        save_data_dir.mkdir(parents=True, exist_ok=True)
        # Episode names restart at "0" for every dataset, so the bookkeeping
        # must be fresh per dataset; otherwise stats of one task would include
        # leftover episodes of the previous ones.
        epname_to_len, epname_to_total_returns, epname_to_trjid = {}, {}, {}
        for i, trj_dict in enumerate(tqdm(trajectories, desc="Writing episodes")):
            file_name = str(i)
            save_path = str(save_data_dir / file_name)
            trj = prepare_trj(trj_dict)
            ep_len, ep_total_return = len(trj["states"]), trj["rewards"].sum()
            # Plain Python floats keep the maps JSON-serialisable.
            epname_to_len[file_name] = float(ep_len)
            epname_to_total_returns[file_name] = float(ep_total_return)
            epname_to_trjid[file_name] = trj["trj_id"]
            save_episode(trj, save_path, save_format, compress)
        if not epname_to_len:
            # Statistics over zero episodes are undefined; skip instead of failing.
            print(f"No episodes found in dataset: {p.name}")
            continue
        # save stats for single task
        save_json_stats(epname_to_len, epname_to_total_returns, epname_to_trjid, save_data_dir)


if __name__ == "__main__":
    # Command-line options: where to write, which dataset config to use, the
    # output format and whether to compress.
    parser = argparse.ArgumentParser()
    for flag, default in (("--save_dir", "./"),
                          ("--data_paths", "mt50_v2_cwnet_2M"),
                          ("--save_format", "hdf5")):
        parser.add_argument(flag, type=str, default=default)
    parser.add_argument("--compress", action="store_true")
    args = parser.parse_args()

    # Compose the hydra config; only the data_paths group is taken from the CLI.
    overrides = [
        "env_params=mt50_pretrain",
        "agent_params=cdt_pretrain",
        f"agent_params/data_paths={args.data_paths}",
    ]
    hydra.initialize(config_path="../../configs")
    conf = hydra.compose(config_name="config", overrides=overrides)
    conf.env_params.eval_env_names = None
    print(conf)

    # Every dataset name is resolved relative to the configured base directory.
    data_cfg = conf.agent_params.data_paths
    data_paths = [Path(data_cfg.base) / name for name in data_cfg.names]
    convert_to_single_trjs(data_paths, args.save_dir, args.compress, args.save_format)
