"""
Evaluate policies on close_drawer task with proper EE trajectory visualization.

Supports:
- PDP (with mode-colored trajectories, 8 colors for 8 modes)
- DP, BC, BC-GMM, IBC (single color trajectories)

Plots saved to: {logdir}/render/ee_trajectories_3d.png
(or {logdir}_wall{wall_style}/render/ee_trajectories_3d.png when a wall style is enabled)
Title includes: Success Rate (mean ± std)
"""

import os
import sys
import numpy as np
import torch
import logging
import pickle
from pathlib import Path
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

log = logging.getLogger(__name__)

# Import wall collision utilities
try:
    from RLBench_close_drawer.make_dataset.wall_collision import (
        WALL_STYLES,
        DEFAULT_WALL_CONFIG,
    )
    from RLBench_close_drawer.make_dataset.visualize_wall_trajectories import (
        check_ee_trajectory_wall_collision,
    )
    WALL_COLLISION_AVAILABLE = True
except ImportError:
    WALL_COLLISION_AVAILABLE = False
    WALL_STYLES = {}
    DEFAULT_WALL_CONFIG = {}
    log.warning("Wall collision module not available. Wall collision checking disabled.")

# Color palette for 8 modes - same as RLBench_close_drawer/plots/create_trajectory_plots.py
# Index i is the color for mode i; non-PDP plots use index 1 (blue) for every trajectory.
MODE_COLORS = [
    '#e41a1c', '#377eb8', '#4daf4a', '#984ea3',  # red, blue, green, purple
    '#ff7f00', '#ffff33', '#a65628', '#f781bf',  # orange, yellow, brown, pink
]

# Make project-local packages importable: the repository root is three directory
# levels above this file; the close_drawer encoder directory is searched first.
dppo_root = os.path.abspath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir)
)
sys.path.insert(0, os.path.join(dppo_root, "RLBench_close_drawer", "encoder"))
sys.path.append(dppo_root)

from util.timer import Timer
from agent.eval.eval_agent import EvalAgent


def check_handle_contact_from_trajectory(ee_trajectory, handle_pos, threshold=0.05):
    """
    Check if EE trajectory reached close enough to the handle position.

    In the dataset, the robot first reaches the handle (end of reach phase),
    then pushes to close. We check if the EE trajectory got within threshold
    of the handle at any point, which confirms the robot followed the correct
    approach pattern.

    Non-finite samples (NaN/inf EE positions) are ignored so that a single bad
    reading neither hides a real contact nor poisons the reported distance.

    Args:
        ee_trajectory: array-like of shape (N, 3), EE positions along trajectory
        handle_pos: array-like with 3 elements (any shape that flattens to (3,)),
            handle position (open drawer position); None skips the check
        threshold: float, distance threshold for contact (default 0.05, i.e. 5cm)

    Returns:
        contacted: bool (Python bool), True if EE got within threshold of handle
            at any point. Also True when handle_pos is None (cannot check).
        min_dist: float (Python float), minimum distance to handle along the
            trajectory; inf when the inputs are empty, malformed, or contain no
            finite distances; 0.0 when handle_pos is None.
    """
    # If no trajectory, no contact
    if ee_trajectory is None or len(ee_trajectory) == 0:
        return False, float('inf')
    # If handle_pos is None (non-close_drawer task), skip check
    if handle_pos is None:
        return True, 0.0  # Assume contact if we can't check

    # Convert to numpy arrays; malformed (ragged / non-numeric) input counts as no contact.
    try:
        traj = np.asarray(ee_trajectory, dtype=float)
        handle = np.asarray(handle_pos, dtype=float).flatten()
    except (TypeError, ValueError):
        return False, float('inf')

    if traj.ndim != 2 or traj.shape[1] != 3 or handle.shape != (3,):
        return False, float('inf')

    # Distance from each trajectory point to the handle; drop non-finite entries.
    distances = np.linalg.norm(traj - handle, axis=1)
    finite = np.isfinite(distances)
    if not finite.any():
        return False, float('inf')

    min_dist = float(distances[finite].min())
    return bool(min_dist <= threshold), min_dist


class EvalCloseDrawerAgent(EvalAgent):
    """
    Base evaluation agent for close_drawer task.
    Collects EE trajectories and plots them with success rate in title.

    Success criteria (all 3 must be satisfied):
    1. No wall collision (if wall is enabled)
    2. EE touched handle (got within handle_contact_threshold)
    3. Drawer closed (info['success'] when the env reports it, otherwise
       terminated and not truncated)

    Optional config keys read here (all via ``cfg.get``):
        model_name: label used in the plot title (default 'Policy').
        handle_contact_threshold: EE-handle contact distance in metres (default 0.05).
        x_T_file: .npy with an optimized initial noise tensor (DP inversion).
        selected_component_file: .npy holding a dict with key 'selected_component' (BC-GMM).
        demo_trajectory_file / demo_noise_std: demo-initialized IBC inference.
        wall_style: -1 (or missing / None) disables the wall; 0 uses
            DEFAULT_WALL_CONFIG; other values are looked up in WALL_STYLES.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        self.model_name = cfg.get('model_name', 'Policy')
        self.handle_contact_threshold = cfg.get('handle_contact_threshold', 0.05)  # 5cm default

        # x_T file for DP inversion (optimized initial noise)
        self.x_T_file = cfg.get('x_T_file', None)
        self.x_T = None
        if self.x_T_file is not None:
            if os.path.exists(self.x_T_file):
                self.x_T = torch.from_numpy(np.load(self.x_T_file)).float().to(self.device)
                log.info(f"Loaded optimized x_T from {self.x_T_file}, shape: {self.x_T.shape}")
            else:
                log.warning(f"x_T_file not found, ignoring: {self.x_T_file}")

        # Selected component file for BC-GMM posterior selection
        self.selected_component_file = cfg.get('selected_component_file', None)
        self.selected_component = None
        if self.selected_component_file is not None:
            if os.path.exists(self.selected_component_file):
                component_data = np.load(self.selected_component_file, allow_pickle=True).item()
                self.selected_component = component_data['selected_component']
                log.info(f"Loaded selected component k*={self.selected_component} from {self.selected_component_file}")
            else:
                log.warning(f"selected_component_file not found, ignoring: {self.selected_component_file}")

        # Demo trajectory file for IBC demo-initialized inference
        self.demo_trajectory_file = cfg.get('demo_trajectory_file', None)
        self.demo_trajectory = None
        self.demo_noise_std = cfg.get('demo_noise_std', 0.1)
        if self.demo_trajectory_file is not None:
            if os.path.exists(self.demo_trajectory_file):
                self.demo_trajectory = torch.from_numpy(np.load(self.demo_trajectory_file)).float().to(self.device)
                log.info(f"Loaded demo trajectory from {self.demo_trajectory_file}, shape: {self.demo_trajectory.shape}")
                log.info(f"Demo noise std: {self.demo_noise_std}")
            else:
                log.warning(f"demo_trajectory_file not found, ignoring: {self.demo_trajectory_file}")

        # Wall collision checking. An explicit null wall_style is treated like the
        # default (-1 = no wall) instead of failing on `None >= 0`.
        self.wall_style = cfg.get('wall_style', -1)
        if self.wall_style is None:
            self.wall_style = -1
        self.wall_config = None
        if self.wall_style >= 0 and WALL_COLLISION_AVAILABLE:
            # Load wall config: style 0 = default, style 1/2/3 = predefined
            if self.wall_style == 0:
                self.wall_config = self._copy_wall_config(DEFAULT_WALL_CONFIG)
                log.info("Wall collision enabled: style 0 (default)")
            elif self.wall_style in WALL_STYLES:
                self.wall_config = self._copy_wall_config(WALL_STYLES[self.wall_style])
                log.info(f"Wall collision enabled: style {self.wall_style}")
            else:
                log.warning(f"Wall style {self.wall_style} not found. Available: 0 (default), {list(WALL_STYLES.keys())}")

            if self.wall_config is not None:
                log.info(f"  Wall Y={self.wall_config['wall_y']:.2f}, "
                        f"X=[{self.wall_config['wall_min_x']:.2f}, {self.wall_config['wall_max_x']:.2f}], "
                        f"Z=[{self.wall_config['wall_min_z']:.2f}, {self.wall_config['wall_max_z']:.2f}]")
                if self.wall_config.get("opening"):
                    log.info(f"  Opening: X=[{self.wall_config['opening']['min_x']:.2f}, "
                            f"{self.wall_config['opening']['max_x']:.2f}], "
                            f"Z=[{self.wall_config['opening']['min_z']:.2f}, "
                            f"{self.wall_config['opening']['max_z']:.2f}]")

                # Modify save paths to include wall style
                wall_suffix = f"_wall{self.wall_style}"
                self.logdir = self.logdir + wall_suffix
                self.render_dir = os.path.join(self.logdir, "render")
                self.result_path = os.path.join(self.logdir, "result.npz")
                os.makedirs(self.render_dir, exist_ok=True)
                log.info(f"Wall style {self.wall_style} enabled, saving to: {self.logdir}")
        elif self.wall_style >= 0 and not WALL_COLLISION_AVAILABLE:
            log.warning("Wall collision requested but module not available.")

    @staticmethod
    def _copy_wall_config(config):
        """Copy a wall config dict, also copying its nested 'opening' dict so the shared template is never mutated."""
        copied = config.copy()
        if copied.get("opening") is not None:
            copied["opening"] = copied["opening"].copy()
        return copied

    @staticmethod
    def _summary_stats(values):
        """Return (min, max, mean, std) of *values* as (int, int, mean, std); (0, 0, 0.0, 0.0) when empty."""
        if len(values) == 0:
            return 0, 0, 0.0, 0.0
        arr = np.asarray(values)
        return int(np.min(arr)), int(np.max(arr)), np.mean(arr), np.std(arr)

    def _sample_action_chunk(self, obs_venv):
        """
        Query the policy on the batched observation and return the first
        ``act_steps`` actions of each sampled chunk.

        Uses the optional overrides loaded in ``__init__``:
        - x_T (DP inversion) is passed as cond["noise_action"].
        - selected_component (BC-GMM) is passed as fixed_component.
        - demo_trajectory (IBC) is passed otherwise, with demo_noise_std.
        """
        with torch.no_grad():
            cond = {
                "state": torch.from_numpy(obs_venv["state"])
                .float()
                .to(self.device)
            }
            # If x_T is provided (DP inversion), use it as initial noise (broadcast over the batch)
            if self.x_T is not None:
                batch_size = cond["state"].shape[0]
                cond["noise_action"] = self.x_T.expand(batch_size, -1, -1)
            if self.selected_component is not None:
                samples = self.model(cond=cond, deterministic=True, fixed_component=self.selected_component)
            elif self.demo_trajectory is not None:
                samples = self.model(cond=cond, deterministic=True,
                                     demo_trajectory=self.demo_trajectory,
                                     demo_noise_std=self.demo_noise_std)
            else:
                samples = self.model(cond=cond, deterministic=True)
            output_venv = samples.trajectories.cpu().numpy()
        return output_venv[:, :self.act_steps]

    def run(self):
        """
        Roll out the policy until ``n_eval_episodes`` episodes finish (default
        ``10 * n_envs``) or ``n_steps`` policy calls elapse.

        Each finished episode with a non-empty EE trajectory is scored with the
        three success criteria. Summary statistics are written to
        ``self.result_path``, and, when any trajectory was recorded, a 3D plot
        and a ``*_trajectories.pkl`` pickle are saved next to it.
        """
        # Prepare video paths
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"eval_trial-{env_ind}.mp4"
                )

        # Reset env
        self.model.eval()
        firsts_trajs = np.zeros((self.n_steps + 1, self.n_envs))
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)
        firsts_trajs[0] = 1
        reward_trajs = np.zeros((self.n_steps, self.n_envs))

        # Track EE trajectories
        ee_trajectories = []  # Full trajectories for contact checking
        ee_trajectories_plot = []  # Trajectories truncated at success for plotting
        trajectory_success = []
        trajectory_wall_collision = []  # Track wall collisions separately
        trajectory_collision_idx = []  # Index where wall collision occurred (None if no collision)
        trajectory_handle_contact = []  # Track handle contact separately
        trajectory_drawer_closed = []  # Track if drawer was closed (task reward)
        trajectory_min_handle_dist = []  # Track minimum distance to handle
        trajectory_mode_ids = []  # For PDP
        episode_steps_list = []  # Track steps per episode
        actions_executed_list = []  # Track actions executed per episode (reported by the env)
        current_ee_trajectory = [[] for _ in range(self.n_envs)]
        episode_has_initial_pos = [False] * self.n_envs
        current_handle_pos = [None] * self.n_envs  # Track handle position per env
        current_success_step = [None] * self.n_envs  # Trajectory length when success was first reported

        # Episode tracking
        episodes_completed = 0
        episodes_succeeded = 0
        current_episode_steps = np.zeros(self.n_envs, dtype=int)
        max_steps_per_episode = self.max_episode_steps // self.act_steps

        n_eval_episodes = self.n_eval_episodes if self.n_eval_episodes else self.n_envs * 10

        # Progress bars
        episode_pbar = tqdm(total=n_eval_episodes, desc="Episodes", position=0, leave=True, unit="ep")
        step_pbar = tqdm(total=max_steps_per_episode, desc="Current episode", position=1, leave=False, unit="step")

        # `step` is 0-based, so reward_trajs[step] / firsts_trajs[step + 1] stay within
        # their (n_steps, ...) / (n_steps + 1, ...) buffers on every iteration.
        for step in range(self.n_steps):
            if episodes_completed >= n_eval_episodes:
                break

            action_venv = self._sample_action_chunk(prev_obs_venv)

            # Step environment
            obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = (
                self.venv.step(action_venv)
            )
            done_venv = terminated_venv | truncated_venv
            reward_trajs[step] = reward_venv
            firsts_trajs[step + 1] = done_venv

            # Collect EE positions and handle position
            for env_idx in range(self.n_envs):
                info = info_venv[env_idx]
                if not episode_has_initial_pos[env_idx] and "initial_ee_pos" in info:
                    initial_pos = info["initial_ee_pos"]
                    if initial_pos is not None:
                        initial_pos = np.array(initial_pos).flatten()
                        if initial_pos.shape[0] == 3:
                            current_ee_trajectory[env_idx].append(initial_pos)
                            episode_has_initial_pos[env_idx] = True

                if "all_primitive_ee_positions" in info:
                    for ee_pos in info["all_primitive_ee_positions"]:
                        if ee_pos is not None and hasattr(ee_pos, '__len__'):
                            ee_pos = np.array(ee_pos).flatten()
                            if ee_pos.shape[0] == 3:
                                current_ee_trajectory[env_idx].append(ee_pos)

                # Track handle position for contact checking
                if "handle_pos" in info and info["handle_pos"] is not None:
                    current_handle_pos[env_idx] = np.array(info["handle_pos"])

                # Track success step (for plotting - stop at success)
                if "success_step" in info and info["success_step"] is not None:
                    if current_success_step[env_idx] is None:
                        current_success_step[env_idx] = len(current_ee_trajectory[env_idx])

            current_episode_steps += 1
            step_pbar.n = int(np.mean(current_episode_steps))
            step_pbar.refresh()

            # Check episode completion
            for env_idx in np.where(done_venv)[0]:
                episodes_completed += 1
                info = info_venv[env_idx]
                # In collect_full_trajectory mode, task success is tracked in info['success']
                # even though terminated_venv is False (we don't break on success)
                if 'success' in info:
                    drawer_closed = info['success']
                else:
                    drawer_closed = terminated_venv[env_idx] and not truncated_venv[env_idx]

                # NOTE: episodes without any recorded EE position count toward
                # episodes_completed (and success_rate as failures) but are not
                # added to the per-trajectory lists used for success_mean/std.
                if len(current_ee_trajectory[env_idx]) > 0:
                    traj_array = np.array(current_ee_trajectory[env_idx])

                    # Check wall collision on the collected trajectory
                    wall_collision = False
                    collision_idx = None
                    if self.wall_config is not None and WALL_COLLISION_AVAILABLE:
                        wall_collision, collision_idx = check_ee_trajectory_wall_collision(
                            traj_array, self.wall_config
                        )

                    # Check if EE trajectory reached near the handle (proper approach)
                    handle_contact, contact_dist = check_handle_contact_from_trajectory(
                        traj_array, current_handle_pos[env_idx], self.handle_contact_threshold
                    )

                    # Success criteria: all 3 must be satisfied
                    success = (not wall_collision) and handle_contact and drawer_closed
                    if success:
                        episodes_succeeded += 1

                    ee_trajectories.append(traj_array)
                    # Truncated trajectory for plotting (up to success), full otherwise
                    if current_success_step[env_idx] is not None:
                        traj_plot = traj_array[:current_success_step[env_idx]]
                    else:
                        traj_plot = traj_array
                    ee_trajectories_plot.append(traj_plot)
                    trajectory_success.append(success)
                    trajectory_wall_collision.append(wall_collision)
                    trajectory_collision_idx.append(collision_idx)
                    trajectory_handle_contact.append(handle_contact)
                    trajectory_drawer_closed.append(drawer_closed)
                    trajectory_min_handle_dist.append(contact_dist)
                    trajectory_mode_ids.append(self._get_current_mode_id(episodes_completed - 1))
                    episode_steps_list.append(current_episode_steps[env_idx])
                    # Track actions executed (from environment info)
                    if 'actions_executed' in info:
                        actions_executed_list.append(info['actions_executed'])

                # Reset per-env episode state for every finished episode (also empty
                # ones), so no stale handle position / success step leaks forward.
                current_ee_trajectory[env_idx] = []
                episode_has_initial_pos[env_idx] = False
                current_handle_pos[env_idx] = None
                current_success_step[env_idx] = None

                episode_pbar.update(1)
                current_episode_steps[env_idx] = 0

                if self.n_envs == 1:
                    step_pbar.reset()

                if episodes_completed >= n_eval_episodes:
                    break

            if episodes_completed >= n_eval_episodes:
                break

            prev_obs_venv = obs_venv

        episode_pbar.close()
        step_pbar.close()

        # Compute success rate statistics
        success_rate = episodes_succeeded / episodes_completed if episodes_completed > 0 else 0

        # Per-trajectory success mean/std; np.mean([]) would be nan, so report 0 like success_rate.
        if trajectory_success:
            success_array = np.array(trajectory_success, dtype=float)
            success_mean = float(np.mean(success_array))
            success_std = float(np.std(success_array))
        else:
            success_mean = success_std = 0.0

        log.info(f"eval: {episodes_completed} episodes | success rate {success_rate:.4f} ({episodes_succeeded}/{episodes_completed})")
        log.info(f"Success rate: {success_mean*100:.1f}% ± {success_std*100:.1f}%")

        # Log breakdown of success criteria
        n_handle_contact = sum(trajectory_handle_contact)
        n_drawer_closed = sum(trajectory_drawer_closed)
        n_no_wall_collision = sum(1 for w in trajectory_wall_collision if not w)
        log.info("Success criteria breakdown:")
        log.info(f"  - No wall collision: {n_no_wall_collision}/{episodes_completed}")
        log.info(f"  - Handle contact at push (threshold={self.handle_contact_threshold*100:.0f}cm): {n_handle_contact}/{episodes_completed}")
        log.info(f"  - Drawer closed: {n_drawer_closed}/{episodes_completed}")

        # Log statistics of the minimum EE-handle distance (unreachable inf values excluded)
        valid_dists = [d for d in trajectory_min_handle_dist if d != float('inf')]
        if valid_dists:
            dist_array = np.array(valid_dists)
            log.info(f"EE-handle distance at push: min={dist_array.min()*100:.1f}cm, max={dist_array.max()*100:.1f}cm, mean={dist_array.mean()*100:.1f}cm")

        # "episode steps" = number of policy calls per episode;
        # "actions executed" = per-episode count reported by the env in info['actions_executed'].
        steps_min, steps_max, steps_mean, steps_std = self._summary_stats(episode_steps_list)
        if episode_steps_list:
            log.info(f"Policy calls per episode: min={steps_min}, max={steps_max}, mean={steps_mean:.1f}, std={steps_std:.1f}")

        actions_min, actions_max, actions_mean, actions_std = self._summary_stats(actions_executed_list)
        if actions_executed_list:
            log.info(f"Actions executed: min={actions_min}, max={actions_max}, mean={actions_mean:.1f}, std={actions_std:.1f}")

        # Save results
        np.savez(
            self.result_path,
            num_episode=episodes_completed,
            eval_success_rate=success_rate,
            success_mean=success_mean,
            success_std=success_std,
            episode_steps=episode_steps_list,
            steps_min=steps_min,
            steps_max=steps_max,
            steps_mean=steps_mean,
            steps_std=steps_std,
            actions_executed=actions_executed_list,
            actions_min=actions_min,
            actions_max=actions_max,
            actions_mean=actions_mean,
            actions_std=actions_std,
            # Success criteria breakdown
            n_handle_contact=n_handle_contact,
            n_drawer_closed=n_drawer_closed,
            n_no_wall_collision=n_no_wall_collision,
        )

        # Plot trajectories (use truncated trajectories for plotting)
        if len(ee_trajectories_plot) > 0:
            self._plot_trajectories(ee_trajectories_plot, trajectory_success, trajectory_mode_ids,
                                    success_mean, success_std, trajectory_collision_idx)

            # Save trajectories (save both full and truncated for analysis)
            traj_path = self.result_path.replace('.npz', '_trajectories.pkl')
            with open(traj_path, 'wb') as f:
                pickle.dump({
                    'ee_trajectories': ee_trajectories,  # Full trajectories for contact checking
                    'ee_trajectories_plot': ee_trajectories_plot,  # Truncated for plotting
                    'success': trajectory_success,
                    'mode_ids': trajectory_mode_ids,
                    'collision_idx': trajectory_collision_idx,
                    'handle_contact': trajectory_handle_contact,
                    'drawer_closed': trajectory_drawer_closed,
                }, f)
            log.info(f"Saved trajectories to {traj_path}")

    def _get_current_mode_id(self, episode_idx):
        """Get mode ID for current episode. Override in PDP eval agent."""
        return 0  # Single mode for non-PDP methods

    def _plot_trajectories(self, ee_trajectories, trajectory_success, trajectory_mode_ids,
                           success_mean, success_std, trajectory_collision_idx=None):
        """
        Plot EE trajectories in 3D with the success rate in the title and save the
        figure to ``{render_dir}/ee_trajectories_3d.png``.

        Every trajectory uses MODE_COLORS[1] (blue); successful ones are drawn more
        opaque. Matches style of RLBench_close_drawer/plots/create_trajectory_plots.py

        Args:
            ee_trajectories: list of (N_i, 3) arrays of EE positions (may be empty).
            trajectory_success: per-trajectory success flags, same order.
            trajectory_mode_ids: per-trajectory mode ids; unused by this base
                implementation (presumably consumed by the PDP override — confirm).
            success_mean, success_std: success statistics shown in the title.
            trajectory_collision_idx: per-trajectory wall-collision index or None.
                A trajectory with a collision is truncated just after that point
                and its end is marked with a red 'X'.
        """
        from matplotlib.lines import Line2D
        from matplotlib.patches import Patch

        if trajectory_collision_idx is None:
            trajectory_collision_idx = [None] * len(ee_trajectories)

        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        # Collect (collision-truncated) trajectory points for the wall drawing helper
        all_x, all_y, all_z = [], [], []

        # Use blue color for non-PDP methods (single mode)
        color = MODE_COLORS[1]  # Blue

        for trajectory, success, collision_idx in zip(
                ee_trajectories, trajectory_success, trajectory_collision_idx):
            if len(trajectory) == 0:
                continue
            # Truncate trajectory at collision point if wall collision occurred
            if collision_idx is not None:
                trajectory = trajectory[:collision_idx + 1]  # Include collision point

            x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
            all_x.extend(x)
            all_y.extend(y)
            all_z.extend(z)
            alpha = 0.7 if success else 0.4
            ax.plot(x, y, z, color=color, linewidth=1.2, alpha=alpha)

            # Mark start point
            ax.scatter(trajectory[0, 0], trajectory[0, 1], trajectory[0, 2],
                       color=color, s=20, marker='o', alpha=0.8)
            # Mark end point - red 'X' for collision, 'x' for normal end
            if collision_idx is not None:
                ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2],
                           color='red', s=50, marker='X', alpha=1.0)
            else:
                ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2],
                           color=color, s=20, marker='x', alpha=0.8)

        # Draw wall if enabled
        if self.wall_config is not None and len(all_x) > 0:
            self._draw_wall_on_3d_plot(ax, all_x, all_y, all_z)

        legend_handles = [Line2D([0], [0], color=color, linewidth=2, label='Trajectory')]
        if self.wall_config is not None:
            wall_label = 'Wall (with opening)' if self.wall_config.get("opening") else 'Wall'
            legend_handles.append(Patch(facecolor='red', alpha=0.3, edgecolor='darkred', label=wall_label))
            # Add collision marker to legend if any collisions occurred
            if any(c is not None for c in trajectory_collision_idx):
                legend_handles.append(Line2D([0], [0], marker='X', color='red', linestyle='None',
                                             markersize=10, label='Wall Collision'))
        ax.legend(handles=legend_handles, loc='upper right', fontsize=9)

        ax.set_xlabel('X (m)', fontsize=11)
        ax.set_ylabel('Y (m)', fontsize=11)
        ax.set_zlabel('Z (m)', fontsize=11)

        # Build title with wall info
        title = f'{self.model_name} - Close Drawer\n'
        title += f'Success Rate: {success_mean*100:.1f}% ± {success_std*100:.1f}% '
        title += f'({sum(trajectory_success)}/{len(trajectory_success)} trajectories)'
        if self.wall_config is not None:
            title += f'\nWall Style {self.wall_style} (Y={self.wall_config["wall_y"]:.2f})'
        ax.set_title(title, fontsize=14)

        # Equal aspect ratio for the 3D plot (matching create_trajectory_plots.py).
        # Guarded: np.vstack([]) raises when every trajectory is empty.
        nonempty = [t for t in ee_trajectories if len(t) > 0]
        if nonempty:
            all_points = np.vstack(nonempty)
            mins = all_points[:, :3].min(axis=0)
            maxs = all_points[:, :3].max(axis=0)
            max_range = (maxs - mins).max() / 2.0
            mid_x, mid_y, mid_z = (maxs + mins) * 0.5

            ax.set_xlim(mid_x - max_range, mid_x + max_range)
            ax.set_ylim(mid_y - max_range, mid_y + max_range)
            ax.set_zlim(mid_z - max_range, mid_z + max_range)

        # View angle to match camera frame (same as visualize_wall_trajectories.py)
        ax.view_init(elev=30, azim=-135)

        fig.tight_layout()
        plot_path = os.path.join(self.render_dir, 'ee_trajectories_3d.png')
        fig.savefig(plot_path, dpi=200, bbox_inches='tight')
        log.info(f"Saved EE trajectory plot to {plot_path}")
        plt.close(fig)  # close this figure explicitly, not whatever pyplot considers current

    def _draw_wall_on_3d_plot(self, ax, all_x, all_y, all_z):
        """
        Draw the configured wall on a 3D axis at its actual size (no clipping to
        the trajectory area).

        With an opening, the wall is drawn as up to four rectangles (below,
        above, left, right of the opening, after clipping the opening to the
        wall). An opening that does not overlap the wall is ignored and the wall
        is drawn solid. ``all_x/all_y/all_z`` are accepted but not used.
        """
        wall_config = self.wall_config

        wall_y = wall_config["wall_y"]
        wall_min_x = wall_config["wall_min_x"]
        wall_max_x = wall_config["wall_max_x"]
        wall_min_z = wall_config["wall_min_z"]
        wall_max_z = wall_config["wall_max_z"]

        opening = wall_config.get("opening", None)

        wall_pieces = None
        if opening is not None:
            # Clip opening to wall bounds (for display)
            op_min_x = max(opening["min_x"], wall_min_x)
            op_max_x = min(opening["max_x"], wall_max_x)
            op_min_z = max(opening["min_z"], wall_min_z)
            op_max_z = min(opening["max_z"], wall_max_z)

            # Only split the wall when the clipped opening is a real rectangle;
            # otherwise the pieces below would extend beyond the wall.
            if op_min_x < op_max_x and op_min_z < op_max_z:
                wall_pieces = []

                # Bottom piece (below opening)
                if wall_min_z < op_min_z:
                    wall_pieces.append([
                        [wall_min_x, wall_y, wall_min_z],
                        [wall_max_x, wall_y, wall_min_z],
                        [wall_max_x, wall_y, op_min_z],
                        [wall_min_x, wall_y, op_min_z],
                    ])

                # Top piece (above opening)
                if op_max_z < wall_max_z:
                    wall_pieces.append([
                        [wall_min_x, wall_y, op_max_z],
                        [wall_max_x, wall_y, op_max_z],
                        [wall_max_x, wall_y, wall_max_z],
                        [wall_min_x, wall_y, wall_max_z],
                    ])

                # Left piece (left of opening, between opening Z bounds)
                if wall_min_x < op_min_x:
                    wall_pieces.append([
                        [wall_min_x, wall_y, op_min_z],
                        [op_min_x, wall_y, op_min_z],
                        [op_min_x, wall_y, op_max_z],
                        [wall_min_x, wall_y, op_max_z],
                    ])

                # Right piece (right of opening, between opening Z bounds)
                if op_max_x < wall_max_x:
                    wall_pieces.append([
                        [op_max_x, wall_y, op_min_z],
                        [wall_max_x, wall_y, op_min_z],
                        [wall_max_x, wall_y, op_max_z],
                        [op_max_x, wall_y, op_max_z],
                    ])

        if wall_pieces is not None:
            for piece in wall_pieces:
                piece_poly = Poly3DCollection([piece], alpha=0.3, facecolor='red',
                                              edgecolor='darkred', linewidth=1)
                ax.add_collection3d(piece_poly)
        else:
            # No (overlapping) opening - draw solid wall
            wall_vertices = [
                [wall_min_x, wall_y, wall_min_z],
                [wall_max_x, wall_y, wall_min_z],
                [wall_max_x, wall_y, wall_max_z],
                [wall_min_x, wall_y, wall_max_z],
            ]
            wall_poly = Poly3DCollection([wall_vertices], alpha=0.3, facecolor='red',
                                         edgecolor='darkred', linewidth=2)
            ax.add_collection3d(wall_poly)


class EvalCloseDrawerPDPAgent(EvalCloseDrawerAgent):
    """
    Evaluation agent for PDP on close_drawer task.
    Uses 8 different colors for 8 modes.
    """

    def __init__(self, cfg):
        """Configure PDP evaluation and resolve the latent z embeddings to evaluate.

        Config keys read via ``cfg.get`` (all optional):
            z_list: explicit list of z vectors.
            z_file: path to a ``.npy`` file of z embeddings, one per row.
            encoder_checkpoint, dataset_path: used to auto-generate z embeddings
                when neither ``z_file`` nor ``z_list`` yields embeddings.
            n_noise_samples: episodes evaluated per z embedding (default 5).
            demo_trajectory_path: demo directory used for the DTW comparison.

        Resolution priority is ``z_file`` (if it exists) > non-empty ``z_list`` >
        auto-generation. Afterwards ``self.z_list`` holds tensors on
        ``self.device`` and ``self.n_modes`` equals their count.

        Raises:
            ValueError: if no z source is available.
        """
        # These attributes are assigned before super().__init__(cfg); keep that
        # order in case the parent constructor relies on any of them.
        z_list_cfg = cfg.get('z_list', None)
        # Treat an explicit ``z_list: null`` like "not given"; otherwise
        # ``len(self.z_list)`` below raises TypeError.
        self.z_list = [] if z_list_cfg is None else z_list_cfg
        self.z_file = cfg.get('z_file', None)  # Path to .npy file with z embeddings
        self.encoder_checkpoint = cfg.get('encoder_checkpoint', None)
        self.dataset_path = cfg.get('dataset_path', None)
        self.n_noise_samples = cfg.get('n_noise_samples', 5)
        # Upper bound on modes collected by _generate_z_embeddings; replaced by
        # the actual number of embeddings once they are resolved below.
        self.n_modes = 8
        self.demo_trajectory_path = cfg.get('demo_trajectory_path', None)  # Path to target demo for DTW

        super().__init__(cfg)

        z_file_exists = self.z_file is not None and os.path.exists(self.z_file)
        if self.z_file is not None and not z_file_exists:
            # Make the fallback visible instead of silently ignoring a bad path.
            log.warning(f"z_file {self.z_file} not found; falling back to z_list or auto-generation")

        # Load z embeddings: priority is z_file > z_list > auto-generate
        if z_file_exists:
            # atleast_2d: a single embedding saved with shape (latent_dim,) is one
            # mode, not latent_dim scalar "modes".
            z_array = np.atleast_2d(np.load(self.z_file))
            log.info(f"Loaded z embeddings from {self.z_file}, shape: {z_array.shape}")
            self.z_list = [torch.FloatTensor(z_array[i]).to(self.device) for i in range(z_array.shape[0])]
            self.n_modes = len(self.z_list)
        elif len(self.z_list) > 0:
            self.z_list = [torch.FloatTensor(z).to(self.device) for z in self.z_list]
            self.n_modes = len(self.z_list)
        else:
            # Auto-generate z embeddings from dataset
            if self.encoder_checkpoint is None or self.dataset_path is None:
                raise ValueError("Must provide z_file, z_list, or (encoder_checkpoint and dataset_path) for z generation")
            self.z_list = self._generate_z_embeddings()
            self.n_modes = len(self.z_list)

        log.info(f"Evaluating {len(self.z_list)} modes x {self.n_noise_samples} noise samples = "
                 f"{len(self.z_list) * self.n_noise_samples} total episodes")

    def _generate_z_embeddings(self):
        """Encode the first clean demo of each mode into a latent z embedding.

        Loads the ``TrajectoryVAE`` from ``self.encoder_checkpoint`` and the
        concatenated demo arrays (``states``, ``actions``, ``traj_lengths``) from
        ``self.dataset_path``. Per-episode metadata is read from
        ``<dirname(dirname(dataset_path))>/train/episodes/episode{i}/metadata.npy``.
        For each mode, the first episode with ``demo_in_mode == 0`` and a falsy
        ``with_noise`` is encoded; scanning stops once ``self.n_modes`` modes
        have been found.

        Returns:
            list of z tensors ordered by mode id, so list index i is mode i when
            the mode ids are 0..n-1 (callers use the index as the mode id).

        Raises:
            ValueError: if no qualifying demo is found.
        """
        from trajectory_encoder import TrajectoryVAE
        from train_close_drawer_encoder import make_trajectory_relative

        log.info(f"Loading encoder from {self.encoder_checkpoint}")
        # weights_only=False unpickles arbitrary objects: the checkpoint must be trusted.
        checkpoint = torch.load(self.encoder_checkpoint, map_location=self.device, weights_only=False)
        config = checkpoint['config']

        encoder = TrajectoryVAE(
            state_dim=config['state_dim'],
            action_dim=config['action_dim'],
            hidden_dim=config['hidden_dim'],
            num_layers=config['num_layers'],
            num_heads=config['num_heads'],
            latent_dim=config['latent_dim'],
            horizon=config['horizon']
        ).to(self.device)
        encoder.load_state_dict(checkpoint['model_state_dict'])
        encoder.eval()

        # The context manager closes the NpzFile's file handle; indexing an
        # NpzFile reads each array fully into memory before it is closed.
        with np.load(self.dataset_path) as data:
            all_states = data['states']
            all_actions = data['actions']
            traj_lengths = data['traj_lengths']

        dataset_dir = os.path.dirname(self.dataset_path)
        data_root = os.path.dirname(dataset_dir)
        episodes_dir = os.path.join(data_root, 'train', 'episodes')

        # mode_id -> (states, actions) of that mode's first clean demo
        first_demo_by_mode = {}

        start_idx = 0
        for episode_idx, traj_len in enumerate(traj_lengths):
            end_idx = start_idx + traj_len

            # Metadata is loaded lazily so the early break below avoids reading
            # files for episodes that are never inspected.
            ep_meta_path = os.path.join(episodes_dir, f'episode{episode_idx}', 'metadata.npy')
            if os.path.exists(ep_meta_path):
                # allow_pickle=True: metadata files must come from a trusted source.
                meta = np.load(ep_meta_path, allow_pickle=True).item()
            else:
                log.warning(f"Missing metadata for episode {episode_idx}")
                # with_noise=True ensures this placeholder is never selected.
                meta = {'mode': 0, 'demo_in_mode': episode_idx, 'with_noise': True}

            if meta['demo_in_mode'] == 0 and not meta['with_noise']:
                mode_id = meta['mode']
                if mode_id not in first_demo_by_mode:
                    first_demo_by_mode[mode_id] = (
                        all_states[start_idx:end_idx],
                        all_actions[start_idx:end_idx],
                    )
                    log.info(f"  Mode {mode_id}: episode {episode_idx}")

            start_idx = end_idx
            if len(first_demo_by_mode) >= self.n_modes:
                break

        log.info(f"Found {len(first_demo_by_mode)} modes")
        if not first_demo_by_mode:
            raise ValueError(
                f"No clean first demo (demo_in_mode == 0, with_noise False) found in {self.dataset_path}"
            )

        # Sort by mode id so the returned list index matches the mode id.
        mode_ids = sorted(first_demo_by_mode)
        # np.array stacking assumes every selected demo has the same length
        # (presumably the encoder horizon) — TODO confirm for new datasets.
        mode_states = torch.FloatTensor(
            np.array([first_demo_by_mode[m][0] for m in mode_ids])).to(self.device)
        mode_actions = torch.FloatTensor(
            np.array([first_demo_by_mode[m][1] for m in mode_ids])).to(self.device)
        mode_states_rel, mode_actions_rel = make_trajectory_relative(mode_states, mode_actions)

        with torch.no_grad():
            z_embeddings = encoder.encode(mode_states_rel, mode_actions_rel)

        for i in range(min(self.n_modes, z_embeddings.shape[0])):
            log.info(f"  z[{i}]: {z_embeddings[i].cpu().numpy()}")

        return [z_embeddings[i] for i in range(z_embeddings.shape[0])]

    def _load_demo_ee_trajectory(self):
        """Load the demo end-effector (EE) trajectory used as the DTW reference.

        Looks in ``self.demo_trajectory_path`` for ``ee_trajectory.npy`` first and
        otherwise builds the trajectory from ``gripper_pose[:3]`` of each
        observation in ``low_dim_obs.pkl`` (observations without a gripper pose
        are skipped).

        Returns:
            np.ndarray of EE positions, or None if no path is configured, neither
            file exists, or the pickle yields no usable gripper poses.
        """
        if self.demo_trajectory_path is None:
            return None

        # Try loading from ee_trajectory.npy first
        ee_traj_path = os.path.join(self.demo_trajectory_path, 'ee_trajectory.npy')
        if os.path.exists(ee_traj_path):
            demo_ee_traj = np.load(ee_traj_path)
            log.info(f"Loaded demo EE trajectory from {ee_traj_path}, shape: {demo_ee_traj.shape}")
            return demo_ee_traj

        # Fallback: extract from low_dim_obs.pkl
        pkl_path = os.path.join(self.demo_trajectory_path, 'low_dim_obs.pkl')
        if os.path.exists(pkl_path):
            # NOTE(security): pickle.load can execute arbitrary code; only point
            # demo_trajectory_path at trusted, locally generated demos.
            with open(pkl_path, 'rb') as f:
                demo = pickle.load(f)
            # The first three components of each observation's gripper_pose
            # are taken as the xyz position.
            demo_ee_traj = np.array([
                obs.gripper_pose[:3]
                for obs in demo
                if getattr(obs, 'gripper_pose', None) is not None
            ])
            if demo_ee_traj.size == 0:
                # An empty (0,) array would make the downstream distance
                # computation fail, so report "not loadable" instead.
                log.warning(f"No gripper poses found in {pkl_path}")
                return None
            log.info(f"Extracted demo EE trajectory from {pkl_path}, shape: {demo_ee_traj.shape}")
            return demo_ee_traj

        log.warning(f"Could not load demo trajectory from {self.demo_trajectory_path}")
        return None

    def _compute_dtw_distances(self):
        """Compare each successful EE trajectory against the demo trajectory.

        When the third-party ``fastdtw`` package is installed, the DTW cost
        (Euclidean point distance) is divided by the longer of the two sequence
        lengths. Otherwise a symmetric mean nearest-neighbour distance computed
        with ``scipy.spatial.distance.cdist`` is used. That fallback is not DTW,
        but it also handles sequences of different lengths.

        Reads ``self.all_ee_trajectories`` and ``self.all_success`` (filled by
        ``run``) and writes ``<result_path without extension>_dtw.npz``.
        It does nothing if no demo is configured or loadable, or if no
        trajectory succeeded.
        """
        if self.demo_trajectory_path is None:
            log.info("No demo_trajectory_path provided, skipping DTW computation")
            return

        # Load demo EE trajectory
        demo_ee_traj = self._load_demo_ee_trajectory()
        if demo_ee_traj is None:
            log.warning("Could not load demo trajectory, skipping DTW computation")
            return

        # Collect successful trajectories
        successful_trajectories = [
            np.asarray(traj)
            for traj, success in zip(self.all_ee_trajectories, self.all_success)
            if success and traj is not None and len(traj) > 0
        ]

        if len(successful_trajectories) == 0:
            log.info("No successful trajectories to compute DTW for")
            return

        log.info(f"\n{'='*60}")
        log.info(f"DTW Distance Computation")
        log.info(f"{'='*60}")
        log.info(f"Demo trajectory: {demo_ee_traj.shape}")
        log.info(f"Successful trajectories: {len(successful_trajectories)}")

        # Prefer true DTW via fastdtw; fall back to a symmetric NN distance.
        try:
            from fastdtw import fastdtw
            from scipy.spatial.distance import euclidean
            use_fastdtw = True
            log.info("Using fastdtw for DTW computation")
        except ImportError:
            use_fastdtw = False
            log.info("fastdtw not available, using symmetric mean distance")
        if not use_fastdtw:
            from scipy.spatial.distance import cdist

        dtw_distances = []
        for traj in successful_trajectories:
            if use_fastdtw:
                distance, _ = fastdtw(demo_ee_traj, traj, dist=euclidean)
                # Normalize by the longer sequence length
                dist = distance / max(len(demo_ee_traj), len(traj))
            else:
                cost_matrix = cdist(demo_ee_traj, traj, metric='euclidean')
                # Average nearest-neighbour distance demo->traj and traj->demo
                dist_demo_to_traj = np.mean(np.min(cost_matrix, axis=1))
                dist_traj_to_demo = np.mean(np.min(cost_matrix, axis=0))
                dist = (dist_demo_to_traj + dist_traj_to_demo) / 2.0
            dtw_distances.append(dist)

        dtw_distances = np.array(dtw_distances)
        dtw_mean = np.mean(dtw_distances)
        dtw_std = np.std(dtw_distances)

        log.info(f"DTW Distance (successful trajectories vs demo):")
        log.info(f"  Mean ± Std: {dtw_mean:.4f} ± {dtw_std:.4f}")
        log.info(f"  Min: {dtw_distances.min():.4f}, Max: {dtw_distances.max():.4f}")
        log.info(f"{'='*60}")

        # splitext (not str.replace('.npz', ...)): if result_path lacks a .npz
        # suffix, replace() would return the same path and np.savez would
        # overwrite the main results file.
        dtw_path = f"{os.path.splitext(self.result_path)[0]}_dtw.npz"
        np.savez(dtw_path,
                 dtw_distances=dtw_distances,
                 dtw_mean=dtw_mean,
                 dtw_std=dtw_std,
                 n_successful=len(successful_trajectories),
                 demo_trajectory=demo_ee_traj,
                 # Which metric produced dtw_distances (the fallback is not DTW).
                 method='fastdtw' if use_fastdtw else 'symmetric_mean_nn')
        log.info(f"Saved DTW results to {dtw_path}")

    def run(self):
        """Evaluate every z embedding, then aggregate, save and plot the results.

        For each entry of ``self.z_list`` this sets ``self.model.current_z`` and
        calls ``_run_single_mode``, which appends per-episode records to the
        ``self.all_*`` lists initialised here.

        Outputs:
            - ``self.result_path``: summary statistics (``np.savez``).
            - ``<result_path without extension>_trajectories.pkl``: full and
              plot-truncated EE trajectories plus per-episode criteria.
            - the EE trajectory plot (``_plot_trajectories``) and, when a demo
              is configured, DTW results (``_compute_dtw_distances``).
        """
        log.info(f"Starting PDP evaluation: {len(self.z_list)} modes x {self.n_noise_samples} samples")

        self.all_ee_trajectories = []  # Full trajectories for contact checking
        self.all_ee_trajectories_plot = []  # Truncated trajectories for plotting
        self.all_success = []
        self.all_mode_ids = []
        self.all_episode_steps = []  # Track steps per episode
        self.all_actions_executed = []  # Track actions executed per episode
        self.all_collision_idx = []  # Track wall collision indices
        self.all_handle_contact = []  # Track handle contact
        self.all_drawer_closed = []  # Track drawer closed
        self.all_min_handle_dist = []  # Track min distance to handle
        self.all_episodes_succeeded = 0
        self.all_episodes_completed = 0

        for z_idx, z_embedding in enumerate(self.z_list):
            log.info(f"\n{'='*60}")
            log.info(f"Mode {z_idx + 1}/{len(self.z_list)}")
            log.info(f"{'='*60}")

            # Condition the policy on this mode's latent
            self.model.current_z = z_embedding.unsqueeze(0)
            self.current_mode_idx = z_idx

            self._run_single_mode()

        # Trigger final video save by doing a reset (without video path)
        if self.render_video:
            self.reset_env_all(options_venv=[{}])

        success_array = np.asarray(self.all_success, dtype=float)
        if success_array.size > 0:
            success_mean = np.mean(success_array)
            success_std = np.std(success_array)
        else:
            # np.mean/np.std of an empty array return nan with a RuntimeWarning;
            # report 0 when no episode was recorded.
            success_mean = success_std = 0.0

        log.info(f"\n{'='*60}")
        log.info(f"Overall: {self.all_episodes_succeeded}/{self.all_episodes_completed} succeeded")
        log.info(f"Success rate: {success_mean*100:.1f}% ± {success_std*100:.1f}%")

        # Log breakdown of success criteria
        n_handle_contact = sum(self.all_handle_contact)
        n_drawer_closed = sum(self.all_drawer_closed)
        n_no_wall_collision = sum(1 for c in self.all_collision_idx if c is None)
        log.info(f"Success criteria breakdown:")
        log.info(f"  - No wall collision: {n_no_wall_collision}/{self.all_episodes_completed}")
        log.info(f"  - Handle contact at push (threshold={self.handle_contact_threshold*100:.0f}cm): {n_handle_contact}/{self.all_episodes_completed}")
        log.info(f"  - Drawer closed: {n_drawer_closed}/{self.all_episodes_completed}")

        # Handle distance statistics (distance at moment drawer started moving)
        valid_dists = [d for d in self.all_min_handle_dist if d != float('inf')]
        if valid_dists:
            dist_array = np.array(valid_dists)
            log.info(f"EE-handle distance at push: min={dist_array.min()*100:.1f}cm, max={dist_array.max()*100:.1f}cm, mean={dist_array.mean()*100:.1f}cm")

        # Policy calls per episode
        steps_min, steps_max, steps_mean, steps_std = self._summarize_counts(self.all_episode_steps)
        if len(self.all_episode_steps) > 0:
            log.info(f"Policy calls per episode: min={steps_min}, max={steps_max}, mean={steps_mean:.1f}, std={steps_std:.1f}")

        # Actions executed per episode (as reported by the environment info)
        actions_min, actions_max, actions_mean, actions_std = self._summarize_counts(self.all_actions_executed)
        if len(self.all_actions_executed) > 0:
            log.info(f"Actions executed: min={actions_min}, max={actions_max}, mean={actions_mean:.1f}, std={actions_std:.1f}")

        log.info(f"{'='*60}")

        np.savez(
            self.result_path,
            num_episode=self.all_episodes_completed,
            eval_success_rate=self.all_episodes_succeeded / max(1, self.all_episodes_completed),
            success_mean=success_mean,
            success_std=success_std,
            episode_steps=self.all_episode_steps,
            steps_min=steps_min,
            steps_max=steps_max,
            steps_mean=steps_mean,
            steps_std=steps_std,
            actions_executed=self.all_actions_executed,
            actions_min=actions_min,
            actions_max=actions_max,
            actions_mean=actions_mean,
            actions_std=actions_std,
            # Success criteria breakdown
            n_handle_contact=n_handle_contact,
            n_drawer_closed=n_drawer_closed,
            n_no_wall_collision=n_no_wall_collision,
        )

        # Plot all trajectories (use truncated trajectories for plotting)
        if len(self.all_ee_trajectories_plot) > 0:
            self._plot_trajectories(self.all_ee_trajectories_plot, self.all_success,
                                    self.all_mode_ids, success_mean, success_std,
                                    self.all_collision_idx)

            # splitext rather than str.replace('.npz', ...): replace() is a no-op
            # when result_path lacks the suffix and also hits '.npz' elsewhere
            # in the path.
            traj_path = f"{os.path.splitext(self.result_path)[0]}_trajectories.pkl"
            with open(traj_path, 'wb') as f:
                pickle.dump({
                    'ee_trajectories': self.all_ee_trajectories,  # Full trajectories
                    'ee_trajectories_plot': self.all_ee_trajectories_plot,  # Truncated for plotting
                    'success': self.all_success,
                    'mode_ids': self.all_mode_ids,
                    'collision_idx': self.all_collision_idx,
                    'handle_contact': self.all_handle_contact,
                    'drawer_closed': self.all_drawer_closed,
                    'min_handle_dist': self.all_min_handle_dist,
                }, f)

        # Compute DTW distance for successful trajectories if demo is provided
        self._compute_dtw_distances()

    @staticmethod
    def _summarize_counts(values):
        """Return (min, max, mean, std) of a sequence of counts; zeros when empty."""
        if len(values) == 0:
            return 0, 0, 0.0, 0.0
        arr = np.asarray(values)
        return int(arr.min()), int(arr.max()), float(arr.mean()), float(arr.std())

    def _run_single_mode(self):
        """Run ``self.n_noise_samples`` episodes with the currently selected z.

        Uses a single environment (index 0). Each policy call passes the first
        ``self.act_steps`` predicted actions to ``self.venv.step``. EE positions
        reported in the step ``info`` are accumulated into the episode
        trajectory: ``initial_ee_pos`` once per episode, then every entry of
        ``all_primitive_ee_positions``.

        When an episode terminates or truncates, it is scored. Success requires
        no wall collision, the EE trajectory reaching the handle, and the drawer
        being closed. The result is appended to the ``self.all_*`` records and
        the environment is reset for the next episode.

        ``self.n_steps`` bounds policy calls per episode, because the counter is
        reset with each new episode. If it runs out first, the unfinished
        episode is dropped and a warning is logged.
        """
        episodes_completed = 0
        current_ee_trajectory = [[]]
        episode_has_initial_pos = [False]
        current_episode_steps = np.zeros(1, dtype=int)
        current_handle_pos = [None]  # Track handle position per env
        current_success_step = [None]  # Trajectory length when success was first reported

        # Set up video path for first episode
        options_venv = [{}]
        if self.render_video:
            options_venv[0]["video_path"] = os.path.join(
                self.render_dir, f"eval_mode{self.current_mode_idx}_ep{episodes_completed}.mp4"
            )
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)

        step = 0
        max_steps = self.n_steps

        while episodes_completed < self.n_noise_samples and step < max_steps:
            step += 1

            with torch.no_grad():
                cond = {"state": torch.from_numpy(prev_obs_venv["state"]).float().to(self.device)}
                samples = self.model(cond=cond, deterministic=True)
                output_venv = samples.trajectories.cpu().numpy()
            action_venv = output_venv[:, :self.act_steps]

            obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = self.venv.step(action_venv)

            # Collect EE positions (single env)
            env_idx = 0
            if not episode_has_initial_pos[env_idx] and "initial_ee_pos" in info_venv[env_idx]:
                initial_pos = info_venv[env_idx]["initial_ee_pos"]
                if initial_pos is not None:
                    initial_pos = np.array(initial_pos).flatten()
                    if initial_pos.shape[0] == 3:
                        current_ee_trajectory[env_idx].append(initial_pos)
                        episode_has_initial_pos[env_idx] = True

            if "all_primitive_ee_positions" in info_venv[env_idx]:
                for ee_pos in info_venv[env_idx]["all_primitive_ee_positions"]:
                    if ee_pos is not None and hasattr(ee_pos, '__len__'):
                        ee_pos = np.array(ee_pos).flatten()
                        if ee_pos.shape[0] == 3:
                            current_ee_trajectory[env_idx].append(ee_pos)

            # Track handle position for contact checking
            if "handle_pos" in info_venv[env_idx] and info_venv[env_idx]["handle_pos"] is not None:
                current_handle_pos[env_idx] = np.array(info_venv[env_idx]["handle_pos"])

            # Remember where success happened so the plot can stop there
            if "success_step" in info_venv[env_idx] and info_venv[env_idx]["success_step"] is not None:
                if current_success_step[env_idx] is None:
                    current_success_step[env_idx] = len(current_ee_trajectory[env_idx])

            current_episode_steps[env_idx] += 1

            # Check completion
            if terminated_venv[env_idx] or truncated_venv[env_idx]:
                episodes_completed += 1
                self.all_episodes_completed += 1
                # Prefer the env's explicit success flag (handles collect_full_trajectory mode)
                if 'success' in info_venv[env_idx]:
                    drawer_closed = info_venv[env_idx]['success']
                else:
                    drawer_closed = terminated_venv[env_idx] and not truncated_venv[env_idx]

                if len(current_ee_trajectory[env_idx]) > 0:
                    traj_array = np.array(current_ee_trajectory[env_idx])

                    # Check wall collision on the collected trajectory
                    wall_collision = False
                    collision_idx = None
                    if self.wall_config is not None and WALL_COLLISION_AVAILABLE:
                        wall_collision, collision_idx = check_ee_trajectory_wall_collision(
                            traj_array, self.wall_config
                        )

                    # Check if EE trajectory reached near the handle
                    handle_contact, contact_dist = check_handle_contact_from_trajectory(
                        traj_array, current_handle_pos[env_idx], self.handle_contact_threshold
                    )

                    # Success criteria: all 3 must be satisfied
                    # 1. No wall collision
                    # 2. Handle contact (EE trajectory reached near handle)
                    # 3. Drawer closed (task reward)
                    success = (not wall_collision) and handle_contact and drawer_closed

                    if success:
                        self.all_episodes_succeeded += 1

                    self.all_ee_trajectories.append(traj_array)
                    # Create truncated trajectory for plotting (up to success)
                    if current_success_step[env_idx] is not None:
                        traj_plot = traj_array[:current_success_step[env_idx]]
                    else:
                        traj_plot = traj_array  # No success, use full trajectory
                    self.all_ee_trajectories_plot.append(traj_plot)
                    self.all_success.append(success)
                    self.all_mode_ids.append(self.current_mode_idx)
                    self.all_episode_steps.append(current_episode_steps[env_idx])
                    self.all_collision_idx.append(collision_idx)
                    self.all_handle_contact.append(handle_contact)
                    self.all_drawer_closed.append(drawer_closed)
                    self.all_min_handle_dist.append(contact_dist)
                    # Track actions executed (from environment info)
                    if 'actions_executed' in info_venv[env_idx]:
                        self.all_actions_executed.append(info_venv[env_idx]['actions_executed'])
                else:
                    # The episode still counts as completed (and unsuccessful)
                    # but has no per-episode records; make that visible.
                    log.warning(
                        f"Mode {self.current_mode_idx}: episode {episodes_completed - 1} "
                        "ended with no EE positions recorded; excluded from per-episode results"
                    )

                current_ee_trajectory[env_idx] = []
                episode_has_initial_pos[env_idx] = False
                current_episode_steps[env_idx] = 0
                current_handle_pos[env_idx] = None  # Reset for next episode
                current_success_step[env_idx] = None  # Reset for next episode

                # Reset for next episode with new video path
                if episodes_completed < self.n_noise_samples:
                    options_venv = [{}]
                    if self.render_video:
                        options_venv[0]["video_path"] = os.path.join(
                            self.render_dir, f"eval_mode{self.current_mode_idx}_ep{episodes_completed}.mp4"
                        )
                    prev_obs_venv = self.reset_env_all(options_venv=options_venv)
                    step = 0
                continue

            prev_obs_venv = obs_venv

        if episodes_completed < self.n_noise_samples:
            # The step budget ran out mid-episode: the partial episode and any
            # remaining episodes for this mode are not evaluated.
            log.warning(
                f"Mode {self.current_mode_idx}: step budget ({max_steps}) exhausted; "
                f"only {episodes_completed}/{self.n_noise_samples} episodes completed"
            )

    def _plot_trajectories(self, ee_trajectories, trajectory_success, trajectory_mode_ids,
                          success_mean, success_std, trajectory_collision_idx=None):
        """Save a 3D plot of EE trajectories, coloured by mode.

        Matches the style of RLBench_close_drawer/plots/create_trajectory_plots.py.
        Colours come from ``MODE_COLORS`` (cycled by mode id), and failed
        episodes are drawn more transparently. Start points are marked 'o' and
        end points 'x'. If a wall collision index falls inside a trajectory,
        the trajectory is cut just after the collision point and that point is
        marked with a red 'X'. The wall, when configured, is drawn by
        ``_draw_wall_on_3d_plot``. Axis limits form a cube around the plotted
        points, giving an equal aspect ratio.

        Args:
            ee_trajectories: per-episode arrays whose first three columns are x, y, z.
            trajectory_success: per-episode success flags.
            trajectory_mode_ids: per-episode mode index (selects the colour).
            success_mean, success_std: success-rate statistics shown in the title.
            trajectory_collision_idx: per-episode wall-collision index, or None.

        The figure is written to ``<render_dir>/ee_trajectories_3d.png``.
        """
        from matplotlib.lines import Line2D
        from matplotlib.patches import Patch

        if trajectory_collision_idx is None:
            trajectory_collision_idx = [None] * len(ee_trajectories)

        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection='3d')

        legend_handles = []
        modes_plotted = set()
        # Points actually drawn (after collision truncation); used for the wall
        # extent and for the axis limits.
        all_x, all_y, all_z = [], [], []

        for trajectory, success, mode_id, collision_idx in zip(
                ee_trajectories, trajectory_success, trajectory_mode_ids, trajectory_collision_idx):
            if len(trajectory) == 0:
                continue

            # run() passes success-truncated trajectories, while collision_idx
            # was computed on the full episode. A collision beyond the plotted
            # part must not be drawn at the (unrelated) last plotted point.
            collided = collision_idx is not None and collision_idx < len(trajectory)
            if collided:
                trajectory = trajectory[:collision_idx + 1]  # Include collision point

            color = MODE_COLORS[mode_id % len(MODE_COLORS)]
            x, y, z = trajectory[:, 0], trajectory[:, 1], trajectory[:, 2]
            all_x.extend(x)
            all_y.extend(y)
            all_z.extend(z)

            # Solid lines for all; failed trajectories slightly more transparent
            alpha = 0.7 if success else 0.4
            ax.plot(x, y, z, color=color, linewidth=1.2, alpha=alpha)

            # Start point
            ax.scatter(trajectory[0, 0], trajectory[0, 1], trajectory[0, 2],
                       color=color, s=20, marker='o', alpha=0.8)
            # End point: red 'X' at a collision, otherwise a mode-coloured 'x'
            if collided:
                ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2],
                           color='red', s=50, marker='X', alpha=1.0)
            else:
                ax.scatter(trajectory[-1, 0], trajectory[-1, 1], trajectory[-1, 2],
                           color=color, s=20, marker='x', alpha=0.8)

            modes_plotted.add(mode_id)

        # Draw wall if enabled
        if self.wall_config is not None and len(all_x) > 0:
            self._draw_wall_on_3d_plot(ax, all_x, all_y, all_z)

        # One legend entry per plotted mode
        for mode_id in sorted(modes_plotted):
            color = MODE_COLORS[mode_id % len(MODE_COLORS)]
            legend_handles.append(Line2D([0], [0], color=color, linewidth=2, label=f'Mode {mode_id}'))

        # Add wall to legend if enabled
        if self.wall_config is not None:
            wall_label = 'Wall (with opening)' if self.wall_config.get("opening") else 'Wall'
            legend_handles.append(Patch(facecolor='red', alpha=0.3, edgecolor='darkred', label=wall_label))
            # Add collision marker to legend if any collisions occurred
            if any(c is not None for c in trajectory_collision_idx):
                legend_handles.append(Line2D([0], [0], marker='X', color='red', linestyle='None',
                                             markersize=10, label='Wall Collision'))

        ax.legend(handles=legend_handles, loc='upper right', fontsize=9)
        ax.set_xlabel('X (m)', fontsize=11)
        ax.set_ylabel('Y (m)', fontsize=11)
        ax.set_zlabel('Z (m)', fontsize=11)

        # Build title with wall info
        title = f'{self.model_name} - Close Drawer\n'
        title += f'Success Rate: {success_mean*100:.1f}% ± {success_std*100:.1f}% '
        title += f'({sum(trajectory_success)}/{len(trajectory_success)} trajectories)'
        if self.wall_config is not None:
            title += f'\nWall Style {self.wall_style} (Y={self.wall_config["wall_y"]:.2f})'
        ax.set_title(title, fontsize=14)

        # Equal aspect ratio: a cube centred on the plotted points. Using the
        # plotted points (not the raw inputs) keeps cut-off segments out of
        # the limits and avoids np.vstack([]) when every trajectory is empty.
        if all_x:
            coords = (np.asarray(all_x), np.asarray(all_y), np.asarray(all_z))
            half_range = max(c.max() - c.min() for c in coords) / 2.0
            # Floor so a single point does not produce identical low/high limits
            half_range = max(half_range, 1e-3)
            for set_lim, c in zip((ax.set_xlim, ax.set_ylim, ax.set_zlim), coords):
                mid = (c.max() + c.min()) * 0.5
                set_lim(mid - half_range, mid + half_range)

        # View angle to match camera frame (same as visualize_wall_trajectories.py)
        ax.view_init(elev=30, azim=-135)

        fig.tight_layout()
        plot_path = os.path.join(self.render_dir, 'ee_trajectories_3d.png')
        try:
            fig.savefig(plot_path, dpi=200, bbox_inches='tight')
        finally:
            # Release the figure even if saving fails (avoids figure buildup).
            plt.close(fig)
        log.info(f"Saved EE trajectory plot to {plot_path}")
