# run_skeleton_task.py (Completed Version – with Exploration Phase)

import os
import re
import time
import numpy as np

from pyrep.objects.shape import Shape                 # keep: imported by skeleton
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *                              # keep: we do NOT redefine skills
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
# Helper: very light-weight PDDL parser to collect predicate symbols
# ---------------------------------------------------------------------------
_PREDICATE_RE = re.compile(r'\(\s*([a-zA-Z0-9_\-]+)')   # symbol that directly follows "("
_PDDL_COMMENT_RE = re.compile(r';[^\n]*')               # PDDL comment: ";" up to end of line
_PREDICATES_HEADER_RE = re.compile(r'\(\s*:predicates\b', re.IGNORECASE)
_SECTION_START_RE = re.compile(r'\(\s*:')               # "(:" opens a new domain section
_PAREN_RE = re.compile(r'[()]')


def _extract_predicates_from_text(text: str):
    """Return the set of predicate symbols declared in a PDDL ``:predicates`` block.

    The text is scanned by parenthesis depth rather than line by line, so the
    result does not depend on how the domain file is laid out:

    * predicates on the same line as ``(:predicates`` are collected,
    * several predicates on one line are all collected,
    * nested lists such as ``(either a b)`` type expressions are ignored,
    * ``;`` comments are stripped before scanning.

    Args:
        text: Full text of a PDDL domain file. May be empty.

    Returns:
        A set of predicate names (case preserved). Empty when ``text`` is
        empty or contains no ``:predicates`` block.
    """
    if not text:
        return set()

    # Drop comments first so commented-out declarations are not picked up.
    text = _PDDL_COMMENT_RE.sub('', text)
    header = _PREDICATES_HEADER_RE.search(text)
    if header is None:
        return set()

    predicates = set()
    depth = 1                       # we start just inside "(:predicates"
    for paren in _PAREN_RE.finditer(text, header.end()):
        if paren.group() == ')':
            depth -= 1
            if depth == 0:          # closing paren of the :predicates block
                break
            continue

        if depth == 1:
            # A "(" directly inside the block starts a predicate declaration.
            declared = _PREDICATE_RE.match(text, paren.start())
            if declared:
                predicates.add(declared.group(1))
            elif _SECTION_START_RE.match(text, paren.start()):
                # Unbalanced block ran into the next "(:section" – stop here.
                break
        depth += 1
    return predicates


def _read_first_existing(path_candidates):
    """Read the first candidate path that is a regular file.

    Args:
        path_candidates: Iterable of file paths, tried in order.

    Returns:
        ``(path, text)`` for the first candidate that is an existing regular
        file, with ``text`` decoded as UTF-8; ``(None, '')`` when no candidate
        qualifies.
    """
    for candidate in path_candidates:
        # isfile() rather than exists(): a directory that happens to share a
        # candidate's name must be skipped, not handed to open().
        if os.path.isfile(candidate):
            with open(candidate, 'r', encoding='utf-8') as handle:
                return candidate, handle.read()
    return None, ''


def find_missing_predicates():
    """
    Report predicates declared by the exploration domain but absent from the
    task domain.

    Both domain files are looked up in a fixed list of candidate locations.
    Returns a sorted list of predicate names; the list is empty when either
    file could not be read (or was empty), in which case a notice is printed
    and no comparison takes place.
    """
    task_candidates = ['combined-domain.pddl', 'domain.pddl', './pddl/domain.pddl']
    exploration_candidates = [
        'exploration-domain.pddl', 'exploration.pddl', './pddl/exploration.pddl'
    ]
    task_domain_path, task_text = _read_first_existing(task_candidates)
    exploration_path, exploration_text = _read_first_existing(exploration_candidates)

    # Both texts are required; bail out early if either one is missing/empty.
    if not (task_text and exploration_text):
        print('[Exploration] Domain files not found – skipping predicate comparison.')
        return []

    task_preds = _extract_predicates_from_text(task_text)
    exploration_preds = _extract_predicates_from_text(exploration_text)
    missing = sorted(exploration_preds.difference(task_preds))

    report_lines = (
        f'[Exploration] Loaded task-domain from   : {task_domain_path}',
        f'[Exploration] Loaded exploration domain: {exploration_path}',
        f'[Exploration] #task predicates        = {len(task_preds)}',
        f'[Exploration] #explore predicates     = {len(exploration_preds)}',
        f'[Exploration] Missing predicates      = {missing}',
    )
    for report_line in report_lines:
        print(report_line)
    return missing


# ---------------------------------------------------------------------------
#   Very small interactive “probe” that tries a gripper rotation and a small
#   end-effector move (with dummy / best-guess parameters) just to observe
#   effects and collect extra information that might be required for planning
#   later. This is deliberately conservative – it will bail out quickly on errors.
# ---------------------------------------------------------------------------
def quick_probe(env, task, positions):
    """
    Run a short, best-effort probe of the environment.

    Two optional steps are attempted, each only if the corresponding skill
    name is present in this module's globals:

    1. ``rotate``: turn the gripper 90° about its own Z axis, i.e. the target
       is the current gripper orientation composed with a 90° Z rotation.
    2. ``move``: send one raw ``task.step`` action that offsets the gripper
       position by (0.03, 0, 0.03) while keeping the current orientation.
       This step runs only when ``positions`` is a non-empty dict whose first
       value is not ``None``.

    Args:
        env: Environment object; only ``env.action_shape`` is read here.
        task: Task object providing ``get_observation()`` and ``step()``.
        positions: Mapping of object name -> position. Only used to decide
            whether the move step should run.

    Returns:
        None. Any exception raised during the probe is caught and reported,
        so the probe never aborts the caller.
    """
    print('========== [Exploration-Probe] START ==========')

    # Only the presence of a known object position matters for the probe.
    target_pos = None
    if isinstance(positions, dict) and positions:
        target_pos = next(iter(positions.values()))

    try:
        # 1) Rotate gripper 90° around its own Z axis (if rotate exists)
        if 'rotate' in globals():
            # assumes gripper_pose[3:7] is an (x, y, z, w) quaternion, the
            # same layout delta_quat uses – TODO confirm against the simulator
            current_quat = np.asarray(
                task.get_observation().gripper_pose[3:7], dtype=float
            )
            half_angle = np.pi / 4          # half of the 90° rotation angle
            delta_quat = np.array(
                [0.0, 0.0, np.sin(half_angle), np.cos(half_angle)]
            )
            # Right-multiplying applies the delta in the gripper's own frame.
            target_quat = _compose_quat_xyzw(current_quat, delta_quat)
            rotate(env, task, target_quat, max_steps=10, threshold=0.25, timeout=2.0)

        # 2) Move slightly (dummy) just to see effect
        if target_pos is not None and 'move' in globals():
            obs = task.get_observation()
            current_pose = obs.gripper_pose
            action = np.zeros(env.action_shape)
            action[:3] = current_pose[:3] + np.array([0.03, 0.0, 0.03])
            action[3:7] = current_pose[3:7]
            # Falls back to -1.0 when the observation has no such attribute.
            action[-1] = getattr(obs, 'gripper_openness', -1.0)
            task.step(action)

    except Exception as exc:
        # Any failure in probe should not abort the whole run
        print(f'[Exploration-Probe] Warning – probe aborted: {exc}')

    print('========== [Exploration-Probe] DONE ==========')


def _compose_quat_xyzw(q_first, q_second):
    """Return the normalised Hamilton product ``q_first * q_second``.

    Both inputs and the result are 4-element (x, y, z, w) quaternions.
    """
    x1, y1, z1, w1 = q_first
    x2, y2, z2, w2 = q_second
    product = np.array([
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    ])
    norm = np.linalg.norm(product)
    # Guard against a degenerate (all-zero) input; otherwise keep unit length.
    return product / norm if norm > 0.0 else product


# ---------------------------------------------------------------------------
#   Main entry (original skeleton + our additions)
# ---------------------------------------------------------------------------
def run_skeleton_task():
    '''Run the generic task skeleton: setup, exploration phase, shutdown.

    The environment is created, the task is reset, and the task's ``step`` /
    ``get_observation`` are wrapped for video recording. Object positions are
    fetched (falling back to an empty dict on error), the predicate comparison
    and the quick probe are run, and the outcome is printed. The environment
    is shut down in all cases.
    '''
    print('===== Starting Skeleton Task =====')

    env, task = setup_environment()
    try:
        # Bring the task to its initial state; the descriptions are not used.
        _descriptions, initial_obs = task.reset()

        # Video capture: initialise the writers, then route step/observation
        # calls through the recording wrappers.
        init_video_writers(initial_obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Object positions are optional – continue with none if lookup fails.
        try:
            positions = get_object_positions()
        except Exception as exc:
            print(f'[Skeleton] Could not fetch object positions: {exc}')
            positions = {}

        # Exploration phase: predicate diff first, then the physical probe.
        missing_predicates = find_missing_predicates()
        quick_probe(env, task, positions)

        # A planner would normally be dispatched here with a domain extended
        # by the discovered predicates; this template only reports them.
        if not missing_predicates:
            print('[Skeleton] No missing predicates detected.')
        else:
            print('[Skeleton] Newly discovered predicates that should be '
                  f'added to the planning domain: {missing_predicates}')

        # TODO (user-side):
        #   1.  Insert planner invocation with corrected domain/problem files.
        #   2.  Translate the resulting action sequence into calls to the
        #       predefined skill functions (move, pick, place, rotate, pull).

        # Short pause so the final scene can be inspected.
        print('[Skeleton] Standing by for 2 s …')
        time.sleep(2.0)
    finally:
        # Shut the environment down whether or not the steps above succeeded.
        shutdown_environment(env)

    print('===== End of Skeleton Task =====')


# Run the skeleton only when this file is executed as a script, not on import.
if __name__ == '__main__':
    run_skeleton_task()