# run_skeleton_task.py  (COMPLETE)

import traceback
import numpy as np
import time

from env import setup_environment, shutdown_environment

# All skills come from the external skill_code module.
# The skeleton's wildcard import (`from skill_code import *`) is replaced
# here by an explicit list of the skill names brought into this namespace.
from skill_code import move, pick, place, rotate, pull

from video import (init_video_writers,
                   recording_step,
                   recording_get_observation)

from object_positions import get_object_positions


def _safe_call(skill_fn, *args, **kwargs):
    """
    Invoke a skill and turn any exception into a "terminated" result so that
    the whole experiment does not crash.

    Args:
        skill_fn: Callable to invoke (e.g. one of the imported skills).
        *args: Positional arguments forwarded unchanged to ``skill_fn``.
        **kwargs: Keyword arguments forwarded unchanged to ``skill_fn``.

    Returns:
        Whatever ``skill_fn`` returns when it completes normally.  If it
        raises an ``Exception``, the error and its traceback are printed and
        the tuple ``(None, 0.0, True)`` is returned, so callers that unpack
        ``obs, reward, done`` observe ``done=True``.
    """
    try:
        return skill_fn(*args, **kwargs)
    except Exception as e:
        # Not every callable carries ``__name__`` (functools.partial objects
        # and callable instances do not).  Fall back to repr() so that the
        # error report itself can never raise and defeat this wrapper.
        skill_name = getattr(skill_fn, "__name__", None) or repr(skill_fn)
        print(f"[ERROR] Skill {skill_name} failed with: {e}")
        traceback.print_exc()
        # create a fake tuple so the rest of the code can continue gracefully
        return None, 0.0, True


def exploration_phase(env, task):
    """
    Exercise the ``rotate`` skill once for every ordered pair of distinct
    gripper angles before the main plan runs.

    With the two angle names used here this issues, in order,
    ``zero_deg → ninety_deg`` and then ``ninety_deg → zero_deg``.  Each call
    goes through ``_safe_call``, so a failing rotation is reported but does
    not stop the phase, and the values returned by the skill are discarded.

    NOTE(review): per the original author's note, the intent is to make the
    ``(rotated ?g ninety_deg)`` predicate appear in the state at least once;
    nothing in this function verifies that it actually did.

    Args:
        env: Environment handle, passed through to ``rotate``.
        task: Task handle, passed through to ``rotate``.
    """
    print("----- Exploration Phase (ensuring `rotated` predicate) -----")

    angle_names = ("zero_deg", "ninety_deg")

    # Every "start → goal" combination, skipping the no-op pairs where both
    # ends are the same angle.
    transitions = [
        (start, goal)
        for start in angle_names
        for goal in angle_names
        if start != goal
    ]

    for start, goal in transitions:
        print(f"[Exploration] rotate({start} → {goal})")
        # The outcome is deliberately not inspected: the only purpose is to
        # invoke the rotate operator.
        _safe_call(rotate, env, task, start, goal)

    print("----- End Exploration Phase -----\n")


def run_skeleton_task():
    """
    Run the drawer-opening task from environment setup to shutdown.

    Steps, in order:
      1. Set up the environment, reset the task and wrap ``task.step`` /
         ``task.get_observation`` with the video-recording helpers.
      2. Run ``exploration_phase``.
      3. Look up the drawer handle in ``get_object_positions()``.
      4. Execute the plan: approach (10 cm above), pre-grasp (3 cm above),
         align on the handle, pick, pull 15 cm along negative y, then one
         final move to the pull target.

    Every skill is invoked through ``_safe_call``.  As soon as a step
    reports ``done`` the matching "Aborted" message is printed and the
    function returns early.  The environment is shut down in all cases.

    Raises:
        RuntimeError: If no key in the positions dict contains ``"handle"``
            or ``"drawer"``.
    """
    print("===== Starting Skeleton Task =====")

    # === Environment Setup ===
    env, task = setup_environment()
    try:
        # Reset the task to its initial state
        descriptions, obs = task.reset()

        # (Optional) Initialize video writers for capturing your simulation
        init_video_writers(obs)

        # Wrap the task steps for recording (if needed)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # -----------------------------------------------------------
        # 1) Exploration Phase – guarantee the `rotated` predicate
        # -----------------------------------------------------------
        exploration_phase(env, task)

        # -----------------------------------------------------------
        # 2) Retrieve Object Positions
        # -----------------------------------------------------------
        positions = get_object_positions()
        print("[Info] Known object positions:", positions)

        # Prefer a key that actually names the handle.  Only fall back to a
        # generic "drawer" key when no handle key exists; otherwise the
        # drawer body could be selected merely because it is listed first.
        # NOTE(review): key naming is an assumption – adjust if your
        # object_positions file differs.
        handle_keys = [key for key in positions if "handle" in key]
        drawer_keys = [key for key in positions if "drawer" in key]
        candidate_keys = handle_keys or drawer_keys

        if not candidate_keys:
            raise RuntimeError("Could not locate drawer handle in positions dict!")

        drawer_handle_key = candidate_keys[0]
        handle_pos = np.asarray(positions[drawer_handle_key])
        print(f"[Task] Drawer-handle found at {handle_pos}")

        # -----------------------------------------------------------
        # 3) Oracle Plan  (high-level)
        #    a) Move above handle
        #    b) Pick / grasp the handle
        #    c) Pull to open drawer
        # -----------------------------------------------------------
        above_handle = handle_pos + np.array([0.0, 0.0, 0.10])  # 10 cm above
        pre_grasp = handle_pos + np.array([0.0, 0.0, 0.03])     # 3 cm above
        # If we know the axis we can displace ~15 cm in negative y for example
        pull_target = handle_pos + np.array([0.0, -0.15, 0.0])

        # Each entry: (skill, target position, message printed on abort).
        plan = (
            (move, above_handle, "[Task] Aborted during approach."),
            (move, pre_grasp, "[Task] Aborted during pre-grasp move."),
            (move, handle_pos, "[Task] Aborted while aligning with handle."),
            (pick, handle_pos, "[Task] Aborted during pick."),
            (pull, pull_target, "[Task] Aborted during pull."),
        )

        for skill, target, abort_message in plan:
            _obs, _reward, done = _safe_call(
                skill, env, task, target_pos=target)
            # `done` is also what _safe_call reports when the skill raised.
            if done:
                print(abort_message)
                return

        # Final small hold to stabilise (result intentionally ignored)
        _safe_call(move, env, task, target_pos=pull_target)

        print("[SUCCESS] Drawer opened (oracle plan executed).")

    finally:
        # Always ensure the environment is properly shutdown
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()