import multiprocessing as mp
import time
import warnings
import argparse
from pathlib import Path
from collections import defaultdict

import gym
import numpy as np
import imageio.v2
import cv2
from tqdm import tqdm
from omegaconf import OmegaConf

import robosuite
from robosuite.controllers import load_controller_config
from robosuite.utils.placement_samplers import UniformRandomSampler

from collect_data.dataset_writer import DatasetWriter
from common.utils import robot_utils  # need to register manipulation envs
from common.utils.robot_utils import ROBOTS

# Install a blanket "ignore" filter so Python warnings are not shown while
# this script runs.
warnings.filterwarnings("ignore")


def create_policy(robot_name: str):
    """Build a stateful scripted pick-and-place controller.

    The returned callable keeps its own phase state between calls, so a new
    policy must be created for every episode.

    Args:
        robot_name (str): the name of the robot. Only used to choose the
            vertical tolerance at which the grasp is triggered
            ("UR5e" and "IIWA" have dedicated values).

    Returns:
        Callable: ``policy(obs) -> (action, done)``. ``obs`` must provide
        "robot0_eef_pos" and "cubeA_pos", plus "robot0_joint_vel" once the
        placing phase is reached. ``action`` is a length-4 array clipped to
        [-1, 1]; its last entry is 1 from the moment the grasp is triggered
        and -1 otherwise (presumably gripper close/open -- confirm against
        the controller). ``done`` is True on the single call in which the
        lift is detected as finished.
    """
    # Vertical distance below which the end effector is considered to be
    # at grasp height.
    if robot_name == "UR5e":
        z_tol = 0.027
    elif robot_name == "IIWA":
        z_tol = 0.025
    else:
        z_tol = 0.015

    xy_tol = 0.005
    gain = 10
    limit = 1.

    # Mutable phase state shared by the helpers below.
    #   countdown: calls left to hold still after grasping; 0 means
    #              "lifting", -1 means "placing".
    #   holding:   the grasp has been triggered.
    #   aligned:   the horizontal target has been reached at least once in
    #              the current phase.
    #   let_go:    the release has been triggered.
    state = {
        "countdown": 10,
        "holding": False,
        "aligned": False,
        "let_go": False,
    }

    def _steer(delta, dist):
        # Command proportional to the offset divided by sqrt(distance).
        return delta / np.sqrt(dist) * gain

    def _reach(eef, cube, command):
        # Move above the cube, descend, and trigger the grasp.
        planar_gap = np.linalg.norm(cube[:2] - eef[:2])
        height_gap = np.abs(cube[2] - eef[2])

        # Back off upwards while still travelling close to cube height.
        if height_gap < 0.15 and not state["aligned"]:
            command[2] = 1

        if planar_gap > xy_tol:
            command[:2] = _steer((cube - eef)[:2], planar_gap)
        else:
            state["aligned"] = True

        if state["aligned"]:
            if height_gap > z_tol:
                command[2] = _steer((cube - eef)[2], height_gap)
            else:
                state["holding"] = True
                state["aligned"] = False
                command[3] = 1

    def _settle(eef, cube, command):
        # Keep tracking the cube horizontally while the hold counter runs.
        planar_gap = np.linalg.norm(cube[:2] - eef[:2])
        if planar_gap > xy_tol:
            command[:2] = _steer((cube - eef)[:2], planar_gap)
        state["countdown"] -= 1

    def _lift(eef, command):
        # Drive the end effector height into [0.95, 1]; report completion.
        height = eef[2]
        if height < 0.95:
            command[2] = 1
            return False
        if height > 1:
            command[2] = -1
            return False
        state["countdown"] -= 1
        return True

    def _place(obs, eef, cube, command):
        # Carry the cube to the xy origin, lower it, then release.
        planar_gap = np.linalg.norm(eef[:2])
        cube_height = np.abs(cube[2])
        joint_speed = np.linalg.norm(obs["robot0_joint_vel"])

        if planar_gap > xy_tol:
            command[:2] = _steer(-eef[:2], planar_gap)
        else:
            state["aligned"] = True

        if state["aligned"]:
            if cube_height > 0.9:
                command[2] = -(cube_height - 0.85) * 10
            elif joint_speed < 0.5:
                state["let_go"] = True

        if state["let_go"]:
            command[3] = -1

    def policy(obs):
        eef = obs["robot0_eef_pos"]
        cube = obs["cubeA_pos"]
        command = np.array([0., 0., 0., -1.])
        finished = False

        if not state["holding"]:
            _reach(eef, cube, command)
        else:
            command[3] = 1
            if state["countdown"] == 0:
                finished = _lift(eef, command)
            elif state["countdown"] == -1:
                _place(obs, eef, cube, command)
            else:
                _settle(eef, cube, command)

        command = np.clip(command, -limit, limit)
        command[:3] /= limit
        return command, finished

    return policy


def rollout_policy(args):
    """Run one scripted Stack episode and return its recorded trajectory.

    Exactly one rollout is performed; the trajectory is returned whether or
    not the episode succeeded (there is no retry on failure).

    Args:
        args (tuple): ``(robot_name, task_id, seed, image_observation,
            verbose)``, packed into a single object so the function can be
            mapped over a process pool.

    Returns:
        defaultdict: maps a record name to a list with one entry per step.
        Keys written are "infos/states", "observations", "infos/<key>" for
        every key of ``env.obs_dict``, "infos/goal_id", "infos/goal",
        "actions", "rewards" and "timeouts".
    """
    robot_name, task_id, seed, image_observation, verbose = args
    np.random.seed(seed)

    env_id = f"{robot_name}-Stack-v1"
    env = gym.make(env_id, image_observation=image_observation)
    try:
        goal = env.setup_task(goal_id=task_id, start_id=0)

        start = time.time()

        # A fresh policy per episode: the scripted policy is stateful.
        policy = create_policy(robot_name)

        # reset the environment to prepare for a rollout
        env.reset()

        np_data_dict = defaultdict(list)

        done = False
        t = 0
        while not done:
            # Record the pre-step simulator state and observations.
            states = env.sim.get_state().flatten()
            obs = env.obs_dict["robot0_proprio-state"]

            np_data_dict["infos/states"].append(states)
            np_data_dict["observations"].append(obs)
            for key, val in env.obs_dict.items():
                np_data_dict[f"infos/{key}"].append(val)
            np_data_dict["infos/goal_id"].append(task_id)
            np_data_dict["infos/goal"].append(goal)

            action, policy_done = policy(env.obs_dict)
            _, rew, done, _ = env.step(action)

            # The episode ends when the env, the policy, or the success
            # check says so.
            success = env.get_success()
            done |= policy_done
            done |= success

            np_data_dict["actions"].append(action)
            np_data_dict["rewards"].append(rew)
            # Stored under "timeouts": the combined termination flag.
            np_data_dict["timeouts"].append(done)

            t += 1

        end = time.time()
        success = env.get_success()
        if verbose:
            print(f"Total Frames: {t}")
            print(f"Elapsed time: {end-start:.4f} s")
            print(f"Time per Frame: {(end-start)/t*1000:.4f} ms/frame")
            print(f"FPS: {t/(end-start):.4f}")
            print(f"Success: {success}")
            print()
    finally:
        # Pool workers run many rollouts; release the environment's
        # resources instead of leaving one open env per call.
        env.close()

    return np_data_dict


def main():
    """Collect scripted Stack rollouts for every task id and save a preview.

    Defaults can be overridden from the command line with OmegaConf
    ``key=value`` arguments. For each of the 8 task ids, ``n_traj`` rollouts
    are run in a pool of ``n_process`` workers and handed to a
    ``DatasetWriter``. When ``image_observation`` is on, the first
    ``n_visualize_episodes`` episodes of each task are appended to a video
    written to ``video/<robot>_Stack.mp4``.

    Raises:
        ValueError: if the requested robot is not in ``ROBOTS``.
    """
    args = OmegaConf.create({
        "robot": "Sawyer",
        "n_traj": 10,
        "n_process": 4,
        "image_observation": True,
        "n_visualize_episodes": 5,
        "verbose": True,
    })
    args = OmegaConf.merge(args, OmegaConf.from_cli())
    # Explicit check rather than `assert`, which is stripped under -O.
    if args.robot not in ROBOTS:
        raise ValueError(f"Unknown robot: {args.robot!r}")
    print(args.robot)

    writer = DatasetWriter()
    # Nominal episode length, only used for the throughput estimate below.
    max_episode_steps = 180
    num_task_ids = 8

    frames = []
    for task_id in range(num_task_ids):
        args_list = [(
            args.robot,
            task_id,
            np.random.randint(2**15),
            args.image_observation,
            args.verbose,
        ) for _ in range(args.n_traj)]

        bar_format = f"Task ID: {task_id} " + "{l_bar}{bar:64}{r_bar}"
        start = time.time()
        # The context manager shuts the workers down once all results have
        # been collected, so pools do not pile up across task ids.
        with mp.Pool(args.n_process) as pool:
            imap = pool.imap(rollout_policy, args_list)
            outputs = list(
                tqdm(imap, total=args.n_traj, bar_format=bar_format))
        end = time.time()
        print(f"Elapsed Time: {end - start:.2f} s")
        # Estimate based on the nominal episode length, not on the number
        # of frames actually simulated.
        print(f"FPS: {(args.n_traj*max_episode_steps)/(end - start):.2f}")
        print()
        for i, np_data_dict in enumerate(outputs):
            writer.extend_data(np_data_dict)
            if args.image_observation and i < args.n_visualize_episodes:
                # Reverse axis 1 of each frame stack (presumably a vertical
                # flip of (T, H, W, C) renders -- confirm), then join the
                # two camera views along axis 2.
                agentview_image = np.array(
                    np_data_dict["infos/agentview_image"])[:, ::-1]
                sideview_image = np.array(
                    np_data_dict["infos/sideview_image"])[:, ::-1]
                images = np.concatenate((agentview_image, sideview_image),
                                        axis=2)
                for image in images:
                    # Draws the label in place on each frame.
                    cv2.putText(
                        image,
                        f"Task ID: {task_id}",
                        (5, 15),
                        fontFace=cv2.FONT_HERSHEY_PLAIN,
                        fontScale=0.8,
                        color=(220, 220, 220),
                    )
                frames.extend(images)

    if len(frames) > 0:
        Path("video").mkdir(exist_ok=True)
        # Use the v2 API explicitly, matching the module this file imports.
        imageio.v2.mimsave(f"video/{args.robot}_Stack.mp4", frames, fps=60)

    # NOTE: dataset export is currently disabled, so the collected data is
    # not persisted by this script.
    # writer.write_dataset(f"data/robot/{args.robot}_Stack.hdf5")


# Run the collection only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
