import multiprocessing as mp
import time
import warnings
import argparse
from pathlib import Path
from collections import defaultdict

import gym
import numpy as np
import imageio.v2
import cv2
from tqdm import tqdm
from omegaconf import OmegaConf

import robosuite
from robosuite.controllers import load_controller_config
from robosuite.utils.placement_samplers import UniformRandomSampler

from collect_data.dataset_writer import DatasetWriter
from common.utils import robot_utils  # need to register manipulation envs
from common.utils.robot_utils import ROBOTS

# Suppress every warning for this process (presumably to quiet noisy
# third-party simulation / gym libraries).
# NOTE(review): this also hides warnings that may signal real problems;
# consider narrowing it with category= / module= filters.
warnings.filterwarnings("ignore")


def create_policy(robot_name: str, z_offset: float):
    """Create a stateful scripted policy for the Lift task.

    The returned closure first moves the end-effector over the cube in the
    xy plane, then descends to it and switches the gripper command to +1.
    It then counts down ``grasp_steps`` steps and lifts until the
    end-effector reaches ``1.0 + z_offset``.

    Args:
        robot_name (str): the name of the robot; selects per-robot gains,
            thresholds and end-effector drift compensation.
        z_offset (float): vertical offset added to the lift and success
            heights (the caller passes the third value returned by
            ``env.task_id_to_pos``).

    Returns:
        Callable[[Dict], Tuple[np.ndarray, bool]]: ``policy(obs)`` reads
        ``obs["robot0_eef_pos"]`` and ``obs["cube_pos"]`` and returns a
        4-element action (xyz command, gripper command) plus a flag that is
        ``True`` once the grasped cube is above ``0.9 + z_offset``. The
        observation passed in is never modified.
    """
    grasp_steps = 15 if robot_name == "Sawyer" else 10
    grasped = False
    close = False
    scale = 1.

    xy_threshold = 0.005
    z_threshold = 0.025 if robot_name == "UR5e" else 0.02

    # Proportional gain for the approach motion (larger for IIWA).
    alpha = 15 if robot_name == "IIWA" else 5

    def policy(obs):
        nonlocal grasped, grasp_steps, close
        eef_pos = np.asarray(obs["robot0_eef_pos"])
        # Work on a copy: the drift compensation below edits cube_pos in
        # place and must not leak into the caller's observation dict, whose
        # arrays may already have been recorded into the dataset.
        cube_pos = np.array(obs["cube_pos"])
        action = np.zeros(4)
        action[3] = -1  # gripper command used before the grasp
        done = False

        # Compensate for the end-effector
        # that shifts to the front as it descends.
        if robot_name == "UR5e":
            cube_pos[0] += 0.02
        elif robot_name == "Panda" and cube_pos[0] < 0.1:
            cube_pos[0] += 0.01

        if grasped:
            action[3] = 1
            if grasp_steps < 0 and eef_pos[2] < 1.0 + z_offset:
                # Grasp has settled: lift straight up.
                action[2] = 1
            elif cube_pos[2] > 0.9 + z_offset:
                # Cube is high enough: report the episode as done.
                done = True
            else:
                # Keep commanding the grasp while the countdown runs. This is
                # also reached when the eef is high but the cube is low.
                grasp_steps -= 1
        else:
            xy_distance = np.linalg.norm(cube_pos[:2] - eef_pos[:2])
            z_distance = np.abs(cube_pos[2] - eef_pos[2])
            # Move up while still misaligned in xy but already within
            # 0.15 vertically of the cube.
            if z_distance < 0.15 and not close:
                action[2] = 1

            if xy_distance > xy_threshold:
                # Dividing by sqrt(distance) makes the step magnitude scale
                # with sqrt(distance) before clipping.
                action[:2] = (cube_pos -
                              eef_pos)[:2] / np.sqrt(xy_distance) * alpha
            else:
                close = True

            if close and z_distance > z_threshold:
                action[2] = (cube_pos -
                             eef_pos)[2] / np.sqrt(z_distance) * alpha
            elif close:
                grasped = True
                action[3] = 1

        action = np.clip(action, -scale, scale)
        action[:3] /= scale
        return action, done

    return policy


def _snapshot(val):
    """Return a copy of ``val`` if it is a NumPy array, else ``val`` itself."""
    # Recording a copy keeps later in-place updates (by the env or the
    # policy) from rewriting steps that were already stored.
    return val.copy() if isinstance(val, np.ndarray) else val


def _run_episode(env, policy, goal, goal_id, image_observation):
    """Roll out ``policy`` in ``env`` for one episode and record every step.

    Args:
        env: environment exposing ``sim``, ``obs_dict`` and a 4-tuple
            ``step(action)``.
        policy: callable returning ``(action, policy_done)`` for an obs dict.
        goal: goal returned by ``env.setup_task``; stored on every step.
        goal_id: task/goal id; stored on every step.
        image_observation (bool): if False, obs keys containing "image" are
            not recorded.

    Returns:
        Tuple[defaultdict, int]: per-step lists keyed by dataset field and
        the number of steps taken.
    """
    np_data_dict = defaultdict(list)
    done = False
    t = 0
    while not done:
        np_data_dict["infos/states"].append(env.sim.get_state().flatten())
        np_data_dict["observations"].append(
            _snapshot(env.obs_dict["robot0_proprio-state"]))
        for key, val in env.obs_dict.items():
            if not image_observation and "image" in key:
                continue
            np_data_dict[f"infos/{key}"].append(_snapshot(val))
        np_data_dict["infos/goal_id"].append(goal_id)
        np_data_dict["infos/goal"].append(goal)

        action, policy_done = policy(env.obs_dict)
        _, rew, env_done, _ = env.step(action)

        # The episode ends on env termination or when the policy reports
        # success.
        done = bool(env_done) or policy_done

        np_data_dict["actions"].append(action)
        np_data_dict["rewards"].append(rew)
        np_data_dict["timeouts"].append(done)

        t += 1
    return np_data_dict, t


def rollout_policy(args):
    """Collect one successful scripted Lift episode in a fresh environment.

    Args:
        args (tuple): ``(robot_name, task_id, seed, image_observation,
            verbose)``. This is packed as one tuple so the function can be
            used directly with ``Pool.imap``. ``seed`` seeds NumPy's global
            RNG.

    Returns:
        defaultdict: per-step recorded data of the first episode for which
        ``env.get_success()`` is true.
    """
    robot_name, task_id, seed, image_observation, verbose = args
    np.random.seed(seed)

    env_id = f"{robot_name}-Lift-v1"
    env_kwargs = {} if image_observation else {'use_camera_obs': False,
                                               'has_offscreen_renderer': False}
    env = gym.make(env_id, image_observation=image_observation,
                   env_kwargs=env_kwargs)
    try:
        _, _, z_offset = env.task_id_to_pos(task_id)
        goal = env.setup_task(goal_id=task_id, start_id=0)
        goal_id = task_id

        # NOTE(review): retries until an episode succeeds with no upper
        # bound; a task the scripted policy can never solve hangs this worker.
        while True:
            start = time.time()

            # Fresh policy state and env reset for every attempt.
            policy = create_policy(robot_name, z_offset)
            env.reset()

            np_data_dict, t = _run_episode(env, policy, goal, goal_id,
                                           image_observation)

            end = time.time()
            success = env.get_success()
            if verbose:
                print(f"Total Frames: {t}")
                print(f"Elapsed time: {end-start:.4f} s")
                print(f"Seconds per Frame: {(end-start)/t*1000:.4f} ms/frame")
                print(f"FPS: {t/(end-start):.4f}")
                print(f"Success: {success}")
                print()

            if success:
                return np_data_dict
    finally:
        # Release simulator / renderer resources held by this worker's env.
        env.close()


def _annotated_frames(np_data_dict, task_id):
    """Return one episode's agentview+sideview frames stamped with its task ID."""
    # Reverse axis 1 of each recorded image stack. This is presumably a
    # vertical flip of frames rendered upside down; confirm against the
    # renderer.
    agentview_image = np.array(
        np_data_dict["infos/agentview_image"])[:, ::-1]
    sideview_image = np.array(
        np_data_dict["infos/sideview_image"])[:, ::-1]
    # Concatenating along axis 2 presumably puts the views side by side.
    images = np.concatenate((agentview_image, sideview_image), axis=2)
    for image in images:
        # Draws in place on each frame (a view into `images`).
        cv2.putText(
            image,
            f"Task ID: {task_id}",
            (5, 15),
            fontFace=cv2.FONT_HERSHEY_PLAIN,
            fontScale=0.8,
            color=(220, 220, 220),
        )
    return images


def main():
    """Collect scripted Lift demonstrations for every task ID and save them.

    Options (override from the command line via ``OmegaConf.from_cli``, e.g.
    ``robot=Panda n_traj=50``): ``robot``, ``n_traj`` (episodes per task),
    ``n_process`` (worker processes), ``image_observation``,
    ``n_visualize_episodes`` (episodes per task written to the video), and
    ``verbose``.

    Writes ``data/robot/<robot>_Lift.hdf5``. When image observations are
    collected, it also writes ``video/<robot>_Lift.mp4``.

    Raises:
        ValueError: if ``robot`` is not in ``ROBOTS``.
    """
    args = OmegaConf.create({
        "robot": "Sawyer",
        "n_traj": 200,
        "n_process": 16,
        "image_observation": False,
        "n_visualize_episodes": 5,
        "verbose": False,
    })
    args = OmegaConf.merge(args, OmegaConf.from_cli())
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.robot not in ROBOTS:
        raise ValueError(
            f"Unknown robot {args.robot!r}; expected one of {list(ROBOTS)}")
    print(args.robot)

    writer = DatasetWriter()
    num_task_ids = 27

    frames = []
    for task_id in range(num_task_ids):
        # Draw seeds from 2**31 values. With only 2**15, 200 draws per task
        # collide ~45% of the time, which likely duplicates trajectories
        # because rollout_policy seeds np.random with this value.
        args_list = [(
            args.robot,
            task_id,
            np.random.randint(2**31),
            args.image_observation,
            args.verbose,
        ) for _ in range(args.n_traj)]

        bar_format = f"Task ID: {task_id} " + "{l_bar}{bar:64}{r_bar}"
        start = time.time()
        # The context manager shuts the workers down; previously a new pool
        # was leaked on every task.
        with mp.Pool(args.n_process) as pool:
            imap = pool.imap(rollout_policy, args_list)
            outputs = list(
                tqdm(imap, total=args.n_traj, bar_format=bar_format))
        end = time.time()
        # Episodes end early on success, so count the frames actually
        # collected. Frames from failed retries are not included.
        n_frames = sum(len(d["actions"]) for d in outputs)
        print(f"Elapsed Time: {end - start:.2f} s")
        print(f"FPS: {n_frames/(end - start):.2f}")
        print()
        for i, np_data_dict in enumerate(outputs):
            writer.extend_data(np_data_dict)
            if args.image_observation and i < args.n_visualize_episodes:
                frames.extend(_annotated_frames(np_data_dict, task_id))

    if frames:
        Path("video").mkdir(exist_ok=True)
        # The file imports `imageio.v2`; use that API explicitly.
        imageio.v2.mimsave(f"video/{args.robot}_Lift.mp4", frames, fps=60)

    dataset_path = f"data/robot/{args.robot}_Lift.hdf5"
    Path(dataset_path).parent.mkdir(parents=True, exist_ok=True)
    writer.write_dataset(dataset_path)


if __name__ == "__main__":
    # Script entry point. The guard is required because main() starts a
    # multiprocessing Pool: under the spawn/forkserver start methods, workers
    # re-import this module and must not start another collection run.
    main()
