# main.py
import numpy as np
from env import setup_environment, shutdown_environment
from skill_code import close_gripper, press
from video import init_video_writers, recording_step, recording_get_observation
from object_positions import get_object_positions
from pyrep.objects.joint import Joint

def _install_recording_hooks(task):
    """Wrap ``task.step`` and ``task.get_observation`` with the video recording wrappers.

    The wrappers come from ``video.recording_step`` and
    ``video.recording_get_observation``. Presumably they feed the writers created
    by ``init_video_writers`` (TODO confirm against the ``video`` module).
    """
    task.step = recording_step(task.step)
    task.get_observation = recording_get_observation(task.get_observation)


def _execute_lamp_on(env, task):
    """Reset the task, close the gripper, then press the lamp button.

    Returns early without pressing if ``close_gripper`` reports ``done``.
    """
    _, obs = task.reset()
    init_video_writers(obs)
    _install_recording_hooks(task)

    positions = get_object_positions()
    button_pos = positions['button_pos']

    # Offset added to the button's z coordinate before pressing. 0.015 is the
    # value the original script used; units presumably metres (TODO confirm).
    button_z_offset = 0.015
    # np.array(..., dtype=float) always copies, so the shared `positions` data is
    # never mutated. It also accepts lists/tuples and avoids integer truncation
    # of the offset.
    button_pos_adjusted = np.array(button_pos, dtype=float)
    button_pos_adjusted[2] += button_z_offset

    print(f"[Task] Original button position: {button_pos}")
    print(f"[Task] Adjusted button position: {button_pos_adjusted}")

    # Joint.get_position() (inherited from Object) returns the joint frame's
    # Cartesian position, which does not track the joint's motion.
    # get_joint_position() returns the joint's own value, i.e. how far the
    # button has travelled.
    button_joint = Joint('target_button_joint')
    initial_joint_pos = button_joint.get_joint_position()
    print(f"[Debug] Initial button joint position: {initial_joint_pos}")

    print("[Task] Closing gripper")
    obs, reward, done = close_gripper(
        env, task,
        max_steps=10,
        timeout=2.0
    )
    if done:
        print("[Task] Task ended after closing gripper!")
        return

    print("[Task] Pressing the button")
    obs, reward, done = press(
        env, task,
        target_pos=button_pos_adjusted,
        max_steps=100,
        threshold=0.005,
        timeout=10.0
    )

    final_joint_pos = button_joint.get_joint_position()
    print(f"[Debug] Final button joint position: {final_joint_pos}")
    print(f"[Debug] Button joint displacement: {abs(final_joint_pos - initial_joint_pos)}")

    if done:
        print("[Task] Task completed successfully! Reward:", reward)
    else:
        print("[Task] Task not completed (done=False).")


def run_lamp_on_task():
    """Run the LampOn task once: set up the environment, execute, shut down.

    The environment is always shut down, even if execution raises. The end
    banner is printed whenever execution returns normally, including the early
    exit after closing the gripper.
    """
    print("===== Starting LampOn Task =====")
    env, task = setup_environment()
    try:
        _execute_lamp_on(env, task)
    finally:
        shutdown_environment(env)

    print("===== End of LampOn Task =====")

if __name__ == "__main__":
    # Script entry point: run the LampOn task once when this file is executed directly.
    run_lamp_on_task()