# run_skeleton_task.py (Completed Implementation)

import time
import numpy as np

from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
from skill_code import *       # <- ALL motion primitives come from this module
from video import (
    init_video_writers,
    recording_step,
    recording_get_observation
)
from object_positions import get_object_positions


# -----------------------------------------------------------------------------
# Helper utilities (do NOT redefine any primitive skills)
# -----------------------------------------------------------------------------
def _normalize_quat(q):
    """Safety-wrapper around quaternion normalisation."""
    q = np.asarray(q).astype(np.float64).flatten()
    if q.shape[0] != 4:
        raise ValueError(f"Quaternion must have 4 elements, got shape {q.shape}")
    norm = np.linalg.norm(q)
    if norm <= 1e-8:
        # Fallback to identity quaternion
        return np.array([0., 0., 0., 1.])
    return q / norm


def _safe_getattr(obj, name, default):
    """Returns attribute value if exists, otherwise returns default."""
    return getattr(obj, name, default)


# -----------------------------------------------------------------------------
# Generic “exploration” phase
# -----------------------------------------------------------------------------
def _position_problem(pos):
    """Return a reason why *pos* cannot be used as a target position, or None if usable."""
    try:
        arr = np.asarray(pos, dtype=np.float64)
    except (TypeError, ValueError) as exc:
        # e.g. a dict such as {'handle_pos': ...} or a non-numeric string
        return f"position is not numeric ({exc})"
    if arr.size == 0:
        return "position is empty"
    if not np.all(np.isfinite(arr)):
        return f"position contains non-finite values: {arr.tolist()}"
    return None


def _run_exploration_step(action, obj_name, call, failed_actions):
    """Run one primitive via the zero-argument *call*; return ``(succeeded, done)``.

    Any exception is printed and appended to *failed_actions* as
    ``(action, obj_name, message)``. Exploration is best-effort, so failures are
    recorded rather than raised.
    """
    print(f"[Exploration]  • {action} → {obj_name}")
    try:
        # The call (and the primitive-name lookup inside it) happens in the try
        # so that a missing or misbehaving primitive is recorded, not fatal.
        _obs, _reward, done = call()
    except Exception as exc:
        print(f"[Exploration]  • {action} FAILED: {exc}")
        failed_actions.append((action, obj_name, str(exc)))
        return False, False
    if done:
        print(f"[Exploration]  • Task signalled DONE during {action} — abort exploration.")
    return True, bool(done)


def explore_and_identify(env, task, positions):
    """
    A very light-weight exploration routine that:
    (1) moves close to every location provided by `positions`;
    (2) performs a ‘pick-and-place back’ cycle when possible;
    (3) records any issues so we can infer missing predicates / preconditions.

    For each entry, ``move`` → ``pick`` → ``place`` (back to the same position)
    run in order. The first failing step ends that object's cycle, so the code
    never tries to place something it could not pick. Entries whose position is
    not a finite numeric value are skipped and recorded as ``'validate'``
    failures.

    Args:
        env, task: environment handles from ``setup_environment()``.
        positions: mapping of object name → target position, or None. ``None``
            is treated like an empty mapping.

    Returns:
        False if any primitive reported the task as done, which aborts
        exploration immediately. Otherwise True, even if some steps failed;
        failures are printed in a summary.
    """
    print("===== [Exploration] START =====")

    failed_actions = []   # (action, object name, message) for later reasoning

    for name, pos in (positions or {}).items():
        print(f"[Exploration] Handling object: {name}")

        problem = _position_problem(pos)
        if problem is not None:
            print(f"[Exploration]  • skipping {name}: {problem}")
            failed_actions.append(('validate', name, problem))
            continue

        # Arguments shared by all three primitives for this object. NOTE(review):
        # `move` is assumed to accept the same keyword interface as `pick`.
        common = dict(target_pos=pos, max_steps=120, approach_axis='z', timeout=10.0)
        # The lambdas are invoked within this same iteration, so capturing the
        # loop variables `common`/`pos` is safe.
        steps = (
            ('move', lambda: move(env, task, approach_distance=0.20,
                                  threshold=0.015, **common)),
            ('pick', lambda: pick(env, task, approach_distance=0.05,
                                  threshold=0.01, **common)),
            ('place', lambda: place(env, task, approach_distance=0.05,
                                    threshold=0.01, **common)),
        )
        for action, call in steps:
            succeeded, done = _run_exploration_step(action, name, call, failed_actions)
            if done:
                return False
            if not succeeded:
                break   # later steps depend on this one (e.g. no place without pick)

    # Simple reasoning about failures
    if failed_actions:
        print("===== [Exploration] SUMMARY OF FAILURES =====")
        for act, obj, msg in failed_actions:
            print(f"  - Action '{act}' on '{obj}' failed because: {msg}")
        print("  > Missing predicates or unmet preconditions may exist. "
              "Refer to failure logs above.")
    else:
        print("[Exploration] All primitives executed without any blocking failure.")

    print("===== [Exploration] END =====")
    return True


# -----------------------------------------------------------------------------
# Example task-specific routine (opens a drawer if present)
# -----------------------------------------------------------------------------
def attempt_open_drawer(env, task, drawer_info):
    """
    Attempt to open a drawer with the ``move`` → ``rotate`` → ``pull`` primitives.

    The domain's full sequence is rotate → move-to-side → move-to-anchor →
    pick-drawer → pull. This routine only covers a very high-level happy path.
    It moves near the handle, rotates the gripper 90° about z, then pulls.
    NOTE(review): the handle is never grasped before ``pull``, so the pull may
    have nothing to act on. Confirm against the skill definitions.

    Args:
        env, task: environment handles from ``setup_environment()``.
        drawer_info: mapping expected to contain ``'handle_pos'``. ``None``,
            a mapping without a handle position, or any non-mapping value (e.g.
            a bare position array) is logged and skipped instead of crashing.

    Returns:
        None. Every failure or task-done signal is logged and ends the attempt.
    """
    from collections.abc import Mapping

    if drawer_info is None:
        print("[Drawer] No drawer information provided. Skipping.")
        return
    if not isinstance(drawer_info, Mapping):
        # Without this guard a plain position array raises AttributeError on .get().
        print(f"[Drawer] Unsupported drawer info type {type(drawer_info).__name__}; "
              "expected a mapping with 'handle_pos'. Skipping.")
        return

    drawer_pos = drawer_info.get('handle_pos')
    if drawer_pos is None:
        print("[Drawer] Drawer handle position unknown. Skipping.")
        return

    print("[Drawer] Starting open drawer sequence.")

    # 1) Move close to drawer handle
    try:
        _obs, _reward, done = move(
            env, task,
            target_pos=drawer_pos,
            approach_distance=0.15,
            max_steps=150,
            threshold=0.01,
            approach_axis='xy',   # assume horizontal approach makes sense
            timeout=12.0
        )
    except Exception as exc:
        print(f"[Drawer] move to handle FAILED: {exc}")
        return
    if done:
        print("[Drawer] Task signalled DONE during move — stopping.")
        return

    # 2) Rotate gripper to 90° about z (assumed required for side approach).
    #    [x, y, z, w] = [0, 0, sin(45°), cos(45°)].
    ninety_deg_quat = _normalize_quat([0., 0., np.sin(np.pi / 4), np.cos(np.pi / 4)])
    try:
        result = rotate(env, task, ninety_deg_quat, max_steps=60, threshold=0.05, timeout=6.0)
    except Exception as exc:
        print(f"[Drawer] rotate FAILED: {exc}")
        return
    # NOTE(review): rotate's return type is not visible here. If it follows the
    # (obs, reward, done) convention of move/pull, honour its done flag too.
    if isinstance(result, tuple) and len(result) == 3 and result[2]:
        print("[Drawer] Task signalled DONE during rotate — stopping.")
        return

    # 3) Pull
    print("[Drawer] Attempting pull.")
    try:
        _obs, _reward, done = pull(env, task)
    except Exception as exc:
        print(f"[Drawer] pull FAILED: {exc}")
        return
    if done:
        print("[Drawer] Pull ended task prematurely.")


# -----------------------------------------------------------------------------
# Main entrypoint
# -----------------------------------------------------------------------------
def run_skeleton_task():
    """Run the skeleton task end to end.

    The steps are: set up the environment, enable video recording, collect
    object positions, explore, and try to open a drawer.

    The environment is always shut down via ``finally``. The closing banner is
    printed both on the normal path and when exploration reports the task as
    already done. If an exception escapes, shutdown still runs, but the closing
    banner is skipped and the exception propagates.
    """
    print("===== Starting Skeleton Task =====")

    # ---------------------------------------------------------------------
    # Environment setup
    # ---------------------------------------------------------------------
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # (Optional) video recorder: wrap step/get_observation so frames are captured.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # -----------------------------------------------------------------
        # Collect world information
        # -----------------------------------------------------------------
        # `or {}` guards against a None return so the .get() below cannot crash.
        positions = get_object_positions() or {}  # user implemented utility
        if not positions:
            print("[Warning] get_object_positions() returned empty dict.")

        # -----------------------------------------------------------------
        # EXPLORATION phase – quickly interact with each object/location
        # -----------------------------------------------------------------
        if explore_and_identify(env, task, positions):
            # -------------------------------------------------------------
            # Task-specific logic (example: try to open a drawer if present)
            # -------------------------------------------------------------
            drawer_info = positions.get('drawer')   # expects dict with at least 'handle_pos'
            attempt_open_drawer(env, task, drawer_info)

            # -------------------------------------------------------------
            # (OPTIONAL) Any further task goals can be appended here
            # -------------------------------------------------------------
            print("[Main] Finished all planned actions.")
        else:
            print("[Main] Exploration terminated early due to task completion.")

    finally:
        # Always clean up env resources
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


if __name__ == "__main__":
    # Script entry point: run the task only when executed directly, not on import.
    run_skeleton_task()
