# run_skeleton_task.py – Completed version

import time
import math
import traceback
import numpy as np

from env import setup_environment, shutdown_environment

# pull every symbol from skill_code (pick, place, move, rotate, pull …)
from skill_code import *

from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions



# --------------------------------------------------------------------------- #
# ---------------------------  Helper Functions ----------------------------- #
# --------------------------------------------------------------------------- #

def normalize_quaternion(q):
    """
    Scale the quaternion *q* (xyzw order) to unit length.

    The input is converted to a float64 NumPy array first.  When its
    Euclidean norm is below 1e-12 it cannot be meaningfully normalised, so
    the identity rotation ``[0, 0, 0, 1]`` is returned instead.
    """
    quat = np.asarray(q, dtype=np.float64)
    length = np.linalg.norm(quat)
    if length < 1e-12:
        # Degenerate (zero-length) input: use the identity rotation.
        identity = np.array([0.0, 0.0, 0.0, 1.0], dtype=np.float64)
        return identity
    unit = quat / length
    return unit


def euler_from_quat(q):
    """
    Turn an xyzw quaternion into ``(roll, pitch, yaw)`` Euler angles (radians).

    The quaternion is normalised with normalize_quaternion() first.  The
    argument to asin for pitch is clamped into [-1, 1] so round-off cannot
    push it out of the function's domain.  Deliberately minimal; intended
    only for printing.
    """
    x, y, z, w = normalize_quaternion(q)

    # Rotation about the x-axis.
    roll = math.atan2(2.0 * (w * x + y * z),
                      1.0 - 2.0 * (x * x + y * y))

    # Rotation about the y-axis (clamped sine of the pitch angle).
    sin_pitch = 2.0 * (w * y - z * x)
    if sin_pitch > 1.0:
        sin_pitch = 1.0
    elif sin_pitch < -1.0:
        sin_pitch = -1.0
    pitch = math.asin(sin_pitch)

    # Rotation about the z-axis.
    yaw = math.atan2(2.0 * (w * z + x * y),
                     1.0 - 2.0 * (y * y + z * z))

    return roll, pitch, yaw


# --------------------------------------------------------------------------- #
# --------------------------  Exploration Logic  ---------------------------- #
# --------------------------------------------------------------------------- #

def _attempt_move(env, task, target_pos):
    """
    Best-effort call of the external ``move`` skill towards *target_pos*.

    The exact signature of ``move`` is not visible in this file, so two call
    forms are tried in turn: ``move(env, task, target_pos)`` and, if that
    raises TypeError, ``move(task, target_pos)``.  A failure of the fallback
    is reported (not silently discarded) and exploration continues.  Any
    non-TypeError raised by the first form propagates to the caller.
    """
    try:
        move(env, task, target_pos)
        return
    except TypeError as exc:
        # NOTE(review): a TypeError may also originate *inside* move(), not
        # only from a signature mismatch - the retry is a heuristic.
        print(f"[Exploration] move(env, task, pos) raised {exc!r}; "
              "retrying as move(task, pos)")
    try:
        move(task, target_pos)
    except Exception as exc:
        # Deliberate best-effort: report the failure but keep exploring.
        print(f"[Exploration] move() fallback failed: {exc!r}")


def _idle_steps(env, task, n_steps=5):
    """
    Send up to *n_steps* all-zero actions so the video captures a few frames.

    Stops early if ``task.step`` reports termination.  The result is only
    inspected when it is a tuple of at least three items (assumed to be
    ``(obs, reward, terminate)``); any other return shape is ignored.
    """
    for _ in range(n_steps):
        # NOTE(review): an all-zero action is only a no-op under a delta-style
        # action mode; under an absolute-pose mode it may be invalid - confirm.
        result = task.step(np.zeros(env.action_shape))
        if isinstance(result, tuple) and len(result) >= 3 and result[2]:
            print("[Exploration] Task terminated during idle steps.")
            break


def exploration_phase(env, task, positions):
    """
    Light-weight, task-agnostic 'exploration' routine.

    It rotates the gripper to the identity orientation, prints every object
    position it was given, optionally nudges the robot towards the first
    object via the ``move`` skill (if one was imported), and finally sends a
    few idle actions so the recording contains some frames.  The gathered
    logs/observations are meant for an external reasoner.

    Args:
        env: environment handle; ``env.action_shape`` sizes the idle action.
        task: task handle passed to the skills and stepped directly.
        positions: mapping of object name -> position (``None`` is treated
            as an empty mapping).

    Returns early (without error) if the initial rotation reports ``done``.
    Exceptions raised by ``rotate`` propagate to the caller; the
    start/end banners are always printed.
    """
    print("----- Exploration Phase  (start) -----")
    try:
        # 1) Rotate the gripper to a canonical orientation (identity
        #    quaternion, xyzw) using the rotate() skill from skill_code.
        identity_quat = np.array([0.0, 0.0, 0.0, 1.0])
        _, _, done = rotate(
            env=env,
            task=task,
            target_quat=identity_quat,
            max_steps=75,
            threshold=0.05,
            timeout=5.0
        )
        if done:
            print("[Exploration] Task terminated during initial rotation.")
            return

        # 2) Report every known object location.
        positions = positions or {}
        for obj_name, pos in positions.items():
            print(f"[Exploration] Object '{obj_name}' @ {np.round(pos, 3)}")

        # If a 'move' skill was pulled in by the star import, approach the
        # first object once so the simulator shows some robot motion.
        if 'move' in globals() and positions:
            target_obj, target_pos = next(iter(positions.items()))
            print(f"[Exploration] Quick approach towards '{target_obj}'")
            _attempt_move(env, task, target_pos)

        # 3) Small delay so the video captures a few frames even if the
        #    above actions finish instantly.
        _idle_steps(env, task)
    finally:
        print("----- Exploration Phase  (end) -----")


# --------------------------------------------------------------------------- #
# --------------------------  Main Entry Point ------------------------------ #
# --------------------------------------------------------------------------- #

def run_skeleton_task():
    """
    Generic skeleton for running any task in the simulation.

    Sets up the environment, resets the task, installs the video-recording
    wrappers around ``task.step`` and ``task.get_observation``, fetches the
    object positions and runs the exploration phase.  Any exception raised
    after setup is reported with a traceback (not re-raised), and the
    environment is always shut down.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        # Reset task (the task descriptions are not needed here).
        _descriptions, obs = task.reset()

        # Video recording helpers: wrap step/get_observation so every call
        # is also recorded.
        init_video_writers(obs)
        original_step = task.step
        task.step = recording_step(original_step)
        original_get_obs = task.get_observation
        task.get_observation = recording_get_observation(original_get_obs)

        # Obtain current object locations from helper module
        positions = get_object_positions()

        # -- Exploration phase ------------------------------------------------
        exploration_phase(env, task, positions)

        # --------------------------------------------------------------------
        # At this point an external planner could analyse the collected data
        # (video, observations, printed logs) to infer which predicate(s) were
        # missing from the model, then deliver a refined plan.  For the scope
        # of this executable, we stop after the exploration to keep things
        # generic and free of task-specific assumptions.
        # --------------------------------------------------------------------

    except Exception:
        # Top-level boundary: report the full traceback, then fall through
        # to the guaranteed shutdown below.
        print("!!! Exception raised during run_skeleton_task() !!!")
        traceback.print_exc()
    finally:
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")


# Script entry point: run the skeleton task only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()