import numpy as np
from scipy.spatial.transform import Rotation as R
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor
import time

def normalize_quaternion(quat):
    """Return ``quat`` scaled to unit Euclidean length.

    Args:
        quat: Quaternion components as a NumPy array or any array-like
            (list, tuple). The component order is not inspected.

    Returns:
        np.ndarray: Array with the same shape as ``quat`` and norm 1.

    Raises:
        ValueError: If the norm is zero or not finite. Such an input has no
            defined direction and dividing by it would produce NaN/inf
            components.
    """
    # Accept plain sequences too; a bare list cannot be divided by a float.
    quat = np.asarray(quat)
    norm = np.linalg.norm(quat)
    if not np.isfinite(norm) or norm == 0.0:
        raise ValueError(f"Cannot normalize quaternion with norm {norm}: {quat}")
    return quat / norm

def euler_from_quat(quat_xyzw, seq='xyz'):
    """Convert a quaternion into Euler angles expressed in degrees.

    Args:
        quat_xyzw: Quaternion handed unchanged to ``Rotation.from_quat``
            (scalar-last ``[x, y, z, w]`` order).
        seq (str): Axis sequence forwarded to ``Rotation.as_euler``.

    Returns:
        The Euler angles, in degrees, produced by ``Rotation.as_euler``.
    """
    return R.from_quat(quat_xyzw).as_euler(seq, degrees=True)

def pick(env, task, target_pos, approach_distance=0.15, max_steps=100, threshold=0.01, approach_axis='z', timeout=10.0):
    """Move the open gripper to ``target_pos`` via an approach point, then close it.

    The path has two straight segments: from the current gripper position to
    an approach point offset from the target by ``approach_distance`` along
    ``approach_axis``, then from that point to the target. The orientation
    observed at the start is commanded for the whole motion. After the path,
    the target pose is held for 5 steps with the gripper command at 1.0, then
    for 10 steps with the gripper command at -1.0 (close).

    Args:
        env: Object exposing ``action_shape``, used to size the action array.
        task: Object exposing ``get_observation()`` (result must carry
            ``gripper_pose`` as position followed by quaternion) and
            ``step(action)`` returning ``(obs, reward, done)``.
        target_pos: Grasp position ``[x, y, z]``.
        approach_distance (float): Offset of the approach point from the target.
        max_steps (int): Step budget per waypoint. When it runs out the motion
            simply proceeds to the next waypoint.
        threshold (float): Distance below which a waypoint counts as reached.
        approach_axis (str): One of 'x', '-x', 'y', '-y', 'z', '-z'.
        timeout (float): Wall-clock seconds allowed per waypoint.

    Returns:
        tuple: ``(obs, reward, done)`` from the last executed step. Returned
        early as soon as a step reports ``done``.

    Raises:
        KeyError: If ``approach_axis`` is not one of the supported names.
        RuntimeError: If a waypoint is not reached within ``timeout``;
            ``shutdown_environment(env)`` is called before raising.
    """
    print("========== [pick] START ==========")
    obs = task.get_observation()
    start_pos = obs.gripper_pose[:3]
    start_quat = normalize_quaternion(obs.gripper_pose[3:7])
    print(f"[pick] Initial gripper pos: {start_pos}")
    print(f"[pick] Initial gripper quat(xyzw): {start_quat}, euler={euler_from_quat(start_quat)}")
    approach_dir_map = {
        'z': np.array([0, 0, 1]), '-z': np.array([0, 0, -1]),
        'x': np.array([1, 0, 0]), '-x': np.array([-1, 0, 0]),
        'y': np.array([0, 1, 0]), '-y': np.array([0, -1, 0]),
    }
    approach_dir = approach_dir_map[approach_axis]
    approach_pos = target_pos + approach_dir * approach_distance
    print(f"[pick] target_pos: {target_pos}")
    print(f"[pick] approach_pos: {approach_pos} (approach_axis: {approach_axis})")

    # Linearly interpolated waypoints: start -> approach point -> target.
    NUM_WAYPOINT_1, NUM_WAYPOINT_2 = 5, 10
    segments = (
        (start_pos, approach_pos, NUM_WAYPOINT_1),
        (approach_pos, target_pos, NUM_WAYPOINT_2),
    )
    waypoints = []
    for seg_start, seg_end, count in segments:
        for t in range(count):
            alpha = t / (count - 1) if count > 1 else 1
            waypoints.append(seg_start + alpha * (seg_end - seg_start))

    # Action layout used here: [x, y, z, qx, qy, qz, qw, ..., gripper].
    action = np.zeros(env.action_shape)
    action[-1] = 1.0  # keep the gripper open while approaching
    for i, waypoint in enumerate(waypoints):
        step_count = 0
        start_time = time.time()
        while step_count < max_steps:
            obs = task.get_observation()
            current_pos = obs.gripper_pose[:3]
            dist_to_wp = np.linalg.norm(waypoint - current_pos)
            if dist_to_wp < threshold:
                print(f"      -> Reached waypoint {i}.")
                break
            action[:3] = waypoint
            action[3:7] = start_quat
            obs, reward, done = task.step(action)
            step_count += 1
            if done:
                print("[pick] Task ended during movement!")
                return obs, reward, done
            if time.time() - start_time > timeout:
                print(f"[pick] Timeout: Failed to reach waypoint {i} within {timeout} seconds.")
                from env import shutdown_environment
                shutdown_environment(env)
                raise RuntimeError("Task failed due to timeout.")

    # Settle at the target before grasping. The position is set explicitly:
    # relying on the value left over from the waypoint loop would command
    # all zeros whenever no waypoint needed a step.
    for _ in range(5):
        action[:3] = target_pos
        action[3:7] = start_quat
        obs, reward, done = task.step(action)
        if done:
            return obs, reward, done
    print("[pick] Closing gripper...")
    action[-1] = -1.0
    for _ in range(10):
        action[:3] = target_pos
        action[3:7] = start_quat
        obs, reward, done = task.step(action)
        if done:
            return obs, reward, done
    print("[pick] Done pick process.")
    return obs, reward, done

def place(env, task, target_pos, approach_distance=0.15, max_steps=100, threshold=0.01, approach_axis='z', timeout=10.0):
    """Move the closed gripper to ``target_pos`` via an approach point, then open it.

    The path has two straight segments: from the current gripper position to
    an approach point offset from the target by ``approach_distance`` along
    ``approach_axis``, then from that point to the target. The orientation
    observed at the start is commanded for the whole motion. After the path,
    the target pose is held for 5 steps with the gripper command at -1.0, then
    for 10 steps with the gripper command at 1.0 (open).

    Args:
        env: Object exposing ``action_shape``, used to size the action array.
        task: Object exposing ``get_observation()`` (result must carry
            ``gripper_pose`` as position followed by quaternion) and
            ``step(action)`` returning ``(obs, reward, done)``.
        target_pos: Release position ``[x, y, z]``.
        approach_distance (float): Offset of the approach point from the target.
        max_steps (int): Step budget per waypoint. When it runs out the motion
            simply proceeds to the next waypoint.
        threshold (float): Distance below which a waypoint counts as reached.
        approach_axis (str): One of 'x', '-x', 'y', '-y', 'z', '-z'.
        timeout (float): Wall-clock seconds allowed per waypoint.

    Returns:
        tuple: ``(obs, reward, done)`` from the last executed step. Returned
        early as soon as a step reports ``done``.

    Raises:
        KeyError: If ``approach_axis`` is not one of the supported names.
        RuntimeError: If a waypoint is not reached within ``timeout``;
            ``shutdown_environment(env)`` is called before raising.
    """
    print("========== [place] START ==========")
    obs = task.get_observation()
    start_pos = obs.gripper_pose[:3]
    start_quat = normalize_quaternion(obs.gripper_pose[3:7])
    print(f"[place] Initial gripper pos: {start_pos}")
    print(f"[place] Initial gripper quat(xyzw): {start_quat}, euler={euler_from_quat(start_quat)}")
    approach_dir_map = {
        'z': np.array([0, 0, 1]), '-z': np.array([0, 0, -1]),
        'x': np.array([1, 0, 0]), '-x': np.array([-1, 0, 0]),
        'y': np.array([0, 1, 0]), '-y': np.array([0, -1, 0]),
    }
    approach_dir = approach_dir_map[approach_axis]
    approach_pos = target_pos + approach_dir * approach_distance
    print(f"[place] target_pos: {target_pos}")
    print(f"[place] approach_pos: {approach_pos} (approach_axis: {approach_axis})")

    # Linearly interpolated waypoints: start -> approach point -> target.
    NUM_WAYPOINT_1, NUM_WAYPOINT_2 = 5, 10
    segments = (
        (start_pos, approach_pos, NUM_WAYPOINT_1),
        (approach_pos, target_pos, NUM_WAYPOINT_2),
    )
    waypoints = []
    for seg_start, seg_end, count in segments:
        for t in range(count):
            alpha = t / (count - 1) if count > 1 else 1
            waypoints.append(seg_start + alpha * (seg_end - seg_start))

    # Action layout used here: [x, y, z, qx, qy, qz, qw, ..., gripper].
    action = np.zeros(env.action_shape)
    action[-1] = -1.0  # keep the gripper closed while carrying
    for i, waypoint in enumerate(waypoints):
        step_count = 0
        start_time = time.time()
        while step_count < max_steps:
            obs = task.get_observation()
            current_pos = obs.gripper_pose[:3]
            dist_to_wp = np.linalg.norm(waypoint - current_pos)
            if dist_to_wp < threshold:
                print(f"      -> Reached waypoint {i}.")
                break
            action[:3] = waypoint
            action[3:7] = start_quat
            obs, reward, done = task.step(action)
            step_count += 1
            if done:
                print("[place] Task ended during movement!")
                return obs, reward, done
            if time.time() - start_time > timeout:
                print(f"[place] Timeout: Failed to reach waypoint {i} within {timeout} seconds.")
                from env import shutdown_environment
                shutdown_environment(env)
                raise RuntimeError("Task failed due to timeout.")

    # Settle at the target before releasing. The position is set explicitly:
    # relying on the value left over from the waypoint loop would command
    # all zeros whenever no waypoint needed a step.
    for _ in range(5):
        action[:3] = target_pos
        action[3:7] = start_quat
        obs, reward, done = task.step(action)
        if done:
            return obs, reward, done
    print("[place] Opening gripper...")
    action[-1] = 1.0
    for _ in range(10):
        action[:3] = target_pos
        action[3:7] = start_quat
        obs, reward, done = task.step(action)
        if done:
            return obs, reward, done
    print("[place] Done place process.")
    return obs, reward, done

def move(env, task, target_pos, max_steps=100, threshold=0.01, timeout=10.0):
    """
    Move the gripper to a target position while maintaining current orientation and gripper state.

    The path is a straight line from the current position to ``target_pos``,
    split into 10 interpolated waypoints, followed by 5 holding steps at the
    target.

    Args:
        env: The environment object; its ``action_shape`` sizes the action array.
        task: The task object providing ``get_observation()`` and
            ``step(action)`` returning ``(obs, reward, done)``.
        target_pos (np.ndarray): Target position [x, y, z] to move the gripper to.
        max_steps (int): Maximum steps to reach each waypoint (default: 100).
            When exhausted, the motion proceeds to the next waypoint.
        threshold (float): Distance threshold to consider a waypoint reached (default: 0.01).
        timeout (float): Maximum time in seconds to reach a waypoint before timeout (default: 10.0).

    Returns:
        tuple: (obs, reward, done) from the final step. Returned early as soon
        as a step reports ``done``.

    Raises:
        RuntimeError: If a waypoint is not reached within ``timeout``;
            ``shutdown_environment(env)`` is called before raising.
    """
    print("========== [move] START ==========")

    # Get initial observation
    obs = task.get_observation()
    start_pos = obs.gripper_pose[:3]
    start_quat = normalize_quaternion(obs.gripper_pose[3:7])

    # Gripper command convention in this file: 1.0 = open, -1.0 = closed.
    # Prefer `gripper_openness` when the observation provides it. Otherwise
    # fall back to a `gripper_open` flag (assumed RLBench-style: 1.0 open /
    # 0.0 closed -- TODO confirm for the environment in use) so that a closed
    # gripper is not re-opened by the 1.0 default while carrying an object.
    initial_gripper_state = getattr(obs, 'gripper_openness', None)
    if initial_gripper_state is None:
        gripper_open = getattr(obs, 'gripper_open', None)
        if gripper_open is None:
            initial_gripper_state = 1.0  # nothing observable: default to open
        else:
            initial_gripper_state = 1.0 if gripper_open > 0.5 else -1.0

    print(f"[move] Initial gripper pos: {start_pos}")
    print(f"[move] Initial gripper quat(xyzw): {start_quat}, euler={euler_from_quat(start_quat)}")
    print(f"[move] Target pos: {target_pos}")

    # Create waypoints from start to target position
    NUM_WAYPOINT = 10
    waypoints = [
        start_pos + (t / (NUM_WAYPOINT - 1) if NUM_WAYPOINT > 1 else 1) * (target_pos - start_pos)
        for t in range(NUM_WAYPOINT)
    ]

    # Initialize action array
    action = np.zeros(env.action_shape)
    action[-1] = initial_gripper_state  # Maintain current gripper state

    # Move to each waypoint
    for i, waypoint in enumerate(waypoints):
        step_count = 0
        start_time = time.time()
        while step_count < max_steps:
            obs = task.get_observation()
            current_pos = obs.gripper_pose[:3]
            dist_to_wp = np.linalg.norm(waypoint - current_pos)
            if dist_to_wp < threshold:
                print(f"      -> Reached waypoint {i}.")
                break
            action[:3] = waypoint
            action[3:7] = start_quat
            obs, reward, done = task.step(action)
            step_count += 1
            if done:
                print("[move] Task ended during movement!")
                return obs, reward, done
            if time.time() - start_time > timeout:
                print(f"[move] Timeout: Failed to reach waypoint {i} within {timeout} seconds.")
                from env import shutdown_environment
                shutdown_environment(env)
                raise RuntimeError("Task failed due to timeout.")

    # Hold at target position for stability
    for _ in range(5):
        action[:3] = target_pos
        action[3:7] = start_quat
        obs, reward, done = task.step(action)
        if done:
            return obs, reward, done

    print("[move] Done move process.")
    return obs, reward, done

def rotate(env, task, target_quat, max_steps=100, threshold=0.05, timeout=10.0):
    """
    Rotate the gripper in place to a target orientation, keeping its gripper state.

    The position observed at the start is commanded on every step together
    with the normalized ``target_quat``. Stepping stops once the angle between
    the current and target orientation drops below ``threshold``, the step
    budget is used up, or ``timeout`` elapses; the target pose is then held
    for 5 more steps.

    Args:
        env: The environment object; its ``action_shape`` sizes the action array.
        task: The task object providing ``get_observation()`` and
            ``step(action)`` returning ``(obs, reward, done)``.
        target_quat: Target orientation quaternion (xyzw); array or any
            array-like. It is normalized before use.
        max_steps (int): Maximum number of rotation steps.
        threshold (float): Angular error in radians below which the target
            orientation counts as reached.
        timeout (float): Wall-clock seconds after which the rotation loop is
            left. Unlike the translation skills, this only logs a warning.

    Returns:
        tuple: (obs, reward, done) from the final step. Returned early as soon
        as a step reports ``done``.
    """
    print("========== [rotate] START ==========")

    obs = task.get_observation()
    start_pos = obs.gripper_pose[:3]
    start_quat = normalize_quaternion(obs.gripper_pose[3:7])
    # Convert first so lists/tuples are accepted as well as arrays.
    target_quat = normalize_quaternion(np.asarray(target_quat, dtype=float))

    # Gripper command convention in this file: 1.0 = open, -1.0 = closed.
    # Prefer `gripper_openness` when the observation provides it. Otherwise
    # fall back to a `gripper_open` flag (assumed RLBench-style: 1.0 open /
    # 0.0 closed -- TODO confirm for the environment in use) so the current
    # state is actually preserved instead of always using the default.
    initial_gripper_state = getattr(obs, 'gripper_openness', None)
    if initial_gripper_state is None:
        gripper_open = getattr(obs, 'gripper_open', None)
        if gripper_open is None:
            initial_gripper_state = -1.0  # nothing observable: default to closed
        else:
            initial_gripper_state = 1.0 if gripper_open > 0.5 else -1.0

    print(f"[rotate] Initial gripper pos: {start_pos}")
    print(f"[rotate] Initial gripper quat(xyzw): {start_quat}, euler={euler_from_quat(start_quat)}")
    print(f"[rotate] Target quat(xyzw): {target_quat}, euler={euler_from_quat(target_quat)}")

    action = np.zeros(env.action_shape)
    action[-1] = initial_gripper_state

    step_count = 0
    start_time = time.time()

    while step_count < max_steps:
        obs = task.get_observation()
        current_quat = normalize_quaternion(obs.gripper_pose[3:7])
        # q and -q encode the same rotation, so compare via |dot|; the clip
        # guards arccos against round-off slightly above 1.
        dot_product = abs(np.dot(current_quat, target_quat))
        angle_diff = 2 * np.arccos(np.clip(dot_product, -1.0, 1.0))

        print(f"[rotate] Step {step_count}: Angle diff = {angle_diff:.3f} rad")

        if angle_diff < threshold:
            print(f"[rotate] Reached target orientation (angle diff: {angle_diff:.3f} rad).")
            break

        action[:3] = start_pos
        action[3:7] = target_quat
        obs, reward, done = task.step(action)
        step_count += 1

        if done:
            print("[rotate] Task ended during rotation!")
            return obs, reward, done

        if time.time() - start_time > timeout:
            # Leaves the rotation loop but still runs the holding steps below.
            print(f"[rotate] Warning: Exceeded timeout of {timeout} seconds, but continuing...")
            break

    # Hold the target pose for a few steps.
    for _ in range(5):
        action[:3] = start_pos
        action[3:7] = target_quat
        obs, reward, done = task.step(action)
        if done:
            return obs, reward, done

    print("[rotate] Done rotate process.")
    return obs, reward, done

def euler_from_quat(quat_xyzw, seq='xyz'):
    """Return the Euler angles, in degrees, equivalent to a quaternion.

    Note: this repeats the ``euler_from_quat`` defined earlier in this file
    and rebinds the name to an identical implementation.

    Args:
        quat_xyzw: Quaternion handed unchanged to ``Rotation.from_quat``
            (scalar-last ``[x, y, z, w]`` order).
        seq (str): Axis sequence forwarded to ``Rotation.as_euler``.

    Returns:
        The Euler angles, in degrees, produced by ``Rotation.as_euler``.
    """
    rotation = R.from_quat(quat_xyzw)
    angles_deg = rotation.as_euler(seq, degrees=True)
    return angles_deg

def pull(env, task, pull_distance, pull_axis='x', max_steps=100, threshold=0.01, timeout=10.0):
    """
    Pull the gripper along a specified axis by a given distance.

    The path is a straight line from the current position to the point
    ``pull_distance`` away along ``pull_axis``, split into 10 interpolated
    waypoints, followed by 5 holding steps at the end point. The orientation
    observed at the start is commanded throughout.

    Args:
        env: The environment object; its ``action_shape`` sizes the action array.
        task: The task object providing ``get_observation()`` and
            ``step(action)`` returning ``(obs, reward, done)``.
        pull_distance (float): Distance to pull the gripper.
        pull_axis (str): Axis to pull along ('x', '-x', 'y', '-y', 'z', '-z').
        max_steps (int): Maximum steps to reach each waypoint. When exhausted,
            the motion proceeds to the next waypoint.
        threshold (float): Distance threshold to consider a waypoint reached.
        timeout (float): Maximum time in seconds to reach each waypoint.

    Returns:
        tuple: (obs, reward, done) from the final step. Returned early as soon
        as a step reports ``done``.

    Raises:
        KeyError: If ``pull_axis`` is not one of the supported names.
        RuntimeError: If a waypoint is not reached within ``timeout``;
            ``shutdown_environment(env)`` is called before raising.
    """
    print("========== [pull] START ==========")

    obs = task.get_observation()
    start_pos = obs.gripper_pose[:3]
    start_quat = normalize_quaternion(obs.gripper_pose[3:7])

    # Gripper command convention in this file: 1.0 = open, -1.0 = closed.
    # Prefer `gripper_openness` when the observation provides it. Otherwise
    # fall back to a `gripper_open` flag (assumed RLBench-style: 1.0 open /
    # 0.0 closed -- TODO confirm for the environment in use) so the current
    # state is actually preserved instead of always using the default.
    initial_gripper_state = getattr(obs, 'gripper_openness', None)
    if initial_gripper_state is None:
        gripper_open = getattr(obs, 'gripper_open', None)
        if gripper_open is None:
            initial_gripper_state = -1.0  # nothing observable: default to closed
        else:
            initial_gripper_state = 1.0 if gripper_open > 0.5 else -1.0

    print(f"[pull] Initial gripper pos: {start_pos}")
    print(f"[pull] Initial gripper quat(xyzw): {start_quat}, euler={euler_from_quat(start_quat)}")

    axis_map = {
        'x': np.array([1, 0, 0]), '-x': np.array([-1, 0, 0]),
        'y': np.array([0, 1, 0]), '-y': np.array([0, -1, 0]),
        'z': np.array([0, 0, 1]), '-z': np.array([0, 0, -1])
    }
    pull_dir = axis_map[pull_axis]
    target_pos = start_pos + pull_dir * pull_distance

    print(f"[pull] Target pos: {target_pos} (pull_axis: {pull_axis}, distance: {pull_distance})")

    # Linearly interpolated waypoints from the start to the pull end point.
    NUM_WAYPOINT = 10
    waypoints = [
        start_pos + (t / (NUM_WAYPOINT - 1) if NUM_WAYPOINT > 1 else 1) * (target_pos - start_pos)
        for t in range(NUM_WAYPOINT)
    ]

    action = np.zeros(env.action_shape)
    action[-1] = initial_gripper_state

    for i, waypoint in enumerate(waypoints):
        step_count = 0
        start_time = time.time()
        while step_count < max_steps:
            obs = task.get_observation()
            current_pos = obs.gripper_pose[:3]
            dist_to_wp = np.linalg.norm(waypoint - current_pos)
            if dist_to_wp < threshold:
                print(f"      -> Reached waypoint {i}.")
                break
            action[:3] = waypoint
            action[3:7] = start_quat
            obs, reward, done = task.step(action)
            step_count += 1
            if done:
                print("[pull] Task ended during pull!")
                return obs, reward, done
            if time.time() - start_time > timeout:
                print(f"[pull] Timeout: Failed to reach waypoint {i} within {timeout} seconds.")
                from env import shutdown_environment
                shutdown_environment(env)
                raise RuntimeError("Pull failed due to timeout.")

    # Hold at the end point for a few steps.
    for _ in range(5):
        action[:3] = target_pos
        action[3:7] = start_quat
        obs, reward, done = task.step(action)
        if done:
            return obs, reward, done

    print("[pull] Done pull process.")
    return obs, reward, done

