import os
import sys

import cv2
import numpy as np
import argparse
import tqdm

# Extra import location; presumably the path to a local Metaworld checkout
# (TODO confirm). It is empty by default, and an empty sys.path entry makes
# Python search the current working directory.
METAWORLD_PATH = ""
sys.path.append(METAWORLD_PATH)

# Project layout, resolved from the real (symlink-free) location of this file:
#   FILE_DIR      - directory containing this script
#   ROOT_DIR      - parent directory of FILE_DIR
#   MAIN_DATA_DIR - "data" directory directly under ROOT_DIR
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT_DIR = os.path.dirname(FILE_DIR)
MAIN_DATA_DIR = os.path.join(ROOT_DIR, "data")

# Make modules under the project root and next to this script importable.
sys.path.extend((ROOT_DIR, FILE_DIR))


def _as_saveable_array(nested):
    """Turn ``[context][trajectory] -> ndarray`` lists into one array.

    When every trajectory has the same shape the result is a regular
    numeric array. When trajectories differ in length (an episode ended
    early), recent NumPy versions refuse to build an array implicitly and
    raise ``ValueError``; in that case a 2-D ``dtype=object`` array holding
    one ndarray per (context, trajectory) cell is returned instead.
    """
    try:
        return np.asarray(nested)
    except ValueError:
        n_rows = len(nested)
        n_cols = len(nested[0]) if n_rows else 0
        ragged = np.empty((n_rows, n_cols), dtype=object)
        # Fill cell by cell so NumPy stores each trajectory as an object and
        # never tries to broadcast the inner arrays against each other.
        for i, row in enumerate(nested):
            for j, traj in enumerate(row):
                ragged[i, j] = traj
        return ragged


def collect_data_pybullet():
    """Collect random-walk rollouts from an environment and save them to disk.

    Configuration is read from the module-level ``args`` namespace:
    ``env``, ``dataset``, ``n_contexts``, ``n_traj_per_context``,
    ``max_traj_len`` and, if present, ``render``. ``args.action_dim`` is set
    to 2 as a side effect.

    For every context, ``n_traj_per_context`` trajectories of at most
    ``max_traj_len`` observations are recorded. Actions follow a clipped
    random walk in ``[-1, 1]``. For ``obstacle_env`` a 2-D context vector is
    sampled uniformly from ``[-0.5, 0.5]`` per context and passed to
    ``env.reset``.

    Files written to ``MAIN_DATA_DIR/<args.dataset>``:
        obs_all.npy            - ``obs["video"]`` frames (uint8)
        actions_all.npy        - applied actions (float32)
        proprioception_all.npy - ``obs["state"]`` vectors (float32)
        rewards_all.npy        - rewards (float32)
        context_all.npy        - sampled contexts (``obstacle_env`` only)

    If some trajectories terminate early, the per-trajectory arrays have
    different lengths and the files are saved as object arrays; those must
    be read back with ``np.load(..., allow_pickle=True)``.

    Raises:
        NotImplementedError: if ``args.env`` is not a supported environment.
    """
    # Resolve the environment first so an unsupported name fails before any
    # directory is created or any simulation is started.
    if args.env == "spiral_env":
        from modules.envs.spiral_env.spiral_env import SpiralEnv as Env
        action_noise = 0.25
    elif args.env == "obstacle_env":
        from modules.envs.obstacle_env.obstacle_env import ObstacleEnv as Env
        action_noise = 0.25
    else:
        raise NotImplementedError(
            "Unsupported environment: {!r}".format(args.env)
        )

    # Create directory for dataset
    data_dir = os.path.join(MAIN_DATA_DIR, args.dataset)
    os.makedirs(data_dir, exist_ok=True)

    # Create environment; honour the --render flag when the namespace has
    # one (the default, 0, keeps rendering off).
    env = Env(render=bool(getattr(args, "render", 0)))
    args.action_dim = 2

    has_context = args.env == "obstacle_env"
    reset_info = {"mode": "train", "context_params": None}

    imgs_all = []
    gts_all = []
    actions_all = []
    context_all = []
    rewards_all = []

    for _ in tqdm.tqdm(range(args.n_contexts)):
        imgs_context = []
        gts_context = []
        actions_context = []
        rewards_context = []

        if has_context:
            reset_info["context_params"] = np.random.uniform(-0.5, 0.5, 2)
            context_all.append(reset_info["context_params"])

        for _ in tqdm.tqdm(range(args.n_traj_per_context)):
            # Reset environment
            obs = env.reset(reset_info=reset_info)

            # Cast the initial observation like every later step so the
            # stacked trajectory is not promoted to a wider dtype.
            # ``astype`` returns a fresh array, so no extra copy is needed.
            imgs_traj = [obs["video"].astype(np.uint8)]
            gts_traj = [obs["state"].astype(np.float32)]
            actions_traj = []
            rewards_traj = []

            # Random walk: start from a uniform action, then add Gaussian
            # noise each step and clip back into the valid range.
            action = np.random.uniform(-1.0, 1.0, args.action_dim)
            for _ in range(args.max_traj_len - 1):
                action += np.random.normal(0.0, action_noise, args.action_dim)
                action = np.clip(action, -1.0, 1.0)

                obs, reward, done, _info = env.step(action=action.copy())

                imgs_traj.append(obs["video"].astype(np.uint8))
                gts_traj.append(obs["state"].astype(np.float32))
                actions_traj.append(action.astype(np.float32))
                rewards_traj.append(reward)
                if done:
                    break

            imgs_context.append(np.array(imgs_traj))
            gts_context.append(np.array(gts_traj))
            actions_context.append(np.array(actions_traj))
            rewards_context.append(np.array(rewards_traj).astype(np.float32))

        imgs_all.append(imgs_context)
        gts_all.append(gts_context)
        actions_all.append(actions_context)
        rewards_all.append(rewards_context)

    # Early termination makes trajectories ragged; convert explicitly so the
    # collected data is not lost to a ValueError at save time.
    outputs = (
        ("obs_all.npy", imgs_all),
        ("actions_all.npy", actions_all),
        ("proprioception_all.npy", gts_all),
        ("rewards_all.npy", rewards_all),
    )
    for filename, data in outputs:
        np.save(os.path.join(data_dir, filename), _as_saveable_array(data))
    if has_context:
        np.save(os.path.join(data_dir, "context_all.npy"), context_all)


if __name__ == '__main__':

    # Build the command-line interface: integer options first, then the
    # two string options, in the same order they appear in --help.
    parser = argparse.ArgumentParser()
    for flag, default in (
        ('--n_frames', 3),
        ('--n_contexts', 5),
        ('--n_traj_per_context', 20),
        ('--max_traj_len', 20),
        ('--render', 0),
    ):
        parser.add_argument(flag, type=int, default=default)
    for flag in ('--dataset', '--env'):
        parser.add_argument(flag, type=str, default="obstacle_env")

    # Keep `args` at module level: collect_data_pybullet() reads it as a
    # global rather than taking it as a parameter.
    args = parser.parse_args()

    # Only the two environments handled by collect_data_pybullet are valid.
    if args.env not in ("spiral_env", "obstacle_env"):
        raise NotImplementedError
    collect_data_pybullet()
