import os
import yaml
import torch
import numpy as np

from isaacgym import gymapi

import storm.differentiable_robot_model.coordinate_transform as tf

from goal_set_planning.sim.gym_robot import GymRobot
from goal_set_planning.sim.gym_world import GymWorld
from goal_set_planning.storm.arm_rollout import ArmRollout
from goal_set_planning.util.pose import Pose, QuaternionFormat
from goal_set_planning.util.misc import PoseDistance
from goal_set_planning.util.acronym import pick_successful_grasps


class TerminalPathChecker:
    """Convergence test comparing an early rollout pose against the rollout's last pose.

    Each call measures the distance between the end-effector pose at index 1
    of the rollout (presumably the next timestep -- confirm against the
    rollout layout) and the pose at the final index. The check passes once
    the median of the most recent ``window`` distances is at most ``thresh``.
    """

    def __init__(self, rot_scale=0.06, window=5, thresh=0.05):
        self.dist_fn = PoseDistance(rot_scale=rot_scale, squared=False)
        self.dists = []  # Most recent distances, oldest first.
        self.window = window
        self.thresh = thresh

    def __call__(self, state_dict, goal_samples=None):
        """Record the current distance and report whether the median is within ``thresh``.

        ``goal_samples`` is accepted but not used by this checker.
        """
        positions = state_dict["ee_pos_seq"]
        rotations = state_dict["ee_rot_seq"]

        # Flatten position + rotation matrix into one pose vector for each of the two poses.
        near_pose = torch.cat([positions[0, 1, :], rotations[0, 1, :, :].flatten()])
        last_pose = torch.cat([positions[0, -1, :], rotations[0, -1, :, :].flatten()])

        self.dists.append(self.dist_fn(near_pose, last_pose)[0].item())

        history_len = len(self.dists)

        # Not enough history yet: report "not converged".
        if history_len < self.window:
            return False

        # Drop the oldest entry so that at most `window` distances are kept.
        if history_len > self.window:
            del self.dists[0]

        return np.median(self.dists) <= self.thresh


class TerminalPoseChecker:
    """Convergence test comparing an early rollout pose against a given goal pose.

    Each call measures the distance between the end-effector pose at index 1
    of the rollout (presumably the next timestep -- confirm against the
    rollout layout) and ``goal_pose``. The check passes once the median of the
    most recent ``window`` distances is at most ``thresh``.
    """

    def __init__(self, rot_scale=0.2, window=5, thresh=0.02):
        self.dist_fn = PoseDistance(rot_scale=rot_scale, squared=False)
        self.dists = []  # Most recent distances, oldest first.
        self.window = window
        self.thresh = thresh

    def __call__(self, state_dict, goal_pose=None):
        """Record the distance to ``goal_pose`` and report whether the median is within ``thresh``.

        ``goal_pose`` is forwarded unchanged as the second argument of the distance function.
        """
        position = state_dict["ee_pos_seq"][0, 1, :]
        rotation = state_dict["ee_rot_seq"][0, 1, :, :]

        # Flatten position + rotation matrix into a single pose vector.
        current_pose = torch.cat([position, rotation.flatten()])

        self.dists.append(self.dist_fn(current_pose, goal_pose)[0].item())

        history_len = len(self.dists)

        # Not enough history yet: report "not converged".
        if history_len < self.window:
            return False

        # Drop the oldest entry so that at most `window` distances are kept.
        if history_len > self.window:
            del self.dists[0]

        return np.median(self.dists) <= self.thresh


def compute_pregrasp_poses(goal_tf, offset=0.1):
    """Return ``goal_tf`` right-multiplied by a pure translation of ``[0, 0, -offset]``.

    The translation pose uses the identity quaternion (x, y, z, w) and is
    converted with ``goal_tf.tensor_args`` before being composed.
    """
    backoff = Pose(pos=[0, 0, -offset], quat=[0, 0, 0, 1], as_format=QuaternionFormat.XYZW)
    return goal_tf * backoff.to(goal_tf.tensor_args)


def closest_goal(ee_pos, ee_rot, goal_samples, rot_scale=0.2):
    """Find the goal sample with the smallest pose distance to the end effector.

    The end-effector pose is the position concatenated with the flattened
    rotation; distances come from ``PoseDistance`` (non-squared) with the
    given ``rot_scale``.

    Returns:
        Tuple ``(distance, index)`` of the minimum along the first axis of the
        distance tensor, both as Python scalars.
    """
    metric = PoseDistance(rot_scale=rot_scale, squared=False)
    query_pose = torch.cat([ee_pos, ee_rot.flatten()])
    best_dist, best_idx = metric(query_pose, goal_samples).min(0)

    return best_dist.item(), best_idx.item()


def visualize_grasp(world, grasp_pose, gripper_file, color=[0, 0, 1], name="end_effector", local_tf=None):
    """Spawn a single gripper asset in ``world`` at ``grasp_pose``.

    If ``local_tf`` is given, the grasp pose is right-multiplied by it first.
    The asset is loaded from the "assets" root with a fixed base link and
    collapsed fixed joints, in collision group 1.
    """
    pose = grasp_pose if local_tf is None else grasp_pose * local_tf

    gripper_options = gymapi.AssetOptions()
    gripper_options.fix_base_link = True
    gripper_options.collapse_fixed_joints = True

    spawn_kwargs = dict(asset_root="assets", color=color, collision_group=1,
                        asset_options=gripper_options)
    world.spawn_object(name, gripper_file, pose.gym_tf(), **spawn_kwargs)


def visualize_grasps(world, grasp_tf, gripper_file, num_viz=None, local_tf=None,
                     name="end_effector", color=[0, 1, 0]):
    """Spawn one gripper asset per grasp in ``grasp_tf`` (or a random subset).

    If ``local_tf`` is given, ``grasp_tf`` is right-multiplied by it first.
    When ``num_viz`` is set, at most that many grasps are drawn, chosen
    uniformly without replacement; otherwise all grasps are drawn. Each
    spawned object is named ``"<name>_<index>"``.
    """
    if local_tf is not None:
        grasp_tf = grasp_tf * local_tf

    positions = grasp_tf.translation().cpu().numpy()
    rotations = grasp_tf.rotation().cpu().numpy()
    total = positions.shape[0]

    if num_viz is None:
        chosen = np.arange(total)
    else:
        # Never ask for more samples than there are grasps.
        chosen = np.random.choice(np.arange(total), min(num_viz, total), replace=False)

    gripper_options = gymapi.AssetOptions()
    gripper_options.fix_base_link = True
    gripper_options.collapse_fixed_joints = True

    for idx in chosen:
        grasp_pose = Pose(pos=positions[idx, :], rot=rotations[idx, :, :])
        world.spawn_object("{}_{}".format(name, idx), gripper_file, grasp_pose.gym_tf(),
                           asset_root="assets", color=color, collision_group=1,
                           asset_options=gripper_options)


# ********************************
#      Environment Loaders
# ********************************


def setup_rollout(args, model_params):
    """Create the ``ArmRollout`` and copy its trajectory settings onto ``args``.

    Sets ``args.steps``, ``args.dof`` and ``args.dt`` from the rollout's
    dynamics model, and ``args.bounds`` to the transposed concatenation of the
    first ``3 * n_dofs`` lower and upper state bounds.

    Returns:
        The constructed ``ArmRollout``.
    """
    rollout = ArmRollout(args.num_particles, args.horizon, model_params, tensor_args=args.tensor_kwargs)
    dynamics = rollout.dynamics_model

    args.steps = dynamics.num_traj_points
    args.dof = dynamics.n_dofs
    args.dt = dynamics.traj_dt

    # Only the first 3 * n_dofs state entries are bounded here.
    n_bounded = rollout.n_dofs * 3
    lower = dynamics.state_lower_bounds[:n_bounded].unsqueeze(0)
    upper = dynamics.state_upper_bounds[:n_bounded].unsqueeze(0)
    args.bounds = torch.cat([lower, upper], dim=0).T

    return rollout


def pick_samples(num_goals, scene_config, tensor_kwargs, acronym_root=None):
    """Load successful grasps for the scene's goal object and express them via the goal pose.

    The grasp file is ``<acronym_root>/grasps/<category>_<id>_<scale>.h5``;
    ``acronym_root`` falls back to ``scene_config["acronym_root"]`` when not
    given. The sampled 4x4 grasp matrices are split into translation and
    rotation, wrapped in a ``CoordinateTransform``, and left-multiplied by the
    goal pose from the scene config (quaternion in x, y, z, w order).

    Returns:
        The composed transform holding the grasp samples.
    """
    if acronym_root is None:
        acronym_root = scene_config["acronym_root"]

    goal_object = scene_config["goal"]["object"]
    object_name = "{}_{}_{}".format(goal_object["category"], goal_object["id"], goal_object["scale"])
    grasp_path = os.path.join(acronym_root, "grasps", object_name + ".h5")

    grasps = pick_successful_grasps(grasp_path, num=num_goals)
    grasp_positions = torch.as_tensor(grasps[:, :3, 3], **tensor_kwargs)
    grasp_rotations = torch.as_tensor(grasps[:, :3, :3], **tensor_kwargs)

    # Pose of the goal object.
    goal_pose_cfg = scene_config["goal"]["pose"]
    object_pose = Pose(pos=goal_pose_cfg["pos"], quat=goal_pose_cfg["quat"],
                       as_format=QuaternionFormat.XYZW)

    # Move the grasp samples from the object frame into the frame of the goal pose.
    grasps_in_object = tf.CoordinateTransform(trans=grasp_positions, rot=grasp_rotations,
                                              tensor_args=tensor_kwargs)
    return object_pose.to(tensor_kwargs) * grasps_in_object


def load_scene(scene_config, args, gym, sim, env, viewer=None, pick_grasps=True,
               target_collision=False, acronym_root=None):
    """Create the world and robot for a scene and, if present, spawn its goal object.

    Side effect: sets ``args.world`` to ``scene_config["world_file"]``.

    Args:
        scene_config: Scene dictionary. Reads ``robot.pose``, ``robot.init_state``,
            ``world_file`` and, optionally, ``goal`` and ``acronym_root``.
        args: Namespace providing ``tensor_kwargs``, ``asset_root``, ``robot_file``
            and (when grasps are picked) ``num_goals``.
        gym, sim, env: Handles forwarded to ``GymWorld`` and ``GymRobot``.
        viewer: Optional viewer forwarded to ``GymWorld``.
        pick_grasps: When true, sample grasp goals with ``pick_samples``.
        target_collision: When true, also return collision parameters for the
            goal object (``None`` if the scene does not define any).
        acronym_root: Root of the ACRONYM data; defaults to
            ``scene_config["acronym_root"]``.

    Returns:
        ``(robot, world, goal_tf)`` or, when ``target_collision`` is true,
        ``(robot, world, goal_tf, collision_params)``. ``goal_tf`` is ``None``
        when the scene has no goal or ``pick_grasps`` is false.
    """
    robot_cfg = scene_config["robot"]
    robot_pose = Pose(pos=robot_cfg["pose"]["pos"], quat=robot_cfg["pose"]["quat"],  # x, y, z, w
                      as_format=QuaternionFormat.XYZW)

    args.world = scene_config["world_file"]
    world = GymWorld(gym, sim, env, args.world, world_tf=robot_pose.gym_tf(), collision_group=0, viewer=viewer)
    robot = GymRobot(gym, sim, env, collision_group=0, tensor_kwargs=args.tensor_kwargs)
    robot.spawn_robot(robot_pose.gym_tf(), asset_root=args.asset_root, asset_file=args.robot_file,
                      init_state=robot_cfg["init_state"])

    if "goal" not in scene_config:
        # Keep the tuple length consistent with the goal branch so callers that
        # unpack four values with target_collision=True do not fail on goal-less scenes.
        if target_collision:
            return robot, world, None, None
        return robot, world, None

    goal_cfg = scene_config["goal"]
    if acronym_root is None:
        acronym_root = scene_config["acronym_root"]

    obj_cfg = goal_cfg["object"]
    obj_type = obj_cfg["category"]
    obj_full_name = "{}_{}_{}".format(obj_type, obj_cfg["id"], obj_cfg["scale"])
    # The URDF file name is the object name with the dots removed.
    acronym_urdf = os.path.join("objects", obj_type, obj_full_name.replace(".", "") + ".urdf")

    # Object asset options.
    asset_options = gymapi.AssetOptions()
    asset_options.use_mesh_materials = True
    asset_options.mesh_normal_mode = gymapi.COMPUTE_PER_VERTEX
    asset_options.override_inertia = True
    asset_options.override_com = True
    if obj_type != "CerealBox":  # The cereal box decomposition doesn't work well because it's already a box.
        asset_options.vhacd_enabled = True
        asset_options.vhacd_params = gymapi.VhacdParams()
        asset_options.vhacd_params.resolution = 300000

    # Get the goal pose (quaternion is x, y, z, w).
    target_pose = Pose(pos=goal_cfg["pose"]["pos"], quat=goal_cfg["pose"]["quat"],
                       as_format=QuaternionFormat.XYZW)

    # Add the goal to the scene.
    world.spawn_object("target", acronym_urdf, target_pose.gym_tf(), collision_group=0,
                       asset_root=acronym_root, asset_options=asset_options)

    collision_params = None
    if target_collision and "collision" in goal_cfg:
        collision_cfg = goal_cfg["collision"]
        # The configured pose is [x, y, z, qx, qy, qz, qw] relative to the target.
        pos, quat = collision_cfg["pose"][:3], collision_cfg["pose"][3:]
        collision_pose = target_pose * Pose(pos=pos, quat=quat, as_format=QuaternionFormat.XYZW)
        pos = collision_pose.position(as_tensor=False)
        quat = collision_pose.quaternion(as_tensor=False)

        # NOTE(review): `pos + quat` concatenates only if both are Python lists;
        # confirm what Pose returns for as_tensor=False.
        collision_params = {"target_collision": {"dims": collision_cfg["dims"], "pose": pos + quat}}

    # Get the acronym goals.
    goal_tf = None
    if pick_grasps:
        goal_tf = pick_samples(args.num_goals, scene_config, args.tensor_kwargs, acronym_root=acronym_root)

    if target_collision:
        return robot, world, goal_tf, collision_params
    return robot, world, goal_tf


# ********************************
#          Control Loop
# ********************************


def wait(seconds, world, robot, dt=0.02, dof=7):
    """Hold the robot still while stepping the simulation for ``seconds`` of simulated time.

    The robot state is read once before the loop and that same state is
    recorded for every simulated step. The loop stops early if the viewer is
    closed.

    Args:
        seconds: Simulated time to wait; elapsed time grows by ``dt`` per step.
        world: Object providing ``viewer_closed()``, ``simulate()`` and ``step_graphics()``.
        robot: Object providing ``get_state_tensor(dof=...)`` and ``hold()``.
        dt: Time credited to each simulation step.
        dof: Forwarded to ``robot.get_state_tensor``.

    Returns:
        Array of shape ``(num_steps, *state_shape)``. When no step was run
        (``seconds <= 0`` or the viewer was already closed) an empty array with
        ``num_steps == 0`` is returned instead of raising.
    """
    time_elapsed = 0
    traj = []
    state = robot.get_state_tensor(dof=dof).cpu().numpy()

    while time_elapsed < seconds:
        if world.viewer_closed():
            break

        # Keep robot still at initial position.
        robot.hold()

        traj.append(state)

        time_elapsed += dt

        # Simulate.
        world.simulate()
        world.step_graphics()

    # np.stack raises on an empty list, so build the empty trajectory explicitly.
    if not traj:
        return np.empty((0,) + state.shape, dtype=state.dtype)

    return np.stack(traj)


def control_loop(controller, world, robot, model_params,
                 goal=None, termination_fn=None, save_data=None, timeout=None,
                 n_iters=50, min_iters=5, iter_step=10, dof=7,
                 bbox_lims=None, show_traj=True, headless=False):
    """Run the receding-horizon control loop until termination, timeout or viewer close.

    Each iteration reads the robot state, optimizes an action sequence, rolls
    it out to obtain the next state, commands the robot to that state and
    steps the simulation. A ``KeyboardInterrupt`` ends the loop cleanly.

    Args:
        controller: Object providing ``optimize(state, shift_steps=, n_iters=,
            return_idx=)``, ``rollout(...)`` and ``action_particles()``.
        world: Simulation world (``viewer_closed``, ``simulate``, drawing helpers).
        robot: Robot interface (``get_state_tensor``, ``send_position_cmd``, ``set_state``).
        model_params: Not used by this function; kept for interface compatibility.
        goal: Passed as the second argument to ``termination_fn``.
        termination_fn: Optional callable ``(state_dict, goal)``; a truthy result stops the loop.
        save_data: Optional directory; the controller's action particles are
            saved there as ``particles_<iteration>.npy`` on every iteration.
        timeout: Optional iteration limit; the loop stops once the iteration
            count exceeds it.
        n_iters: Optimizer iterations for the first step. After every step it
            is reduced by ``iter_step`` but never below ``min_iters``.
        min_iters: Lower bound for the per-step optimizer iterations.
        iter_step: Amount by which ``n_iters`` shrinks after each step.
        dof: Forwarded to ``robot.get_state_tensor`` and used as the chunk size
            when splitting the state into position / velocity.
        bbox_lims: Optional limits drawn with ``world.draw_bbox``.
        show_traj: Draw the controller's end-effector trajectories.
        headless: Skip all drawing and graphics stepping.

    Returns:
        The state dictionary of the last rollout with an added ``"trajectory"``
        entry holding the stacked visited states (including the final state).
        If the loop ended before any rollout was computed, the dictionary
        contains only ``"trajectory"``.
    """
    # Make directory to store the trajectory data if necessary.
    if save_data is not None:
        os.makedirs(save_data, exist_ok=True)

    it = 0
    state_dict = None
    traj = []
    while not world.viewer_closed():
        try:
            # Get the current state.
            state_tensor = robot.get_state_tensor(dof=dof).unsqueeze(0)

            traj.append(state_tensor.squeeze().cpu().numpy())

            # Calculate the next control command. The previous solution is only
            # shifted once there is one, i.e. from the second iteration on.
            shift = 1 if it > 0 else 0
            action_seq, action_idx = controller.optimize(state_tensor, shift_steps=shift, n_iters=n_iters,
                                                         return_idx=True)

            if save_data is not None:
                particles_path = os.path.join(save_data, "particles_{:04d}.npy".format(it))
                np.save(particles_path, controller.action_particles().cpu().numpy())

            # Simulate action sequence in order to get next state for simulator.
            state_dict = controller.rollout(state_tensor, action_seq.unsqueeze(0))
            next_state = state_dict['state_seq'][0, 1, :]
            # The state is split into chunks of size `dof`; the first two are used
            # as position and velocity, the third is ignored.
            next_pos, next_vel, _ = next_state.split(dof)[:3]

            robot.send_position_cmd(next_pos)
            robot.set_state(pos=next_pos, vel=next_vel)

            # Simulate.
            world.simulate()

            if not headless:
                world.clear_lines()
                if show_traj:
                    ee_traj = controller.rollout()["ee_pos_seq"]
                    world.draw_trajectories(ee_traj, highlight_idx=action_idx, highlight_color=[0, 1, 0])
                if bbox_lims is not None:
                    world.draw_bbox(bbox_lims)

                world.step_graphics()

            it += 1
            n_iters = max(min_iters, n_iters - iter_step)

            if it % 20 == 0:
                print("\tIteration", it)

            # Check termination function, if given.
            if termination_fn is not None and termination_fn(state_dict, goal):
                break

            # Check for timeout, if given.
            if timeout is not None and it > timeout:
                print("Control loop timed out after {} iterations :(".format(it))
                break

        except KeyboardInterrupt:
            break

    # Save the final state.
    state_tensor = robot.get_state_tensor(dof=dof)
    traj.append(state_tensor.cpu().numpy())

    # No rollout was computed if the viewer was already closed or the first
    # iteration was interrupted; still return the recorded trajectory.
    if state_dict is None:
        state_dict = {}

    state_dict["trajectory"] = np.stack(traj)

    return state_dict


# ********************************
#          Data Savers
# ********************************


def save_seq_data(args, state, model_params, save_dir):
    """Write the run configuration and a state to ``<save_dir>/data.yaml``.

    Args:
        args: Namespace providing ``optim``, ``world``, ``scene``,
            ``num_particles`` and ``horizon``.
        state: State to store. Tensors are squeezed, detached and moved to the
            CPU first; arrays and plain sequences are stored as nested lists.
        model_params: Stored as-is under ``"model_params"``.
        save_dir: Output directory; created if it does not exist.
    """
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_dir, exist_ok=True)

    if torch.is_tensor(state):
        state = state.squeeze().detach().cpu().numpy()

    data = {"optim_type": args.optim, "world_file": args.world, "scene_file": args.scene,
            # np.asarray also accepts plain lists/tuples, which have no tolist().
            "state": np.asarray(state).tolist(),
            "num_particles": args.num_particles, "horizon": args.horizon,
            "model_params": model_params}

    with open(os.path.join(save_dir, "data.yaml"), 'w') as f:
        yaml.dump(data, f)
