"""
Evaluate pick-and-place diffusion policy in RLBench simulation.

Full pipeline evaluation:
1. REACH phase: Diffusion policy (home -> pregrasp), conditioned on reach CP
2. DESCEND + GRASP: Hard-coded (pregrasp -> grasp)
3. LIFT: Hard-coded (grasp -> lift)
4. CARRY phase: Diffusion policy (lift -> prerelease), conditioned on carry CP
5. DESCEND_RELEASE + RELEASE: Hard-coded (prerelease -> release)

Reports:
- Reach EE error (final EE to pregrasp target)
- Carry EE error (final EE to prerelease target)
- Grasp success rate
- Lift success rate
- Release success rate
- Task success rate
"""

import sys
import os

# Add RLBench_pick_place to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../RLBench_pick_place/make_dataset'))

import logging
import numpy as np
import torch
import hydra
from pathlib import Path
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from rlbench.environment import Environment
from rlbench.backend.conditions import DetectedSeveralCondition, NothingGrasped
from pyrep.objects.shape import Shape
from pyrep.objects.proximity_sensor import ProximitySensor

from utils import (
    create_obs_config, create_action_mode, get_task_classes,
    HOME_JOINTS,
)

from grasp_and_lift import descend_and_grasp, lift, descend_and_release
from pick_and_place_utils import align_robot_to_carry_start

log = logging.getLogger(__name__)


class EvalPickPlaceReach:
    """
    Evaluation agent for pick-and-place using RLBench simulation.

    Evaluates diffusion policy on both REACH and CARRY phases:
    - REACH: home -> pregrasp (64 steps, learned)
    - DESCEND + GRASP: pregrasp -> grasp (hard-coded)
    - LIFT: grasp -> lift (hard-coded)
    - CARRY: lift -> prerelease (64 steps, learned)
    - DESCEND_RELEASE + RELEASE: prerelease -> release (hard-coded)

    Control points for REACH and CARRY are sampled independently.
    """

    def __init__(self, cfg):
        """Build the evaluator from a Hydra config.

        Seeds the global NumPy and torch RNGs with ``cfg.seed``, creates
        ``cfg.logdir``, loads the normalization statistics and the REACH and
        CARRY datasets, and instantiates the diffusion model from
        ``cfg.model``. The RLBench simulator is not launched here; ``run()``
        does that.
        """
        self.cfg = cfg
        self.device, self.seed = cfg.device, cfg.seed

        # Global seeding happens before any data loading / model construction.
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)

        self.logdir = Path(cfg.logdir)
        self.logdir.mkdir(parents=True, exist_ok=True)
        log.info(f"Logging to {self.logdir}")

        # Min/max statistics, then the two phase datasets (REACH first).
        self._load_normalization(cfg.normalization_path)
        self._load_reach_dataset(cfg.dataset_path, cfg.metadata_path, cfg.full_metadata_path)
        self._load_carry_dataset(cfg.carry_dataset_path, cfg.carry_metadata_path)

        self.model = hydra.utils.instantiate(cfg.model)

        # Rollout settings; optional keys fall back to the defaults shown.
        self.n_rollouts_per_cp = cfg.n_rollouts_per_cp
        self.use_random_noise = cfg.get('use_random_noise', True)
        self.horizon_steps = cfg.horizon_steps
        self.action_repeat = cfg.get('action_repeat', 5)  # physics steps per executed action

        # Simulator handles are populated lazily by run() / _setup_rlbench().
        self.rlbench_env = self.task_env = None

    def _load_normalization(self, path):
        """Read the min/max bounds used for the [-1, 1] scaling of observations and actions.

        Args:
            path: file readable by ``np.load`` that exposes the keys
                ``obs_min``, ``obs_max``, ``action_min`` and ``action_max``
                (presumably an ``.npz`` archive).
        """
        stats = np.load(path)
        self.obs_min, self.obs_max = stats['obs_min'], stats['obs_max']
        self.action_min, self.action_max = stats['action_min'], stats['action_max']
        log.info(f"Loaded normalization from {path}")

    def _load_reach_dataset(self, dataset_path, metadata_path, full_metadata_path):
        """Load the normalized REACH dataset and its two metadata files.

        Args:
            dataset_path: archive with ``states``, ``actions`` and
                ``traj_lengths`` arrays (trajectories stored back to back).
            metadata_path: per-trajectory processed metadata (entries are
                indexed like dicts, e.g. ``canonical_cp_params``, ``end_pos``).
            full_metadata_path: per-trajectory metadata holding the world
                positions used during rollouts.
        """
        arrays = np.load(dataset_path)
        self.reach_states = arrays['states']
        self.reach_actions = arrays['actions']
        self.reach_traj_lengths = arrays['traj_lengths']
        self.n_reach_trajectories = len(self.reach_traj_lengths)

        # Metadata entries are Python objects, so pickle loading must be allowed.
        self.reach_metadata = np.load(metadata_path, allow_pickle=True)
        self.full_metadata = np.load(full_metadata_path, allow_pickle=True)

        log.info(f"Loaded {self.n_reach_trajectories} REACH trajectories from {dataset_path}")
        log.info(f"REACH States shape: {self.reach_states.shape}, Actions shape: {self.reach_actions.shape}")

    def _load_carry_dataset(self, dataset_path, metadata_path):
        """Load the normalized CARRY dataset and its processed metadata.

        Args:
            dataset_path: archive with ``states``, ``actions`` and
                ``traj_lengths`` arrays (trajectories stored back to back).
            metadata_path: per-trajectory processed metadata; for CARRY,
                ``end_pos`` is the prerelease position.
        """
        arrays = np.load(dataset_path)
        self.carry_states = arrays['states']
        self.carry_actions = arrays['actions']
        self.carry_traj_lengths = arrays['traj_lengths']
        self.n_carry_trajectories = len(self.carry_traj_lengths)

        # Metadata entries are Python objects, so pickle loading must be allowed.
        self.carry_metadata = np.load(metadata_path, allow_pickle=True)

        log.info(f"Loaded {self.n_carry_trajectories} CARRY trajectories from {dataset_path}")
        log.info(f"CARRY States shape: {self.carry_states.shape}, Actions shape: {self.carry_actions.shape}")

    def _select_cps_by_varying_param(self, metadata, mode, fixed_val1, fixed_val2, tol=0.01):
        """Select CPs varying one parameter while keeping the other two fixed.

        Each metadata entry's ``canonical_cp_params`` is unpacked as
        ``(angle, dist, pos)``. An entry is kept when both non-varied
        parameters are within ``tol`` of the requested fixed values.

        Args:
            metadata: dataset metadata (sequence of dict-like entries).
            mode: 'angle', 'dist', or 'pos' - which parameter to vary.
            fixed_val1: first fixed parameter value.
            fixed_val2: second fixed parameter value.
            tol: absolute tolerance for matching the fixed values
                (strict ``<`` comparison; default 0.01).

        For mode='angle': fixed_val1=dist_frac, fixed_val2=pos_frac
        For mode='dist':  fixed_val1=angle, fixed_val2=pos_frac
        For mode='pos':   fixed_val1=angle, fixed_val2=dist_frac

        Returns:
            List of (index, varying_param_value) tuples sorted by the varying
            parameter (stable sort, so ties keep dataset order).

        Raises:
            ValueError: if ``mode`` is not one of 'angle', 'dist', 'pos'.
        """
        # Slot of each parameter inside (angle, dist, pos).
        param_slots = {'angle': 0, 'dist': 1, 'pos': 2}
        if mode not in param_slots:
            raise ValueError(f"mode must be one of {sorted(param_slots)}, got {mode!r}")

        vary_slot = param_slots[mode]
        # The two remaining slots, kept in (angle, dist, pos) order, are matched
        # against fixed_val1 and fixed_val2 respectively.
        fixed_slots = [slot for slot in range(3) if slot != vary_slot]
        fixed_vals = (fixed_val1, fixed_val2)

        indices_with_value = []
        for i, m in enumerate(metadata):
            angle, dist, pos = m['canonical_cp_params']
            params = (angle, dist, pos)
            if all(abs(params[slot] - val) < tol for slot, val in zip(fixed_slots, fixed_vals)):
                indices_with_value.append((i, params[vary_slot]))

        indices_with_value.sort(key=lambda item: item[1])
        return indices_with_value

    def _get_gt_ee_trajectory(self, phase, traj_idx):
        """Extract ground truth EE trajectory from dataset.

        Args:
            phase: 'reach' or 'carry'
            traj_idx: trajectory index

        Returns:
            ee_traj: (T, 3) denormalized EE positions
        """
        if phase == 'reach':
            states = self.reach_states
            traj_lengths = self.reach_traj_lengths
        else:
            states = self.carry_states
            traj_lengths = self.carry_traj_lengths

        # Get trajectory states
        start_idx = sum(traj_lengths[:traj_idx])
        end_idx = start_idx + traj_lengths[traj_idx]
        traj_states = states[start_idx:end_idx]

        # Extract EE position (indices 15:18)
        ee_norm = traj_states[:, 15:18]

        # Denormalize
        ee_denorm = (ee_norm + 1) * (self.obs_max[15:18] - self.obs_min[15:18]) / 2 + self.obs_min[15:18]

        return ee_denorm

    def _get_reach_trajectory(self, traj_idx):
        """Get states and actions for a specific reach trajectory."""
        start_idx = sum(self.reach_traj_lengths[:traj_idx])
        end_idx = start_idx + self.reach_traj_lengths[traj_idx]
        states = self.reach_states[start_idx:end_idx]
        actions = self.reach_actions[start_idx:end_idx]
        return states, actions

    def _get_carry_trajectory(self, traj_idx):
        """Get states and actions for a specific carry trajectory."""
        start_idx = sum(self.carry_traj_lengths[:traj_idx])
        end_idx = start_idx + self.carry_traj_lengths[traj_idx]
        states = self.carry_states[start_idx:end_idx]
        actions = self.carry_actions[start_idx:end_idx]
        return states, actions

    def _denormalize_action(self, action_norm):
        """Denormalize action from [-1, 1] to original scale using min/max."""
        return (action_norm + 1) * (self.action_max - self.action_min) / 2 + self.action_min

    def _normalize_state(self, state):
        """Normalize state to [-1, 1] using min/max."""
        return 2 * (state - self.obs_min) / (self.obs_max - self.obs_min) - 1

    def _setup_rlbench(self):
        """Launch a headless RLBench environment configured like data collection.

        Builds the observation config (no video, task low-dim state included)
        and the non-joint action mode via the shared ``utils`` helpers,
        launches the simulator into ``self.rlbench_env``, and binds
        ``self.task_env`` to variation 0 of the ``stack_blocks`` task.
        """
        log.info("Setting up RLBench environment...")

        observation_cfg = create_obs_config(save_video=False, include_task_low_dim=True)
        control_mode = create_action_mode(joint_action_mode=False)

        # Store the handle before launching so callers can still shut it down.
        self.rlbench_env = Environment(
            action_mode=control_mode,
            obs_config=observation_cfg,
            headless=True,
        )
        self.rlbench_env.launch()

        stack_blocks_cls = get_task_classes(["stack_blocks"])[0]
        self.task_env = self.rlbench_env.get_task(stack_blocks_cls)
        self.task_env.set_variation(0)

        log.info("RLBench environment ready")

    def _reset_to_home(self):
        """Put the scene into the fixed start configuration used for data collection.

        - Moves the arm to ``HOME_JOINTS`` with dynamics disabled and zero
          target velocities, opens the gripper, and steps physics 20 times.
        - Places ``stack_blocks_target0`` at a fixed world position.
        - Moves ``stack_blocks_target1..3`` out of the way to (10, 10, 0),
          skipping any that cannot be found/moved.
        - Replaces the task's success conditions: the target block must be
          detected by ``stack_blocks_success`` and nothing may be grasped.
        - Steps physics 10 more times.
        """
        scene = self.task_env._scene
        robot = scene.robot
        gripper = robot.gripper

        # Set to home joints
        robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
        robot.arm.set_joint_target_velocities([0] * 7)
        gripper.actuate(1.0, 0.2)  # Open gripper

        for _ in range(20):
            scene.pyrep.step()

        # Fix object position (same as data collection)
        FIXED_OBJECT_POS = np.array([0.500925600528717, -0.04781968146562576, 0.7750056982040405])
        target_block = Shape('stack_blocks_target0')
        target_block.set_position(FIXED_OBJECT_POS)

        # Hide the other blocks. This is best-effort: a block that cannot be
        # looked up or moved is skipped. Catch Exception (not a bare except)
        # so KeyboardInterrupt / SystemExit still propagate, and log the skip.
        for i in range(1, 4):
            block_name = f'stack_blocks_target{i}'
            try:
                Shape(block_name).set_position([10, 10, 0])
            except Exception as err:
                log.debug(f"Could not hide {block_name}: {err}")

        # Override success condition (only require 1 block, same as data collection)
        success_detector = ProximitySensor('stack_blocks_success')
        self.task_env._task._success_conditions = [
            DetectedSeveralCondition([target_block], success_detector, 1),
            NothingGrasped(gripper)
        ]

        for _ in range(10):
            scene.pyrep.step()

    def _get_ee_position(self):
        """Get current end-effector position in world coordinates."""
        tip = self.task_env._scene.robot.arm.get_tip()
        return np.array(tip.get_position())

    def _get_current_state(self):
        """Get current observation state (normalized).

        State format (from process_pick_place_dataset.py):
        - joint_positions: 7 dims
        - joint_velocities: 7 dims
        - gripper_open: 1 dim
        - gripper_pose position: 3 dims
        - gripper_pose quaternion: 4 dims
        Total: 22 dims
        """
        obs = self.task_env._scene.get_observation()

        state = np.concatenate([
            np.array(obs.joint_positions, dtype=np.float32),        # 7
            np.array(obs.joint_velocities, dtype=np.float32),       # 7
            np.array([obs.gripper_open], dtype=np.float32),         # 1
            np.array(obs.gripper_pose[:3], dtype=np.float32),       # position: 3
            np.array(obs.gripper_pose[3:], dtype=np.float32),       # quaternion: 4
        ])

        # Normalize
        state_norm = self._normalize_state(state)
        return state_norm

    def _execute_action(self, action_norm):
        """Execute a single action in simulation.

        Action format (from data collection): joint_positions (7) + gripper_open (1)
        The data collection uses set_joint_positions, not velocities.
        """
        robot = self.task_env._scene.robot

        # Denormalize action
        action = self._denormalize_action(action_norm)

        # Action format: joint positions (7) + gripper (1)
        joint_positions = action[:7]
        gripper_action = action[7] if len(action) > 7 else 1.0

        # Apply action (same as data collection)
        robot.arm.set_joint_positions(joint_positions, disable_dynamics=True)
        robot.arm.set_joint_target_velocities([0] * 7)
        robot.gripper.actuate(gripper_action, 0.2)

        # Step simulation multiple times (configurable, default=5 like data collection)
        for _ in range(self.action_repeat):
            self.task_env._scene.pyrep.step()

    def run(self):
        """Run full pick-and-place evaluation in RLBench simulation.

        For each selected (REACH CP, CARRY CP) pair, performs
        ``n_rollouts_per_cp`` rollouts of:
        REACH (policy) -> DESCEND+GRASP -> LIFT -> align to lift pose ->
        CARRY (policy) -> DESCEND_RELEASE+RELEASE, recording the final EE
        errors and per-stage success flags.

        Control-point selection follows ``cfg.cp_selection_strategy``:
        ``'vary_angle'``, ``'vary_dist'`` and ``'vary_pos'`` sweep one canonical
        parameter while holding the other two at ``cfg.fixed_angle`` /
        ``cfg.fixed_dist_frac`` / ``cfg.fixed_pos_frac``. Any other value draws
        ``cfg.n_control_points`` REACH and CARRY trajectories at random. The
        count is clamped to what is available. REACH and CARRY CPs are paired by
        position; if the lists differ in length only the first ``min(len)``
        pairs are evaluated, and a warning is logged.

        ``cfg.reset_seed`` (default 42) reseeds the global NumPy RNG before
        every ``task_env.reset()``.

        Results are saved to ``<logdir>/evaluation_results.npy`` and the EE
        trajectories are passed to ``_plot_trajectories``.

        Raises:
            ValueError: if the strategy selects no control-point pairs.
        """
        log.info(f"\n{'='*80}")
        log.info(f"FULL PICK-AND-PLACE EVALUATION (RLBench Simulation)")
        log.info(f"{'='*80}")
        log.info(f"REACH trajectories: {self.n_reach_trajectories}")
        log.info(f"CARRY trajectories: {self.n_carry_trajectories}")

        # Select control points before launching the simulator. Selection only
        # needs dataset metadata, so a bad config fails without leaking an env.
        cp_strategy = self.cfg.get('cp_selection_strategy', 'random')
        fixed_angle = self.cfg.get('fixed_angle', 0.4)
        fixed_dist = self.cfg.get('fixed_dist_frac', 0.4)
        fixed_pos = self.cfg.get('fixed_pos_frac', 0.5)

        # Sweep strategies: varied parameter, plus the two held-fixed
        # (name, value) pairs in the argument order _select_cps_by_varying_param expects.
        # Expected sweeps: angle [0.0..0.8], dist_frac [0.2..1.0], pos_frac [0.35, 0.5, 0.65].
        sweep_strategies = {
            'vary_angle': ('angle', (('dist', fixed_dist), ('pos', fixed_pos))),
            'vary_dist': ('dist', (('angle', fixed_angle), ('pos', fixed_pos))),
            'vary_pos': ('pos', (('angle', fixed_angle), ('dist', fixed_dist))),
        }

        # Store varying parameter name for plot labels
        self.varying_param = None

        if cp_strategy in sweep_strategies:
            varied, fixed = sweep_strategies[cp_strategy]
            (_, fixed_val1), (_, fixed_val2) = fixed
            reach_cps = self._select_cps_by_varying_param(self.reach_metadata, varied, fixed_val1, fixed_val2)
            carry_cps = self._select_cps_by_varying_param(self.carry_metadata, varied, fixed_val1, fixed_val2)
            self.varying_param = varied

            fixed_desc = ", ".join(f"{name}={val}" for name, val in fixed)
            log.info(f"\n--- CP Selection: {cp_strategy} (fixed {fixed_desc}) ---")
            for phase_name, cps in (("REACH", reach_cps), ("CARRY", carry_cps)):
                log.info(f"{phase_name} CPs ({len(cps)}):")
                for idx, val in cps:
                    log.info(f"  [{idx}]: {varied}={val:.2f}")
        else:
            # Random selection (original behavior). choice(replace=False) cannot
            # draw more items than exist, so clamp the requested count.
            n_requested = self.cfg.get('n_control_points', 5)
            n_random = min(n_requested, self.n_reach_trajectories, self.n_carry_trajectories)
            if n_random < n_requested:
                log.warning(f"Requested {n_requested} random CPs but only {n_random} are available; using {n_random}")
            reach_indices = np.random.choice(
                np.arange(self.n_reach_trajectories),
                size=n_random,
                replace=False
            )
            carry_indices = np.random.choice(
                np.arange(self.n_carry_trajectories),
                size=n_random,
                replace=False
            )
            reach_cps = [(idx, None) for idx in reach_indices]
            carry_cps = [(idx, None) for idx in carry_indices]

            log.info(f"\n--- Selected REACH CPs (random): {reach_indices} ---")
            log.info(f"--- Selected CARRY CPs (random): {carry_indices} ---")

        # REACH and CARRY CPs are paired by position; never index past the shorter list.
        if len(reach_cps) != len(carry_cps):
            n_pairs = min(len(reach_cps), len(carry_cps))
            log.warning(
                f"REACH selected {len(reach_cps)} CPs but CARRY selected {len(carry_cps)}; "
                f"evaluating only the first {n_pairs} pairs"
            )
            reach_cps, carry_cps = reach_cps[:n_pairs], carry_cps[:n_pairs]
        if not reach_cps:
            raise ValueError(
                f"No control-point pairs selected for strategy '{cp_strategy}' "
                f"(fixed angle={fixed_angle}, dist={fixed_dist}, pos={fixed_pos})"
            )

        # Extract indices and values
        reach_indices = np.array([idx for idx, _ in reach_cps])
        carry_indices = np.array([idx for idx, _ in carry_cps])
        reach_varying_vals = [val for _, val in reach_cps]
        carry_varying_vals = [val for _, val in carry_cps]

        # Store for plotting
        self.reach_varying_vals = reach_varying_vals
        self.carry_varying_vals = carry_varying_vals

        # Loop-invariant rollout settings
        n_samples = self.cfg.get('n_samples_per_phase', 10)
        reset_seed = self.cfg.get('reset_seed', 42)

        n_control_points = len(reach_indices)
        log.info(f"\nEvaluating {n_control_points} control point pairs, {self.n_rollouts_per_cp} rollouts each")
        log.info(f"Action repeat: {self.action_repeat} (physics steps per action)")
        log.info(f"DDIM steps: {self.cfg.ddim_steps}")

        # Storage for results
        all_results = []
        all_ee_trajectories = []

        try:
            # Inside the try so a partially-initialized simulator is still shut down.
            self._setup_rlbench()

            for cp_idx, (reach_traj_idx, carry_traj_idx) in enumerate(zip(reach_indices, carry_indices)):
                log.info(f"\n{'='*60}")
                log.info(f"Evaluation {cp_idx + 1}/{n_control_points}")
                log.info(f"  REACH CP[{reach_traj_idx}], CARRY CP[{carry_traj_idx}]")
                log.info(f"{'='*60}")

                # Get REACH metadata
                reach_meta = self.reach_metadata[reach_traj_idx]
                full_meta = self.full_metadata[reach_traj_idx]  # Contains world positions
                reach_cp_params = np.array(reach_meta['canonical_cp_params'])
                reach_end_pos = np.array(reach_meta['end_pos'])  # pregrasp position

                # Get CARRY metadata
                carry_meta = self.carry_metadata[carry_traj_idx]
                carry_cp_params = np.array(carry_meta['canonical_cp_params'])
                carry_end_pos = np.array(carry_meta['end_pos'])  # prerelease position

                # World positions (from full_metadata, same for all trajs since fixed object)
                home_pos = np.array(full_meta['home_pos'])
                pregrasp_pos = np.array(full_meta['pregrasp_pos'])
                grasp_pos = np.array(full_meta['grasp_pos'])
                lift_pos = np.array(full_meta['lift_pos'])
                prerelease_pos = np.array(full_meta['prerelease_pos'])
                release_pos = np.array(full_meta['release_pos'])
                orientation = full_meta['orientation']

                log.info(f"  REACH z: angle={reach_cp_params[0]:.3f}, dist={reach_cp_params[1]:.3f}, pos={reach_cp_params[2]:.3f}")
                log.info(f"  CARRY z: angle={carry_cp_params[0]:.3f}, dist={carry_cp_params[1]:.3f}, pos={carry_cp_params[2]:.3f}")
                log.info(f"  Pregrasp target: {pregrasp_pos}")
                log.info(f"  Prerelease target: {prerelease_pos}")

                # Get ground truth EE trajectories from dataset
                gt_reach_traj = self._get_gt_ee_trajectory('reach', reach_traj_idx)
                gt_carry_traj = self._get_gt_ee_trajectory('carry', carry_traj_idx)

                # Run rollouts
                cp_reach_errors = []
                cp_carry_errors = []
                cp_grasp_successes = []
                cp_lift_successes = []
                cp_release_successes = []
                cp_task_successes = []
                cp_trajectories = []

                for rollout_idx in range(self.n_rollouts_per_cp):
                    # Reset environment. The global NumPy RNG is reseeded every
                    # rollout, presumably so RLBench's reset is identical each
                    # time; this also resets any NumPy randomness used later in
                    # the rollout.
                    np.random.seed(reset_seed)
                    self.task_env.reset()
                    self._reset_to_home()

                    # Get initial state and EE position
                    init_state = self._get_current_state()
                    init_ee = self._get_ee_position()

                    # ========================================================
                    # Phase 1: REACH (diffusion policy)
                    # ========================================================
                    # REACH expects gripper open (1.0), use multi-sample selection
                    reach_actions = self._run_policy(
                        init_state[np.newaxis, :],
                        reach_cp_params,
                        reach_end_pos,
                        use_random_noise=self.use_random_noise,
                        expected_gripper=1.0,  # REACH: gripper open
                        n_samples=n_samples
                    )

                    # Execute REACH actions
                    reach_trajectory = [init_ee.copy()]
                    for step_idx in range(self.horizon_steps):
                        self._execute_action(reach_actions[step_idx])
                        reach_trajectory.append(self._get_ee_position().copy())
                    reach_trajectory = np.array(reach_trajectory)

                    final_reach_ee = reach_trajectory[-1]
                    reach_error = np.linalg.norm(final_reach_ee - pregrasp_pos)
                    log.info(f"  Rollout {rollout_idx + 1}/{self.n_rollouts_per_cp}: REACH error={reach_error:.4f}m")

                    # ========================================================
                    # Phase 2 & 3: DESCEND + GRASP (hard-coded)
                    # ========================================================
                    grasp_result = descend_and_grasp(
                        self.task_env,
                        pregrasp_pos=pregrasp_pos,
                        grasp_pos=grasp_pos,
                        orientation=orientation,
                        object_shape_name='stack_blocks_target0',
                        descend_steps=8,
                        grasp_steps=8,
                        steps_per_point=5,
                        capture_video=False,
                        verbose=False
                    )
                    grasp_success = grasp_result['grasp_success']
                    log.info(f"    GRASP success: {grasp_success}")

                    # ========================================================
                    # Phase 4: LIFT (hard-coded)
                    # ========================================================
                    lift_result = lift(
                        self.task_env,
                        grasp_pos=grasp_pos,
                        lift_pos=lift_pos,
                        orientation=orientation,
                        object_shape_name='stack_blocks_target0',
                        lift_steps=8,
                        steps_per_point=5,
                        capture_video=False,
                        verbose=False,
                        prev_joints=grasp_result['prev_joints']
                    )
                    object_lifted = lift_result['object_lifted']
                    log.info(f"    LIFT success: {object_lifted}")

                    # ========================================================
                    # Phase 5: CARRY (diffusion policy)
                    # ========================================================
                    # Align the robot to the expected LIFT pose before CARRY so the
                    # policy's input state starts where the CARRY training data did.
                    align_success, _ = align_robot_to_carry_start(
                        self.task_env,
                        target_ee_pos=lift_pos,
                        orientation=orientation,
                        settle_steps=10
                    )
                    if not align_success:
                        log.warning("    Failed to align robot to CARRY start pose")

                    # Get current state after alignment
                    carry_init_state = self._get_current_state()
                    carry_init_ee = self._get_ee_position()

                    # CARRY expects gripper closed (-1.0), use multi-sample selection
                    carry_actions = self._run_policy(
                        carry_init_state[np.newaxis, :],
                        carry_cp_params,
                        carry_end_pos,  # prerelease position
                        use_random_noise=self.use_random_noise,
                        expected_gripper=-1.0,  # CARRY: gripper closed
                        n_samples=n_samples
                    )

                    # Debug: print first and last carry actions
                    if rollout_idx == 0:
                        log.info(f"    CARRY DEBUG: init_ee={carry_init_ee}")
                        log.info(f"    CARRY DEBUG: target={carry_end_pos}")
                        log.info(f"    CARRY DEBUG: z={carry_cp_params}")
                        log.info(f"    CARRY DEBUG: first action (norm)={carry_actions[0, :4]}...")
                        log.info(f"    CARRY DEBUG: last action (norm)={carry_actions[-1, :4]}...")
                        first_raw = self._denormalize_action(carry_actions[0])
                        last_raw = self._denormalize_action(carry_actions[-1])
                        log.info(f"    CARRY DEBUG: first action (raw joints)={first_raw[:7]}")
                        log.info(f"    CARRY DEBUG: last action (raw joints)={last_raw[:7]}")

                    # Execute CARRY actions
                    carry_trajectory = [carry_init_ee.copy()]
                    for step_idx in range(self.horizon_steps):
                        self._execute_action(carry_actions[step_idx])
                        carry_trajectory.append(self._get_ee_position().copy())
                    carry_trajectory = np.array(carry_trajectory)

                    final_carry_ee = carry_trajectory[-1]
                    carry_error = np.linalg.norm(final_carry_ee - prerelease_pos)
                    log.info(f"    CARRY error={carry_error:.4f}m (target: {prerelease_pos})")

                    # ========================================================
                    # Phase 6 & 7: DESCEND_RELEASE + RELEASE (hard-coded)
                    # ========================================================
                    release_result = descend_and_release(
                        self.task_env,
                        prerelease_pos=prerelease_pos,
                        release_pos=release_pos,
                        orientation=orientation,
                        object_shape_name='stack_blocks_target0',
                        descend_steps=8,
                        release_steps=8,
                        steps_per_point=5,
                        capture_video=False,
                        verbose=False,
                        prev_joints=None  # Will get current joints internally
                    )
                    object_released = release_result['object_released']
                    log.info(f"    RELEASE success: {object_released}")

                    # Check task success
                    task_success, _ = self.task_env._task.success()
                    log.info(f"    TASK success: {task_success}")

                    # Store results
                    cp_reach_errors.append(reach_error)
                    cp_carry_errors.append(carry_error)
                    cp_grasp_successes.append(grasp_success)
                    cp_lift_successes.append(object_lifted)
                    cp_release_successes.append(object_released)
                    cp_task_successes.append(task_success)
                    cp_trajectories.append({
                        'reach': reach_trajectory,
                        'carry': carry_trajectory
                    })

                # Stats for this CP pair
                mean_reach_error = np.mean(cp_reach_errors)
                std_reach_error = np.std(cp_reach_errors)
                mean_carry_error = np.mean(cp_carry_errors)
                std_carry_error = np.std(cp_carry_errors)
                grasp_rate = np.mean(cp_grasp_successes)
                lift_rate = np.mean(cp_lift_successes)
                release_rate = np.mean(cp_release_successes)
                task_rate = np.mean(cp_task_successes)

                log.info(f"\n  Summary (REACH[{reach_traj_idx}], CARRY[{carry_traj_idx}]):")
                log.info(f"    Reach error:  {mean_reach_error:.4f}m ± {std_reach_error:.4f}m")
                log.info(f"    Carry error:  {mean_carry_error:.4f}m ± {std_carry_error:.4f}m")
                log.info(f"    Grasp rate:   {grasp_rate*100:.0f}%")
                log.info(f"    Lift rate:    {lift_rate*100:.0f}%")
                log.info(f"    Release rate: {release_rate*100:.0f}%")
                log.info(f"    Task rate:    {task_rate*100:.0f}%")

                all_results.append({
                    'reach_traj_idx': reach_traj_idx,
                    'carry_traj_idx': carry_traj_idx,
                    'reach_cp_params': reach_cp_params.tolist(),
                    'carry_cp_params': carry_cp_params.tolist(),
                    'pregrasp_pos': pregrasp_pos.tolist(),
                    'prerelease_pos': prerelease_pos.tolist(),
                    'reach_errors': cp_reach_errors,
                    'carry_errors': cp_carry_errors,
                    'grasp_successes': cp_grasp_successes,
                    'lift_successes': cp_lift_successes,
                    'release_successes': cp_release_successes,
                    'task_successes': cp_task_successes,
                    'mean_reach_error': mean_reach_error,
                    'std_reach_error': std_reach_error,
                    'mean_carry_error': mean_carry_error,
                    'std_carry_error': std_carry_error,
                    'grasp_rate': grasp_rate,
                    'lift_rate': lift_rate,
                    'release_rate': release_rate,
                    'task_rate': task_rate,
                    'gt_reach_traj': gt_reach_traj,  # Ground truth REACH trajectory
                    'gt_carry_traj': gt_carry_traj,  # Ground truth CARRY trajectory
                })
                all_ee_trajectories.append(cp_trajectories)

        finally:
            # Cleanup
            if self.rlbench_env is not None:
                self.rlbench_env.shutdown()

        # Overall statistics
        all_reach_errors = [e for r in all_results for e in r['reach_errors']]
        all_carry_errors = [e for r in all_results for e in r['carry_errors']]
        all_grasp_successes = [s for r in all_results for s in r['grasp_successes']]
        all_lift_successes = [s for r in all_results for s in r['lift_successes']]
        all_release_successes = [s for r in all_results for s in r['release_successes']]
        all_task_successes = [s for r in all_results for s in r['task_successes']]

        overall_reach_mean = np.mean(all_reach_errors)
        overall_reach_std = np.std(all_reach_errors)
        overall_carry_mean = np.mean(all_carry_errors)
        overall_carry_std = np.std(all_carry_errors)
        overall_grasp_rate = np.mean(all_grasp_successes)
        overall_lift_rate = np.mean(all_lift_successes)
        overall_release_rate = np.mean(all_release_successes)
        overall_task_rate = np.mean(all_task_successes)

        log.info(f"\n{'='*80}")
        log.info(f"OVERALL RESULTS")
        log.info(f"{'='*80}")
        log.info(f"Reach EE error (to pregrasp):   {overall_reach_mean:.4f}m ± {overall_reach_std:.4f}m")
        log.info(f"Carry EE error (to prerelease): {overall_carry_mean:.4f}m ± {overall_carry_std:.4f}m")
        log.info(f"Grasp success rate:   {overall_grasp_rate*100:.1f}%")
        log.info(f"Lift success rate:    {overall_lift_rate*100:.1f}%")
        log.info(f"Release success rate: {overall_release_rate*100:.1f}%")
        log.info(f"Task success rate:    {overall_task_rate*100:.1f}%")

        # Save results
        results_path = self.logdir / 'evaluation_results.npy'
        np.save(results_path, {
            'all_results': all_results,
            'overall_reach_mean': overall_reach_mean,
            'overall_reach_std': overall_reach_std,
            'overall_carry_mean': overall_carry_mean,
            'overall_carry_std': overall_carry_std,
            'overall_grasp_rate': overall_grasp_rate,
            'overall_lift_rate': overall_lift_rate,
            'overall_release_rate': overall_release_rate,
            'overall_task_rate': overall_task_rate,
            'n_control_points': n_control_points,
            'n_rollouts_per_cp': self.n_rollouts_per_cp,
            'reach_indices': reach_indices.tolist(),
            'carry_indices': carry_indices.tolist(),
            'cp_selection_strategy': cp_strategy,
            'varying_param': self.varying_param,
            'reach_varying_vals': self.reach_varying_vals,
            'carry_varying_vals': self.carry_varying_vals,
        }, allow_pickle=True)
        log.info(f"\nSaved results to {results_path}")

        # Plot trajectories
        self._plot_trajectories(all_ee_trajectories, all_results)

        return {
            'reach_error': (overall_reach_mean, overall_reach_std),
            'carry_error': (overall_carry_mean, overall_carry_std),
            'grasp_rate': overall_grasp_rate,
            'lift_rate': overall_lift_rate,
            'release_rate': overall_release_rate,
            'task_rate': overall_task_rate,
        }

    def _run_policy(self, init_state, canonical_params, end_pos, use_random_noise=True,
                    expected_gripper=None, n_samples=1):
        """
        Run policy to predict actions.

        Args:
            init_state: Initial normalized state (1, 22)
            canonical_params: Control point params [angle, dist_frac, pos_frac] normalized to [0,1]
            end_pos: Target end position in WORLD coordinates
            use_random_noise: Whether to use random noise initialization
            expected_gripper: float, expected gripper state (1.0 for REACH, -1.0 for CARRY)
                             If provided, samples n_samples times and selects best match
            n_samples: int, number of samples to draw when expected_gripper is provided

        Returns:
            predicted_actions: (horizon_steps, action_dim) in normalized space
        """
        with torch.no_grad():
            obs = torch.FloatTensor(init_state).to(self.device)

            # Pass z directly - training config has normalize_angle=False
            # The canonical_params from processed metadata are already in [0, 1] range
            # and training dataset does NOT apply additional normalization
            z = torch.FloatTensor(canonical_params).unsqueeze(0).to(self.device)
            target = torch.FloatTensor(end_pos).unsqueeze(0).to(self.device)

            # Set z in model
            self.model.current_z = z

            # Create condition dict
            cond = {
                "state": obs,
                "target": target,
            }

            if expected_gripper is not None and n_samples > 1:
                # Multi-sample selection: sample multiple times and pick best match
                samples = []
                grippers = []
                for seed in range(n_samples):
                    torch.manual_seed(seed)
                    sample = self.model(cond)
                    actions = sample.trajectories.squeeze(0).cpu().numpy()
                    samples.append(actions)
                    grippers.append(actions[0, 7])  # First timestep gripper

                # Select sample with gripper closest to expected
                if expected_gripper > 0:
                    best_idx = np.argmax(grippers)  # Most positive for REACH
                else:
                    best_idx = np.argmin(grippers)  # Most negative for CARRY
                actions = samples[best_idx]
                log.debug(f"Multi-sample: selected idx={best_idx}, gripper={grippers[best_idx]:.2f}")
            else:
                # Single sample
                if not use_random_noise:
                    torch.manual_seed(self.seed)
                sample = self.model(cond)
                actions = sample.trajectories.squeeze(0).cpu().numpy()

        return actions

    def _plot_trajectories(self, all_ee_trajectories, all_results):
        """Write the evaluation plots for all control-point (CP) pairs to ``self.logdir``.

        Produces three PNG files:
        - ``ee_trajectories_3d.png``: 3D EE paths (overall, REACH zoomed, CARRY zoomed),
          GT as thin solid lines and evaluation rollouts as dashed lines.
        - ``error_distribution.png``: histograms of REACH and CARRY final EE errors.
        - ``error_per_cp.png``: per-CP-pair mean ± std EE error bars.

        Args:
            all_ee_trajectories: one entry per CP pair; each entry is a list of rollout dicts
                with ``'reach'`` and ``'carry'`` EE position arrays (columns 0..2 are x, y, z).
            all_results: per-CP-pair result dicts in the same order, holding
                ``reach_traj_idx``, ``carry_traj_idx``, ``pregrasp_pos``, ``prerelease_pos``,
                ``reach_errors``, ``carry_errors``, ``mean_/std_reach_error``,
                ``mean_/std_carry_error``, ``task_rate`` and optionally
                ``gt_reach_traj`` / ``gt_carry_traj``.
        """
        if not all_results:
            # Nothing to plot; the means used in titles/histograms would only be NaN.
            log.warning("No evaluation results to plot; skipping trajectory/error plots.")
            return
        self._plot_ee_trajectories_3d(all_ee_trajectories, all_results)
        self._plot_error_histograms(all_results)
        self._plot_error_per_cp(all_results)

    @staticmethod
    def _value_at(values, idx):
        """Return ``values[idx]``, or None if ``values`` is None or has no such index."""
        if values is None or idx >= len(values):
            return None
        return values[idx]

    @staticmethod
    def _gt_label(varying_param, value, phase_tag, traj_idx):
        """Legend label for a GT trajectory: the varying-parameter value if known, else its index."""
        if varying_param and value is not None:
            return f'GT {varying_param}={value:.2f}'
        return f'GT {phase_tag}[{traj_idx}]'

    @staticmethod
    def _set_equal_aspect_3d(ax):
        """Expand a 3D axis to a cube around its current center so x/y/z share one scale."""
        limits = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
        half_range = (limits[:, 1] - limits[:, 0]).max() / 2
        mids = limits.mean(axis=1)
        ax.set_xlim3d([mids[0] - half_range, mids[0] + half_range])
        ax.set_ylim3d([mids[1] - half_range, mids[1] + half_range])
        ax.set_zlim3d([mids[2] - half_range, mids[2] + half_range])

    @staticmethod
    def _fit_limits_to_points(ax, point_arrays, padding=0.02):
        """Zoom a 3D axis to the bounding box of ``point_arrays`` plus ``padding`` (2cm default)."""
        if not point_arrays:
            return
        points = np.vstack(point_arrays)[:, :3]
        lo = points.min(axis=0) - padding
        hi = points.max(axis=0) + padding
        ax.set_xlim3d([lo[0], hi[0]])
        ax.set_ylim3d([lo[1], hi[1]])
        ax.set_zlim3d([lo[2], hi[2]])

    def _plot_ee_trajectories_3d(self, all_ee_trajectories, all_results):
        """Save ``ee_trajectories_3d.png`` with overall, REACH-zoomed and CARRY-zoomed 3D EE paths.

        GT trajectories are thin solid lines and evaluation rollouts are dashed and
        semi-transparent. Each CP pair gets its own (cycled) color. Per-CP targets are
        marked '*' (pregrasp) and '^' (prerelease). Home/lift/pregrasp/prerelease reference
        waypoints come from the first CP pair.
        """
        fig = plt.figure(figsize=(18, 6))
        ax1 = fig.add_subplot(131, projection='3d')  # Overall trajectory
        ax2 = fig.add_subplot(132, projection='3d')  # REACH phase only (zoomed)
        ax3 = fig.add_subplot(133, projection='3d')  # CARRY phase only (zoomed)

        # Distinct colors per CP pair: blue, orange, green, red, purple (cycled)
        color_names = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']

        # Varying-parameter metadata is optional; missing or short lists fall back to index labels
        varying_param = getattr(self, 'varying_param', 'angle')
        reach_varying_vals = getattr(self, 'reach_varying_vals', None)
        carry_varying_vals = getattr(self, 'carry_varying_vals', None)

        gt_linewidth = 1.2    # GT lines
        eval_linewidth = 1.0  # Evaluation rollout lines
        eval_alpha = 0.5      # Lighter appearance for evaluation rollouts

        # Reference waypoints, taken from the first CP pair
        ref_home = ref_pregrasp = ref_lift = ref_prerelease = None

        # Every plotted point per phase, used to zoom the REACH / CARRY subplots
        all_reach_points = []
        all_carry_points = []

        for cp_idx, (trajectories, result) in enumerate(zip(all_ee_trajectories, all_results)):
            color = color_names[cp_idx % len(color_names)]
            pregrasp = np.array(result['pregrasp_pos'])
            prerelease = np.array(result['prerelease_pos'])

            # Labels are built independently so a missing CARRY value cannot break formatting
            reach_label = self._gt_label(varying_param, self._value_at(reach_varying_vals, cp_idx),
                                         'R', result['reach_traj_idx'])
            carry_label = self._gt_label(varying_param, self._value_at(carry_varying_vals, cp_idx),
                                         'C', result['carry_traj_idx'])

            gt_reach = result.get('gt_reach_traj')
            gt_carry = result.get('gt_carry_traj')

            if cp_idx == 0:
                ref_pregrasp = pregrasp
                ref_prerelease = prerelease

            # ---- GT trajectories (thin solid lines) ----
            if gt_reach is not None:
                gt_reach = np.asarray(gt_reach)
                for ax in (ax1, ax2):
                    ax.plot(gt_reach[:, 0], gt_reach[:, 1], gt_reach[:, 2], '-',
                            color=color, alpha=1.0, linewidth=gt_linewidth, label=reach_label)
                all_reach_points.append(gt_reach)
                if cp_idx == 0:
                    ref_home = gt_reach[0]  # GT REACH starts at home

            if gt_carry is not None:
                gt_carry = np.asarray(gt_carry)
                # Unlabelled in the overall plot; only the GT REACH line carries a legend entry there
                ax1.plot(gt_carry[:, 0], gt_carry[:, 1], gt_carry[:, 2], '-',
                         color=color, alpha=1.0, linewidth=gt_linewidth)
                ax3.plot(gt_carry[:, 0], gt_carry[:, 1], gt_carry[:, 2], '-',
                         color=color, alpha=1.0, linewidth=gt_linewidth, label=carry_label)
                all_carry_points.append(gt_carry)
                if cp_idx == 0:
                    ref_lift = gt_carry[0]  # GT CARRY starts at lift

            # ---- Evaluation rollouts (dashed, lighter) ----
            eval_style = dict(color=color, alpha=eval_alpha, linewidth=eval_linewidth)
            for traj_dict in trajectories:
                reach_traj = traj_dict['reach']
                carry_traj = traj_dict['carry']
                all_reach_points.append(reach_traj)
                all_carry_points.append(carry_traj)
                # Overall plot gets both phases; the zoomed plots get their own phase only
                for ax, traj in ((ax1, reach_traj), (ax1, carry_traj),
                                 (ax2, reach_traj), (ax3, carry_traj)):
                    ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], '--', **eval_style)

            # Per-CP target markers
            target_style = dict(color=color, edgecolors='black', linewidths=0.5, zorder=5)
            ax1.scatter(*pregrasp, marker='*', s=80, **target_style)
            ax1.scatter(*prerelease, marker='^', s=60, **target_style)
            ax2.scatter(*pregrasp, marker='*', s=80, **target_style)
            ax3.scatter(*prerelease, marker='^', s=60, **target_style)

        # Reference waypoint markers
        waypoint_style = dict(s=100, zorder=10, edgecolors='black', linewidths=1)
        if ref_home is not None:
            ax1.scatter(*ref_home, c='green', marker='o', label='Home', **waypoint_style)
            ax2.scatter(*ref_home, c='green', marker='o', label='Home (start)', **waypoint_style)
        if ref_pregrasp is not None:
            ax1.scatter(*ref_pregrasp, c='orange', marker='*', label='Pregrasp', **waypoint_style)
            ax2.scatter(*ref_pregrasp, c='orange', marker='*', label='Pregrasp (end)', **waypoint_style)
        if ref_lift is not None:
            ax1.scatter(*ref_lift, c='purple', marker='^', label='Lift', **waypoint_style)
            ax3.scatter(*ref_lift, c='purple', marker='^', label='Lift (start)', **waypoint_style)
        if ref_prerelease is not None:
            ax1.scatter(*ref_prerelease, c='cyan', marker='s', label='Prerelease', **waypoint_style)
            ax3.scatter(*ref_prerelease, c='cyan', marker='s', label='Prerelease (end)', **waypoint_style)

        # Axis formatting; titles summarise the mean-of-CP-means errors and task rate
        reach_mean = np.mean([r["mean_reach_error"] for r in all_results])
        carry_mean = np.mean([r["mean_carry_error"] for r in all_results])
        task_rate = np.mean([r["task_rate"] for r in all_results])
        axis_specs = (
            (ax1, f'Overall Trajectory\n({len(all_results)} CPs, Task: {task_rate*100:.0f}%)', 7),
            (ax2, f'REACH Phase\n(Home → Pregrasp, Error: {reach_mean*100:.1f}cm)\nSolid=GT, Dashed=Eval', 8),
            (ax3, f'CARRY Phase\n(Lift → Prerelease, Error: {carry_mean*100:.1f}cm)\nSolid=GT, Dashed=Eval', 8),
        )
        for ax, title, legend_fontsize in axis_specs:
            ax.set_xlabel('X (m)')
            ax.set_ylabel('Y (m)')
            ax.set_zlabel('Z (m)')
            ax.set_title(title)
            ax.legend(fontsize=legend_fontsize, loc='upper left')
            ax.view_init(elev=20, azim=45)

        # Overall plot uses equal aspect; REACH/CARRY subplots fit tightly to their data
        self._set_equal_aspect_3d(ax1)
        self._fit_limits_to_points(ax2, all_reach_points)
        self._fit_limits_to_points(ax3, all_carry_points)

        fig.tight_layout()
        plot_path = self.logdir / 'ee_trajectories_3d.png'
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        log.info(f"Saved trajectory plot to {plot_path}")
        plt.close(fig)

    def _plot_error_histograms(self, all_results):
        """Save ``error_distribution.png``: REACH and CARRY final-EE-error histograms with means."""
        fig, axes = plt.subplots(1, 2, figsize=(14, 5))
        phases = (
            ('reach_errors', 'REACH Error Distribution', 'steelblue'),
            ('carry_errors', 'CARRY Error Distribution', 'darkorange'),
        )
        for ax, (key, title, color) in zip(axes, phases):
            errors = [e for r in all_results for e in r[key]]
            mean_error = np.mean(errors)
            ax.hist(errors, bins=20, edgecolor='black', alpha=0.7, color=color)
            ax.axvline(mean_error, color='red', linestyle='--', linewidth=2,
                       label=f'Mean: {mean_error:.4f}m')
            ax.set_xlabel('EE Error (m)', fontsize=12)
            ax.set_ylabel('Count', fontsize=12)
            ax.set_title(title, fontsize=14, fontweight='bold')
            ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        hist_path = self.logdir / 'error_distribution.png'
        fig.savefig(hist_path, dpi=150, bbox_inches='tight')
        log.info(f"Saved error histogram to {hist_path}")
        plt.close(fig)

    def _plot_error_per_cp(self, all_results):
        """Save ``error_per_cp.png``: side-by-side REACH/CARRY mean ± std EE error per CP pair."""
        fig, ax = plt.subplots(figsize=(14, 6))
        x = np.arange(len(all_results))
        width = 0.35

        ax.bar(x - width/2, [r['mean_reach_error'] for r in all_results], width,
               yerr=[r['std_reach_error'] for r in all_results], capsize=3,
               label='REACH', color='steelblue', edgecolor='black', alpha=0.8)
        ax.bar(x + width/2, [r['mean_carry_error'] for r in all_results], width,
               yerr=[r['std_carry_error'] for r in all_results], capsize=3,
               label='CARRY', color='darkorange', edgecolor='black', alpha=0.8)

        cp_labels = [f'R{r["reach_traj_idx"]}/C{r["carry_traj_idx"]}' for r in all_results]
        ax.set_xticks(x)
        ax.set_xticklabels(cp_labels, rotation=45, ha='right')
        ax.set_xlabel('Control Point Pair', fontsize=12)
        ax.set_ylabel('EE Error (m)', fontsize=12)
        ax.set_title('EE Error per Control Point Pair', fontsize=14, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3, axis='y')

        fig.tight_layout()
        bar_path = self.logdir / 'error_per_cp.png'
        fig.savefig(bar_path, dpi=150, bbox_inches='tight')
        log.info(f"Saved per-CP error bar chart to {bar_path}")
        plt.close(fig)
