import argparse
import multiprocessing as mp
import random
import time
import warnings
import copy
from collections import defaultdict
from pathlib import Path

import gym
import metaworld
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
import imageio.v2

from collect_data.metaworld.train_model import MetaWorldEnv
from collect_data.dataset_writer import DatasetWriter
import common.utils.metaworld_utils

# Silence every warning for the whole process.
warnings.filterwarnings("ignore")

# Output directory for the collected HDF5 datasets (relative to the CWD).
data_dir = Path("data/metaworld")
# parents=True: a fresh checkout without "data/" must not raise FileNotFoundError.
data_dir.mkdir(parents=True, exist_ok=True)



def rollout_policy(args_tuple):
    args, task_id, seed = args_tuple
    np.random.seed(seed)

    while True:
        start = time.time()

        env = gym.make(args.task)
        env.setup_task(task_id, 0)
        goal = np.array([0, 0])
        goal_id = task_id

        if "open" in args.task:
            policy = make_window_open_policy()
        elif "close" in args.task:
            policy = make_window_close_policy()

        # reset the environment to prepare for a rollout
        obs = env.reset()
        env.render(
            offscreen=True,
            camera_name="corner3",
            resolution=args.resolution,
        )
        env.render(
            offscreen=True,
            camera_name="corner",
            resolution=args.resolution,
        )
        env.step(np.zeros(4))
        env.unwrapped.t = 0

        # reset data_dict
        np_data_dict = defaultdict(list)

        done = False
        t = 0
        while not done:
            states = env.sim.get_state().flatten()
            obs = env.unwrapped.env._get_obs()

            np_data_dict["infos/states"].append(states)
            np_data_dict["observations"].append(obs)

            obs_dict = env.get_obs_dict()
            for key, val in obs_dict.items():
                np_data_dict[f"infos/{key}"].append(val)

            np_data_dict["infos/goal_id"].append(goal_id)
            np_data_dict["infos/goal"].append(goal)

            if args.image:
                camera_names = [
                    "corner3",
                    "corner",
                ]
                for camera_name in camera_names:
                    image = env.render(
                        offscreen=True,
                        camera_name=camera_name,
                        resolution=args.resolution,
                    )
                    np_data_dict[f"infos/{camera_name}_image"].append(image)

            action = policy(obs_dict)
            obs, rew, done, info = env.step(action)

            success = env.get_success()
            done |= success

            for key, val in info.items():
                np_data_dict[f"infos/{key}"].append(val)

            np_data_dict["actions"].append(action)
            np_data_dict["rewards"].append(rew)
            np_data_dict["timeouts"].append(done)

            t += 1

        end = time.time()
        if args.verbose:
            print(f"Total Frames: {t}")
            print(f"Elapsed time: {end-start:.4f} s")
            print(f"Seconds per Frame: {(end-start)/t*1000:.4f} ms/frame")
            print(f"FPS: {t/(end-start):.4f}")
            print(f"Success: {success}")
            print()

        if success:
            break

    return np_data_dict


def _annotate_episode_frames(np_data_dict, task_id):
    """Return one episode's frames, both cameras side by side, labelled with task_id."""
    corner3_image = np.array(np_data_dict["infos/corner3_image"])
    corner_image = np.array(np_data_dict["infos/corner_image"])
    # Assuming each rendered image is (H, W, C), the stacks are (T, H, W, C)
    # and axis=2 places the two camera views side by side.
    images = np.concatenate((corner3_image, corner_image), axis=2)
    for image in images:
        # `image` is a view into `images`, so the text is drawn in place.
        cv2.putText(
            image,
            f"Task ID: {task_id}",
            (5, 15),
            fontFace=cv2.FONT_HERSHEY_PLAIN,
            fontScale=0.8,
            color=(20, 20, 20),
        )
    return images


def main():
    """Collect successful scripted rollouts in parallel and write them to disk.

    The defaults below can be overridden from the command line with OmegaConf
    dot-list syntax (e.g. ``task=window-open-v2 n_traj=100 seed=0``).

    Outputs:
        ``data/metaworld/<task>.hdf5``, written by ``DatasetWriter``.
        ``video/<task>.mp4``, a preview of the first ``n_visualize_episodes``
        episodes per task id. It is written only when images are recorded and
        at least one frame was collected.

    Config keys:
        task: Gym id of the task to collect.
        n_traj: successful trajectories to collect per task id.
        n_process: worker processes in the pool.
        verbose: print per-episode timing from the workers.
        image: also record camera images.
        resolution: offscreen render resolution.
        n_visualize_episodes: episodes per task id to include in the video.
        seed: optional seed for drawing per-trajectory seeds; ``None`` draws
            them non-deterministically.
    """
    args = OmegaConf.create({
        "task": "window-close-v2",
        # "task": "window-open-v2",
        "n_traj": 400,
        "n_process": 24,
        "verbose": False,
        "image": True,
        "resolution": (128, 128),
        "n_visualize_episodes": 10,
        "seed": None,
    })
    args = OmegaConf.merge(args, OmegaConf.from_cli())
    print(args.task)

    writer = DatasetWriter()
    num_task_ids = 1
    # Sampling without replacement guarantees no two rollouts share a seed.
    # Duplicate seeds would presumably reproduce identical trajectories.
    seed_rng = random.Random(args.seed)

    frames = []
    for task_id in range(num_task_ids):
        seeds = seed_rng.sample(range(2**31), args.n_traj)
        args_list = [(args, task_id, seed) for seed in seeds]

        bar_format = f"Task ID: {task_id} " + "{l_bar}{bar:64}{r_bar}"
        start = time.time()
        # The context manager shuts the workers down once every result has
        # been consumed, instead of leaking one pool per task id.
        with mp.Pool(args.n_process) as pool:
            imap = pool.imap(rollout_policy, args_list)
            outputs = list(tqdm(imap, total=args.n_traj, bar_format=bar_format))
        end = time.time()
        total_frames = sum(len(data["observations"]) for data in outputs)
        mean_length = total_frames / len(outputs) if outputs else 0.0
        print(f"Elapsed Time: {end - start:.2f} s")
        print(f"Mean Episode Length: {mean_length:.2f}")
        print(f"FPS: {total_frames/(end - start):.2f}")
        print()
        for i, np_data_dict in enumerate(outputs):
            writer.extend_data(np_data_dict)
            # Camera images exist only when args.image is set; without them
            # the side-by-side concatenation would fail on empty arrays.
            if args.image and i < args.n_visualize_episodes:
                frames.extend(_annotate_episode_frames(np_data_dict, task_id))

    if frames:
        Path("video").mkdir(exist_ok=True)
        imageio.mimsave(f"video/{args.task}.mp4", frames, fps=60)

    writer.write_dataset(data_dir / f"{args.task}.hdf5")


# Run only when executed as a script. Under the "spawn"/"forkserver" start
# methods, multiprocessing workers re-import this module, and they must not
# re-run main().
if __name__ == "__main__":
    main()
