# run_skeleton_task.py

import numpy as np
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from env import setup_environment, shutdown_environment

from skill_code import *  # pick, place, move, rotate, pull, etc.

from video import init_video_writers, recording_step, recording_get_observation

from object_positions import get_object_positions

def _unpack_skill_result(result):
    """Normalise a skill's return value to an ``(obs, reward, done)`` triple.

    A 3-tuple is passed through unchanged; any other value is treated as a
    bare observation, paired with ``reward=None`` and ``done=False``.
    """
    if isinstance(result, tuple) and len(result) == 3:
        return result
    return result, None, False


def _call_skill(skill_fn, env, task, kwargs):
    """Invoke ``skill_fn(env, task, ...)`` exactly once and return its result.

    The keyword form ``skill_fn(env, task, **kwargs)`` is preferred. If the
    skill's signature does not accept those keyword names, the values of
    ``kwargs`` are passed positionally instead, in insertion order.

    Compatibility is decided up front with ``inspect.signature(...).bind``
    rather than by calling the skill and catching ``TypeError``: a
    ``TypeError`` raised *inside* a skill that has already acted on the
    simulation must not trigger a second execution of that skill.

    Raises:
        TypeError: if neither the keyword nor the positional form matches
            the skill's signature.
        Exception: whatever the skill itself raises is propagated unchanged.
    """
    import inspect  # local import keeps this block self-contained

    try:
        signature = inspect.signature(skill_fn)
    except (TypeError, ValueError):
        # No introspectable signature (e.g. some C-implemented callables):
        # the keyword form is the only thing we can attempt.
        return skill_fn(env, task, **kwargs)

    try:
        signature.bind(env, task, **kwargs)
    except TypeError:
        # Keyword names do not match; fall back to positional values.
        positional = list(kwargs.values())
        signature.bind(env, task, *positional)  # raises TypeError on mismatch
        return skill_fn(env, task, *positional)
    return skill_fn(env, task, **kwargs)


def run_skeleton_task():
    """Run the scripted skill plan in the simulation environment.

    Sets up the environment, resets the task, wraps ``task.step`` and
    ``task.get_observation`` with the video-recording helpers, looks up the
    required object positions, and then executes a fixed plan of
    ``(skill_name, kwargs)`` steps. Skills are resolved by name from this
    module's globals (populated by ``from skill_code import *``).

    Execution stops at the first step whose skill is missing, raises, or
    reports ``done``. The environment is always shut down, including on
    early exit. If a required position key is missing, the function returns
    immediately after shutdown without printing the end banner.
    """
    print("===== Starting Skeleton Task =====")

    env, task = setup_environment()
    try:
        _descriptions, obs = task.reset()
        init_video_writers(obs)
        # Wrap step and get_observation so every call is recorded.
        task.step = recording_step(task.step)
        task.get_observation = recording_get_observation(task.get_observation)

        # Retrieve positions of relevant objects.
        # Expected keys: 'drawer_side_pos', 'drawer_anchor_pos',
        # 'object_to_manipulate', 'target_destination'
        positions = get_object_positions()
        try:
            drawer_side_pos = positions['drawer_side_pos']
            drawer_anchor_pos = positions['drawer_anchor_pos']
            object_pos = positions['object_to_manipulate']
            target_pos = positions['target_destination']
        except KeyError as e:
            print(f"[Error] Missing position for key: {e}")
            return

        # 90-degree rotation about the Z axis.
        # NOTE(review): assumes (x, y, z, w) component order - confirm
        # against the rotate() skill.
        ninety_deg_quat = [0.0, 0.0, np.sin(np.pi/4), np.cos(np.pi/4)]

        # Oracle plan: sequence of (skill_name, kwargs). The key order inside
        # each kwargs dict matters: it is the argument order used if a skill
        # has to be called positionally.
        plan = [
            ('rotate',       {'target_quat': ninety_deg_quat}),
            ('move',         {'target_pos': drawer_side_pos, 'max_steps': 100, 'threshold': 0.01, 'timeout': 10.0}),
            ('move',         {'target_pos': drawer_anchor_pos, 'max_steps': 100, 'threshold': 0.01, 'timeout': 10.0}),
            ('pick',         {'target_pos': drawer_anchor_pos, 'approach_distance': 0.10, 'max_steps': 100, 'threshold': 0.01, 'approach_axis': 'z', 'timeout': 10.0}),
            ('pull',         {}),  # pull(env, task)
            ('move',         {'target_pos': object_pos, 'max_steps': 100, 'threshold': 0.01, 'timeout': 10.0}),
            ('pick',         {'target_pos': object_pos, 'approach_distance': 0.10, 'max_steps': 100, 'threshold': 0.01, 'approach_axis': 'z', 'timeout': 10.0}),
            ('move',         {'target_pos': target_pos, 'max_steps': 100, 'threshold': 0.01, 'timeout': 10.0}),
            ('place',        {'target_pos': target_pos, 'approach_distance': 0.10, 'max_steps': 100, 'threshold': 0.01, 'approach_axis': 'z', 'timeout': 10.0}),
        ]

        # Execute the plan step by step, stopping at the first failure.
        for step_idx, (skill_name, kwargs) in enumerate(plan, start=1):
            print(f"[Plan] Step {step_idx}: executing {skill_name} with args {kwargs}")
            skill_fn = globals().get(skill_name)
            if skill_fn is None:
                print(f"[Error] Skill '{skill_name}' is not found.")
                break

            try:
                result = _call_skill(skill_fn, env, task, kwargs)
            except Exception as exc:
                # Top-level boundary: report and abort the plan; the
                # finally clause below still shuts the environment down.
                print(f"[Error] Unexpected error during '{skill_name}': {exc}")
                break

            _obs, _reward, done = _unpack_skill_result(result)
            if done:
                print(f"[Task] Task ended prematurely at step {step_idx} ({skill_name}).")
                break

    finally:
        shutdown_environment(env)

    print("===== End of Skeleton Task =====")

# Script entry point: run the task only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run_skeleton_task()