# run_skeleton_task.py (Completed Implementation)

import os
import time
import numpy as np
from typing import Dict, Tuple, List

# RLBench / task-specific utilities
from env import setup_environment, shutdown_environment
from object_positions import get_object_positions

# Recording helpers
from video import init_video_writers, recording_step, recording_get_observation

# Pre-implemented skills – **do not redefine**
from skill_code import *       # noqa: F403, F401   <- gives access to pick, place, move, rotate, pull


###############################################################################
# Helper utilities                                                             #
###############################################################################
# Minimal PDDL domain: types and predicates only (it declares no actions).
# Within this file it is only consumed by run_skeleton_task(), which extracts
# the predicate names via extract_domain_predicates() and compares them with
# the feedback predicate list during the exploration phase.
DOMAIN_PDDL_STR: str = """
(define (domain combined-domain)
  (:requirements :strips :typing :negative-preconditions :equality :disjunctive-preconditions)
  (:types
    object
    location
    drawer - object
    gripper - object
    position - object
    angle - object
  )
  (:predicates
    (at ?obj - object ?loc - location)
    (holding ?obj - object)
    (handempty)
    (is-locked ?d - drawer)
    (is-open ?d - drawer)
    (rotated ?g - gripper ?a - angle)
    (gripper-at ?g - gripper ?p - position)
    (holding-drawer ?g - gripper ?d - drawer)
    (is-side-pos ?p - position ?d - drawer)
    (is-anchor-pos ?p - position ?d - drawer)
  )
)
"""


def extract_domain_predicates(domain_str: str) -> List[str]:
    """
    Return the names of all predicates declared in a PDDL domain string.

    The first ``(:predicates ...)`` block is located. Every list sitting
    directly inside it contributes its first token (the predicate name), in
    declaration order. Parameter lists, type annotations and any deeper
    nested parentheses (e.g. ``(either a b)``) are ignored. ``;`` line
    comments are stripped before parsing. The ``:predicates`` keyword is
    matched case-insensitively, as PDDL keywords are case-insensitive.

    Parenthesis depth is tracked token by token rather than line by line, so
    the layout of the block does not matter. One predicate per line, several
    per line, or the whole block on a single line all give the same result.

    Args:
        domain_str: PDDL domain definition text.

    Returns:
        Predicate names in declaration order; an empty list when the text
        contains no ``(:predicates`` block.
    """
    import re

    # Remove ';' comments so parentheses or keywords inside them are ignored.
    text = re.sub(r";[^\n]*", "", domain_str)
    # Tokens are single parentheses or runs of non-space, non-paren characters.
    tokens = re.findall(r"[()]|[^\s()]+", text)

    # Locate the "(" ":predicates" pair that opens the block.
    start = next(
        (i for i in range(1, len(tokens))
         if tokens[i].lower() == ":predicates" and tokens[i - 1] == "("),
        None,
    )
    if start is None:
        return []

    preds: List[str] = []
    depth = 1            # already inside the "(:predicates" list itself
    expect_name = False  # True right after a "(" that opens a predicate
    for tok in tokens[start + 1:]:
        if tok == "(":
            depth += 1
            # Only lists directly inside the block are predicate declarations.
            expect_name = depth == 2
        elif tok == ")":
            depth -= 1
            expect_name = False
            if depth == 0:  # closing paren of the (:predicates ...) block
                break
        else:
            if expect_name:
                preds.append(tok)
            expect_name = False
    return preds


def find_missing_predicates(domain_predicates: List[str],
                            feedback_list: List[str]) -> List[str]:
    """
    Report the feedback predicates that the domain does not declare.

    Args:
        domain_predicates: Predicate names declared by the domain.
        feedback_list: Predicate names mentioned by feedback.

    Returns:
        The entries of ``feedback_list`` absent from ``domain_predicates``,
        keeping their original order and any duplicates.
    """
    declared = frozenset(domain_predicates)
    return list(filter(lambda name: name not in declared, feedback_list))


###############################################################################
# Main task runner                                                             #
###############################################################################
def run_skeleton_task() -> None:
    """
    Run the skeleton task end to end.

    Steps:
      1. Set up the environment and reset the task. Then wrap ``task.step``
         and ``task.get_observation`` with the recording helpers.
      2. Exploration: compare a hard-coded feedback predicate list with the
         predicates parsed from ``DOMAIN_PDDL_STR`` and log any missing ones.
      3. Fetch object positions (best effort; a failure is logged and treated
         as "no objects").
      4. If any object is known, ``pick`` the first one, idle for five
         zero-action steps, then ``place`` it back at the same position.

    The environment is always shut down in the ``finally`` clause, including
    on the early-return paths.
    """
    print("============  STARTING SKELETON TASK  ============\n")

    env, task = setup_environment()
    try:
        # task.reset() is unpacked as (descriptions, obs).
        descriptions, obs = task.reset()
        print("[Env] Task reset completed.")

        init_video_writers(obs)

        # Route every step / observation through the recording wrappers.
        raw_step = task.step
        task.step = recording_step(raw_step)
        raw_get_observation = task.get_observation
        task.get_observation = recording_get_observation(raw_get_observation)

        # ---- Exploration: which feedback predicates does the domain lack? ----
        feedback_predicates: List[str] = ["handempty"]  # taken from feedback
        missing_preds = find_missing_predicates(
            extract_domain_predicates(DOMAIN_PDDL_STR), feedback_predicates
        )
        if missing_preds:
            print("[Exploration] Missing predicate(s) w.r.t. feedback:", missing_preds)
            # A full exploration phase would act here to discover the missing
            # predicates; this skeleton only logs them.
        else:
            print("[Exploration] All feedback predicates already present in domain.")

        # ---- Object positions (best effort) ----
        try:
            positions: Dict[str, Tuple[float, float, float]] = get_object_positions()
            print("[Env] Object positions fetched from helper:", positions)
        except Exception as exc:
            print("[Warning] Could not fetch object positions:", exc)
            positions = {}

        if not positions:
            print("[Info] No object positions available – skipping skill demo.")
            return

        # ---- Demo plan: pick the first object, then put it back in place ----
        target_name = next(iter(positions))
        target_pos = positions[target_name]
        # Motion settings shared by both skill calls.
        motion_params = dict(
            approach_distance=0.15,
            max_steps=120,
            threshold=0.01,
            approach_axis="z",
            timeout=10.0,
        )

        print(f"\n[Plan] Will PICK '{target_name}' at {target_pos}")
        obs, reward, done = pick(  # noqa: F405 (imported via wildcard)
            env, task, target_pos=np.array(target_pos), **motion_params
        )
        if done:
            print("[Plan] Task finished unexpectedly during PICK.")
            return

        # Idle briefly with zero actions so the scene can be observed.
        # NOTE(review): whatever these steps return (including any termination
        # flag) is discarded; confirm that is intended.
        for _ in range(5):
            task.step(np.zeros(env.action_shape))

        print(f"[Plan] Will PLACE '{target_name}' back to {target_pos}")
        obs, reward, done = place(  # noqa: F405
            env, task, target_pos=np.array(target_pos), **motion_params
        )
        print(
            "[Plan] Task finished after PLACE."
            if done
            else "[Plan] Skeleton plan completed (object placed)."
        )

    finally:
        # Always release the environment, even on early return or error.
        shutdown_environment(env)
        print("\n============  END OF SKELETON TASK  ============")


# --------------------------------------------------------------------------- #
# Script entry point: run the skeleton task only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()
