# run_skeleton_task.py   (COMPLETE)

import time
import math
import numpy as np
from typing import Dict, Tuple

from env import setup_environment, shutdown_environment
from skill_code import *
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# --------------------------------------------------------
# Helper – quaternion → Euler (roll, pitch, yaw) converter
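# NOTE(review): rotate() is imported from skill_code, and an imported function
# resolves global names in its own module, so defining this helper here does not
# by itself make it visible to skill_code.rotate(); confirm where it is needed.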
# --------------------------------------------------------
def euler_from_quat(q):
    """
    Convert a quaternion given as (x, y, z, w) into Euler angles.

    The pitch term is clamped to [-1, 1] before ``asin`` so that slightly
    non-normalised quaternions (or float rounding) never raise a domain error.

    Returns:
        Tuple ``(roll, pitch, yaw)`` in radians.
    """
    qx, qy, qz, qw = q

    # Roll: rotation about the X axis.
    sinr_cosp = +2.0 * (qw * qx + qy * qz)
    cosr_cosp = +1.0 - 2.0 * (qx * qx + qy * qy)

    # Pitch: rotation about the Y axis (argument of asin, clamped below).
    sinp = +2.0 * (qw * qy - qz * qx)
    if sinp > +1.0:
        sinp = +1.0
    elif sinp < -1.0:
        sinp = -1.0

    # Yaw: rotation about the Z axis.
    siny_cosp = +2.0 * (qw * qz + qx * qy)
    cosy_cosp = +1.0 - 2.0 * (qy * qy + qz * qz)

    return (
        math.atan2(sinr_cosp, cosr_cosp),
        math.asin(sinp),
        math.atan2(siny_cosp, cosy_cosp),
    )


# --------------------------------------------------------
# Generic safe-wrapper around any predefined skill
# --------------------------------------------------------
def call_skill(skill_fn, *args, **kwargs):
    """
    Execute a skill function, converting any failure into a neutral result.

    Args:
        skill_fn: Callable skill; called as ``skill_fn(*args, **kwargs)``.
        *args, **kwargs: Forwarded unchanged to ``skill_fn``.

    Returns:
        The skill's own return value if it is a 3-tuple (conventionally
        ``(obs, reward, done)``); otherwise ``(None, 0.0, False)``. Any other
        return value (single object, tuple of a different length) is discarded.

    Only ``Exception`` subclasses are caught and reported as a warning;
    ``KeyboardInterrupt``/``SystemExit`` still propagate.
    """
    try:
        ret = skill_fn(*args, **kwargs)
    except Exception as e:
        # functools.partial objects and callable instances have no __name__;
        # looking it up unguarded would raise inside this handler and mask
        # the original error.
        skill_name = getattr(skill_fn, "__name__", repr(skill_fn))
        print(f"[WARN] Skill {skill_name} failed – {e}")
        return None, 0.0, False

    if isinstance(ret, tuple) and len(ret) == 3:
        return ret
    return None, 0.0, False


# --------------------------------------------------------
# Basic workspace bounds check – lets callers skip obviously unsafe targets (no clamping)
# --------------------------------------------------------
def is_within_workspace(pos: Tuple[float, float, float],
                        bounds: Tuple[Tuple[float, float], Tuple[float, float], Tuple[float, float]]
                        = ((-1.0, 1.0), (-1.0, 1.0), (0.0, 1.5))) -> bool:
    """
    Report whether a 3-D point lies inside an axis-aligned box (inclusive).

    Args:
        pos: Exactly three coordinates (x, y, z).
        bounds: ``((lo_x, hi_x), (lo_y, hi_y), (lo_z, hi_z))``.

    Returns:
        True when every coordinate is within its [lo, hi] range.
    """
    (lo_x, hi_x), (lo_y, hi_y), (lo_z, hi_z) = bounds
    px, py, pz = pos

    inside_x = lo_x <= px <= hi_x
    inside_y = lo_y <= py <= hi_y
    inside_z = lo_z <= pz <= hi_z
    return inside_x and inside_y and inside_z


# --------------------------------------------------------
# Exploration routine – tries to trigger every available skill
# to discover missing predicates / interactions.
# --------------------------------------------------------
def exploration_phase(env, task, positions: Dict[str, Tuple[float, float, float]]):
    """
    Exercise the available skills on every known object to surface missing predicates.

    Three passes are made over ``positions``:
      1. move towards every location inside the workspace;
      2. pick every non-drawer object inside the workspace, then place it
         back at the same position;
      3. for every drawer inside the workspace: rotate the gripper, approach
         the drawer position, and pull.

    Skills are invoked through ``call_skill``, so a failing skill is logged
    and exploration continues. As soon as any skill reports ``done=True`` the
    episode has ended, so exploration stops and returns what it has recorded.

    Args:
        env: Environment handle, forwarded unchanged to every skill.
        task: Task handle; ``task.get_observation()`` must return an object
            whose ``gripper_pose`` starts with the x, y, z position.
        positions: Mapping of object name -> (x, y, z). Mutated in place: a
            ``'gripper'`` entry is added from the current observation when absent.

    Returns:
        Set of predicate names flagged during exploration.
    """
    print("\n===== EXPLORATION PHASE – searching for unknown predicates =====")
    missing_predicates = set()
    obs = task.get_observation()

    def _ended(done, stage):
        # Report and signal an early stop once a skill says the episode is over.
        if done:
            print(f"[Explore] Task unexpectedly ended during {stage}; aborting exploration.")
        return done

    # 0) Record the current gripper position. This intentionally mutates the
    #    caller's dict (pre-existing behaviour callers may rely on).
    if 'gripper' not in positions:
        positions['gripper'] = tuple(obs.gripper_pose[:3])

    # 1) MOVE towards each reachable location.
    for name, pos in positions.items():
        if not is_within_workspace(pos):
            print(f"[Skip] {name} outside workspace: {pos}")
            continue
        print(f"[Explore] Moving towards {name} at {pos}")
        _, _, done = call_skill(move, env, task, target_pos=pos,
                                approach_distance=0.10, threshold=0.02,
                                timeout=5.0)
        if _ended(done, "move"):
            return missing_predicates

    # 2) PICK every graspable object, then PLACE it back where it was.
    for name, pos in positions.items():
        # The 'gripper' entry is the end-effector itself, not an object to
        # grasp; drawers are handled in step 3.
        if name == 'gripper' or 'drawer' in name.lower():
            continue
        # Same safety check as step 1 before commanding the arm.
        if not is_within_workspace(pos):
            print(f"[Skip] {name} outside workspace: {pos}")
            continue
        print(f"[Explore] Attempting to pick {name}")
        _, _, done = call_skill(pick, env, task, target_pos=pos,
                                approach_distance=0.10, threshold=0.01,
                                approach_axis='z', timeout=5.0)
        if _ended(done, "pick"):
            return missing_predicates

        # Re-place the object if possible.
        _, _, done = call_skill(place, env, task, target_pos=pos,
                                approach_distance=0.12, threshold=0.02,
                                approach_axis='z', timeout=5.0)
        if _ended(done, "place"):
            return missing_predicates

    # 3) Drawer-specific routine: rotate, approach, pull.
    drawer_candidates = [n for n in positions if 'drawer' in n.lower()]
    for drawer_name in drawer_candidates:
        drawer_pos = positions[drawer_name]
        if not is_within_workspace(drawer_pos):
            print(f"[Skip] {drawer_name} outside workspace: {drawer_pos}")
            continue
        print(f"[Explore] Testing drawer '{drawer_name}' at {drawer_pos}")

        # 3-a) Rotate the gripper ≈ 90° about Y so a side approach is possible.
        #      Quaternion order assumed (x, y, z, w), matching euler_from_quat
        #      — TODO confirm against rotate()'s convention.
        target_quat = np.array([0.0, 0.7071, 0.0, 0.7071])
        _, _, done = call_skill(rotate, env, task, target_quat,
                                max_steps=120, threshold=0.04, timeout=8.0)
        if _ended(done, "rotate"):
            return missing_predicates

        # 3-b) Approach the drawer position.
        _, _, done = call_skill(move, env, task, target_pos=drawer_pos,
                                approach_distance=0.08, threshold=0.02,
                                timeout=5.0)
        if _ended(done, "move"):
            return missing_predicates

        # 3-c) Attempt to pull.
        _, _, done = call_skill(pull, env, task)
        # NOTE: 'lock-known' is recorded after every pull attempt, whether or
        # not it succeeded. call_skill swallows exceptions, so the failure
        # reason is not available here for a finer-grained decision.
        missing_predicates.add('lock-known')
        if _ended(done, "pull"):
            return missing_predicates

    print("===== EXPLORATION COMPLETE =====\n")
    print("Potentially missing predicates discovered:", missing_predicates)
    return missing_predicates


# --------------------------------------------------------
# Main task runner
# --------------------------------------------------------
def _run_skill_sequence(env, task, steps) -> bool:
    """Run ``(skill_fn, kwargs)`` pairs in order; return True as soon as one reports done."""
    for skill_fn, kwargs in steps:
        _, _, done = call_skill(skill_fn, env, task, **kwargs)
        if done:
            return True
    return False


def run_skeleton_task():
    """
    Run the complete skeleton task.

    Sequence: set up the environment, reset the task, wrap step/observation
    for video recording, explore with every skill, then try a simple plan
    (pick the first cube/block and place it slightly above the first drawer).
    The environment is always shut down, even if a step raises.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # Reset environment
        descriptions, obs = task.reset()

        # Initialise optional video capture by wrapping step/get_observation.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Retrieve positions of all relevant objects; fall back to an empty dict.
        positions = get_object_positions()
        if positions is None:
            positions = {}

        # ------------------------------------------------
        # 1) Exploration – learn unknown predicates
        # ------------------------------------------------
        # The result is not consumed by the plan below; exploration_phase
        # already prints it.
        missing_predicates = exploration_phase(env, task, positions)

        # ------------------------------------------------
        # 2) MAIN PLAN (simple demonstration)
        #    Pick the first cube/block and place it slightly above the first
        #    drawer. Drawers are NOT verified to be open here.
        #    NOTE(review): positions were captured before exploration; a drawer
        #    that was pulled may have moved — consider refreshing positions.
        # ------------------------------------------------
        small_objs = [n for n in positions if 'cube' in n.lower() or 'block' in n.lower()]
        drawers = [n for n in positions if 'drawer' in n.lower()]

        if not (small_objs and drawers):
            print("[Task] Could not execute main plan – required objects missing.")
            return

        cube_name = small_objs[0]
        drawer_name = drawers[0]
        cube_pos = positions[cube_name]
        drawer_pos = positions[drawer_name]

        print(f"[Task] Picking {cube_name}")
        done = _run_skill_sequence(env, task, [
            (move, dict(target_pos=cube_pos, approach_distance=0.10,
                        threshold=0.01, timeout=6.0)),
            (pick, dict(target_pos=cube_pos, approach_distance=0.08,
                        threshold=0.01, timeout=6.0)),
        ])

        if not done:
            print(f"[Task] Placing {cube_name} into {drawer_name}")
            place_target = (drawer_pos[0], drawer_pos[1], drawer_pos[2] + 0.05)  # slightly above drawer
            done = _run_skill_sequence(env, task, [
                (move, dict(target_pos=place_target, approach_distance=0.10,
                            threshold=0.02, timeout=6.0)),
                (place, dict(target_pos=place_target, approach_distance=0.10,
                             threshold=0.02, timeout=6.0)),
            ])

        if done:
            print("[Task] Task reported done during main plan.")
        # Skill failures are swallowed by call_skill, so success cannot be
        # asserted here; only report that the sequence finished.
        print("[Task] Main plan finished (skill failures, if any, were logged above).")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == '__main__':
    # Script entry point: only runs when executed directly, not on import.
    run_skeleton_task()