# run_skeleton_task.py (Completed)

import os
import re
import time
import numpy as np

from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *          # <- predefined skills :  pick, place, move, rotate, pull
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ----------------------------------------------------------
# Utility: very small PDDL predicate-extractor
# ----------------------------------------------------------
# Matches "(" followed by optional whitespace (including newlines) and a
# symbol; the captured group is the first symbol inside the parenthesis.
# Symbols starting with ":" (e.g. ":init") or "?" (variables) are not matched
# by this character class.
_PRED_TOKEN_RGX = re.compile(r'\(\s*([a-zA-Z0-9_\-]+)')

# PDDL line comments run from ";" to the end of the line.
_PDDL_COMMENT_RGX = re.compile(r';[^\n]*')


def _collect_predicates_from_file(path: str) -> set:
    """
    Return the candidate predicate symbols found in a PDDL file.

    Every symbol that directly follows an opening parenthesis is collected.
    ``;`` line comments are removed first so commented-out PDDL does not
    contribute tokens, and symbols are lower-cased because PDDL names are
    case-insensitive.

    :param path: path to a PDDL domain or problem file.
    :return: set of lower-cased symbols; an empty set (after printing a
             message) when the file is missing or cannot be read/decoded.
    """
    # EAFP: open directly instead of exists()+open(), which races and also
    # crashed on directories or unreadable files.
    try:
        with open(path, 'r', encoding='utf-8') as fp:
            txt = fp.read()
    except FileNotFoundError:
        print(f"[Exploration] File does not exist: {path}")
        return set()
    except (OSError, UnicodeDecodeError) as exc:
        print(f"[Exploration] Could not read {path}: {exc}")
        return set()

    txt = _PDDL_COMMENT_RGX.sub('', txt)
    return {tok.lower() for tok in _PRED_TOKEN_RGX.findall(txt)}


def detect_missing_predicates(domain_path: str, problem_path: str):
    """
    Report symbols used in a PDDL problem that never occur in its domain.

    Very light-weight static exploration helper:
      - reads the PDDL domain and problem files,
      - takes the first symbol after every '(' as a candidate predicate,
      - drops PDDL keywords / logical connectives,
      - returns candidates that appear in the problem but nowhere in the
        domain (case-insensitively, since PDDL names are case-insensitive).

    :param domain_path: path to the PDDL domain file.
    :param problem_path: path to the PDDL problem file.
    :return: sorted list of lower-cased candidate predicate names. Missing
             or unreadable files simply contribute no symbols.
    """
    # Heuristic filter for tokens that can follow "(" but are not predicates.
    # "or", "imply" and "exists" commonly appear in problem goals; without
    # them here they were reported as missing predicates. Entries starting
    # with ":" are kept as a harmless safeguard.
    blacklist = {
        'define', 'problem', 'domain', ':requirements', ':types',
        ':objects', ':init', ':goal', ':predicates', ':action',
        ':parameters', ':precondition', ':effect', 'and', 'not',
        'or', 'imply', 'exists',
        'forall', 'when', 'either', ':typing', ':strips', ':negative-preconditions',
        ':conditional-effects', ':equality', ':universal-preconditions',
    }

    # Lower-case both sides so e.g. "On" and "on" are treated as the same.
    domain_preds = {p.lower() for p in _collect_predicates_from_file(domain_path)}
    problem_preds = {p.lower() for p in _collect_predicates_from_file(problem_path)}
    domain_preds -= blacklist
    problem_preds -= blacklist

    return sorted(problem_preds - domain_preds)


# ----------------------------------------------------------
# High-level helper for opening (pulling) every drawer
# ----------------------------------------------------------
def _attempt_drawer_skill(skill_name: str, obj_name: str, call):
    """
    Run one skill invocation for a drawer and report the outcome.

    ``call`` is a zero-argument callable that invokes the skill and returns
    an ``(obs, reward, done)`` triple. It is invoked inside the ``try`` so
    any exception (including an undefined skill name) is printed instead of
    aborting the whole routine.

    :return: ``(succeeded, episode_done)``; ``succeeded`` is False when the
             skill raised, ``episode_done`` reflects the skill's ``done``.
    """
    try:
        _obs, _reward, done = call()
        episode_done = bool(done)
    except Exception as exc:  # deliberate best-effort: log and carry on
        print(f"[Task]  {skill_name}() failed on {obj_name}: {exc}")
        return False, False
    if episode_done:
        print(f"[Task] Episode terminated during {skill_name}!")
    return True, episode_done


def try_open_all_drawers(env, task, positions: dict):
    """
    Attempt to open every drawer-like object listed in ``positions``.

    For each object whose name contains "drawer" or "handle"
    (case-insensitive) this:
      1) rotates the gripper to a fixed quaternion,
      2) moves to the object's position,
      3) picks the object (drawer handle); on failure the object is skipped,
      4) pulls (opens) the drawer,
      5) places the handle back, if a 'place' skill exists in module globals.

    Each skill call is shielded so a failure on one drawer does not abort
    the others. If any skill reports ``done`` the routine returns at once.

    :param env: environment handle forwarded to the skills.
    :param task: task handle forwarded to the skills.
    :param positions: mapping of object name -> position; None or empty
                      means there is nothing to do.
    """
    print("========== [Task] Opening all drawers ==========")

    # Same orientation before every attempt. Presumably (x, y, z, w) order,
    # i.e. the identity rotation - TODO confirm against skill_code.rotate.
    identity_quat = np.array([0., 0., 0., 1.], dtype=np.float32)

    # positions may be None when the position lookup failed.
    for obj_name, pos in (positions or {}).items():
        # Case-insensitive so names such as "Top_Drawer_Handle" are not skipped.
        lowered = str(obj_name).lower()
        if 'drawer' not in lowered and 'handle' not in lowered:
            continue

        print(f"\n[Task] ---- Handling {obj_name} ----")

        # 1) rotate (ensure we start each attempt with same orientation)
        _, done = _attempt_drawer_skill("rotate", obj_name, lambda: rotate(
            env,
            task,
            target_quat=identity_quat,
            max_steps=120,
            threshold=0.05,
            timeout=5.0,
        ))
        if done:
            return

        # 2) approach / move (signature assumed by skill_code)
        print(f"[Task]  Moving to {obj_name} @ {pos}")
        _, done = _attempt_drawer_skill("move", obj_name, lambda: move(
            env,
            task,
            target_pos=pos,
            approach_distance=0.10,
            max_steps=150,
            threshold=0.01,
            approach_axis='z',
            timeout=10.0,
        ))
        if done:
            return

        # 3) pick - a failed grasp makes pulling pointless, so skip the object
        print("[Task]  Picking the drawer handle")
        picked, done = _attempt_drawer_skill("pick", obj_name, lambda: pick(
            env,
            task,
            target_pos=pos,
            approach_distance=0.03,
            max_steps=120,
            threshold=0.005,
            approach_axis='z',
            timeout=8.0,
        ))
        if done:
            return
        if not picked:
            continue

        # 4) pull (called without extra arguments)
        print("[Task]  Pulling / Opening the drawer")
        _, done = _attempt_drawer_skill("pull", obj_name, lambda: pull(env, task))
        if done:
            return

        # 5) place (optional - only if the skill was imported)
        if 'place' in globals():
            print("[Task]  Placing back the handle (releasing)")
            _, done = _attempt_drawer_skill("place", obj_name, lambda: place(
                env,
                task,
                target_pos=pos,
                approach_distance=0.05,
                max_steps=80,
                threshold=0.01,
                approach_axis='z',
                timeout=6.0,
            ))
            if done:
                return

    print("\n========== [Task] Finished opening drawers ==========")


# ----------------------------------------------------------
# Main entry point
# ----------------------------------------------------------
def run_skeleton_task(domain_path: str = "INPUT_YOUR_PATH",
                      problem_path: str = "INPUT_YOUR_PATH"):
    """
    Run the complete skeleton pipeline.

      1) Statically compare the PDDL domain/problem files and print symbols
         used in the problem that never appear in the domain.
      2) Set up the environment, reset the task, initialise the video
         writers and wrap ``task.step`` / ``task.get_observation`` for
         recording.
      3) Fetch object positions and try to open every drawer.
      4) Always shut the environment down, even if a step raises.

    :param domain_path: PDDL domain file (defaults to the original
                        placeholder path).
    :param problem_path: PDDL problem file (defaults to the original
                         placeholder path).
    """
    print("===== Starting Skeleton Task =====")

    # --------------------------------------------------
    # 1) VERY LIGHT STATIC EXPLORATION on PDDL files
    # --------------------------------------------------
    missing_preds = detect_missing_predicates(domain_path, problem_path)
    if missing_preds:
        print("[Exploration] Missing predicates detected (present in problem, absent from domain):")
        for p in missing_preds:
            print(f"   - {p}")
    else:
        print("[Exploration] No missing predicates detected (or files unavailable).")

    # --------------------------------------------------
    # 2) RLBench Environment set-up
    # --------------------------------------------------
    env, task = setup_environment()

    try:
        descriptions, obs = task.reset()
        # --- Video initialisation --
        init_video_writers(obs)

        # --- wrap step() / get_observation() for recording ---
        task.step            = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # --------------------------------------------------
        # 3) Retrieve object positions (user-provided util)
        # --------------------------------------------------
        positions = get_object_positions()
        if not positions:
            print("[Warning] object_positions.get_object_positions() returned empty dict")
            # Normalise None (or any falsy result) so the drawer routine can
            # iterate it safely.
            positions = {}

        # --------------------------------------------------
        # 4) Execute high-level plan (open every drawer)
        # --------------------------------------------------
        try_open_all_drawers(env, task, positions)

    finally:
        # --------------------------------------------------
        # 5) Always cleanly shut down the environment
        # --------------------------------------------------
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    run_skeleton_task()