# run_skeleton_task.py (Completed)

import time
import numpy as np

# Environment / RLBench interaction
from env import setup_environment, shutdown_environment

# Pre-implemented low-level robot skills (NO new primitives are defined here)
from skill_code import move, pick, place, rotate, pull    # pylint: disable=unused-import

# Video utilities (optional – they do nothing if recording is disabled)
from video import init_video_writers, recording_step, recording_get_observation

# Utility that returns a dictionary {object_name : np.ndarray([x, y, z])}
from object_positions import get_object_positions


def _safe_skill_call(skill_fn, *skill_args, **skill_kwargs):
    """
    Invoke a skill and always hand back an ``(observation, reward, done)``
    triple, even when the skill raises or returns something else.

    Args:
        skill_fn: The callable to invoke.
        *skill_args: Positional arguments forwarded to ``skill_fn``.
        **skill_kwargs: Keyword arguments forwarded to ``skill_fn``.

    Returns:
        A 3-tuple:
          • the skill's own return value if it is already a 3-tuple;
          • ``(None, 0.0, False)`` if the skill raised or returned ``None``;
          • ``(ret, 0.0, False)`` for any other return value ``ret``.

    Exceptions derived from ``Exception`` are printed as a warning and
    swallowed on purpose so that the surrounding pipeline keeps running.
    """
    try:
        ret = skill_fn(*skill_args, **skill_kwargs)
    except Exception as exc:      # pylint: disable=broad-except
        # functools.partial objects and callable instances have no __name__;
        # fall back to repr() so that reporting the failure cannot itself
        # raise from inside this handler.
        skill_name = getattr(skill_fn, "__name__", repr(skill_fn))
        print(f"[Warning] Skill <{skill_name}> raised an exception: {exc}")
        return None, 0.0, False

    # Normalise the output so that the caller can always unpack three values
    if ret is None:
        return None, 0.0, False
    if isinstance(ret, tuple) and len(ret) == 3:
        return ret
    # Any other return shape → wrap into the canonical (obs, reward, done)
    return ret, 0.0, False


def exploration_phase(env, task, positions):
    """
    Very light-weight open-loop exploration: for every object in
    ``positions`` run the predefined skills in the fixed order
    MOVE → PICK → ROTATE → PLACE → PULL.

    Args:
        env: Environment handle, forwarded unchanged to every skill.
        task: Task handle, forwarded unchanged to every skill.
        positions: Mapping ``{object_name: position}``.  An optional
            ``"deposit"`` entry is used as the shared place target and is
            *not* treated as an object to manipulate.

    Returns:
        None.  The function returns early (without printing the
        "Phase END" banner) as soon as any skill reports ``done``.

    Every skill is invoked through ``_safe_skill_call``, so a skill that
    raises is reported and the sequence simply continues with the next
    skill.  The routine makes no assumption about the concrete task.
    """
    print("===== [Exploration] Phase START =====")

    # Hand-tuned fallback deposit location, used only when the environment
    # does not provide an entry named "deposit".
    deposit_pos = positions.get("deposit", np.array([0.50, 0.00, 0.90]))

    # 90° rotation about Z (half-angle π/4); the layout looks like
    # (x, y, z, w) — TODO confirm against the rotate skill's convention.
    # Loop-invariant, so it is built once.
    z_90deg_quat = np.array(
        [0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)]
    )

    # Iterate over every known object and try a naïve Pick-and-Place
    for obj_name, obj_pos in positions.items():
        if obj_name == "deposit":
            # The deposit entry is the drop-off target, not something to
            # pick up; picking and placing it onto itself is meaningless.
            continue

        print(f"[Exploration]  • Processing object <{obj_name}> @ {obj_pos}")

        # (label, skill, keyword arguments) in execution order
        steps = (
            ("MOVE", move, {
                "target_pos": obj_pos,
                "approach_distance": 0.15,
                "max_steps": 150,
                "threshold": 0.01,
                "timeout": 10.0,
            }),
            ("PICK", pick, {
                "target_pos": obj_pos,
                "approach_distance": 0.05,
                "max_steps": 150,
                "threshold": 0.01,
                "approach_axis": 'z',
                "timeout": 10.0,
            }),
            ("ROTATE", rotate, {
                "target_quat": z_90deg_quat,
                "max_steps": 75,
                "threshold": 0.05,
                "timeout": 5.0,
            }),
            ("PLACE", place, {
                "target_pos": deposit_pos,
                "approach_distance": 0.10,
                "max_steps": 150,
                "threshold": 0.01,
                "approach_axis": 'z',
                "timeout": 10.0,
            }),
            # PULL only makes sense for drawers / handles; tried regardless
            ("PULL", pull, {}),
        )

        for label, skill, skill_kwargs in steps:
            _obs, _reward, done = _safe_skill_call(
                skill, env, task, **skill_kwargs
            )
            if done:
                print(f"[Exploration] Episode terminated during {label}.")
                return

    print("===== [Exploration] Phase END =====")


def run_skeleton_task():
    """
    Script entry point: set the environment up, run one exploration pass
    and tear the environment down again.

    No task-specific policy or plan is executed here; replace or extend
    the ``exploration_phase`` call to run a dedicated plan instead.  The
    environment is shut down even when an exception escapes the body.
    """
    print("===== Starting Skeleton Task =====")

    # --- Environment setup --------------------------------------------
    env, task = setup_environment()

    try:
        # Reset the task; only the first observation is needed below.
        _descriptions, first_obs = task.reset()

        # --- Optional video recording hooks ---------------------------
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # --- Current positions of the relevant objects ----------------
        positions = get_object_positions()
        has_positions = positions is not None and len(positions) != 0
        if not has_positions:
            print("[Warning] get_object_positions() returned an empty result. "
                  "Proceeding with an empty exploration loop.")
            positions = {}

        # --- Generic exploration pass ---------------------------------
        exploration_phase(env, task, positions)

    finally:
        # Shut the environment down no matter how the body exited.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Run the skeleton task only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()