# run_skeleton_task.py (Completed)

import inspect
import time
from typing import Dict, Any

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Import every available skill exactly as provided
from skill_code import *     # noqa: F401,F403

from video import init_video_writers, recording_step, recording_get_observation

# Helper that returns the dictionary {object_name : position(np.ndarray)}
from object_positions import get_object_positions


# ------------- Helper / Utility ------------------------------------------------
def call_skill(skill_fn, *args, **kwargs):
    """
    Invoke a skill and report whether it completed without raising.

    Args:
        skill_fn: The callable to invoke.
        *args: Positional arguments forwarded unchanged to `skill_fn`.
        **kwargs: Keyword arguments forwarded unchanged to `skill_fn`.

    Returns:
        A `(success, result)` tuple.  `success` is True when `skill_fn`
        returned normally, and `result` is then whatever it returned.  When
        `skill_fn` raised an `Exception`, a warning is printed and
        `(False, None)` is returned so the caller can carry on.

    Only `Exception` subclasses are swallowed; `KeyboardInterrupt` and
    `SystemExit` still propagate.
    """
    try:
        return True, skill_fn(*args, **kwargs)
    except Exception as e:         # noqa: BLE001
        # Not every callable carries `__name__` (functools.partial objects and
        # callable instances do not, and neither does a non-callable such as
        # None).  Fall back to repr() so that reporting the failure can never
        # raise from inside this handler.
        skill_name = getattr(skill_fn, "__name__", None) or repr(skill_fn)
        print(f"[Warning] Skill `{skill_name}` raised an exception: {e}")
        return False, None


def find_first_skill(name: str):
    """
    Look up `name` in this module's global namespace and return it if it is
    callable.

    The skills are pulled in through `from skill_code import *`, so each one
    is expected to be present in `globals()` under its own name.  When the
    name is missing, or is bound to something that is not callable, an error
    line is printed and None is returned.
    """
    candidate = globals().get(name)
    if callable(candidate):
        return candidate

    print(f"[Error] Requested skill `{name}` not found or not callable.")
    return None


# ------------- Exploration Phase -----------------------------------------------
def exploration_phase(
    env,
    task,
    positions: Dict[str, np.ndarray],
    missing_predicate_hint: str | None = None,
):
    """
    Probe up to two objects with `move` / `pick` and collect the names of
    predicates that are suspected to be missing.

    Args:
        env: Environment handle, forwarded untouched to every skill call.
        task: Task handle, forwarded untouched to every skill call.
        positions: Mapping of object name to position; iterated in its own
            order, and at most the first two entries are visited.
        missing_predicate_hint: Optional predicate name that is placed in the
            result before any probing takes place.

    Returns:
        The set of suspected missing predicate names.

    Notes:
        * A failed `pick` (one that raised) is always recorded as the
          predicate "handempty"; the exception itself is not inspected.
        * The loop stops as soon as the result set is non-empty, so when a
          hint is supplied only the first object is ever probed.
        * If no `pick` skill can be resolved, the function returns right away
          with whatever the hint contributed.
    """
    print("----- [Exploration] Start -----")

    # Seed the result with the caller-supplied hint, if there is one.
    suspected: set[str] = (
        set() if missing_predicate_hint is None else {missing_predicate_hint}
    )

    # Resolve both skills up front (pick first, then move).
    grasp = find_first_skill("pick")
    approach = find_first_skill("move")

    if grasp is None:
        # Nothing to probe with; hand back what we already have.
        print("[Exploration] No `pick` skill available – skipping exploration.")
        return suspected

    # Fixed tuning values for the two probing calls.
    move_settings = {
        "approach_distance": 0.20,
        "max_steps": 75,
        "threshold": 0.02,
        "timeout": 5.0,
    }
    pick_settings = {
        "approach_distance": 0.15,
        "max_steps": 100,
        "threshold": 0.01,
        "approach_axis": "z",
        "timeout": 8.0,
    }

    for attempt, (obj_name, obj_pos) in enumerate(positions.items(), start=1):
        if attempt > 2:                            # never probe more than two objects
            break

        print(f"[Exploration] Trying to pick `{obj_name}` …")

        # Optional approach step, only when a `move` skill was resolved.
        if approach:
            moved, _ = call_skill(
                approach, env, task, target_pos=obj_pos, **move_settings
            )
            print(f"[Exploration] move -> success={moved}")

        # The actual probe: attempt the pick and see whether it raised.
        picked, _ = call_skill(
            grasp, env, task, target_pos=obj_pos, **pick_settings
        )
        if picked:
            print(f"[Exploration] Successfully picked `{obj_name}` – likely predicates OK.")
        else:
            # No parsing of the failure is done; the feedback-provided
            # predicate name is recorded instead.
            suspected.add("handempty")

        # One finding is enough to end the exploration.
        if suspected:
            break

    print(f"[Exploration] Detected missing predicate(s): {suspected}")
    print("----- [Exploration] End -----")
    return suspected


# ------------- Main Skeleton ---------------------------------------------------
def run_skeleton_task():
    """
    Run the demo end to end: set up the environment, explore, execute a
    minimal skill chain, and shut the environment down.

    Sequence:
        1. Create the environment and reset the task.
        2. Initialise the video writers and wrap `task.step` and
           `task.get_observation` with their recording counterparts.
        3. Fetch object positions and run `exploration_phase` with the hint
           "handempty" (its return value is not used afterwards).
        4. If any object is known, run move -> pick -> rotate -> place on the
           first one, skipping every skill that cannot be resolved.

    The environment is shut down in a `finally` clause, so it is released
    even when one of the steps above raises.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Bring the task to its initial state.
        descriptions, obs = task.reset()
        print(f"[Task Info] descriptions: {descriptions}")

        # Video recording: set up writers, then wrap the two task entry points.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Retrieve positions of every detectable object ===
        positions: Dict[str, np.ndarray] = get_object_positions()
        print(f"[Env] Found {len(positions)} object positions.")

        # === Phase 1: Exploration to discover missing predicates ===
        # The hint comes from the feedback; the returned set is currently not
        # consumed by the plan below.
        exploration_phase(
            env,
            task,
            positions,
            missing_predicate_hint="handempty",
        )

        # === Phase 2: Task Plan Execution (Placeholder) ============
        # A single pass of move -> pick -> rotate -> place on the first known
        # object; no symbolic planning is involved.
        if positions:
            target_name, target_pos = next(iter(positions.items()))
            print(f"[Plan] Executing minimal plan on object `{target_name}`")

            # Resolve every skill before running any of them.
            mover = find_first_skill("move")
            picker = find_first_skill("pick")
            rotator = find_first_skill("rotate")
            placer = find_first_skill("place")

            # Pick and place use the same tuning values.
            grasp_settings = {
                "approach_distance": 0.15,
                "max_steps": 100,
                "threshold": 0.01,
                "approach_axis": "z",
                "timeout": 8.0,
            }

            # 1) Approach the object.
            if mover:
                call_skill(
                    mover,
                    env,
                    task,
                    target_pos=target_pos,
                    approach_distance=0.20,
                    max_steps=100,
                    threshold=0.02,
                    timeout=5.0,
                )

            # 2) Grasp it.
            if picker:
                call_skill(picker, env, task, target_pos=target_pos, **grasp_settings)

            # 3) Rotate by ~90 degrees about Z.
            if rotator:
                # Half-angle form of the quaternion; assumes (x, y, z, w)
                # component ordering -- TODO confirm against the rotate skill.
                half_angle = np.pi / 4
                quat_90_z = np.array(
                    [0.0, 0.0, np.sin(half_angle), np.cos(half_angle)]
                )
                call_skill(rotator, env, task, target_quat=quat_90_z)

            # 4) Put the object back at the position it was taken from.
            if placer:
                call_skill(placer, env, task, target_pos=target_pos, **grasp_settings)

        print("[Plan] Minimal demo plan complete.")

    finally:
        # Release the simulator no matter how the body exited.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()