# run_skeleton_task.py  (Completed)

import time
import inspect
import numpy as np

from env import setup_environment, shutdown_environment
from skill_code import pick, place, move, rotate, pull
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


def attempt_skill(skill_fn, env, task, **kwargs):
    """
    Call a skill with only the arguments its signature can accept.

    The signature of ``skill_fn`` is inspected so that callers can supply a
    superset of options without knowing each skill's exact parameter list.

    Args:
        skill_fn: The callable to invoke.
        env: Passed as the first positional argument when the skill's first
            parameter is named ``env``.
        task: Passed as the second positional argument when ``env`` was
            passed and the skill's second parameter is named ``task``.
        **kwargs: Candidate keyword options.  Options the skill does not
            declare are dropped, unless the skill declares ``**kwargs``, in
            which case all of them are forwarded.

    Returns:
        The skill's own result when it is a 3-tuple (the predefined skills
        are expected to return ``(obs, reward, done)``); otherwise, or when
        the call raises, the neutral triple ``(None, 0.0, False)``.
    """
    # Callables such as functools.partial objects or callable instances have
    # no __name__; resolve a printable label up front so that the error
    # handler below can never raise on its own.
    skill_label = getattr(skill_fn, "__name__", repr(skill_fn))
    try:
        sig = inspect.signature(skill_fn)
        params = list(sig.parameters.values())

        # A skill declaring **kwargs accepts every option, so nothing needs
        # to be filtered out for it.
        takes_any_keyword = any(
            p.kind is inspect.Parameter.VAR_KEYWORD for p in params
        )
        filtered_kwargs = {
            k: v for k, v in kwargs.items()
            if takes_any_keyword or k in sig.parameters
        }

        # env / task are positional.  task is only supplied when env was
        # supplied before it; otherwise it would land in the first slot and
        # bind to the wrong parameter.
        args = []
        if params and params[0].name == 'env':
            args.append(env)
            if len(params) >= 2 and params[1].name == 'task':
                args.append(task)

        result = skill_fn(*args, **filtered_kwargs)
        # Normalise the return value: anything that is not a 3-tuple is
        # replaced by the neutral triple.
        if isinstance(result, tuple) and len(result) == 3:
            return result
        return (None, 0.0, False)
    except Exception as exc:
        # Deliberate best-effort boundary: report the failure and carry on.
        print(f"[attempt_skill] {skill_label} failed – {exc}")
        return (None, 0.0, False)


def safe_rotate(env, task,
                target_quat=None,
                max_steps=100, threshold=0.05, timeout=10.0):
    """
    Run the rotate skill after a basic sanity check on the task.

    Args:
        env: Environment handle forwarded to the rotate skill.
        task: Task handle.  When it exposes ``get_object_list`` and that
            list is empty, the rotation is skipped.
        target_quat: Target orientation forwarded to the rotate skill.
            ``None`` (the default) means ``[0., 0., 0., 1.]`` as float32 --
            presumably the identity rotation in (x, y, z, w) order; confirm
            against the rotate skill.
        max_steps, threshold, timeout: Forwarded unchanged to the rotate
            skill.

    Returns:
        The triple produced by ``attempt_skill`` for the rotate skill, or
        ``(None, 0.0, False)`` when the rotation is skipped.
    """
    # Build the default per call instead of in the signature: an array in
    # the signature would be one shared object, so an in-place change made
    # by the skill would leak into every later call.
    if target_quat is None:
        target_quat = np.array([0., 0., 0., 1.], dtype=np.float32)

    # If the task exposes an object list, check that something exists
    if hasattr(task, "get_object_list"):
        obj_list = task.get_object_list()
        if not obj_list:
            print("[safe_rotate] No manipulable objects reported by task. "
                  "Skipping rotation.")
            return (None, 0.0, False)

    # Forward the request to the original rotate skill
    return attempt_skill(rotate, env, task,
                         target_quat=target_quat,
                         max_steps=max_steps,
                         threshold=threshold,
                         timeout=timeout)


def exploration_phase(env, task):
    """
    Report which candidate predicates have no matching observation attribute.

    The current observation is fetched from ``task`` and its attribute
    names are listed.  A predicate counts as present when at least one
    attribute name starts with the predicate's name with dashes replaced by
    underscores.  ``env`` is accepted but not used.

    Returns:
        The candidate predicates without a matching attribute, in
        candidate order.
    """
    print("=== Exploration Phase ===")
    attr_names = dir(task.get_observation())

    def _is_exposed(pred_name):
        # True when some observation attribute begins with the
        # underscore form of the predicate name.
        prefix = pred_name.replace('-', '_')
        for attr in attr_names:
            if attr.startswith(prefix):
                return True
        return False

    # Predicates of interest (taken from the exploration domain)
    wanted = ('identified', 'temperature-known', 'weight-known',
              'durability-known', 'lock-known')
    missing = [pred for pred in wanted if not _is_exposed(pred)]

    print("[Exploration] Potentially missing predicates:", missing)
    print("===========================")
    return missing


def run_skeleton_task():
    """
    Run the demo routine: move, pick, rotate, pull, then place.

    The environment is set up, the task is reset, the video recording
    wrappers are installed and the exploration pass is run.  The first
    entry returned by ``get_object_positions()`` becomes the target.  The
    stages then run in order, and the routine stops as soon as a stage
    reports ``done``.  The environment is shut down on every exit path.
    """
    print("===== Starting Skeleton Task =====")
    env, task = setup_environment()

    try:
        _descriptions, first_obs = task.reset()

        # Optional video recording hooks
        init_video_writers(first_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Quick exploration pass (requirement #7); the result is not used
        # by the rest of the routine.
        exploration_phase(env, task)

        # Object pose information
        positions = get_object_positions()
        if not positions:
            print("[Task] No objects returned by get_object_positions(); "
                  "nothing to do.")
            return

        # The first object returned is the target
        target_name, target_pos = next(iter(positions.items()))
        print(f"[Task] Selected target '{target_name}' at {target_pos}")

        # Options shared by the pick and place stages
        grasp_opts = dict(approach_distance=0.15,
                          max_steps=150,
                          threshold=0.01,
                          approach_axis='z',
                          timeout=10.0)

        def do_move():
            # 1) Move close to the target position
            return attempt_skill(move, env, task,
                                 target_pos=target_pos,
                                 approach_distance=0.20,
                                 max_steps=150,
                                 threshold=0.01,
                                 timeout=10.0)

        def do_pick():
            # 2) Pick the object (if possible)
            return attempt_skill(pick, env, task,
                                 target_pos=target_pos, **grasp_opts)

        def do_rotate():
            # 3) Rotate gripper / object (safety wrapper)
            return safe_rotate(env, task)

        def do_pull():
            # 4) Pull (e.g., open a drawer) - optional
            return attempt_skill(pull, env, task)

        def do_place():
            # 5) Place the object back, 0.10 above the original position.
            # The offset is computed here so it is only evaluated once
            # this stage is reached.
            lifted = tuple(np.array(target_pos) + np.array([0.0, 0.0, 0.10]))
            return attempt_skill(place, env, task,
                                 target_pos=lifted, **grasp_opts)

        stages = (
            (do_move, "[Task] Finished during move."),
            (do_pick, "[Task] Finished during pick."),
            (do_rotate, "[Task] Finished during rotate."),
            (do_pull, "[Task] Finished after pull."),
            (do_place, "[Task] Finished after place."),
        )
        for run_stage, finished_msg in stages:
            _obs, _reward, done = run_stage()
            if done:
                print(finished_msg)
                return

        print("[Task] Completed nominal routine.")

    finally:
        shutdown_environment(env)
        print("===== End of Skeleton Task =====")


# Run the demo routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()