from matplotlib.pyplot import axis
import numpy as np
import os
import random
import torch

from hasac.envs.dexhands.DexterousHands.bidexhands.utils.torch_jit_utils import *
from hasac.envs.dexhands.DexterousHands.bidexhands.tasks.hand_base.base_task import (
    BaseTask,
)

from isaacgym import gymtorch
from isaacgym import gymapi


def _egg_asset_options():
    """Create the asset options used by every task whose object is the egg."""
    options = gymapi.AssetOptions()
    options.density = 500
    return options


def _door_asset_options():
    """Create the asset options used by every task whose object is the door."""
    options = gymapi.AssetOptions()
    options.density = 500
    options.fix_base_link = True
    options.disable_gravity = True
    options.use_mesh_materials = True
    options.mesh_normal_mode = gymapi.COMPUTE_PER_VERTEX
    options.override_com = True
    options.override_inertia = True
    options.vhacd_enabled = True
    options.vhacd_params = gymapi.VhacdParams()
    options.vhacd_params.resolution = 100000
    options.default_dof_drive_mode = gymapi.DOF_MODE_NONE
    return options


def obtain_task_info(task_name):
    """Return the scene layout and object asset settings for ``task_name``.

    Supported names are "hand_over", "catch_underarm", "two_catch_underarm",
    "catch_abreast", "door_open_outward", "door_close_inward",
    "door_open_inward" and "door_close_outward".

    Args:
        task_name: Name of the task to look up.

    Returns:
        A 9-tuple ``(hand_start_pose, another_hand_start_pose, object_pose,
        goal_pose, another_object_pose, another_goal_pose, table_pose_dim,
        object_asset_options, object_type)``. The six poses are freshly
        created 6-element lists (presumably ``[x, y, z, roll, pitch, yaw]`` --
        confirm against the caller), ``table_pose_dim`` is a 9-element list,
        ``object_asset_options`` is a new ``gymapi.AssetOptions`` and
        ``object_type`` is ``"egg"`` or ``"door"``.

    Raises:
        ValueError: If ``task_name`` is not one of the supported names.
    """
    # Per egg task: (hand, another hand, object, goal, another object,
    # another goal). The table is rebuilt on every call so callers always
    # receive lists they are free to mutate.
    egg_layouts = {
        "hand_over": (
            [0, 0, 0.5, 0, 0, 0],
            [0, -1, 0.5, 0, 0, 3.1415],
            [0, -0.39, 0.54, 0, 0, 0],
            [0, -0.64, 0.54, 0, -0.0, 0.0],
            [0, 0, 10, 0, 0, 0],
            [0, 0, 10, 0, 0, 0],
        ),
        "catch_underarm": (
            [0, 0, 0.5, 0, 0, 0],
            [0, -1.15, 0.5, 0, 0, 3.1415],
            [0, -0.39, 0.54, 0, 0, 0],
            [0, -0.79, 0.54, 0, -0.0, 0.0],
            [0, 0, 10, 0, 0, 0],
            [0, 0, 10, 0, 0, 0],
        ),
        "two_catch_underarm": (
            [0, 0, 0.5, 0, 0, 0],
            [0, -1, 0.5, 0, 0, 3.1415],
            [0, -0.39, 0.54, 0, 0, 0],
            [0, -0.64, 0.54, -0.0, -0.0, 0.0],
            [0, -0.64, 0.54, 0, 0, 0],
            [-0.05, -0.39, 0.54, -0.0, -0.0, 0.0],
        ),
        "catch_abreast": (
            [0, -0.55, 0.5, 0, 0.3925, -1.57],
            [0, -1, 0.5, 0, -0.3925, -1.57],
            [-0.39, -0.55, 0.54, 0, 0, 0],
            [-0.39, -1, 0.54, 0, 0, 0],
            [0, 0, 10, 0, 0, 0],
            [0, 0, 10, 0, 0, 0],
        ),
    }
    # The four door tasks share everything except the door's own pose.
    door_object_poses = {
        "door_open_outward": [0.0, 0.0, 0.7, 0, 0.0, 0.0],
        "door_close_inward": [0.0, 0.0, 0.7, 0, 3.14159, 0.0],
        "door_open_inward": [0.0, 0.0, 0.7, 0, 3.14159, 0.0],
        "door_close_outward": [0.0, 0.0, 0.7, 0, 0.0, 0.0],
    }

    if task_name in egg_layouts:
        (
            hand_start_pose,
            another_hand_start_pose,
            object_pose,
            goal_pose,
            another_object_pose,
            another_goal_pose,
        ) = egg_layouts[task_name]
        table_pose_dim = [0, 0, 0, 0, 0, 0, 0, 0, 0]
        object_asset_options = _egg_asset_options()
        object_type = "egg"
    elif task_name in door_object_poses:
        hand_start_pose = [0.55, 0.2, 0.6, 3.14159, 1.57, 1.57]
        another_hand_start_pose = [0.55, -0.2, 0.6, 3.14159, -1.57, 1.57]
        object_pose = door_object_poses[task_name]
        goal_pose = [0, -0.39, 10, 0, 0, 0]
        another_object_pose = [0, 0, 10, 0, 0, 0]
        another_goal_pose = [0, 0, 10, 0, 0, 0]
        table_pose_dim = [0.0, -0.6, 0, 0, 0, 0, 0.3, 0.3, 0.0]
        object_asset_options = _door_asset_options()
        object_type = "door"
    else:
        # Fail with a clear message instead of an UnboundLocalError at the
        # return statement below.
        supported = sorted([*egg_layouts, *door_object_poses])
        raise ValueError(
            f"Unknown task_name {task_name!r}; expected one of {supported}"
        )

    return (
        hand_start_pose,
        another_hand_start_pose,
        object_pose,
        goal_pose,
        another_object_pose,
        another_goal_pose,
        table_pose_dim,
        object_asset_options,
        object_type,
    )


@torch.jit.script
def compute_hand_reward(
    rew_buf,
    reset_buf,
    reset_goal_buf,
    progress_buf,
    successes,
    consecutive_successes,
    max_episode_length: float,
    object_pos,
    object_rot,
    target_pos,
    target_rot,
    another_object_pos,
    another_object_rot,
    another_target_pos,
    another_target_rot,
    object_left_handle_pos,
    object_right_handle_pos,
    left_hand_pos,
    right_hand_pos,
    right_hand_ff_pos,
    right_hand_mf_pos,
    right_hand_rf_pos,
    right_hand_lf_pos,
    right_hand_th_pos,
    left_hand_ff_pos,
    left_hand_mf_pos,
    left_hand_rf_pos,
    left_hand_lf_pos,
    left_hand_th_pos,
    dist_reward_scale: float,
    rot_reward_scale: float,
    rot_eps: float,
    actions,
    action_penalty_scale: float,
    success_tolerance: float,
    reach_goal_bonus: float,
    fall_dist: float,
    fall_penalty: float,
    max_consecutive_successes: int,
    av_factor: float,
    this_task: str,
):
    """Compute rewards, reset flags and success statistics for one task.

    Three task families are handled, selected by ``this_task``:

    * "catch_underarm", "two_catch_underarm", "catch_abreast",
      "catch_over2underarm": reward decays exponentially with the distance
      of both objects to their goals plus the rotation error of the first
      object; a bonus is added when the first object is within 0.03 of its
      goal and ``fall_penalty`` is added when either object's third
      coordinate is <= 0.1 (which also triggers a reset).
    * "door_open_inward": reward is ``2`` minus the summed fingertip-to-handle
      distances of both hands, plus twice the handle separation along the
      second coordinate once both summed distances are below 0.5.
    * "lift_pot": reward is ``0.2`` minus both hand-to-handle distances, plus
      ``3 * (0.985 - goal_dist)`` once both hands are within 0.08 of their
      handle; resets when the object drops to <= 0.3 or a hand is >= 0.2
      from its handle.

    Position tensors are indexed as ``[:, k]`` and reduced over the last
    dimension, so they are expected to be 2-D with one row per environment
    (assumed -- confirm against the caller).

    ``rew_buf``, ``rot_reward_scale``, ``rot_eps``, ``actions``,
    ``action_penalty_scale`` and ``fall_dist`` are accepted for interface
    compatibility but do not influence any of the current reward terms.

    Returns:
        ``(reward, resets, goal_resets, progress_buf, successes,
        cons_successes)``.

    NOTE(review): task names outside the three families above fall through
    without an explicit return -- confirm callers only pass supported names.
    """
    if this_task in [
        "catch_underarm",
        "two_catch_underarm",
        "catch_abreast",
        "catch_over2underarm",
    ]:
        # Distance from each object to its own goal
        goal_dist = torch.norm(target_pos - object_pos, p=2, dim=-1)
        another_goal_dist = torch.norm(
            another_target_pos - another_object_pos, p=2, dim=-1
        )

        # Rotation angle between the first object and its goal orientation
        quat_diff = quat_mul(object_rot, quat_conjugate(target_rot))
        rot_dist = 2.0 * torch.asin(
            torch.clamp(torch.norm(quat_diff[:, 0:3], p=2, dim=-1), max=1.0)
        )

        # Base reward: exponential decay in the combined position and
        # rotation error (rotation only counts for the first object)
        reward = torch.exp(
            -0.2
            * (
                (goal_dist * dist_reward_scale + rot_dist)
                + (another_goal_dist * dist_reward_scale)
            )
        )

        # Flag envs whose first object is within 0.03 of its goal and update
        # the success count; other envs keep their incoming reset_goal_buf
        goal_resets = torch.where(
            torch.abs(goal_dist) <= 0.03,
            torch.ones_like(reset_goal_buf),
            reset_goal_buf,
        )
        successes = successes + goal_resets

        # Success bonus for envs flagged above
        reward = torch.where(goal_resets == 1, reward + reach_goal_bonus, reward)

        # Fall penalty: applied once per object whose height is <= 0.1
        reward = torch.where(object_pos[:, 2] <= 0.1, reward + fall_penalty, reward)
        reward = torch.where(
            another_object_pos[:, 2] <= 0.1, reward + fall_penalty, reward
        )

        # Termination: either object fell
        resets = torch.where(
            object_pos[:, 2] <= 0.1, torch.ones_like(reset_buf), reset_buf
        )
        resets = torch.where(
            another_object_pos[:, 2] <= 0.1, torch.ones_like(resets), resets
        )

        if max_consecutive_successes > 0:
            # Restart the episode timer where the rotation error is within
            # tolerance, and terminate once enough successes were collected
            progress_buf = torch.where(
                torch.abs(rot_dist) <= success_tolerance,
                torch.zeros_like(progress_buf),
                progress_buf,
            )
            resets = torch.where(
                successes >= max_consecutive_successes, torch.ones_like(resets), resets
            )
        # Termination: episode timed out
        resets = torch.where(
            progress_buf >= max_episode_length, torch.ones_like(resets), resets
        )

        # Apply penalty for timing out without reaching the goal
        if max_consecutive_successes > 0:
            reward = torch.where(
                progress_buf >= max_episode_length, reward + 0.5 * fall_penalty, reward
            )

        num_resets = torch.sum(resets)
        finished_cons_successes = torch.sum(successes * resets.float())

        # Running average of successes over envs that are resetting this step
        cons_successes = torch.where(
            num_resets > 0,
            av_factor * finished_cons_successes / num_resets
            + (1.0 - av_factor) * consecutive_successes,
            consecutive_successes,
        )

        return reward, resets, goal_resets, progress_buf, successes, cons_successes

    if this_task in ["door_open_inward"]:
        # Summed distance from the five fingertips of each hand to the
        # handle on the matching side
        right_hand_finger_dist = (
            torch.norm(object_right_handle_pos - right_hand_ff_pos, p=2, dim=-1)
            + torch.norm(object_right_handle_pos - right_hand_mf_pos, p=2, dim=-1)
            + torch.norm(object_right_handle_pos - right_hand_rf_pos, p=2, dim=-1)
            + torch.norm(object_right_handle_pos - right_hand_lf_pos, p=2, dim=-1)
            + torch.norm(object_right_handle_pos - right_hand_th_pos, p=2, dim=-1)
        )
        left_hand_finger_dist = (
            torch.norm(object_left_handle_pos - left_hand_ff_pos, p=2, dim=-1)
            + torch.norm(object_left_handle_pos - left_hand_mf_pos, p=2, dim=-1)
            + torch.norm(object_left_handle_pos - left_hand_rf_pos, p=2, dim=-1)
            + torch.norm(object_left_handle_pos - left_hand_lf_pos, p=2, dim=-1)
            + torch.norm(object_left_handle_pos - left_hand_th_pos, p=2, dim=-1)
        )

        # Opening bonus: only granted once both hands are near their handle
        # (summed fingertip distance < 0.5); grows with handle separation
        up_rew = torch.zeros_like(right_hand_finger_dist)
        up_rew = torch.where(
            right_hand_finger_dist < 0.5,
            torch.where(
                left_hand_finger_dist < 0.5,
                torch.abs(object_right_handle_pos[:, 1] - object_left_handle_pos[:, 1])
                * 2,
                up_rew,
            ),
            up_rew,
        )

        reward = 2 - right_hand_finger_dist - left_hand_finger_dist + up_rew

        # The only termination condition for this task is the time limit
        resets = torch.where(
            progress_buf >= max_episode_length, torch.ones_like(reset_buf), reset_buf
        )

        goal_resets = torch.zeros_like(resets)

        num_resets = torch.sum(resets)
        finished_cons_successes = torch.sum(successes * resets.float())

        cons_successes = torch.where(
            num_resets > 0,
            av_factor * finished_cons_successes / num_resets
            + (1.0 - av_factor) * consecutive_successes,
            consecutive_successes,
        )

        return reward, resets, goal_resets, progress_buf, successes, cons_successes

    if this_task in ["lift_pot"]:
        # Distance from the object to its goal
        goal_dist = torch.norm(target_pos - object_pos, p=2, dim=-1)
        # Distance from each hand to the handle on its own side. The left
        # hand was previously measured against the right handle, which
        # penalised and reset the left hand for not holding the wrong handle.
        right_hand_dist = torch.norm(
            object_right_handle_pos - right_hand_pos, p=2, dim=-1
        )
        left_hand_dist = torch.norm(
            object_left_handle_pos - left_hand_pos, p=2, dim=-1
        )

        # Lifting bonus: only granted once both hands are within 0.08 of
        # their handle
        up_rew = torch.zeros_like(right_hand_dist)
        up_rew = torch.where(
            right_hand_dist < 0.08,
            torch.where(left_hand_dist < 0.08, 3 * (0.985 - goal_dist), up_rew),
            up_rew,
        )

        reward = 0.2 - right_hand_dist - left_hand_dist + up_rew

        # Termination: object too low, or either hand strayed from its handle
        resets = torch.where(
            object_pos[:, 2] <= 0.3, torch.ones_like(reset_buf), reset_buf
        )
        resets = torch.where(right_hand_dist >= 0.2, torch.ones_like(resets), resets)
        resets = torch.where(left_hand_dist >= 0.2, torch.ones_like(resets), resets)

        # Termination: episode timed out
        resets = torch.where(
            progress_buf >= max_episode_length, torch.ones_like(resets), resets
        )

        goal_resets = torch.zeros_like(resets)

        num_resets = torch.sum(resets)
        finished_cons_successes = torch.sum(successes * resets.float())

        cons_successes = torch.where(
            num_resets > 0,
            av_factor * finished_cons_successes / num_resets
            + (1.0 - av_factor) * consecutive_successes,
            consecutive_successes,
        )

        return reward, resets, goal_resets, progress_buf, successes, cons_successes


def obtain_task_dof_info(this_task, env_ids, num_each_envs, device):
    """Build the initial values of the two task DOFs for each environment.

    Args:
        this_task: Accepted for interface compatibility; not read here.
        env_ids: Indexable collection of environment ids to initialise.
        num_each_envs: Ids below ``2 * num_each_envs`` get ``[0, 0]``; all
            other ids get ``[1.57, 1.57]``.
        device: Device passed through to ``to_torch`` for the result.

    Returns:
        A tensor with one row of two values per entry in ``env_ids``.
    """
    env_count = len(env_ids)
    dof_inits = to_torch([0, 0], device=device).repeat((env_count, 1))
    # Environment ids at or beyond this index use the 1.57 initial values.
    split_index = num_each_envs * 2
    for row, env_id in enumerate(env_ids):
        if env_id >= split_index:
            dof_inits[row] = to_torch([1.57, 1.57], device=device)
        else:
            dof_inits[row] = to_torch([0, 0], device=device)
    return dof_inits
