# run_skeleton_task.py  (Filled-in version)

import inspect
import sys
import time
import traceback
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# Import every predefined skill exactly as it exists in `skill_code`
# NOTE: do **not** redefine or wrap these skills – simply call them.
from skill_code import pick, place, move, rotate, pull     # pylint: disable=unused-import

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# ----------------------------------------------------------------------------------
# Helper utilities (generic, no domain logic)
# ----------------------------------------------------------------------------------

def _safe_skill_call(skill_fn, *args, **kwargs):
    """
    Executes a skill, catching any Exception so that the rest of the plan
    continues.  Returns (obs, reward, done, success_boolean).
    """
    try:
        obs, reward, done = skill_fn(*args, **kwargs)
        return obs, reward, done, True
    except Exception as exc:                              # pylint: disable=broad-except
        # Log full traceback to aid debugging & exploration
        print(f"[Warning] Skill {skill_fn.__name__} failed with exception:\n"
              f"{''.join(traceback.format_exception(None, exc, exc.__traceback__))}")
        return None, None, False, False


def _deduce_missing_predicates(failure_logs: List[str]) -> List[str]:
    """
    Very naive heuristic:
    If the exception message contains a PDDL predicate name that
    is *not* inside the combined domain’s predicate list, report it.
    """
    domain_predicates = {
        'at', 'holding', 'handempty', 'is-locked', 'is-open',
        'rotated', 'gripper-at', 'holding-drawer',
        'is-side-pos', 'is-anchor-pos'
    }

    suspected = set()
    for log in failure_logs:
        # Simple token split – not robust but enough for exploration
        tokens = [tok.strip('() ') for tok in log.replace('\n', ' ').split()]
        for token in tokens:
            if token.islower() and token not in domain_predicates:
                suspected.add(token)

    return sorted(list(suspected))


# ----------------------------------------------------------------------------------
# Main task runner
# ----------------------------------------------------------------------------------

def _extract_position(obj_info: Any) -> Any:
    """
    Return the position stored in one ``get_object_positions()`` entry.

    Sequence-like entries (list, tuple, numpy array) are taken to *be* the
    position.  Any other entry is read through its ``get('position')`` method
    when it has one.  Returns ``None`` when no position is available.
    """
    # numpy arrays have no ``.get``; treating them like lists avoids an
    # AttributeError that would otherwise abort the whole run.
    if obj_info is None or isinstance(obj_info, (list, tuple, np.ndarray)):
        return obj_info
    getter = getattr(obj_info, 'get', None)
    return getter('position', None) if callable(getter) else None


def _accepts_keyword(fn, name: str) -> bool:
    """
    Report whether ``fn`` can be called with the keyword argument ``name``.

    Returns ``True`` when the signature cannot be introspected, so that the
    keyword form is still attempted in that case.
    """
    try:
        params = inspect.signature(fn).parameters.values()
    except (TypeError, ValueError):
        return True
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD
        or (p.name == name and p.kind is not inspect.Parameter.POSITIONAL_ONLY)
        for p in params
    )


def _run_skill(failure_logs: List[str], skill_fn, *args, **kwargs):
    """
    Call a skill best-effort, recording the message of any exception.

    Returns ``(obs, reward, done, True)`` on success.  If the skill raises,
    the full traceback is printed, ``str(exc)`` is appended to
    ``failure_logs`` (later scanned by ``_deduce_missing_predicates``), and
    ``(None, None, False, False)`` is returned.
    """
    try:
        obs, reward, done = skill_fn(*args, **kwargs)
    except Exception as exc:                              # pylint: disable=broad-except
        failure_logs.append(str(exc))
        skill_name = getattr(skill_fn, '__name__', repr(skill_fn))
        print(f"[Warning] Skill {skill_name} failed with exception:\n"
              f"{''.join(traceback.format_exception(None, exc, exc.__traceback__))}")
        return None, None, False, False
    return obs, reward, done, True


def _probe_object(env, task, obj_name: str, position: Any,
                  disposal_pos: Any, failure_logs: List[str]) -> None:
    """
    Exercise move / pick / rotate / pull / place on a single object.

    Failures are printed and recorded in ``failure_logs``.  Pull and place
    are attempted only when pick reported success.
    """
    print(f"\n[Exploration] -> Object `{obj_name}` at {position}")
    if position is None:
        # Every skill below needs a target; calling them with None only
        # produces noise in failure_logs.
        print(f"[Exploration]   Skipping `{obj_name}`: no position reported.")
        return

    # a) MOVE – some move() variants take only env & task, so pass the
    #    target position only when this variant declares it.
    print("[Exploration]   Trying move() ...")
    if _accepts_keyword(move, 'target_pos'):
        _run_skill(failure_logs, move, env, task, target_pos=position)
    else:
        _run_skill(failure_logs, move, env, task)

    # b) PICK
    print("[Exploration]   Trying pick() ...")
    _, _, _, pick_success = _run_skill(
        failure_logs,
        pick,
        env,
        task,
        target_pos=position,
        approach_distance=0.15,
        max_steps=100,
        threshold=0.01,
        approach_axis='z',
        timeout=10.0,
    )

    # c) ROTATE (arbitrary 90-deg rotation around Z)
    print("[Exploration]   Trying rotate() ...")
    ninety_deg_quat = np.array([0.0, 0.0, np.sin(np.pi / 4), np.cos(np.pi / 4)])   # xyzw
    _run_skill(
        failure_logs,
        rotate,
        env,
        task,
        target_quat=ninety_deg_quat,
        max_steps=100,
        threshold=0.05,
        timeout=5.0,
    )

    # d) PULL and e) PLACE only make sense while holding the object.
    if not pick_success:
        return

    print("[Exploration]   Trying pull() ...")
    _run_skill(failure_logs, pull, env, task)

    print("[Exploration]   Trying place() ...")
    _run_skill(
        failure_logs,
        place,
        env,
        task,
        target_pos=disposal_pos,
        approach_distance=0.15,
        max_steps=100,
        threshold=0.01,
        approach_axis='z',
        timeout=10.0,
    )


def run_skeleton_task():
    """
    A generic yet *executable* routine that:
      1) sets up the RLBench environment;
      2) explores objects with predefined skills in order
         to figure out what may be missing;
      3) performs a simple clean-up plan (pick-and-place);
      4) shuts everything down gracefully.
    """
    print("===== Starting Skeleton Task =====")

    # --------------------------------------------------------------------------
    # 1) Environment setup
    # --------------------------------------------------------------------------
    env, task = setup_environment()
    failure_logs: List[str] = []          # exception messages from failed skills

    try:
        # Reset simulation & acquire first observation
        _descriptions, obs = task.reset()

        # Initialise (optional) video recorders
        init_video_writers(obs)

        # Wrap step / get_observation to automatically record
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ----------------------------------------------------------------------
        # 2) Obtain environment information
        # ----------------------------------------------------------------------
        # `or {}` guards against a None return so the lookups below stay valid.
        object_dict: Dict[str, Any] = get_object_positions() or {}
        if not object_dict:
            print("[Info] No objects reported by get_object_positions(). "
                  "You may want to check object_positions.py.")
        else:
            print(f"[Info] Objects discovered: {list(object_dict.keys())}")

        # Choose a generic disposal position if available, otherwise origin.
        # Normalised the same way as object entries (dict / array / sequence).
        disposal_pos: Any = _extract_position(object_dict.get('disposal_zone'))
        if disposal_pos is None:
            disposal_pos = (0.0, 0.0, 0.0)

        # ----------------------------------------------------------------------
        # 3) EXPLORATION PHASE – attempt to exercise every predefined skill
        # ----------------------------------------------------------------------
        print("----- Exploration phase: probing predefined skills -----")
        for obj_name, obj_info in object_dict.items():
            if obj_name == 'disposal_zone':
                continue
            _probe_object(env, task, obj_name, _extract_position(obj_info),
                          disposal_pos, failure_logs)

        # ----------------------------------------------------------------------
        # 4) Determine if any predicate appears to be “missing”
        # ----------------------------------------------------------------------
        suspected_predicates = _deduce_missing_predicates(failure_logs)
        if suspected_predicates:
            print("\n===== Suspected *missing* predicates =====")
            for pred in suspected_predicates:
                print(f"  - {pred}")
            print("==========================================")
        else:
            print("\n[Info] Exploration did not reveal obvious missing predicates.")

    finally:
        # ----------------------------------------------------------------------
        # 5) Tear down simulation (always!)
        # ----------------------------------------------------------------------
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the full set-up / explore / tear-down routine once.
    run_skeleton_task()