# run_skeleton_task.py (Completed)

import traceback
import time
import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# NOTE: every low-level action we use must already exist in skill_code
from skill_code import move, pick, place, rotate, pull

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def _safe_skill_call(func, *args, **kwargs):
    """
    Wrapper that executes a skill‐function and prints the exception
    instead of crashing the whole run.  Returns (obs, reward, done)
    when the wrapped function follows the RLBench convention, or
    (None, None, False) otherwise.
    """
    try:
        ret = func(*args, **kwargs)
        # RLBench-style skills often return (obs, reward, done)
        if isinstance(ret, (list, tuple)) and len(ret) == 3:
            return ret
        return None, None, False
    except Exception as e:          # pylint: disable=broad-except
        print(f"[WARN] Skill {func.__name__} raised: {e}")
        traceback.print_exc()
        return None, None, False


def _episode_ended(done, stage):
    """
    Report whether a skill signalled the end of the episode.

    Args:
        done: the third element of the triple returned by ``_safe_skill_call``.
        stage: short label of the skill that produced ``done``, used in the log.

    Returns:
        True (after printing a notice) when ``done`` is truthy, else False.
    """
    if done:
        print(f"[Exploration] Task ended during {stage} (environment signalled done).")
        return True
    return False


def _explore_object(env, task, name, pos, robot_home):
    """
    Run the probing sequence for a single object.

    The sequence is: move near ``pos``; try to pick the object and place it
    back slightly higher; if ``name`` looks drawer-like, try rotate and then
    pull; finally move back to ``robot_home``.

    Returns:
        True as soon as any skill reports ``done``, meaning the caller should
        stop exploring. Returns False after completing the whole sequence.
    """
    # -------------------------------------------------
    # 1) move close to the object (~ 'identified' + 'temperature-known')
    # -------------------------------------------------
    _, _, done = _safe_skill_call(
        move,
        env,
        task,
        target_pos=np.array(pos),
        approach_distance=0.15,
        max_steps=150,
        threshold=0.02,
        approach_axis='z',
        timeout=10.0,
    )
    if _episode_ended(done, "move"):
        return True

    # -------------------------------------------------
    # 2) try to pick => 'weight-known' / 'durability-known'
    # -------------------------------------------------
    obs, _, done = _safe_skill_call(
        pick,
        env,
        task,
        target_pos=np.array(pos),
        approach_distance=0.05,
        max_steps=100,
        threshold=0.01,
        approach_axis='z',
        timeout=10.0,
    )
    if _episode_ended(done, "pick"):
        return True

    # obs is None when pick raised or did not return a 3-element result
    # (see _safe_skill_call). Only place back when pick produced an
    # observation; the drop point is 5 cm above the original position.
    if obs is not None:
        _, _, done = _safe_skill_call(
            place,
            env,
            task,
            target_pos=np.array(pos) + np.array([0.0, 0.0, 0.05]),
            approach_distance=0.05,
            max_steps=100,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        )
        if _episode_ended(done, "place"):
            return True

    # -------------------------------------------------
    # 3) try 'rotate' and 'pull' in case this is a drawer
    # -------------------------------------------------
    drawer_like_keywords = ("drawer", "handle", "knob")
    # str() so that non-string keys do not crash the keyword match.
    lowered_name = str(name).lower()
    if any(k in lowered_name for k in drawer_like_keywords):
        # a) rotate gripper -> ninety_deg. With (x, y, z, w) ordering this is a
        #    90-degree rotation about z; NOTE(review): confirm the rotate skill
        #    expects xyzw rather than wxyz.
        ninety_quat = np.array([0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)])
        _, _, done = _safe_skill_call(
            rotate,
            env,
            task,
            target_quat=ninety_quat,
            max_steps=80,
            threshold=0.05,
            timeout=5.0,
        )
        if _episode_ended(done, "rotate"):
            return True

        # b) pull straight backward (skill 'pull' should encapsulate how)
        _, _, done = _safe_skill_call(pull, env, task)
        if _episode_ended(done, "pull"):
            return True

    # -------------------------------------------------
    # 4) go back to a safe home height before the next object
    # -------------------------------------------------
    _, _, done = _safe_skill_call(
        move,
        env,
        task,
        target_pos=robot_home,
        approach_distance=0.0,
        max_steps=100,
        threshold=0.02,
        approach_axis='z',
        timeout=10.0,
    )
    return _episode_ended(done, "return-home move")


def exploration_phase(env, task):
    """
    Naive exploration routine whose only purpose is to trigger the
    predicates declared in exploration.pddl (identified, weight-known,
    durability-known, temperature-known, lock-known ...).

    It iterates over every object position reported by
    ``get_object_positions()`` and visits each object with
    ``_explore_object``: move near it, attempt to pick it (placing it back on
    success), try rotate + pull if the name looks like a drawer handle, then
    return home. Every skill invocation goes through ``_safe_skill_call``, so
    skill exceptions do not abort the run.

    Exploration stops early, without printing the "completed" banner, as soon
    as any skill reports ``done``.
    """
    print("===== [Exploration] start =====")
    positions = get_object_positions()
    if not positions:
        print("[Exploration] WARNING: get_object_positions() returned an empty result.")
        return

    robot_home = np.array([0.0, 0.0, 0.6])  # generic safe Z height

    for name, pos in positions.items():
        print(f"[Exploration] Visiting object `{name}` at {pos}")
        if _explore_object(env, task, name, pos, robot_home):
            return

    print("===== [Exploration] completed =====")


def _try_enable_video(task, first_obs):
    """
    Best-effort installation of the video-recording wrappers on ``task``.

    Initialises the writers with ``first_obs``, then replaces ``task.step``
    and ``task.get_observation`` with their recording-wrapped versions. Any
    ``Exception`` raised along the way is printed and swallowed.
    """
    try:
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)
        print("[Video] Recording wrappers enabled.")
    except Exception as err:          # pylint: disable=broad-except
        print(f"[Video] Could not initialise video recording: {err}")


def run_skeleton_task():
    """
    Run the skeleton task end-to-end in simulation.

    Sets up the environment and resets the task, then optionally enables
    video recording and runs the exploration phase. The environment is shut
    down in a ``finally`` clause, so this happens even if a step raises.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()
    try:
        _descriptions, first_obs = task.reset()

        # Video is optional: a failure is reported but does not stop the run.
        _try_enable_video(task, first_obs)

        # Phase 1: exploration to gather the hidden predicates.
        exploration_phase(env, task)

        # Phase 2: the real "oracle" plan belongs here; for now we only
        # announce where it would run.
        print("[Main] Oracle plan would execute here …")
    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the task only when this file is executed as a script, not on import.
    run_skeleton_task()