from pathlib import Path
from time import time

import numpy as np
import pytest

from dex_retargeting.constants import ROBOT_NAMES, get_default_config_path, RetargetingType, HandType, RobotName
from dex_retargeting.optimizer import VectorOptimizer, PositionOptimizer, Optimizer
from dex_retargeting.retargeting_config import RetargetingConfig
from dex_retargeting.robot_wrapper import RobotWrapper

class TestOptimizer:
    """Retarget hand keypoints onto a robot hand using dex_retargeting's position optimizer.

    Despite the ``Test`` prefix this is a stateful helper rather than a pytest test class (pytest does
    not collect classes that define ``__init__``). It keeps the previous joint solution of each hand so
    that consecutive calls to :meth:`bimanual_position_optimizer` warm-start the optimizer.
    """

    # Length of the warm-start joint vectors for the Shadow hand.
    # NOTE(review): hard-coded; presumably equals the number of target joints in the shadow
    # position config — confirm if the config changes.
    _SHADOW_NUM_JOINTS = 24

    def __init__(self, hand_type):
        """Configure the helper for ``hand_type``.

        Args:
            hand_type: robot hand name; only ``"shadow"`` is currently supported.

        Raises:
            ValueError: if ``hand_type`` is not supported. Previously the instance attributes were
                simply left unset, which only surfaced later as an ``AttributeError``.
        """
        if hand_type != "shadow":
            raise ValueError(f"Unsupported hand_type {hand_type!r}; only 'shadow' is supported")

        np.set_printoptions(precision=4)
        robot_dir = Path(__file__).parent / "dex_retargeting" / "robots" / "hands"
        RetargetingConfig.set_default_urdf_dir(str(robot_dir.absolute()))

        self.robot_name = RobotName.shadow
        # NOTE(review): the left hand is also configured with the RIGHT-hand config. This may be
        # intentional (e.g. no left-hand config available, or left targets already mirrored) — confirm.
        self.left_hand_type = HandType.right
        self.right_hand_type = HandType.right

        # Warm-start joint vectors. They must be floating point: position_optimizer writes each new
        # solution back in place (``last_qpos[:] = ...``), and an integer array would silently
        # truncate every joint angle toward zero.
        self.left_last_qpos = np.zeros(self._SHADOW_NUM_JOINTS)
        self.right_last_qpos = np.zeros(self._SHADOW_NUM_JOINTS)

    @staticmethod
    def sample_qpos(optimizer: Optimizer):
        """Sample a random joint configuration and a perturbed initial guess for ``optimizer``.

        Returns:
            ``(random_qpos, init_qpos)``. ``random_qpos`` is drawn uniformly within the robot's joint
            limits (then passed through the optimizer's adaptor, if any). ``init_qpos`` is
            ``random_qpos`` plus Gaussian noise (std 0.5), clipped strictly inside the limits and
            restricted to the optimizer's target joints.
        """
        joint_eps = 1e-5
        robot = optimizer.robot
        adaptor = optimizer.adaptor
        joint_limit = robot.joint_limits
        random_qpos = np.random.uniform(joint_limit[:, 0], joint_limit[:, 1])
        if adaptor is not None:
            random_qpos = adaptor.forward_qpos(random_qpos)

        init_qpos = np.clip(
            random_qpos + np.random.randn(robot.dof) * 0.5, joint_limit[:, 0] + joint_eps, joint_limit[:, 1] - joint_eps
        )[optimizer.idx_pin2target]
        return random_qpos, init_qpos

    @staticmethod
    def compute_pin_qpos(optimizer: Optimizer, qpos: np.ndarray, fixed_qpos: np.ndarray):
        """Assemble the full pinocchio joint vector from target joints ``qpos`` and ``fixed_qpos``.

        The result has length ``optimizer.robot.model.nq`` and is passed through the optimizer's
        adaptor when one is configured.
        """
        adaptor = optimizer.adaptor
        full_qpos = np.zeros(optimizer.robot.model.nq)
        full_qpos[optimizer.idx_pin2target] = qpos
        full_qpos[optimizer.idx_pin2fixed] = fixed_qpos
        if adaptor is not None:
            full_qpos = adaptor.forward_qpos(full_qpos)
        return full_qpos

    @staticmethod
    def generate_vector_retargeting_data_gt(robot: RobotWrapper, optimizer: VectorOptimizer):
        """Generate ground-truth data for vector retargeting.

        Returns:
            ``(random_pin_qpos, init_qpos, random_target_vector)``, where the target vectors are the
            task-link positions minus the origin-link positions at ``random_pin_qpos``.
        """
        random_pin_qpos, init_qpos = TestOptimizer.sample_qpos(optimizer)
        robot.compute_forward_kinematics(random_pin_qpos)
        random_pos = np.array([robot.get_link_pose(i)[:3, 3] for i in optimizer.computed_link_indices])
        origin_pos = random_pos[optimizer.origin_link_indices]
        task_pos = random_pos[optimizer.task_link_indices]
        random_target_vector = task_pos - origin_pos

        return random_pin_qpos, init_qpos, random_target_vector

    @staticmethod
    def generate_position_retargeting_data_gt(robot: RobotWrapper, optimizer: PositionOptimizer):
        """Generate ground-truth data for position retargeting.

        Returns:
            ``(random_pin_qpos, init_qpos, random_target_pos)``, where the target positions are the
            translations of the optimizer's target links at ``random_pin_qpos``.
        """
        random_pin_qpos, init_qpos = TestOptimizer.sample_qpos(optimizer)
        robot.compute_forward_kinematics(random_pin_qpos)
        random_target_pos = np.array([robot.get_link_pose(i)[:3, 3] for i in optimizer.target_link_indices])

        return random_pin_qpos, init_qpos, random_target_pos

    def bimanual_position_optimizer(self, left_target_pos, right_target_pos):
        """Retarget both hands, warm-starting each from its previous solution.

        Returns:
            ``(left_computed_qpos, right_computed_qpos)``.
        """
        left_computed_qpos, self.left_last_qpos = self.position_optimizer(
            left_target_pos, self.left_hand_type, self.left_last_qpos
        )
        right_computed_qpos, self.right_last_qpos = self.position_optimizer(
            right_target_pos, self.right_hand_type, self.right_last_qpos
        )
        return left_computed_qpos, right_computed_qpos

    def _get_retargeting(self, hand_type):
        """Build (once) and return the retargeting object for ``self.robot_name`` and ``hand_type``.

        Loading the config and building the robot model is expensive, and the optimizer's ``retarget``
        call is driven entirely by its arguments, so the built object is reused across calls.
        """
        # Created lazily so instances constructed without __init__ still work.
        cache = self.__dict__.setdefault("_retargeting_cache", {})
        key = (self.robot_name, hand_type)
        if key not in cache:
            config_path = get_default_config_path(self.robot_name, RetargetingType.position, hand_type)

            # Note: The parameters below are adjusted solely for this test
            # The smoothness penalty is deactivated here, meaning no low pass filter and no continuous joint value
            # This is because the test is focused solely on the efficiency of single step optimization
            override = {"normal_delta": 0}
            config = RetargetingConfig.load_from_file(config_path, override)
            cache[key] = config.build()
        return cache[key]

    @staticmethod
    def _to_numpy(array):
        """Return ``array`` as a NumPy array, detaching and moving torch tensors to the CPU first."""
        if hasattr(array, "detach"):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    def position_optimizer(self, target_pos, hand_type, last_qpos):
        """Solve one position-retargeting step for a single hand.

        Args:
            target_pos: target positions, one row per optimizer target link. May be a torch tensor
                (detached and moved to CPU) or anything ``np.asarray`` accepts.
                NOTE(review): presumably shape (num_target_links, 3) — confirm against the config.
            hand_type: ``HandType`` selecting the default config to use.
            last_qpos: warm-start joint vector. It is overwritten in place with the new solution, so it
                should be a floating-point array.

        Returns:
            ``(computed_qpos, last_qpos)``: the solution and the updated ``last_qpos`` buffer.
        """
        retargeting = self._get_retargeting(hand_type)
        optimizer = retargeting.optimizer
        robot: RobotWrapper = optimizer.robot

        tic = time()
        # Preserved from the original: reseeds NumPy's global RNG on every call.
        np.random.seed(1)
        # No joints are held fixed during optimization.
        fixed_qpos = []

        target_pos = self._to_numpy(target_pos)

        computed_qpos = optimizer.retarget(target_pos, fixed_qpos=fixed_qpos, last_qpos=last_qpos)

        # Run forward kinematics at the solution and measure the mean distance between the reached
        # target-link positions and the requested ones.
        robot.compute_forward_kinematics(self.compute_pin_qpos(optimizer, computed_qpos, fixed_qpos))
        computed_target_pos = np.array([robot.get_link_pose(i)[:3, 3] for i in optimizer.target_link_indices])
        error = np.mean(np.linalg.norm(computed_target_pos - target_pos, axis=1))

        tac = time()
        print(f"Mean optimization position error: {error}")
        print(f"Retargeting computation for {self.robot_name.name} takes {tac - tic}s for 1 times")

        last_qpos[:] = computed_qpos

        return computed_qpos, last_qpos