# run_skeleton_task.py (Completed)

import math
import sys
from typing import Set

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment
import skill_code                             # we will monkey-patch helpers into this module
from skill_code import *                      # available skills (rotate, pick, place, move, pull)

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions


# -------------------------------------------------------------------------- #
#  Helper functions required by the rotate skill in skill_code               #
# -------------------------------------------------------------------------- #
def normalize_quaternion(quat: np.ndarray) -> np.ndarray:
    """Scale a quaternion to unit length.

    Args:
        quat: Array-like quaternion; it is converted to a float64 array.

    Returns:
        The float64 array divided by its Euclidean norm. When the norm is
        exactly zero the converted array is returned unscaled, which avoids
        a division by zero.
    """
    q = np.asarray(quat, dtype=np.float64)
    magnitude = np.linalg.norm(q)
    # A zero-length quaternion has no direction to preserve; hand it back as is.
    return q if magnitude == 0.0 else q / magnitude


def euler_from_quat(quat: np.ndarray) -> np.ndarray:
    """Convert an (x, y, z, w) quaternion to roll, pitch and yaw angles.

    The quaternion is used as given (no normalisation happens here), so the
    caller is expected to pass a unit quaternion.

    Args:
        quat: Sequence of exactly four components ordered x, y, z, w.

    Returns:
        float64 array ``[roll, pitch, yaw]`` in radians, i.e. the rotations
        about the x, y and z axes respectively.
    """
    qx, qy, qz, qw = quat

    # Rotation about the x axis.
    roll = math.atan2(
        +2.0 * (qw * qx + qy * qz),
        +1.0 - 2.0 * (qx * qx + qy * qy),
    )

    # Rotation about the y axis. Rounding can push the sine marginally
    # outside [-1, 1], which asin() would reject, so clamp it first.
    sin_pitch = +2.0 * (qw * qy - qz * qx)
    if sin_pitch > 1.0:
        sin_pitch = 1.0
    elif sin_pitch < -1.0:
        sin_pitch = -1.0
    pitch = math.asin(sin_pitch)

    # Rotation about the z axis.
    yaw = math.atan2(
        +2.0 * (qw * qz + qx * qy),
        +1.0 - 2.0 * (qy * qy + qz * qz),
    )

    return np.array([roll, pitch, yaw], dtype=np.float64)


# -------------------------------------------------------------------------- #
#  Monkey-patch the missing helpers into skill_code so that rotate() works   #
# -------------------------------------------------------------------------- #
# Publish both helpers as attributes of the imported skill_code module object
# so code inside that module can resolve them by name.
setattr(skill_code, 'normalize_quaternion', normalize_quaternion)
setattr(skill_code, 'euler_from_quat', euler_from_quat)
# Repeat the registration on whatever sys.modules holds under 'skill_code',
# covering the case where that entry is looked up instead of the local name.
setattr(sys.modules['skill_code'], 'normalize_quaternion', normalize_quaternion)
setattr(sys.modules['skill_code'], 'euler_from_quat', euler_from_quat)


# -------------------------------------------------------------------------- #
#  Simple static analyser to discover predicates that exist in exploration   #
#  domain but are absent in the combined domain.                             #
# -------------------------------------------------------------------------- #
def find_missing_predicates() -> Set[str]:
    """Compute which exploration-domain predicates the combined domain lacks.

    Both predicate lists are hard-coded in this function; nothing is parsed
    from domain files.

    Returns:
        Every exploration predicate name that does not appear among the
        combined-domain predicate names.
    """
    # Predicates declared by the combined domain.
    known_to_combined = frozenset((
        'at', 'holding', 'handempty',
        'is-locked', 'is-open',
        'rotated', 'gripper-at', 'holding-drawer',
        'is-side-pos', 'is-anchor-pos',
    ))

    # Predicates used by the exploration domain.
    used_by_exploration = (
        'robot-at', 'identified', 'temperature-known',
        'weight-known', 'durability-known',
        'handempty', 'holding', 'at', 'lock-known',
    )

    return {
        predicate
        for predicate in used_by_exploration
        if predicate not in known_to_combined
    }


# -------------------------------------------------------------------------- #
#  Main runner                                                                #
# -------------------------------------------------------------------------- #
def run_skeleton_task():
    """Run the skeleton task: report predicates, set up, demo a rotate, shut down.

    Steps, in order:
      1. Print the predicates reported by ``find_missing_predicates()``.
      2. Create the environment and reset the task.
      3. Initialise video writers and wrap ``task.step`` and
         ``task.get_observation`` with the recording wrappers.
      4. Fetch object positions and print how many were returned.
      5. Call the ``rotate`` skill with a 90-degree-about-Z target quaternion;
         any exception it raises is printed as a warning, not propagated.
      6. Shut the environment down, even if an earlier step raised.
    """
    print("===== Starting Skeleton Task =====")

    # === Exploration phase: report predicates the combined domain lacks ===
    undefined_predicates = find_missing_predicates()
    if not undefined_predicates:
        print("[Exploration] No missing predicates detected.")
    else:
        listing = ', '.join(sorted(undefined_predicates))
        print(f"[Exploration] Detected missing predicate(s): {listing}")

    # === Environment setup ===
    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()

        # Video recording: create the writers from the first observation, then
        # route stepping and observation calls through the recording wrappers.
        init_video_writers(obs)
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # === Object positions ===
        positions = get_object_positions()
        print(f"[Info] Retrieved {len(positions)} object position(s) from the scene.")

        # ------------------------------------------------------------------
        # Minimal smoke test: ask the rotate skill for a 90-degree turn about
        # Z, which also exercises the quaternion helpers patched into
        # skill_code.
        # ------------------------------------------------------------------
        half_angle = math.pi / 4  # (x, y, z, w) stores sin/cos of half the angle
        rotate_options = dict(
            target_quat=np.array([0.0, 0.0, math.sin(half_angle), math.cos(half_angle)]),
            max_steps=50,
            threshold=0.03,
            timeout=5.0,
        )
        print("[Demo] Executing a test rotate to verify helpers...")
        try:
            # The three-way unpack stays inside the try so that an unexpected
            # return shape is reported as a demo failure as well.
            obs, _reward, _done = rotate(env, task, **rotate_options)
            print("[Demo] Rotate test completed.")
        except Exception as exc:
            # Best-effort demo: report the problem and carry on to shutdown.
            print(f"[Warning] Rotate demo failed: {exc}")

        # ------------------------------------------------------------------
        # Placeholder for the real task logic.
        # The pick / place / move / pull skills can be used here as required.
        # ------------------------------------------------------------------

    finally:
        # The environment must be released no matter how the body exits.
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Script entry point: run the skeleton task only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()