import multiprocessing as mp
import time
import warnings
from collections import defaultdict
from pathlib import Path

import gym
import metaworld
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
import cv2
import imageio.v2

from collect_data.metaworld.scripted_policy import make_policy
from collect_data.dataset_writer import DatasetWriter
import common.utils.metaworld_utils

# Silence all warnings for this process (noisy env/library deprecations).
warnings.filterwarnings("ignore")

# Output directory for the generated HDF5 datasets. parents=True so a fresh
# checkout without an existing data/ directory does not fail at import time.
data_dir = Path("data/metaworld")
data_dir.mkdir(parents=True, exist_ok=True)


def _run_episode(env, args, task_id):
    """Run one scripted-policy episode in ``env`` and record every step.

    Args:
        env: Environment returned by ``gym.make(args.task)``. It must expose the
            ``setup_task``, ``get_obs_dict`` and ``get_success`` methods used
            below, plus ``get_permutation`` for ``reach-color`` tasks.
        args: Run config (reads ``task``, ``resolution`` and ``image``).
        task_id: Task variant passed to ``env.setup_task`` and stored per step
            as ``infos/goal_id``.

    Returns:
        ``(np_data_dict, success, num_steps)``: a ``defaultdict(list)`` of
        per-step values keyed by dataset key, the success flag reported after
        the last step, and the number of recorded steps.
    """
    env.setup_task(task_id, 0)
    # Constant placeholder goal stored on every step; presumably kept so the
    # dataset schema matches goal-conditioned tasks -- TODO confirm.
    goal = np.array([0, 0])
    goal_id = task_id

    policy = make_policy(args.task)

    camera_names = ("corner3", "corner")

    # Reset, then render each camera once before recording. Presumably this
    # initialises the offscreen render contexts -- TODO confirm.
    env.reset()
    for camera_name in camera_names:
        env.render(
            offscreen=True,
            camera_name=camera_name,
            resolution=args.resolution,
        )
    # One zero-action step, then zero the env's internal step counter,
    # presumably so this settling step is not counted against the episode.
    env.step(np.zeros(4))
    env.unwrapped.t = 0

    np_data_dict = defaultdict(list)
    done = False
    success = False
    num_steps = 0
    while not done:
        # Pre-action simulator state and observation for this step.
        np_data_dict["infos/states"].append(env.sim.get_state().flatten())
        np_data_dict["observations"].append(env.unwrapped.env._get_obs())

        obs_dict = env.get_obs_dict()
        for key, val in obs_dict.items():
            np_data_dict[f"infos/{key}"].append(val)

        if "reach-color" in args.task:
            np_data_dict["infos/permutation_id"].append(env.get_permutation())

        np_data_dict["infos/goal_id"].append(goal_id)
        np_data_dict["infos/goal"].append(goal)

        if args.image:
            for camera_name in camera_names:
                image = env.render(
                    offscreen=True,
                    camera_name=camera_name,
                    resolution=args.resolution,
                )
                np_data_dict[f"infos/{camera_name}_image"].append(image)

        action = policy(obs_dict)
        _, rew, done, info = env.step(action)

        # Task success ends the episode early.
        success = env.get_success()
        done |= success

        for key, val in info.items():
            np_data_dict[f"infos/{key}"].append(val)

        np_data_dict["actions"].append(action)
        np_data_dict["rewards"].append(rew)
        np_data_dict["timeouts"].append(done)

        num_steps += 1

    return np_data_dict, success, num_steps


def rollout_policy(args_tuple):
    """Collect one successful scripted-policy episode for a single task id.

    Each attempt creates a fresh env and policy and runs one episode. Failed
    episodes are discarded and retried until one succeeds.

    Args:
        args_tuple: ``(args, task_id, seed)`` packed into one tuple so the
            function can be mapped with ``Pool.imap``. ``args`` is the run
            config (reads ``task``, ``resolution``, ``image``, ``verbose``).
            ``seed`` seeds NumPy's global RNG once, before the first attempt.

    Returns:
        ``defaultdict(list)`` mapping dataset keys (``"observations"``,
        ``"actions"``, ``"rewards"``, ``"timeouts"``, ``"infos/..."``) to
        per-step lists for the successful episode.

    Note:
        There is no retry limit. If the policy can never succeed on the task,
        this function never returns.
    """
    args, task_id, seed = args_tuple
    np.random.seed(seed)

    while True:
        start = time.time()
        env = gym.make(args.task)
        try:
            np_data_dict, success, t = _run_episode(env, args, task_id)
            end = time.time()
        finally:
            # Close every env, including those from discarded failed
            # attempts, so retries do not accumulate simulator/renderer
            # resources in the long-lived pool worker.
            env.close()

        if args.verbose:
            print(f"Total Frames: {t}")
            print(f"Elapsed time: {end-start:.4f} s")
            print(f"Seconds per Frame: {(end-start)/t*1000:.4f} ms/frame")
            print(f"FPS: {t/(end-start):.4f}")
            print(f"Success: {success}")
            print()

        if success:
            return np_data_dict


# Number of task-id variants per task; tasks not listed have a single variant.
_NUM_TASK_IDS = {
    "reach-goal-v2": 4,
    "reach-color-v2": 4,
    "window-close_4-v2": 4,
    "reach-color_simple_3-v2": 3,
    "reach-color_simple_2-v2": 2,
}


def _visualization_frames(np_data_dict, task_id):
    """Return the episode's two camera views side by side, labelled with task_id.

    Assumes each recorded image is ``(H, W, C)``, so the stacked arrays are
    ``(T, H, W, C)`` and ``axis=2`` concatenates along width -- TODO confirm.
    The label is drawn in place on each frame of the returned array.
    """
    corner3_image = np.array(np_data_dict["infos/corner3_image"])
    corner_image = np.array(np_data_dict["infos/corner_image"])
    images = np.concatenate((corner3_image, corner_image), axis=2)
    for image in images:
        cv2.putText(
            image,
            f"Task ID: {task_id}",
            (5, 15),
            fontFace=cv2.FONT_HERSHEY_PLAIN,
            fontScale=0.8,
            color=(20, 20, 20),
        )
    return images


def main():
    """Collect scripted-policy demonstrations for one task and write them out.

    The defaults below can be overridden from the command line as
    ``key=value`` pairs (``OmegaConf.from_cli``). For each task id, runs
    ``n_traj`` successful rollouts in a process pool and appends them to a
    ``DatasetWriter``, then writes ``data/metaworld/<task>.hdf5``. When images
    are collected, the first ``n_visualize_episodes`` episodes of every task
    id are also saved to ``video/<task>.mp4``.
    """
    args = OmegaConf.create({
        "task": "window-close_4-v2",
        "n_traj": 300,
        "n_process": 24,
        "verbose": False,
        "image": True,
        "resolution": (128, 128),
        "n_visualize_episodes": 10,
    })
    args = OmegaConf.merge(args, OmegaConf.from_cli())
    print(args.task)

    writer = DatasetWriter()
    num_task_ids = _NUM_TASK_IDS.get(args.task, 1)

    frames = []
    for task_id in range(num_task_ids):
        # NOTE(review): seeds are drawn with replacement from [0, 2**15) by the
        # unseeded global RNG, so with hundreds of trajectories duplicate seeds
        # (and possibly duplicate episodes) are likely; consider drawing
        # without replacement if episodes must be distinct.
        args_list = [
            (args, task_id, np.random.randint(2**15))
            for _ in range(args.n_traj)
        ]

        bar_format = f"Task ID: {task_id} " + "{l_bar}{bar:64}{r_bar}"
        start = time.time()
        # The context manager tears the workers down after each task id
        # instead of leaking one pool of n_process processes per task id.
        with mp.Pool(args.n_process) as pool:
            imap = pool.imap(rollout_policy, args_list)
            outputs = list(tqdm(imap, total=args.n_traj, bar_format=bar_format))
            end = time.time()

        total_frames = sum(len(data["observations"]) for data in outputs)
        print(f"Elapsed Time: {end - start:.2f} s")
        print(f"Mean Episode Length: {total_frames / args.n_traj:.2f}")
        print(f"FPS: {total_frames/(end - start):.2f}")
        print()

        for i, np_data_dict in enumerate(outputs):
            writer.extend_data(np_data_dict)
            # Camera images are only recorded when args.image is set. Without
            # them the empty image lists would make np.concatenate(axis=2)
            # raise.
            if args.image and i < args.n_visualize_episodes:
                frames.extend(_visualization_frames(np_data_dict, task_id))

    if frames:
        Path("video").mkdir(exist_ok=True)
        imageio.v2.mimsave(f"video/{args.task}.mp4", frames, fps=60)

    writer.write_dataset(data_dir / f"{args.task}.hdf5")


# Script entry point. The guard matters because main() uses multiprocessing:
# under the spawn/forkserver start methods each worker re-imports this module,
# and main() must not run again there.
if __name__ == "__main__":
    main()
