"""
RLBench environment wrapper for dppo evaluation framework
"""

import os
import pickle
import logging
import numpy as np
import gym
from gym import spaces

log = logging.getLogger(__name__)

# Fix cv2 Qt plugin path issue - preserve CoppeliaSim plugin path.
# cv2 may override QT_QPA_PLATFORM_PLUGIN_PATH, so point it back at the
# CoppeliaSim install. This runs at import time, before the RLBench / PyRep
# imports that follow.
_COPPELIASIM_ROOT = os.environ.get('COPPELIASIM_ROOT', '/work/placeholder/placeholder/apps/CoppeliaSim')
os.environ['QT_QPA_PLATFORM_PLUGIN_PATH'] = _COPPELIASIM_ROOT
# Ensure CoppeliaSim is first in LD_LIBRARY_PATH
_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
if _ld_path.split(':')[0] != _COPPELIASIM_ROOT:
    # Only join with ':' when there is an existing value. A trailing ':' adds
    # an empty entry, which the dynamic loader treats as the current
    # working directory.
    os.environ['LD_LIBRARY_PATH'] = (
        f"{_COPPELIASIM_ROOT}:{_ld_path}" if _ld_path else _COPPELIASIM_ROOT
    )

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment
from pyrep.objects.joint import Joint
from pyrep.objects.dummy import Dummy
import imageio

# Settings for the close_drawer task (matching data collection settings).
CLOSE_DRAWER_CONFIG = dict(
    # Pose applied to the cabinet base; the orientation is "no rotation".
    cabinet_position=[0.3, -0.6, 0.77],
    cabinet_orientation=[0, 0, 0],
    # Drawer joint opening: 0.08 push + 0.04 margin.
    drawer_open_amount=0.12,
    # Which drawer is opened: 2 is the top drawer.
    drawer_variation=2,
    # Custom home joints for the top drawer (adjusted to the drawer height).
    # The original RLBench default has the EE too high (Z~1.47); this lowers
    # it to Z~1.16.
    home_joints=np.array([0, 0.15, 0, -1.7, 0, 1.7, 0.785]),
)


def extract_state_from_obs(obs):
    """
    Flatten an RLBench observation into one 1-D state vector.

    The pieces are concatenated in this order: joint positions, joint
    velocities, the gripper-open flag (wrapped in a float32 array), the first
    three entries of ``gripper_pose`` (position) and the remaining entries
    (quaternion).
    """
    joint_pos = obs.joint_positions
    joint_vel = obs.joint_velocities
    open_flag = np.array([obs.gripper_open], dtype=np.float32)
    pose = obs.gripper_pose
    xyz, quat = pose[:3], pose[3:]
    return np.concatenate((joint_pos, joint_vel, open_flag, xyz, quat))


def reset_robot_to_default(task_env):
    """Put the arm and gripper back at the scene's stored start joint positions.

    The arm is set (dynamics disabled) to ``_start_arm_joint_pos``, its target
    velocities are zeroed, the gripper is released and set to
    ``_starting_gripper_joint_pos``, and the simulation is stepped 5 times.
    """
    scene = task_env._scene
    arm = scene.robot.arm
    arm.set_joint_positions(scene._start_arm_joint_pos, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)

    gripper = scene.robot.gripper
    gripper.release()
    gripper.set_joint_positions(
        scene._starting_gripper_joint_pos, disable_dynamics=True)

    remaining = 5
    while remaining > 0:
        scene.pyrep.step()
        remaining -= 1


def set_robot_home_joints(task_env, home_joints):
    """Move the arm to ``home_joints`` and restore the gripper start position.

    The arm is set with dynamics disabled and its target velocities are
    zeroed; the gripper is released and set to the scene's
    ``_starting_gripper_joint_pos``. The simulation is then stepped 10 times.
    """
    scene = task_env._scene
    arm = scene.robot.arm
    arm.set_joint_positions(home_joints, disable_dynamics=True)
    arm.set_joint_target_velocities([0] * 7)

    gripper = scene.robot.gripper
    gripper.release()
    gripper.set_joint_positions(
        scene._starting_gripper_joint_pos, disable_dynamics=True)

    remaining = 10
    while remaining > 0:
        scene.pyrep.step()
        remaining -= 1


def fix_cabinet_orientation(task_env):
    """Apply the configured cabinet pose to the task base object.

    Position and orientation come from ``CLOSE_DRAWER_CONFIG``; afterwards the
    simulation is stepped 5 times.
    """
    base = task_env._task.get_base()
    pose_setters = (
        (base.set_position, 'cabinet_position'),
        (base.set_orientation, 'cabinet_orientation'),
    )
    for apply_pose, config_key in pose_setters:
        apply_pose(CLOSE_DRAWER_CONFIG[config_key])

    remaining = 5
    while remaining > 0:
        task_env._scene.pyrep.step()
        remaining -= 1


def set_drawer_open(task_env, drawer_idx, open_amount):
    """Open one drawer by ``open_amount`` and set the other two to 0.0.

    ``drawer_idx`` indexes ("bottom", "middle", "top"). Joint positions are
    set with dynamics disabled, then the simulation is stepped 5 times.
    """
    for position, label in enumerate(("bottom", "middle", "top")):
        target = open_amount if position == drawer_idx else 0.0
        Joint(f'drawer_joint_{label}').set_joint_position(
            target, disable_dynamics=True)

    pyrep = task_env._scene.pyrep
    remaining = 5
    while remaining > 0:
        pyrep.step()
        remaining -= 1


# Settings for the grasp task: arm home joint positions.
GRASP_CONFIG = dict(
    home_joints=np.array([0, -0.3, 0, -2.0, 0, 1.8, 0.785]),
)


def set_cup_position(task_env, position, cup_name="cup1", max_retries=5, tolerance=0.005):
    """Place the cup at ``position`` and verify that it stayed there.

    Each attempt sets the shape's position, steps the simulation 10 times and
    compares the resulting position with the target.

    Args:
        task_env: Task environment whose ``_scene.pyrep`` is stepped.
        position: Target position for the cup (sequence of coordinates).
        cup_name: Name of the PyRep shape to move.
        max_retries: Maximum number of set-and-verify attempts.
        tolerance: Largest accepted Euclidean distance between the actual and
            target position (presumably metres, given the cm conversion in
            the warning -- confirm).

    Returns:
        True if an attempt ended within ``tolerance``; False otherwise
        (a warning is printed in that case).
    """
    from pyrep.objects.shape import Shape
    cup = Shape(cup_name)
    target_pos = np.array(position)

    # Defined up front so the warning below still works when no attempt is
    # made (max_retries <= 0); previously that case raised NameError.
    error = float("inf")

    for _ in range(max_retries):
        cup.set_position(target_pos)
        for _ in range(10):
            task_env._scene.pyrep.step()

        actual_pos = np.array(cup.get_position())
        error = np.linalg.norm(actual_pos - target_pos)

        if error < tolerance:
            return True

    print(f"Warning: Cup position error {error*100:.2f}cm after {max_retries} attempts")
    return False


class RLBenchEnv(gym.Env):
    """
    Single RLBench environment wrapper compatible with dppo framework
    """

    def __init__(
        self,
        task_name,
        variation=0,
        headless=True,
        record_video=False,
        video_path=None,
        use_demo_reset=False,
        demo_dir=None,
        action_repeat=1,
        collect_full_trajectory=False,  # If True, don't terminate early on success
        fixed_cup_position=None,  # For grasp task: fixed cup position [x, y, z]
    ):
        """Launch RLBench, load the task and define the gym spaces.

        Args:
            task_name: RLBench task file name, resolved with
                ``task_file_to_task_class``. 'close_drawer' and 'pick_up_cup'
                get task-specific handling in ``reset``.
            variation: Task variation applied on reset.
            headless: Forwarded to the RLBench ``Environment``.
            record_video: If True, enable the front camera (256x256 RGB) and
                collect frames for a video.
            video_path: Output file for the recorded video.
            use_demo_reset: If True, reset from data found in ``demo_dir``.
            demo_dir: Directory holding either ``fixed_start_pos.npy`` and
                ``fixed_target_pos.npy``, or demo folders containing
                ``low_dim_obs.pkl``.
            action_repeat: Number of times each action is sent to the task per
                ``step`` call (clamped to at least 1).
            collect_full_trajectory: If True, ``step`` never reports done.
            fixed_cup_position: Optional [x, y, z] cup position used by the
                'pick_up_cup' reset.

        Raises:
            ValueError: ``use_demo_reset`` is True but ``demo_dir`` is None.
            RuntimeError: Demo reset was requested but ``demo_dir`` holds
                neither the fixed endpoint files nor any ``low_dim_obs.pkl``.
        """
        super().__init__()
        self.task_name = task_name
        self.variation = variation
        self.record_video = record_video
        self.video_path = video_path
        self.video_frames = []
        self.use_demo_reset = use_demo_reset
        self.demo_dir = demo_dir
        self.action_repeat = max(1, int(action_repeat))
        self.collect_full_trajectory = collect_full_trajectory
        self.fixed_cup_position = np.array(fixed_cup_position) if fixed_cup_position is not None else None
        self._demo_index = 0
        self._demo_paths = []
        self._fixed_start_pos = None
        self._fixed_target_pos = None
        # Per-episode bookkeeping. Declared here so every instance attribute
        # exists from construction; reset() re-initialises them.
        self._step_count = 0  # Number of step() calls in the current episode
        self._initial_ee_pos = None  # EE position captured at reset
        self._success_recorded = False  # Track if success was already recorded
        self._success_step = None  # Step at which success was first recorded
        self._handle_pos = None  # Handle position for close_drawer task (for contact checking)
        self._drawer_joint = None  # Drawer joint for tracking movement
        self._initial_drawer_pos = None  # Initial drawer joint position (open)
        self._drawer_started_moving = False  # Flag to track if drawer started closing
        self._contact_ee_pos = None  # EE position when drawer started moving (for contact check)
        self._prev_ee_pos = None  # Previous EE position for tracking contact moment

        if self.use_demo_reset:
            if self.demo_dir is None:
                raise ValueError("demo_dir must be provided when use_demo_reset=True.")

            # Check for fixed endpoints
            fixed_start_path = os.path.join(self.demo_dir, "fixed_start_pos.npy")
            fixed_target_path = os.path.join(self.demo_dir, "fixed_target_pos.npy")

            if os.path.exists(fixed_start_path) and os.path.exists(fixed_target_path):
                self._fixed_start_pos = np.load(fixed_start_path)
                self._fixed_target_pos = np.load(fixed_target_path)
                print(f"Loaded fixed start: {self._fixed_start_pos}")
                print(f"Loaded fixed target: {self._fixed_target_pos}")
            else:
                # Standard demo reset - load demo paths
                for root, _, files in os.walk(self.demo_dir):
                    for fname in files:
                        if fname == "low_dim_obs.pkl":
                            self._demo_paths.append(os.path.join(root, fname))
                if not self._demo_paths:
                    raise RuntimeError(f"No demos found under {self.demo_dir}")
                # os.walk yields entries in filesystem order; sort so the
                # demo cycling order is reproducible across runs and machines.
                self._demo_paths.sort()

        # Configure observation to be state-only, but add cameras if recording video
        obs_config = ObservationConfig()
        obs_config.set_all(False)
        obs_config.joint_positions = True
        obs_config.joint_velocities = True
        obs_config.gripper_open = True
        obs_config.gripper_pose = True

        # Enable camera for video recording (use front camera with custom position)
        if self.record_video:
            obs_config.front_camera.rgb = True
            obs_config.front_camera.image_size = (256, 256)

        # Setup action mode (absolute joint positions)
        # Custom JointPosition to match data collection exactly:
        # - For close_drawer: use dynamics for push phase to physically interact with drawer
        # - This matches push_utils.py behavior in data collection
        class CustomJointPosition(JointPosition):
            """Modified JointPosition that handles dynamics for push tasks"""
            def __init__(self, absolute_mode=True, convergence_steps=5, task_name=None):
                super().__init__(absolute_mode)
                self.convergence_steps = convergence_steps
                self.task_name = task_name
                self.intermediate_ee_positions = []  # Track EE positions during convergence
                self.step_count = 0  # Track which step we're on
                # Only the 'reach' entry is read in this class (see _is_push_phase).
                self.phase_steps = {'reach': 64, 'push': 24}  # From close_drawer_config.py

            def _is_push_phase(self):
                """Check if current step is in push phase (for close_drawer task)"""
                if self.task_name != 'close_drawer':
                    return False
                # Push phase starts after reach phase (step 64+)
                return self.step_count >= self.phase_steps['reach']

            def action_pre_step(self, scene, action):
                """Set robot to target joint positions"""
                # Get joint positions from action (first 7 values)
                joint_action = action[:7] if len(action) > 7 else action

                # For close_drawer push phase: use set_joint_target_positions for physics interaction
                # For all other cases: use set_joint_positions with disable_dynamics
                if self._is_push_phase():
                    # Push phase: enable dynamics so gripper physically pushes the drawer
                    # This matches data collection in push_utils.py line 724-725
                    scene.robot.arm.set_joint_target_positions(joint_action)
                else:
                    # Reach phase or other tasks: instant teleport
                    # This matches data collection in push_utils.py line 721
                    scene.robot.arm.set_joint_positions(joint_action, disable_dynamics=True)
                scene.robot.arm.set_joint_target_velocities([0] * 7)

            def action_step(self, scene, action):
                """Step simulation and record EE positions"""
                self.intermediate_ee_positions = []  # Reset for this action
                robot = scene.robot
                tip = robot.arm.get_tip()

                # Use more steps during push phase for physics to work (matching data collection)
                # Data collection: 5 steps for reach, 10 steps (5*2) for push
                if self._is_push_phase():
                    sim_steps = self.convergence_steps * 2  # 10 steps for push
                else:
                    sim_steps = self.convergence_steps  # 5 steps for reach

                for _ in range(sim_steps):
                    scene.step()
                    # Also step the task to update success conditions
                    if hasattr(scene, 'task') and scene.task is not None:
                        scene.task.step()
                    # Record EE position after each sim step for visualization
                    ee_pos = tip.get_position()
                    self.intermediate_ee_positions.append(np.array(ee_pos))

                self.step_count += 1

            def action_post_step(self, scene, action):
                """Update targets to current positions (standard cleanup)"""
                scene.robot.arm.set_joint_target_positions(
                    scene.robot.arm.get_joint_positions())

            def reset_step_count(self):
                """Reset step counter for new episode"""
                self.step_count = 0

        class CustomMoveArmThenGripper(MoveArmThenGripper):
            """MoveArmThenGripper with explicit 8-dim bounds (7 joints + gripper)."""
            def action_bounds(self):
                """Return (low, high) float32 arrays, where high = low + range."""
                ACT_MIN = [
                    -2.8973000049591064,
                    -1.7627999782562256,
                    -2.8973000049591064,
                    -3.0717999935150146,
                    -2.8973000049591064,
                    -0.017500000074505806,
                    -2.8973000049591064,
                    0.0,
                ]
                ACT_RANGE = [
                    5.794600009918213,
                    3.525599956512451,
                    5.794600009918213,
                    3.002000093460083,
                    5.794600009918213,
                    3.7699999809265137,
                    5.794600009918213,
                    1.0,
                ]
                return (
                    np.array(ACT_MIN, dtype=np.float32),
                    np.array(ACT_MIN, dtype=np.float32) + np.array(ACT_RANGE, dtype=np.float32),
                )

        # Store custom action mode to access intermediate positions later
        # convergence_steps=5 matches data collection (5 for reach, 10 for push via 5*2)
        self.custom_joint_position = CustomJointPosition(True, convergence_steps=5, task_name=task_name)
        action_mode = CustomMoveArmThenGripper(self.custom_joint_position, Discrete())

        # Create environment
        self.env = Environment(
            action_mode=action_mode,
            obs_config=obs_config,
            headless=headless
        )
        self.env.launch()

        # Load task
        task_class = task_file_to_task_class(task_name)
        self.task = self.env.get_task(task_class)

        # Set up front camera position/orientation for video recording
        # (matching data generation in RLBench_close_drawer/make_dataset)
        if self.record_video:
            front_cam = self.task._scene._cam_front
            # Camera position and orientation from close_drawer_config.py
            front_cam.set_position([1.25, 0.4, 1.58])
            front_cam.set_orientation([2.158, -0.934, 0.488])

        # Get action bounds
        self.action_low, self.action_high = action_mode.action_bounds()

        # Define observation and action spaces (required by gym.Env)
        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(22,),  # joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
            dtype=np.float32
        )
        self.action_space = spaces.Box(
            low=self.action_low,
            high=self.action_high,
            dtype=np.float32
        )

    def reset(self, **kwargs):
        """Reset the environment and return the initial state vector.

        Any frames recorded during the previous episode are saved first. A
        ``video_path`` given directly in ``kwargs`` or inside
        ``kwargs['options']`` replaces the current path and turns recording on.

        The reset mode is chosen as follows:
        - ``use_demo_reset`` with fixed endpoints loaded: fixed start/target.
        - ``use_demo_reset`` otherwise: next stored demo.
        - task 'close_drawer': cabinet/drawer/home-joint setup.
        - task 'pick_up_cup' with ``fixed_cup_position``: fixed cup placement.
        - anything else: plain RLBench task reset.

        Returns:
            The state vector produced by ``extract_state_from_obs``.
        """
        # Save video from previous episode if recording
        if self.record_video and len(self.video_frames) > 0 and self.video_path is not None:
            self._save_video()

        # Handle video_path passed through kwargs or options dict (for new episodes).
        # `options` may be passed explicitly as None, so fall back to {}.
        options = kwargs.get('options') or {}
        if 'video_path' in kwargs:
            self.video_path = kwargs.pop('video_path')
            self.record_video = True
        elif 'video_path' in options:
            self.video_path = options.pop('video_path')
            self.record_video = True

        self.video_frames = []
        self._step_count = 0
        self._success_recorded = False  # Reset success flag for new episode
        self._success_step = None  # Reset success step for new episode
        # Reset step counter for phase tracking (used by CustomJointPosition for close_drawer)
        if hasattr(self, 'custom_joint_position'):
            self.custom_joint_position.reset_step_count()

        if self.use_demo_reset:
            if self._fixed_start_pos is not None and self._fixed_target_pos is not None:
                obs = self._reset_to_fixed_endpoints()
            else:
                obs = self._reset_to_next_demo()
        elif self.task_name == 'close_drawer':
            obs = self._reset_close_drawer()
        elif self.task_name == 'pick_up_cup' and self.fixed_cup_position is not None:
            obs = self._reset_pick_up_cup()
        else:
            self.task.set_variation(self.variation)
            _, obs = self.task.reset()
        state = extract_state_from_obs(obs)

        # Record initial frame if video recording is enabled. The frame can be
        # None when the camera was not enabled (e.g. recording was switched on
        # via reset kwargs), so check the value rather than the attribute.
        front_rgb = getattr(obs, 'front_rgb', None)
        if self.record_video and front_rgb is not None:
            self.video_frames.append(front_rgb)

        # Store initial EE position for trajectory tracking
        self._initial_ee_pos = state[15:18].copy()  # Gripper position at indices 15:18

        return state

    def _reset_to_fixed_endpoints(self):
        """Reset the task, then place the target and the EE at the fixed endpoints.

        Returns the observation taken after placement, or the plain task-reset
        observation if the placement raised (the error is printed).
        """
        self.task.set_variation(self.variation)
        _, obs = self.task.reset()

        # Set target to fixed position
        self.task._task.target.set_position(self._fixed_target_pos)

        # Step simulation to update target position and sensors
        for _ in range(5):
            self.task._scene.pyrep.step()

        # Move robot to fixed start position, keeping the current orientation
        robot = self.task._scene.robot
        tip = robot.arm.get_tip()
        current_ori = tip.get_orientation()

        try:
            # Iteratively refine position to get EXACT placement
            max_iterations = 5
            tolerance = 0.001  # 1mm tolerance

            for iteration in range(max_iterations):
                # Solve IK for fixed start position
                start_joints = robot.arm.solve_ik(self._fixed_start_pos, euler=current_ori)
                robot.arm.set_joint_positions(start_joints)

                # Step simulation to let robot reach position
                for _ in range(30):
                    self.task._scene.pyrep.step()

                # Check if we're close enough
                current_pos = tip.get_position()
                pos_error = np.linalg.norm(current_pos - self._fixed_start_pos)

                if pos_error < tolerance:
                    break

                # If not close enough, adjust and try again.
                # NOTE(review): the next iteration re-solves IK for the
                # uncorrected start position, so this correction only affects
                # the pose the next solve starts from -- confirm intent.
                if iteration < max_iterations - 1:
                    correction = self._fixed_start_pos - current_pos
                    corrected_target = self._fixed_start_pos + correction * 0.5  # Partial correction to avoid overshoot

                    # Solve IK for corrected position
                    try:
                        start_joints = robot.arm.solve_ik(corrected_target, euler=current_ori)
                        robot.arm.set_joint_positions(start_joints)
                        for _ in range(10):
                            self.task._scene.pyrep.step()
                    except Exception:
                        # Best effort: if the correction fails, keep the
                        # current position. Exception (not a bare except) so
                        # KeyboardInterrupt/SystemExit still propagate.
                        log.debug("IK correction step failed; keeping current pose", exc_info=True)

            # Get final observation
            obs = self.task._scene.get_observation()

            # Verify the placement (reported at debug level only)
            final_pos = tip.get_position()
            pos_error = np.linalg.norm(final_pos - self._fixed_start_pos)
            target_pos = self.task._task.target.get_position()
            target_error = np.linalg.norm(target_pos - self._fixed_target_pos)
            log.debug(
                "Fixed-endpoint reset %d: start error %.2f mm, target error %.2f mm",
                self._demo_index, pos_error * 1000, target_error * 1000,
            )

            self._demo_index += 1

        except Exception as e:
            print(f"ERROR: Could not set fixed positions: {e}")
            import traceback
            traceback.print_exc()
        return obs

    def _reset_to_next_demo(self):
        """Reset the task to the next stored demo, cycling through the demo list."""
        demo_path = self._demo_paths[self._demo_index % len(self._demo_paths)]
        self._demo_index += 1
        # NOTE: pickle.load can execute arbitrary code; only point demo_dir at
        # trusted data.
        with open(demo_path, "rb") as f:
            demo = pickle.load(f)
        _, obs = self.task.reset_to_demo(demo)
        return obs

    def _reset_close_drawer(self):
        """Close-drawer specific reset (matching data collection exactly)."""
        # 1. Reset robot to default first (to avoid collision during task reset)
        reset_robot_to_default(self.task)
        # 2. Reset task environment
        self.task.set_variation(self.variation)
        self.task.reset()
        # 3. Fix cabinet position and orientation
        fix_cabinet_orientation(self.task)
        # 4. Set drawer to open position
        set_drawer_open(self.task, CLOSE_DRAWER_CONFIG['drawer_variation'],
                        CLOSE_DRAWER_CONFIG['drawer_open_amount'])
        # 5. Set robot to proper HOME_JOINTS (lower EE to drawer height)
        set_robot_home_joints(self.task, CLOSE_DRAWER_CONFIG['home_joints'])
        # 6. Get handle position for contact checking.
        # Use same method as dataset generator: get waypoint0 AFTER drawer is opened
        waypoint = Dummy('waypoint0')
        self._handle_pos = np.array(waypoint.get_position()).copy()

        # 7. Track drawer joint for movement detection
        drawer_names = ["bottom", "middle", "top"]
        drawer_name = drawer_names[CLOSE_DRAWER_CONFIG['drawer_variation']]
        self._drawer_joint = Joint(f'drawer_joint_{drawer_name}')
        self._initial_drawer_pos = self._drawer_joint.get_joint_position()
        self._drawer_started_moving = False
        self._contact_ee_pos = None
        self._prev_ee_pos = None  # Reset previous EE position

        # Get updated observation after setup
        return self.task._scene.get_observation()

    def _reset_pick_up_cup(self):
        """Grasp (pick_up_cup) reset with the cup at the fixed position."""
        # 1. Reset robot to default first
        reset_robot_to_default(self.task)
        # 2. Reset task environment
        self.task.set_variation(self.variation)
        self.task.reset()
        # 3. Set cup to fixed position
        set_cup_position(self.task, self.fixed_cup_position)
        # 4. Set robot to HOME_JOINTS for grasp task
        set_robot_home_joints(self.task, GRASP_CONFIG['home_joints'])
        # 5. Get updated observation after setup
        return self.task._scene.get_observation()

    def step(self, action):
        """
        Step environment with action.

        The action is clipped to the action bounds and sent to the task
        ``action_repeat`` times. Outside ``collect_full_trajectory`` mode the
        repetition stops as soon as the task reports termination.

        Returns: obs, reward, done, info (old gym API - 4 values).
        ``reward`` is 1.0 only on the call in which success is first observed
        during the episode, otherwise 0.0. ``done`` is always False in
        ``collect_full_trajectory`` mode.
        """
        # Clip action to valid range
        action = np.clip(action, self.action_low, self.action_high)

        reward = 0.0
        terminate = False
        obs = None
        success_this_step = False
        # Track EE positions at EVERY simulation step for smooth trajectory visualization
        primitive_ee_positions = []
        for _ in range(self.action_repeat):
            obs, step_reward, step_terminate = self.task.step(action)
            # Track success from any step
            if step_reward == 1.0:
                success_this_step = True
                if not self._success_recorded:
                    self._success_recorded = True
                    self._success_step = self._step_count  # Record step when success happened
                    reward = 1.0  # Only record reward on first success
            # Collect all intermediate EE positions from convergence steps
            if hasattr(self, 'custom_joint_position') and hasattr(self.custom_joint_position, 'intermediate_ee_positions'):
                primitive_ee_positions.extend(self.custom_joint_position.intermediate_ee_positions)

            # Note: contact checking uses a trajectory-based approach
            # (checking if EE got close to handle at any point), so no drawer
            # movement detection happens here.

            if step_terminate:
                terminate = True
                # In collect_full_trajectory mode keep executing the remaining
                # repeats; otherwise stop at the first termination.
                if not self.collect_full_trajectory:
                    break
        state = extract_state_from_obs(obs)

        # Record frame if video recording is enabled. The frame can be None
        # when the camera was not enabled, so check the value rather than the
        # attribute to avoid collecting unusable frames.
        front_rgb = getattr(obs, 'front_rgb', None)
        if self.record_video and front_rgb is not None:
            self.video_frames.append(front_rgb)

        # In collect_full_trajectory mode, don't terminate early - execute all actions.
        # This allows checking if trajectory reaches handle even if drawer closes early;
        # the wrapper is expected to handle max_episode_steps.
        # In old gym API, done = terminated OR truncated.
        done = False if self.collect_full_trajectory else terminate

        self._step_count += 1

        # Current EE position (gripper_pose xyz at indices 15:18)
        current_ee_pos = state[15:18].copy()

        info = {
            'success': self._success_recorded,  # True if success was ever recorded
            'success_this_step': success_this_step,  # True if success happened this step
            'success_step': self._success_step,  # Step at which success was first recorded (for plotting)
            'full_obs': {'state': state},
            'primitive_ee_positions': primitive_ee_positions,  # EE positions at each primitive timestep
            'current_ee_pos': current_ee_pos,  # Current EE position (for trajectory visualization)
            'initial_ee_pos': getattr(self, '_initial_ee_pos', None),  # Initial position from reset
            'actions_executed': self._step_count,  # Number of actions executed so far (1-indexed)
            'handle_pos': self._handle_pos,  # Handle position for contact checking (close_drawer only)
            'contact_ee_pos': self._contact_ee_pos,  # EE position when drawer started moving (for contact check)
            'drawer_moved': self._drawer_started_moving,  # True if drawer started closing
        }
        return state, reward, done, info

    def _save_video(self):
        """Save recorded frames to ``self.video_path`` at 30 fps.

        Does nothing when there are no frames or no path. Failures are printed
        rather than raised, so a video problem never aborts an episode.
        """
        if len(self.video_frames) == 0 or self.video_path is None:
            return
        try:
            # Create the output directory if needed so a fresh run directory
            # does not make the save fail.
            out_dir = os.path.dirname(self.video_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

            # Convert frames to uint8 if needed
            frames = []
            for frame in self.video_frames:
                if np.issubdtype(frame.dtype, np.floating):
                    # Assume float frames are in [0, 1] range; clip first so
                    # out-of-range values do not wrap around in the uint8 cast.
                    frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
                frames.append(frame)

            # Save video
            imageio.mimsave(self.video_path, frames, fps=30)
            print(f"Video saved to {self.video_path}")
        except Exception as e:
            print(f"Failed to save video: {e}")

    def close(self):
        """Flush any pending video frames, then shut down the RLBench environment."""
        has_pending_video = (
            self.record_video
            and len(self.video_frames) > 0
            and self.video_path is not None
        )
        if has_pending_video:
            # Frames from the last episode have not been written yet.
            self._save_video()
        self.env.shutdown()
