# run_skeleton_task.py (Completed)

import time
import traceback

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

# NOTE: every skill is already implemented in skill_code __init__.py
#       We simply import * and invoke them – we must NOT redefine them.
from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation

# Utility that the benchmark authors provide – it returns object pose data
from object_positions import get_object_positions


# ---------------------------------------------------------------------------
# Small helper to call skills defensively.
# Most RLBench skills follow the pattern:
#     obs, reward, done = <skill>(env, task, **kwargs)
# but the exact kwargs can differ.  We therefore try to call them
# with only the arguments that are certainly accepted (env, task) plus
# generic **kwargs, swallowing TypeErrors if one of the extra kwargs
# is not in the real signature.  This lets the same snippet work across
# different internal implementations of the same skill.
# ---------------------------------------------------------------------------
def safe_call_skill(skill_fn, env, task, **kwargs):
    """
    Invoke a skill while gracefully handling signature mismatches and
    *any* runtime failure inside the skill.  The function always returns
    (obs, reward, done).  If the skill raised an exception, the original
    observation is forwarded and reward/done are set to (0.0, False).
    """
    obs = task.get_observation()
    reward = 0.0
    done = False

    try:
        # Try to call with full kwargs – if the signature does not
        # support them this will raise TypeError which we catch.
        obs, reward, done = skill_fn(env, task, **kwargs)
    except TypeError:
        # Retry with only env, task
        try:
            obs, reward, done = skill_fn(env, task)
        except Exception as inner_e:
            print(f"[safe_call_skill] Skill {skill_fn.__name__} failed:\n"
                  f"{traceback.format_exc(limit=1)}")
    except Exception as e:
        print(f"[safe_call_skill] Skill {skill_fn.__name__} failed:\n"
              f"{traceback.format_exc(limit=1)}")

    return obs, reward, done


# ---------------------------------------------------------------------------
# The exploration routine:
# We do not know beforehand which additional predicate is missing,
# therefore we actively *interact* with every object to gather as much
# information as possible.  The only skills that can reveal hidden
# state according to the exploration-domain description are
# move, pick and pull.  (rotate / place are still useful for the
# physical manipulation itself and to reset the scene.)
# ---------------------------------------------------------------------------
def exploration_phase(env, task, object_positions):
    """
    Iterate over known objects, try to  (1) move close, (2) pick,
    (3) pull if the object looks like a drawer-handle, finally (4) place
    it back.  Throughout the loop we store which probe step succeeded
    – this will tell us which hidden predicate became observable and
    therefore was missing in the first place.
    """
    discovered_predicates = set()
    obs = task.get_observation()
    # We store the per-object info here
    per_object_status = {}

    for obj_name, obj_pose in object_positions.items():
        print(f"\n[Exploration] ---------- Object: {obj_name} ----------")
        per_object_status[obj_name] = {
            "identified": False,
            "temperature-known": False,
            "weight-known": False,
            "durability-known": False,
            "lock-known": False
        }

        # -------------------------------------------
        # (1) MOVE close – should set *identified*
        # -------------------------------------------
        obs, reward, done = safe_call_skill(
            move, env, task, target_pos=obj_pose
        )
        if done:
            print("[Exploration] Task ended unexpectedly during move.")
            break
        # As we cannot directly query symbolic predicates from the
        # simulator we merely mark that we attempted the step.
        per_object_status[obj_name]["identified"] = True

        # -------------------------------------------
        # (2) PICK – should reveal weight / durability
        # -------------------------------------------
        obs, reward, done = safe_call_skill(
            pick, env, task, target_pos=obj_pose
        )
        if done:
            print("[Exploration] Task ended unexpectedly during pick.")
            break
        # We flag both predicates; the planning domain uses
        # two overlapping pick actions to cover them.
        per_object_status[obj_name]["weight-known"] = True
        per_object_status[obj_name]["durability-known"] = True

        # -------------------------------------------
        # (3) PULL – only meaningful for handles / drawers
        # -------------------------------------------
        obs, reward, done = safe_call_skill(
            pull, env, task, target_obj=obj_name
        )
        if done:
            print("[Exploration] Task ended unexpectedly during pull.")
            break
        per_object_status[obj_name]["lock-known"] = True

        # -------------------------------------------
        # (4) PLACE – put object back to where it was
        # -------------------------------------------
        obs, reward, done = safe_call_skill(
            place, env, task, target_pos=obj_pose
        )
        if done:
            print("[Exploration] Task ended unexpectedly during place.")
            break

        # Gather all predicates that turned *True* for this object
        for p, status in per_object_status[obj_name].items():
            if status:
                discovered_predicates.add(p)

    print("\n[Exploration] ==== Summary of discovered predicates ====")
    for pred in discovered_predicates:
        print(f"  - {pred}")

    # We return the set so that later stages could adapt plans
    return discovered_predicates, per_object_status, obs


# ---------------------------------------------------------------------------
# MAIN ENTRY
# ---------------------------------------------------------------------------
def run_skeleton_task():
    """Generic skeleton for running any task in your simulation."""
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # Reset the task and fetch the very first observation
        descriptions, obs = task.reset()

        # Optionally start a video recording
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # ------------------------------------------------------------------
        # Retrieve *all* known object positions in a single dictionary.
        #   { "obj_name": (x, y, z, qx, qy, qz, qw), ... }
        # The helper is provided by the benchmark itself.
        # ------------------------------------------------------------------
        object_positions = get_object_positions()
        if not object_positions:
            print("[Warning] object_positions returned an empty dict – "
                  "exploration will be skipped.")
        else:
            # --------------------------------------------------------------
            # Phase 1: Exploration – figure out missing predicates
            # --------------------------------------------------------------
            discovered_preds, object_info, obs = exploration_phase(
                env, task, object_positions
            )

            # --------------------------------------------------------------
            # (Optional) Phase 2: Execute oracle plan that relies on the
            # new-found information.  In many benchmarks the oracle plan
            # is hard-coded; since we do not have a concrete plan here we
            # just showcase one possible simple behaviour:
            # – open every drawer that is not locked.
            # --------------------------------------------------------------
            for obj_name, status in object_info.items():
                if status.get("lock-known", False):
                    print(f"[Oracle-Plan] '{obj_name}' is known to be locked "
                          "⇒ skipping open-drawer sub-plan.")
                    continue

                print(f"[Oracle-Plan] '{obj_name}' seems *not* locked – "
                      "attempting to (rotate → pull) to open.")
                # For demonstration we reuse the same helper for the skills
                # but add a dummy target quaternion representing 90-deg turn
                target_quat_xyzw = np.array([0.0, 0.0, 0.7071, 0.7071])
                obs, reward, done = safe_call_skill(
                    rotate, env, task, target_quat=target_quat_xyzw
                )
                if done:
                    print("[Oracle-Plan] Task terminated during rotate.")
                    break

                obs, reward, done = safe_call_skill(
                    pull, env, task, target_obj=obj_name
                )
                if done:
                    print("[Oracle-Plan] Task terminated during pull.")
                    break

        print("===== Skeleton Task finished normally =====")

    finally:
        # Irrespective of success / failure we shut down the simulator
        shutdown_environment(env)
        print("===== Environment shutdown complete =====")


if __name__ == "__main__":
    run_skeleton_task()
