# run_skeleton_task.py (Completed)

import re
import inspect
import numpy as np

from pyrep.objects.shape import Shape                 # keep – skeleton requirement
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# ---- Import every predefined skill exactly as delivered in skill_code.py ----
from skill_code import rotate, pick, place, move, pull

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions


# --------------------------------------------------------------------------- #
# Helpers for simple domain-string parsing (used in the exploration phase)
# --------------------------------------------------------------------------- #
def _extract_predicates(domain_str: str):
    """
    Very small utility that crawls through a PDDL domain string and returns the
    set of predicate names that appear in ‘(:predicates …)’ blocks.
    We are *not* doing full parsing – only something good enough to discover
    typos / missing names across two domain files.
    """
    # Remove comments
    domain_wo_comments = re.sub(r";.*", "", domain_str)
    # Take the content of every (:predicates …) section
    preds = set()
    for block in re.findall(r"\(:predicates(.+?)\)", domain_wo_comments, flags=re.S):
        # Every atom inside the block starts with an opening parenthesis
        for atom in re.findall(r"\(([^()\s]+)", block):
            preds.add(atom.strip())
    return preds


def _discover_missing_predicates(reference_pred_set, candidate_pred_set):
    """
    Convenience wrapper: returns names that are present in `candidate_pred_set`
    but absent from `reference_pred_set`.
    """
    return sorted(list(candidate_pred_set - reference_pred_set))


# --------------------------------------------------------------------------- #
# Embedded PDDL domain strings – **only** for textual comparison
# (No planning is performed – we just check predicate names.)
# --------------------------------------------------------------------------- #
# Combined domain. run_skeleton_task() passes its extracted predicates as the
# *reference* set: predicates declared here are never reported as missing.
COMBINED_DOMAIN_STR = """
(define (domain combined-domain)
  (:requirements :strips :typing :negative-preconditions :equality :disjunctive-preconditions)
  (:types object location drawer - object gripper - object position - object angle - object)
  (:predicates
    (at ?obj - object ?loc - location)
    (holding ?obj - object)
    (handempty)
    (is-locked ?d - drawer)
    (is-open ?d - drawer)
    (rotated ?g - gripper ?a - angle)
    (gripper-at ?g - gripper ?p - position)
    (holding-drawer ?g - gripper ?d - drawer)
    (is-side-pos ?p - position ?d - drawer)
    (is-anchor-pos ?p - position ?d - drawer)
  )
)
"""

# Exploration domain. run_skeleton_task() passes its extracted predicates as
# the *candidate* set: any name declared here but absent from
# COMBINED_DOMAIN_STR is printed as a missing predicate.
EXPLORATION_DOMAIN_STR = """
(define (domain exploration)
  (:requirements :strips :typing :conditional-effects :universal-preconditions)
  (:types robot object location)
  (:predicates
    (robot-at ?r - robot ?loc - location)
    (at ?obj - object ?loc - location)
    (identified ?obj - object)
    (temperature-known ?obj - object)
    (holding ?obj - object)
    (handempty)
    (weight-known ?obj - object)
    (durability-known ?obj - object)
    (lock-known ?obj - object)
  )
)
"""


# --------------------------------------------------------------------------- #
# Main routine
# --------------------------------------------------------------------------- #
# Parameter names under which skills may accept the target position.
_TARGET_PARAM_NAMES = ("target_pos", "target_position", "pos")

# Fall-back values for optional skill parameters, keyed by parameter name.
# A value is passed only if the skill's signature actually has that parameter.
_PICK_DEFAULTS = {
    "approach_distance": 0.15,
    "threshold": 0.02,
    "approach_axis": "z",
    "max_steps": 100,
    "timeout": 10.0,
}
_PLACE_DEFAULTS = {
    "max_steps": 100,
    "threshold": 0.02,
    "timeout": 10.0,
}


def _build_skill_kwargs(skill, env, task, target_pos, defaults):
    """Build keyword arguments for *skill* by matching its parameter names.

    ``env`` and ``task`` are forwarded by name. Any parameter named in
    ``_TARGET_PARAM_NAMES`` receives *target_pos*. Other names found in
    *defaults* receive that default. Unrecognised parameters are omitted, so
    a skill with an unmet required parameter raises ``TypeError`` when called.

    Raises:
        ValueError / TypeError: propagated from ``inspect.signature`` when
        *skill* has no introspectable signature.
    """
    kwargs = {}
    for name in inspect.signature(skill).parameters:
        if name == "env":
            kwargs[name] = env
        elif name == "task":
            kwargs[name] = task
        elif name in _TARGET_PARAM_NAMES:
            kwargs[name] = target_pos
        elif name in defaults:
            kwargs[name] = defaults[name]
    return kwargs


def _demo_pick_and_place(env, task, obj_name, obj_pos):
    """Pick *obj_name* at *obj_pos*, then place it 2 cm higher if the pick worked.

    Any exception from either skill is reported and swallowed, so the demo can
    continue with the next object. place() is skipped when pick() raised,
    because the object was presumably not grasped.
    """
    print(f"\n[Demo] Attempting to pick & place object '{obj_name}' at {obj_pos}")

    try:
        kwargs = _build_skill_kwargs(pick, env, task, obj_pos, _PICK_DEFAULTS)
        print(f"[Demo] → pick() kwargs: {kwargs}")
        pick(**kwargs)
    except Exception as e:  # demo boundary: report and move on
        print(f"[Warning] pick() failed for '{obj_name}': {e}")
        print(f"[Demo] Skipping place() for '{obj_name}' because pick() failed.")
        return

    try:
        # Target slightly (2 cm) above the object's original position.
        raised = (np.array(obj_pos) + np.array([0.0, 0.0, 0.02])).tolist()
        kwargs = _build_skill_kwargs(place, env, task, raised, _PLACE_DEFAULTS)
        print(f"[Demo] → place() kwargs: {kwargs}")
        place(**kwargs)
    except Exception as e:  # demo boundary: report and move on
        print(f"[Warning] place() failed for '{obj_name}': {e}")


def run_skeleton_task():
    """Set up the environment, run the exploration check and a small skill demo.

    Steps:
      1. reset the task and wrap ``step``/``get_observation`` for video capture;
      2. print predicates present in the exploration domain but missing from
         the combined domain;
      3. call ``rotate`` with the current gripper quaternion as a health check;
      4. attempt pick-and-place on at most two objects from
         ``get_object_positions()``.

    The environment is shut down in ``finally`` even if a step raises.
    """
    print("===== Starting Skeleton Task =====")

    # ---------------- Environment Setup ---------------- #
    env, task = setup_environment()
    try:
        # Reset episode / task
        _descriptions, obs = task.reset()

        # ---- (Optional) Video capture helpers ----
        # Wrap the task methods so every call is also recorded.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ---------------- Exploration Phase ---------------- #
        print("\n--- [Exploration] Analysing domain files to find missing predicates ---")
        missing_predicates = _discover_missing_predicates(
            reference_pred_set=_extract_predicates(COMBINED_DOMAIN_STR),
            candidate_pred_set=_extract_predicates(EXPLORATION_DOMAIN_STR),
        )
        if missing_predicates:
            print(f"[Exploration] Predicate(s) present in exploration domain but "
                  f"missing from combined domain:\n    {missing_predicates}")
        else:
            print("[Exploration] No missing predicates detected.")

        # ---------------- Retrieve Object Positions ---------------- #
        positions = get_object_positions()  # presumably dict: name -> (x, y, z)

        print("\n--- [Planning / Acting] Simple demonstration using available skills ---")
        print("[Demo] Calling rotate(...) with current gripper quaternion "
              "– serves as basic skill invocation / health-check.")
        # gripper_pose[3:7] is taken as the orientation quaternion.
        current_quat = obs.gripper_pose[3:7]
        rotate(env, task, target_quat=current_quat, max_steps=1)

        demo_objects = list(positions)[:2]  # at most two for safety
        if not demo_objects:
            print("[Demo] No objects reported via get_object_positions – skipping pick/place.")
        for obj_name in demo_objects:
            _demo_pick_and_place(env, task, obj_name, positions[obj_name])

    finally:
        # Regardless of what happens we shut the simulation down cleanly
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the task only when executed as a script, never on import.
    run_skeleton_task()