"""
Evaluate policies on grasp (pick_up_cup) task with proper EE trajectory visualization.

Grasp task structure:
  - Full trajectory: reach -> descend -> grasp -> lift (107 steps total)
  - Approach angles: 0°, 90°, 180°, 270° (4 modes)
  - Fixed cup position: [0.25, -0.05, 0.82]

Supports:
- PDP (with mode-colored trajectories, 4 colors for 4 approach angles)
- DP, BC, BC-GMM, IBC (single color trajectories)

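Plots saved to: {logdir}/render/ee_trajectories_3d.png
  (EvalGraspAgent writes a 2D top-down grasp-point plot to <render_dir>/grasp_points.png)
Title includes: Success Rate (mean ± std)
  (EvalGraspAgent's title shows task success % and, when a target angle is set, style success %)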
"""

import os
import sys
import numpy as np
import torch
import logging
import pickle
from pathlib import Path
from tqdm import tqdm
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

log = logging.getLogger(__name__)

# Add paths for imports.
# dppo_root is three os.path.dirname() levels above this file's absolute path.
# The RLBench_grasp encoder and make_dataset directories are inserted at the
# front of sys.path so modules inside them can be imported by bare name;
# dppo_root itself is appended at the end.
dppo_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, os.path.join(dppo_root, 'RLBench_grasp', 'encoder'))
sys.path.insert(0, os.path.join(dppo_root, 'RLBench_grasp', 'make_dataset'))
sys.path.append(dppo_root)

# Color palette for 4 approach angles (hex RGB strings).
# NOTE(review): presumably indexed by mode id (0-3) when plotting mode-colored
# trajectories; no use of this list is visible in this part of the file.
MODE_COLORS = [
    '#e41a1c',  # Red - 0°
    '#377eb8',  # Blue - 90°
    '#4daf4a',  # Green - 180°
    '#984ea3',  # Purple - 270°
]

from util.timer import Timer
from agent.eval.eval_agent import EvalAgent

# Import grasp config.
# If the grasp_config module cannot be imported, fall back to hard-coded
# defaults (cup position as XYZ, and a 7-value home joint configuration) and
# record the fact in GRASP_CONFIG_AVAILABLE so the fallback is detectable.
try:
    from grasp_config import FIXED_CUP_POSITION, HOME_JOINTS
    GRASP_CONFIG_AVAILABLE = True
except ImportError:
    GRASP_CONFIG_AVAILABLE = False
    FIXED_CUP_POSITION = np.array([0.25, -0.05, 0.82])
    HOME_JOINTS = np.array([0, -0.3, 0, -2.0, 0, 1.8, 0.785])
    log.warning("Grasp config not available, using default values.")


def check_grasp_success(ee_trajectory, cup_pos, lift_threshold=0.05, reach_threshold=0.08):
    """
    Check if the grasp was successful based on the EE trajectory.

    Success criteria (both must hold somewhere along the trajectory):
    1. EE came within `reach_threshold` of the cup in the XY plane
    2. EE rose more than `lift_threshold` above the cup height

    Note: the lift is the maximum height over the *whole* trajectory; the
    relative ordering of "reached" and "lifted" is not checked.

    Args:
        ee_trajectory: array-like of shape (N, 3) with EE positions along the
            trajectory. A single (3,) point is treated as a one-step trajectory.
            None or an empty sequence counts as a failed grasp.
        cup_pos: array-like of shape (3,), cup position
        lift_threshold: float, minimum lift height above cup to count as success
        reach_threshold: float, maximum XY distance to the cup that counts as
            having reached it (default 0.08, i.e. 8 cm)

    Returns:
        success: bool, True if grasp was successful
        min_dist_to_cup: float, minimum XY distance to cup during trajectory
            (inf when there is no trajectory)
        max_lift: float, maximum height above cup during trajectory
            (0.0 when there is no trajectory)
    """
    if ee_trajectory is None or len(ee_trajectory) == 0:
        return False, float('inf'), 0.0

    # atleast_2d promotes a lone (3,) point to shape (1, 3) so the column
    # slicing below works for single-point trajectories as well.
    ee_trajectory = np.atleast_2d(np.asarray(ee_trajectory, dtype=float))
    cup_pos = np.asarray(cup_pos, dtype=float)

    # Closest approach to the cup, measured in the XY plane only
    xy_distances = np.linalg.norm(ee_trajectory[:, :2] - cup_pos[:2], axis=1)
    min_dist_to_cup = float(np.min(xy_distances))

    # Highest point reached relative to the cup height
    max_lift = float(np.max(ee_trajectory[:, 2] - cup_pos[2]))

    reached_cup = min_dist_to_cup < reach_threshold
    lifted = max_lift > lift_threshold

    # Return plain Python scalars rather than numpy scalar types
    return bool(reached_cup and lifted), min_dist_to_cup, max_lift


def compute_grasp_angle(grasp_point, cup_pos):
    """
    Compute the approach angle of the grasp point relative to cup center.

    Angle is measured in degrees, counter-clockwise from positive X-axis,
    using only the X and Y components of both points.
    Returns angle in range [0, 360).

    Args:
        grasp_point: array-like of shape (3,), grasp point position (XYZ)
        cup_pos: array-like of shape (3,), cup center position (XYZ)

    Returns:
        angle_deg: float, angle in degrees [0, 360). A grasp point directly
            above/below the cup center (dx == dy == 0) yields 0.
    """
    dx = grasp_point[0] - cup_pos[0]
    dy = grasp_point[1] - cup_pos[1]

    # atan2 returns angle in [-pi, pi], convert to [0, 360)
    angle_deg = float(np.degrees(np.arctan2(dy, dx)))
    if angle_deg < 0:
        angle_deg += 360.0

    # A tiny negative angle (e.g. -1e-15°) plus 360 rounds to exactly 360.0 in
    # floating point; fold it back to 0 so the half-open range really holds.
    if angle_deg >= 360.0:
        angle_deg = 0.0

    return angle_deg


def check_angle_match(grasp_angle, target_angle, tolerance=15.0):
    """
    Check if grasp angle matches target angle within tolerance.

    Handles wrap-around at 0°/360° boundary.

    Args:
        grasp_angle: float, actual grasp angle in degrees [0, 360)
        target_angle: float, target angle in degrees [0, 360)
        tolerance: float, allowed deviation in degrees (default ±15°)

    Returns:
        match: bool, True if within tolerance
        angle_diff: float, signed angular difference in degrees, normalized to
            [-180, 180]. For a non-finite difference (NaN/inf input) the result
            is (False, diff) with the raw non-finite difference.
    """
    diff = float(grasp_angle) - float(target_angle)

    # Non-finite differences can never match; bail out before normalizing
    # (a subtract-360 loop would never terminate on +/-inf).
    if not np.isfinite(diff):
        return False, diff

    # fmod keeps the sign of the dividend and yields a value in (-360, 360);
    # one conditional shift then lands in [-180, 180], leaving exact +/-180
    # untouched.
    diff = float(np.fmod(diff, 360.0))
    if diff > 180.0:
        diff -= 360.0
    elif diff < -180.0:
        diff += 360.0

    match = bool(abs(diff) <= tolerance)
    return match, diff


# Target angles for each grasp style (for finetuning adaptation experiments).
# Keys are the integer `grasp_style` ids read from the config; values are the
# target approach angle in degrees that check_angle_match() compares the
# measured grasp angle against.
GRASP_STYLE_TARGET_ANGLES = {
    1: 0.0,    # Style 1: Mode 0° (train, episode0)
    2: 270.0,  # Style 2: Mode 270° (train, episode30)
    3: 45.0,   # Style 3: Test 45° (novel angle)
}


class EvalGraspAgent(EvalAgent):
    """
    Base evaluation agent for grasp (pick_up_cup) task.

    Rolls the policy out in the vectorized env, collects the end-effector (EE)
    trajectory of every episode from the env `info`, scores each episode and
    plots the grasp points with the success rate(s) in the title.

    Per-episode scoring (see `run`):
    1. Task success: the reward returned on the episode's last step is > 0.5.
    2. Style success (only when `grasp_style` maps to a target angle): task
       success AND the grasp angle lies within `angle_tolerance` of the target.
       The grasp angle is taken at the lowest-Z point of the EE trajectory.
    """

    # Number of approach-angle modes the episode mode label cycles through.
    _N_MODES = 4

    def __init__(self, cfg):
        """
        Read grasp-evaluation options from `cfg`.

        Optional keys (all read with `cfg.get`): model_name, cup_pos,
        lift_threshold, normalization_path, x_T_file, selected_component_file,
        demo_trajectory_file, demo_noise_std, grasp_style, angle_tolerance.
        Optional artifact files that are configured but missing are skipped
        with a warning instead of being silently ignored.
        """
        super().__init__(cfg)
        self.model_name = cfg.get('model_name', 'Policy')
        self.cup_pos = np.array(cfg.get('cup_pos', FIXED_CUP_POSITION))
        self.lift_threshold = cfg.get('lift_threshold', 0.05)

        # Load normalization stats for denormalizing actions (to get grasp positions)
        normalization_path = cfg.get('normalization_path', None)
        if normalization_path is not None:
            norm_data = np.load(normalization_path)
            self.obs_min = norm_data['obs_min']
            self.obs_max = norm_data['obs_max']
            self.obs_range = self.obs_max - self.obs_min
            # Avoid division by (near-)zero for constant dimensions
            self.obs_range = np.where(self.obs_range < 1e-6, 1.0, self.obs_range)
            self.action_min = norm_data['action_min']
            self.action_max = norm_data['action_max']
            self.action_range = self.action_max - self.action_min
            self.action_range = np.where(self.action_range < 1e-6, 1.0, self.action_range)
        else:
            self.obs_min = None
            self.obs_max = None
            self.obs_range = None
            self.action_min = None
            self.action_max = None
            self.action_range = None

        # Grasp step index (around step 72, end of descend phase)
        self.grasp_step_idx = 72

        # x_T file for DP inversion (optimized initial noise)
        self.x_T_file = cfg.get('x_T_file', None)
        self.x_T = None
        if self._configured_file_exists(self.x_T_file, 'x_T_file'):
            self.x_T = torch.from_numpy(np.load(self.x_T_file)).float().to(self.device)
            log.info(f"Loaded optimized x_T from {self.x_T_file}, shape: {self.x_T.shape}")

        # Selected component file for BC-GMM posterior selection
        self.selected_component_file = cfg.get('selected_component_file', None)
        self.selected_component = None
        if self._configured_file_exists(self.selected_component_file, 'selected_component_file'):
            # NOTE(security): allow_pickle=True unpickles arbitrary objects and can
            # execute code; only point this at files you produced yourself.
            component_data = np.load(self.selected_component_file, allow_pickle=True).item()
            self.selected_component = component_data['selected_component']
            log.info(f"Loaded selected component k*={self.selected_component} from {self.selected_component_file}")

        # Demo trajectory file for IBC demo-initialized inference
        self.demo_trajectory_file = cfg.get('demo_trajectory_file', None)
        self.demo_trajectory = None
        self.demo_noise_std = cfg.get('demo_noise_std', 0.1)
        if self._configured_file_exists(self.demo_trajectory_file, 'demo_trajectory_file'):
            self.demo_trajectory = torch.from_numpy(np.load(self.demo_trajectory_file)).float().to(self.device)
            log.info(f"Loaded demo trajectory from {self.demo_trajectory_file}, shape: {self.demo_trajectory.shape}")

        # Target grasp style for angle-based success check (finetuning adaptation)
        # grasp_style: 1=0°, 2=270°, 3=45°
        self.grasp_style = cfg.get('grasp_style', None)
        self.angle_tolerance = cfg.get('angle_tolerance', 15.0)  # ±15° default
        self.target_angle = None
        if self.grasp_style is not None:
            self.target_angle = GRASP_STYLE_TARGET_ANGLES.get(self.grasp_style, None)
            if self.target_angle is not None:
                log.info(f"Grasp style {self.grasp_style}: target angle = {self.target_angle}° (±{self.angle_tolerance}°)")
            else:
                log.warning(
                    f"Unknown grasp_style {self.grasp_style!r} "
                    f"(known: {sorted(GRASP_STYLE_TARGET_ANGLES)}); style success will not be evaluated."
                )

    @staticmethod
    def _configured_file_exists(path, cfg_key):
        """Return True if `path` is set and exists; warn when it is set but missing."""
        if path is None:
            return False
        if os.path.exists(path):
            return True
        log.warning(f"{cfg_key} is set to {path} but the file does not exist; ignoring it.")
        return False

    @staticmethod
    def _lowest_z_point(traj_arr):
        """Return the grasp point: the lowest-Z row of an (N, 3) array, or a 1-D point as-is."""
        if traj_arr.ndim == 1:
            return traj_arr
        return traj_arr[np.argmin(traj_arr[:, 2])]

    def _sample_for_model_type(self, cond):
        """Call the model with the extra inputs required by the configured inference variant."""
        x_T = getattr(self, 'x_T', None)
        if x_T is not None:
            # DP with optimized initial noise - pass via cond["noise_action"]
            cond["noise_action"] = x_T
            return self.model(cond=cond, deterministic=True)

        selected_component = getattr(self, 'selected_component', None)
        if selected_component is not None:
            # BC-GMM with posterior selection
            return self.model(cond=cond, deterministic=True, fixed_component=selected_component)

        demo_trajectory = getattr(self, 'demo_trajectory', None)
        if demo_trajectory is not None:
            # IBC with demo initialization (fresh Gaussian noise on every call)
            noisy_demo = demo_trajectory + self.demo_noise_std * torch.randn_like(demo_trajectory)
            return self.model(cond=cond, deterministic=True, demo_trajectory=noisy_demo)

        return self.model(cond=cond, deterministic=True)

    def _record_ee_positions(self, info_venv, current_ee_trajectory, log_once):
        """Overwrite each env's running EE trajectory with `all_ee_positions` from its info dict."""
        for env_idx in range(self.n_envs):
            # info_venv is normally a list of dicts (one per env); a single
            # shared dict is also accepted.
            env_info = info_venv[env_idx] if isinstance(info_venv, (list, tuple)) else info_venv
            if 'all_ee_positions' in env_info:
                ee_positions = env_info['all_ee_positions']
                if isinstance(ee_positions, list) and len(ee_positions) > 0:
                    # The info entry already holds the full trajectory so far,
                    # so replace rather than extend.
                    current_ee_trajectory[env_idx] = [np.array(p) for p in ee_positions]
                    if log_once:
                        log.info(f"Collected {len(ee_positions)} EE positions for env {env_idx}")
            elif log_once:
                log.warning(f"all_ee_positions not found in info_venv. Keys: {list(env_info.keys())}")

    def _score_episode(self, ee_traj, task_success):
        """Return (grasp_angle, style_success) for one finished episode."""
        grasp_angle = None
        style_success = False
        if ee_traj is not None and len(ee_traj) > 0 and ee_traj.ndim > 1:
            # Grasp point = lowest Z position along the trajectory
            grasp_point = self._lowest_z_point(ee_traj)
            grasp_angle = compute_grasp_angle(grasp_point, self.cup_pos)

            # Style success requires task success plus the right approach angle
            if task_success and self.target_angle is not None:
                style_success, _ = check_angle_match(
                    grasp_angle, self.target_angle, self.angle_tolerance
                )
        return grasp_angle, style_success

    def run(self):
        """
        Run evaluation and collect trajectories.

        Steps the vectorized env until `n_eval_episodes` episodes finished (or
        `n_steps` env steps were taken), saves the per-episode results with
        `np.savez` to `self.result_path`, plots the grasp points and returns
        `{"success_rate": ..., "style_success_rate": ...}` (the latter is None
        when no target angle is configured).
        """
        # Prepare video paths
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"eval_trial-{env_ind}.mp4"
                )

        # Reset env
        self.model.eval()
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)

        # Per-episode results
        ee_trajectories = []
        trajectory_success = []
        trajectory_style_success = []  # Style success: task success AND correct angle
        trajectory_grasp_angles = []   # Grasp angles for each trajectory
        trajectory_mode_ids = []
        current_ee_trajectory = [[] for _ in range(self.n_envs)]

        # Episode tracking
        episodes_completed = 0
        episodes_succeeded = 0
        episodes_style_succeeded = 0  # Count of style-correct successes
        current_episode_steps = np.zeros(self.n_envs, dtype=int)
        max_steps_per_episode = self.max_episode_steps // self.act_steps

        n_eval_episodes = self.n_eval_episodes if self.n_eval_episodes else self.n_envs * 10

        # Progress bars
        episode_pbar = tqdm(total=n_eval_episodes, desc="Episodes", position=0, leave=True, unit="ep")
        step_pbar = tqdm(total=max_steps_per_episode, desc="Current episode", position=1, leave=False, unit="step")

        step = 0
        current_mode_id = 0

        try:
            while episodes_completed < n_eval_episodes and step < self.n_steps:
                step += 1

                # Select action
                with torch.no_grad():
                    cond = {
                        "state": torch.from_numpy(prev_obs_venv["state"])
                        .float()
                        .to(self.device)
                    }
                    samples = self._sample_for_model_type(cond)

                # Execute the first act_steps actions of the predicted chunk
                output_venv = samples.trajectories.cpu().numpy()
                action_venv = output_venv[:, :self.act_steps]
                obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = self.venv.step(action_venv)
                done_venv = terminated_venv | truncated_venv

                # Store full EE trajectory from info for later analysis
                self._record_ee_positions(info_venv, current_ee_trajectory, log_once=(step == 1))

                current_episode_steps += 1
                step_pbar.update(1)

                # Handle episode termination
                for env_idx in range(self.n_envs):
                    if not (done_venv[env_idx] or current_episode_steps[env_idx] >= max_steps_per_episode):
                        continue
                    # NOTE(review): only `done` envs are auto-reset by the vector
                    # env; an episode cut off purely by the step budget above is
                    # not reset here - confirm the env's own time limit matches.

                    ee_traj = np.array(current_ee_trajectory[env_idx]) if current_ee_trajectory[env_idx] else None

                    # Check success based on the reward of the final step
                    task_success = reward_venv[env_idx] > 0.5
                    grasp_angle, style_success = self._score_episode(ee_traj, task_success)

                    ee_trajectories.append(ee_traj)
                    trajectory_success.append(task_success)
                    trajectory_style_success.append(style_success)
                    trajectory_grasp_angles.append(grasp_angle)
                    trajectory_mode_ids.append(current_mode_id)

                    episodes_completed += 1
                    if task_success:
                        episodes_succeeded += 1
                    if style_success:
                        episodes_style_succeeded += 1

                    # Update progress bar with both success rates
                    postfix = {
                        'task_sr': f'{episodes_succeeded}/{episodes_completed} ({100*episodes_succeeded/episodes_completed:.1f}%)'
                    }
                    if self.target_angle is not None:
                        postfix['style_sr'] = f'{episodes_style_succeeded}/{episodes_completed} ({100*episodes_style_succeeded/episodes_completed:.1f}%)'
                    episode_pbar.update(1)
                    episode_pbar.set_postfix(postfix)

                    # Reset for next episode
                    current_ee_trajectory[env_idx] = []
                    current_episode_steps[env_idx] = 0
                    step_pbar.reset()

                    # Note: SyncVectorEnv automatically resets done environments
                    # The new observation is already in obs_venv from the step() call

                    current_mode_id = (current_mode_id + 1) % self._N_MODES  # Cycle through 4 modes

                prev_obs_venv = obs_venv
        finally:
            # Always release the progress bars, even if the rollout raised
            episode_pbar.close()
            step_pbar.close()

        # Calculate statistics (np.mean of an empty list would warn and give NaN)
        if trajectory_success:
            success_rate = np.mean(trajectory_success)
        else:
            log.warning("No episodes were completed; success rates are NaN.")
            success_rate = float('nan')
        style_success_rate = None
        if self.target_angle is not None:
            style_success_rate = np.mean(trajectory_style_success) if trajectory_style_success else float('nan')

        log.info(f"\n{'='*60}")
        log.info(f"{self.model_name} Evaluation Results")
        log.info(f"{'='*60}")
        log.info(f"Episodes: {episodes_completed}")
        log.info(f"Task Success Rate: {100*success_rate:.1f}%")
        if self.target_angle is not None:
            log.info(f"Style Success Rate: {100*style_success_rate:.1f}% (target angle: {self.target_angle}° ±{self.angle_tolerance}°)")
            # Log angle statistics for successful grasps
            successful_angles = [a for a, s in zip(trajectory_grasp_angles, trajectory_success) if s and a is not None]
            if successful_angles:
                log.info(f"  Grasp angles (successful): mean={np.mean(successful_angles):.1f}°, std={np.std(successful_angles):.1f}°")
                log.info(f"  Angle range: [{np.min(successful_angles):.1f}°, {np.max(successful_angles):.1f}°]")
        log.info(f"{'='*60}\n")

        # Fill a 1-D object array element by element so the saved layout is one
        # entry per episode regardless of whether all trajectories happen to
        # have the same length (np.array(..., dtype=object) would otherwise
        # collapse equal-length trajectories into a single 3-D object array).
        ee_traj_array = np.empty(len(ee_trajectories), dtype=object)
        for episode_idx, episode_traj in enumerate(ee_trajectories):
            ee_traj_array[episode_idx] = episode_traj

        # Save results
        save_data = {
            'ee_trajectories': ee_traj_array,
            'trajectory_success': np.array(trajectory_success),
            'trajectory_style_success': np.array(trajectory_style_success),
            'trajectory_grasp_angles': np.array(trajectory_grasp_angles, dtype=object),
            'trajectory_mode_ids': np.array(trajectory_mode_ids),
            'success_rate': success_rate,
            'model_name': self.model_name,
        }
        if self.target_angle is not None:
            save_data['style_success_rate'] = style_success_rate
            save_data['target_angle'] = self.target_angle
            save_data['angle_tolerance'] = self.angle_tolerance
        np.savez(self.result_path, **save_data)

        # Plot trajectories
        self._plot_trajectories(
            ee_trajectories,
            trajectory_success,
            trajectory_style_success,
            trajectory_grasp_angles,
            trajectory_mode_ids,
            success_rate,
            style_success_rate
        )

        return {"success_rate": success_rate, "style_success_rate": style_success_rate}

    def _plot_trajectories(self, ee_trajectories, trajectory_success, trajectory_style_success,
                           trajectory_grasp_angles, trajectory_mode_ids, success_rate, style_success_rate=None):
        """
        Plot grasp points in 2D top-down view. Single color for non-PDP methods.

        Each episode contributes one point (the lowest-Z EE position), colored
        by outcome: style success, task success only, or failure. The figure is
        written to `<render_dir>/grasp_points.png`. `trajectory_grasp_angles`
        and `trajectory_mode_ids` are accepted for signature compatibility but
        are not used by this plot.
        """
        fig, ax = plt.subplots(figsize=(8, 8))

        # Collect grasp points by success status
        # If style checking is enabled, distinguish between style_success and task-only success
        grasp_points = {'style_success': [], 'task_success': [], 'fail': []}

        for traj, task_success, style_success in zip(
            ee_trajectories, trajectory_success, trajectory_style_success
        ):
            if traj is None or len(traj) == 0:
                continue

            grasp_point = self._lowest_z_point(np.array(traj))

            if style_success:
                outcome = 'style_success'
            elif task_success:
                outcome = 'task_success'
            else:
                outcome = 'fail'
            grasp_points[outcome].append(grasp_point)

        # Plot grasp points
        # Style success (green): lifted AND correct angle
        if grasp_points['style_success']:
            pts = np.array(grasp_points['style_success'])
            ax.scatter(pts[:, 0], pts[:, 1],
                       c='green', s=80, marker='o', alpha=0.7, edgecolors='black', linewidths=0.5, label='Style Success')

        # Task success only (blue): lifted but wrong angle
        if grasp_points['task_success']:
            pts = np.array(grasp_points['task_success'])
            ax.scatter(pts[:, 0], pts[:, 1],
                       c='blue', s=80, marker='o', alpha=0.7, edgecolors='black', linewidths=0.5, label='Task Success (wrong angle)')

        if grasp_points['fail']:
            pts = np.array(grasp_points['fail'])
            ax.scatter(pts[:, 0], pts[:, 1],
                       c='red', s=80, marker='x', alpha=0.5, linewidths=2, label='Fail')

        cup_x, cup_y = self.cup_pos[0], self.cup_pos[1]

        # Plot cup position (center)
        ax.scatter([cup_x], [cup_y],
                   c='purple', s=200, marker='*', label='Cup', zorder=10, edgecolors='black')

        # Draw reference circle around cup (gripper offset)
        circle_radius = 0.04
        theta = np.linspace(0, 2*np.pi, 100)
        ax.plot(cup_x + circle_radius * np.cos(theta),
                cup_y + circle_radius * np.sin(theta),
                'k--', alpha=0.3, linewidth=1, label='Gripper offset circle')

        # Draw target angle indicator if defined
        if self.target_angle is not None:
            line_length = 0.08
            angle_rad = np.radians(self.target_angle)
            end_x = cup_x + line_length * np.cos(angle_rad)
            end_y = cup_y + line_length * np.sin(angle_rad)
            ax.plot([cup_x, end_x], [cup_y, end_y],
                    'g-', linewidth=3, alpha=0.8, label=f'Target: {self.target_angle}°')

            # Draw tolerance arc spanning target ± tolerance
            arc_angles = np.linspace(
                np.radians(self.target_angle - self.angle_tolerance),
                np.radians(self.target_angle + self.angle_tolerance),
                30
            )
            arc_radius = 0.06
            ax.plot(cup_x + arc_radius * np.cos(arc_angles),
                    cup_y + arc_radius * np.sin(arc_angles),
                    'g--', linewidth=2, alpha=0.5)

        ax.set_xlabel('X (m)', fontsize=12)
        ax.set_ylabel('Y (m)', fontsize=12)

        # Build title with both success rates
        title = f'{self.model_name}: Grasp Points (Top-Down View)\nTask Success: {100*success_rate:.1f}%'
        if style_success_rate is not None:
            title += f' | Style Success: {100*style_success_rate:.1f}%'
        ax.set_title(title, fontsize=14)

        ax.legend(loc='upper right')
        ax.set_aspect('equal')
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        # Make sure the output directory exists so a long evaluation does not
        # fail at the very last step.
        if self.render_dir:
            os.makedirs(self.render_dir, exist_ok=True)
        save_path = os.path.join(self.render_dir, 'grasp_points.png')
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        # Close this specific figure rather than whatever is "current"
        plt.close(fig)
        log.info(f"Saved grasp points plot to {save_path}")

class EvalGraspPDPAgent(EvalGraspAgent):
    """
    Evaluation agent for PDP (Parameterized Diffusion Policy) on grasp task.

    Uses encoder to get z embeddings from demo trajectories, then samples
    with different noise initializations for each z.
    """

    def __init__(self, cfg):
        """
        Read the PDP-specific options from `cfg` and load the z embeddings.

        Optional keys (all read with `cfg.get`): z_list, z_file,
        encoder_checkpoint, dataset_path, n_noise_samples (default 5).
        """
        super().__init__(cfg)

        # PDP specific configs
        # A key that is present but explicitly null comes back as None rather
        # than the default, so normalize it to an empty list here.
        z_list = cfg.get('z_list', None)
        self.z_list = [] if z_list is None else z_list
        self.z_file = cfg.get('z_file', None)
        self.encoder_checkpoint = cfg.get('encoder_checkpoint', None)
        self.dataset_path = cfg.get('dataset_path', None)
        self.n_noise_samples = cfg.get('n_noise_samples', 5)

        # Load z embeddings
        self.z_embeddings = self._load_z_embeddings()
        log.info(f"Loaded {len(self.z_embeddings)} z embeddings for evaluation")

    def _load_z_embeddings(self):
        """
        Load or compute z embeddings for evaluation.

        Sources are tried in this order; the first one available wins:
        1. `z_file` (loaded with np.load) if it is set and exists
        2. `z_list` if it is non-empty
        3. encoding demo trajectories, if both `encoder_checkpoint` and
           `dataset_path` are set
        4. four 2-D points on the unit circle (0°, 90°, 180°, 270°)

        Returns:
            torch.float32 tensor on `self.device` (the encoder path returns
            whatever `_encode_demo_trajectories` produces).
        """
        # Option 1: Load from file
        if self.z_file is not None:
            if os.path.exists(self.z_file):
                z_data = np.load(self.z_file)
                return torch.from_numpy(z_data).float().to(self.device)
            # Configured but missing: say so instead of silently switching source
            log.warning(f"z_file is set to {self.z_file} but the file does not exist; trying other z sources")

        # Option 2: Use provided z_list
        z_list = self.z_list if self.z_list is not None else []
        if len(z_list) > 0:
            return torch.tensor(z_list, dtype=torch.float32, device=self.device)

        # Option 3: Encode demo trajectories using encoder
        has_checkpoint = self.encoder_checkpoint is not None
        has_dataset = self.dataset_path is not None
        if has_checkpoint and has_dataset:
            return self._encode_demo_trajectories()
        if has_checkpoint != has_dataset:
            log.warning("encoder_checkpoint and dataset_path must both be set to encode demo trajectories; ignoring the one provided")

        # Option 4: Default - sample 4 z embeddings on unit circle (4 approach angles)
        log.warning("No z source specified, using default unit circle z embeddings")
        angles = [0, np.pi/2, np.pi, 3*np.pi/2]  # 0°, 90°, 180°, 270°
        z_embeddings = [[np.cos(a), np.sin(a)] for a in angles]
        return torch.tensor(z_embeddings, dtype=torch.float32, device=self.device)

    def _encode_demo_trajectories(self):
        """
        Encode demo trajectories using pretrained encoder to get z embeddings.

        Loads `TrajectoryVAE` and `make_trajectory_relative` from the
        RLBench_grasp/encoder directory by file path, restores the encoder from
        `self.encoder_checkpoint`, and encodes the first demo of each of the
        4 modes from `self.dataset_path`. Demos are assumed to be stored mode
        by mode with 10 demos per mode (episodes 0, 10, 20, 30 are used).

        Returns:
            Tensor stacking one z embedding per mode.

        Raises:
            IndexError: if the dataset holds fewer trajectories than the
                mode layout above requires.
        """
        # Import encoder modules - use absolute path to ensure correct import
        import importlib.util
        grasp_encoder_path = os.path.join(dppo_root, 'RLBench_grasp', 'encoder')

        def load_module(module_name, file_name):
            """Import a module from a file inside the grasp encoder directory."""
            spec = importlib.util.spec_from_file_location(
                module_name,
                os.path.join(grasp_encoder_path, file_name)
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return module

        TrajectoryVAE = load_module("grasp_trajectory_encoder", "trajectory_encoder.py").TrajectoryVAE
        make_trajectory_relative = load_module(
            "train_grasp_encoder", "train_grasp_encoder.py"
        ).make_trajectory_relative

        # Load encoder
        # NOTE(security): weights_only=False unpickles arbitrary objects and can
        # execute code; only load checkpoints from a trusted source.
        checkpoint = torch.load(self.encoder_checkpoint, map_location=self.device, weights_only=False)
        config = checkpoint['config']

        encoder = TrajectoryVAE(
            state_dim=config['state_dim'],
            action_dim=config['action_dim'],
            hidden_dim=config['hidden_dim'],
            num_layers=config['num_layers'],
            num_heads=config['num_heads'],
            latent_dim=config['latent_dim'],
            horizon=config['horizon']
        )
        encoder.load_state_dict(checkpoint['model_state_dict'])
        encoder.to(self.device)
        encoder.eval()

        # Load dataset (kept as numpy on the host; only the demos that are
        # actually encoded get moved to the device below)
        dataset = np.load(self.dataset_path, allow_pickle=False)
        states_np = dataset['states']
        actions_np = dataset['actions']
        traj_lengths = dataset['traj_lengths']

        # Compute cumulative indices to find trajectory boundaries:
        # trajectory i occupies steps cum_lengths[i]:cum_lengths[i + 1]
        cum_lengths = np.cumsum([0] + list(traj_lengths))
        n_trajectories = len(traj_lengths)

        n_modes = 4
        demos_per_mode = 10
        z_embeddings = []

        for mode_idx in range(n_modes):
            # Get first demo of this mode
            demo_idx = mode_idx * demos_per_mode  # Episode index
            if demo_idx >= n_trajectories:
                raise IndexError(
                    f"Dataset {self.dataset_path} has {n_trajectories} trajectories, but the first demo of "
                    f"mode {mode_idx} is expected at episode index {demo_idx} "
                    f"({demos_per_mode} demos per mode)"
                )
            demo_start = int(cum_lengths[demo_idx])  # Step index (cumulative)
            demo_end = int(cum_lengths[demo_idx + 1])

            # Add a leading batch dimension of 1
            demo_states = torch.from_numpy(states_np[demo_start:demo_end]).float().to(self.device).unsqueeze(0)
            demo_actions = torch.from_numpy(actions_np[demo_start:demo_end]).float().to(self.device).unsqueeze(0)

            # Convert to relative coordinates
            states_rel, actions_rel = make_trajectory_relative(demo_states, demo_actions)

            # Encode
            with torch.no_grad():
                z = encoder.encode(states_rel, actions_rel)
                z_embeddings.append(z.squeeze(0))

        return torch.stack(z_embeddings)

    def run(self):
        """Run evaluation with multiple z embeddings and noise samples."""
        timer = Timer()

        # Prepare video paths
        options_venv = [{} for _ in range(self.n_envs)]
        if self.render_video:
            for env_ind in range(self.n_render):
                options_venv[env_ind]["video_path"] = os.path.join(
                    self.render_dir, f"eval_trial-{env_ind}.mp4"
                )

        # Reset env
        self.model.eval()
        prev_obs_venv = self.reset_env_all(options_venv=options_venv)

        # Track EE trajectories
        ee_trajectories = []
        trajectory_success = []
        trajectory_style_success = []  # Style success: task success AND correct angle
        trajectory_grasp_angles = []   # Grasp angles for each trajectory
        trajectory_mode_ids = []
        current_ee_trajectory = [[] for _ in range(self.n_envs)]

        # Episode tracking
        episodes_completed = 0
        episodes_succeeded = 0
        episodes_style_succeeded = 0  # Count of style-correct successes
        current_episode_steps = np.zeros(self.n_envs, dtype=int)
        max_steps_per_episode = self.max_episode_steps // self.act_steps

        # Total episodes: num_z_embeddings * n_noise_samples
        n_z = len(self.z_embeddings)
        n_eval_episodes = n_z * self.n_noise_samples

        log.info(f"Evaluating {n_z} z embeddings x {self.n_noise_samples} noise samples = {n_eval_episodes} episodes")

        # Progress bars
        episode_pbar = tqdm(total=n_eval_episodes, desc="Episodes", position=0, leave=True, unit="ep")
        step_pbar = tqdm(total=max_steps_per_episode, desc="Current episode", position=1, leave=False, unit="step")

        step = 0
        current_z_idx = 0
        current_noise_sample = 0

        while episodes_completed < n_eval_episodes and step < self.n_steps:
            step += 1

            # Get current z embedding and set it on the model
            current_z = self.z_embeddings[current_z_idx].unsqueeze(0)  # (1, z_dim)
            self.model.current_z = current_z

            # Select action with z conditioning
            with torch.no_grad():
                cond = {
                    "state": torch.from_numpy(prev_obs_venv["state"])
                    .float()
                    .to(self.device)
                }
                samples = self.model(cond=cond, deterministic=True)

            # Execute action
            output_venv = samples.trajectories.cpu().numpy()
            action_venv = output_venv[:, :self.act_steps]
            obs_venv, reward_venv, terminated_venv, truncated_venv, info_venv = self.venv.step(action_venv)

            done_venv = terminated_venv | truncated_venv

            # Store full EE trajectory from info for later analysis
            # info_venv is a list of dicts (one per env) from SyncVectorEnv
            for env_idx in range(self.n_envs):
                env_info = info_venv[env_idx] if isinstance(info_venv, list) else info_venv
                if 'all_ee_positions' in env_info:
                    ee_positions = env_info['all_ee_positions']
                    if isinstance(ee_positions, list) and len(ee_positions) > 0:
                        current_ee_trajectory[env_idx] = [np.array(p) for p in ee_positions]
                        if step == 1:
                            log.info(f"PDP: Collected {len(ee_positions)} EE positions for env {env_idx}")
                elif step == 1:
                    log.warning(f"PDP: all_ee_positions not found. Keys: {list(env_info.keys())}")

            current_episode_steps += 1
            step_pbar.update(1)

            # Handle episode termination
            for env_idx in range(self.n_envs):
                if done_venv[env_idx] or current_episode_steps[env_idx] >= max_steps_per_episode:
                    ee_traj = np.array(current_ee_trajectory[env_idx]) if current_ee_trajectory[env_idx] else None
                    task_success = reward_venv[env_idx] > 0.5

                    # Compute grasp angle from trajectory
                    grasp_angle = None
                    style_success = False
                    if ee_traj is not None and len(ee_traj) > 0:
                        traj_arr = np.array(ee_traj)
                        if traj_arr.ndim > 1:
                            # Find grasp point (lowest Z position)
                            min_z_idx = np.argmin(traj_arr[:, 2])
                            grasp_point = traj_arr[min_z_idx]
                            grasp_angle = compute_grasp_angle(grasp_point, self.cup_pos)

                            # Check style success if target angle is defined
                            if task_success and self.target_angle is not None:
                                angle_match, angle_diff = check_angle_match(
                                    grasp_angle, self.target_angle, self.angle_tolerance
                                )
                                style_success = angle_match

                    ee_trajectories.append(ee_traj)
                    trajectory_success.append(task_success)
                    trajectory_style_success.append(style_success)
                    trajectory_grasp_angles.append(grasp_angle)
                    trajectory_mode_ids.append(current_z_idx)

                    episodes_completed += 1
                    if task_success:
                        episodes_succeeded += 1
                    if style_success:
                        episodes_style_succeeded += 1

                    # Update progress bar
                    postfix = {
                        'task_sr': f'{episodes_succeeded}/{episodes_completed} ({100*episodes_succeeded/episodes_completed:.1f}%)',
                        'z_idx': current_z_idx,
                        'noise': current_noise_sample
                    }
                    if self.target_angle is not None:
                        postfix['style_sr'] = f'{episodes_style_succeeded}/{episodes_completed}'
                    episode_pbar.update(1)
                    episode_pbar.set_postfix(postfix)

                    # Reset for next episode
                    current_ee_trajectory[env_idx] = []
                    current_episode_steps[env_idx] = 0
                    step_pbar.reset()

                    # Update z and noise sample indices
                    current_noise_sample += 1
                    if current_noise_sample >= self.n_noise_samples:
                        current_noise_sample = 0
                        current_z_idx += 1

                    # Note: SyncVectorEnv automatically resets done environments
                    # The new observation is already in obs_venv from the step() call

            prev_obs_venv = obs_venv

        episode_pbar.close()
        step_pbar.close()

        # Calculate statistics
        success_rate = np.mean(trajectory_success)
        style_success_rate = np.mean(trajectory_style_success) if self.target_angle is not None else None

        # Per-mode statistics
        log.info(f"\n{'='*60}")
        log.info(f"{self.model_name} Evaluation Results")
        log.info(f"{'='*60}")
        log.info(f"Total Episodes: {episodes_completed}")
        log.info(f"Task Success Rate: {100*success_rate:.1f}%")
        if self.target_angle is not None:
            log.info(f"Style Success Rate: {100*style_success_rate:.1f}% (target angle: {self.target_angle}° ±{self.angle_tolerance}°)")
            # Log angle statistics for successful grasps
            successful_angles = [a for a, s in zip(trajectory_grasp_angles, trajectory_success) if s and a is not None]
            if successful_angles:
                log.info(f"  Grasp angles (successful): mean={np.mean(successful_angles):.1f}°, std={np.std(successful_angles):.1f}°")
                log.info(f"  Angle range: [{np.min(successful_angles):.1f}°, {np.max(successful_angles):.1f}°]")
        log.info(f"\nPer-Mode Results:")
        for mode_idx in range(n_z):
            mode_mask = np.array(trajectory_mode_ids) == mode_idx
            mode_success = np.mean(np.array(trajectory_success)[mode_mask]) if np.any(mode_mask) else 0
            angle = mode_idx * 90
            log.info(f"  Mode {mode_idx} ({angle}°): {100*mode_success:.1f}% task success")
        log.info(f"{'='*60}\n")

        # Save results
        save_data = {
            'ee_trajectories': np.array(ee_trajectories, dtype=object),
            'trajectory_success': np.array(trajectory_success),
            'trajectory_style_success': np.array(trajectory_style_success),
            'trajectory_grasp_angles': np.array(trajectory_grasp_angles, dtype=object),
            'trajectory_mode_ids': np.array(trajectory_mode_ids),
            'z_embeddings': self.z_embeddings.cpu().numpy(),
            'success_rate': success_rate,
            'model_name': self.model_name,
        }
        if self.target_angle is not None:
            save_data['style_success_rate'] = style_success_rate
            save_data['target_angle'] = self.target_angle
            save_data['angle_tolerance'] = self.angle_tolerance
        np.savez(self.result_path, **save_data)

        # Plot trajectories
        self._plot_trajectories(
            ee_trajectories,
            trajectory_success,
            trajectory_style_success,
            trajectory_grasp_angles,
            trajectory_mode_ids,
            success_rate,
            style_success_rate
        )

        return {"success_rate": success_rate, "style_success_rate": style_success_rate}

    def _plot_trajectories(self, ee_trajectories, trajectory_success, trajectory_style_success,
                           trajectory_grasp_angles, trajectory_mode_ids, success_rate, style_success_rate=None):
        """Plot one grasp point per episode in a 2D top-down (X-Y) view and save it as a PNG.

        The grasp point of an episode is the trajectory sample with the lowest Z
        value (a 1-D trajectory is treated as a single point). Points are grouped
        by mode id (wrapped onto the ``MODE_COLORS`` palette) and by outcome:

        - style success: green filled circle
        - task success only: filled circle in the mode color
        - failure: ``x`` marker in the mode color

        The cup position, a reference circle around it and, when
        ``self.target_angle`` is set, the target direction with its tolerance arc
        are drawn on top. The figure is written to
        ``{self.render_dir}/grasp_points.png``.

        Args:
            ee_trajectories: Per-episode EE trajectories; entries that are ``None``
                or empty are skipped.
            trajectory_success: Per-episode task-success flags.
            trajectory_style_success: Per-episode style-success flags.
            trajectory_grasp_angles: Per-episode grasp angles. Accepted for
                interface compatibility; not used by the plot.
            trajectory_mode_ids: Per-episode mode indices used for coloring.
            success_rate: Overall task success rate in [0, 1], shown in the title.
            style_success_rate: Optional style success rate in [0, 1]; appended
                to the title when not ``None``.

        Reads ``self.cup_pos``, ``self.target_angle``, ``self.angle_tolerance``,
        ``self.model_name`` and ``self.render_dir``.
        """
        # Derive the number of modes from the palette instead of repeating a literal.
        n_modes = len(MODE_COLORS)
        grasp_points = {
            category: {mode: [] for mode in range(n_modes)}
            for category in ('style_success', 'task_success', 'fail')
        }

        for traj, task_success, style_success, mode_id in zip(
            ee_trajectories, trajectory_success, trajectory_style_success, trajectory_mode_ids
        ):
            if traj is None or len(traj) == 0:
                continue

            traj_arr = np.array(traj)
            if traj_arr.ndim == 1 and traj_arr.shape[0] >= 2:
                # A single position rather than a sequence of positions.
                grasp_point = traj_arr
            elif traj_arr.ndim == 2 and traj_arr.shape[1] >= 3:
                # Grasp point = sample with the lowest Z coordinate.
                grasp_point = traj_arr[np.argmin(traj_arr[:, 2])]
            else:
                # Do not let one malformed trajectory abort the whole plot.
                log.warning(f"Skipping trajectory with unexpected shape {traj_arr.shape} in grasp plot")
                continue

            if style_success:
                category = 'style_success'
            elif task_success:
                category = 'task_success'
            else:
                category = 'fail'
            # Only X/Y are plotted; storing 2-D points keeps every group a regular (k, 2) array.
            grasp_points[category][mode_id % n_modes].append(grasp_point[:2])

        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            # Plot grasp points (mode-colored for PDP)
            for mode_id, color in enumerate(MODE_COLORS):
                # Style success: green circle
                if grasp_points['style_success'][mode_id]:
                    pts = np.array(grasp_points['style_success'][mode_id])
                    ax.scatter(pts[:, 0], pts[:, 1],
                               c='green', s=100, marker='o', alpha=0.8, edgecolors='black', linewidths=1)

                # Task success only: mode color
                if grasp_points['task_success'][mode_id]:
                    pts = np.array(grasp_points['task_success'][mode_id])
                    ax.scatter(pts[:, 0], pts[:, 1],
                               c=color, s=80, marker='o', alpha=0.7, edgecolors='black', linewidths=0.5)

                # Fail: mode color x
                if grasp_points['fail'][mode_id]:
                    pts = np.array(grasp_points['fail'][mode_id])
                    ax.scatter(pts[:, 0], pts[:, 1],
                               c=color, s=80, marker='x', alpha=0.5, linewidths=2)

            # Plot cup position (center)
            cup_x, cup_y = self.cup_pos[0], self.cup_pos[1]
            ax.scatter([cup_x], [cup_y],
                       c='purple', s=200, marker='*', label='Cup', zorder=10, edgecolors='black')

            # Draw reference circle around cup (gripper offset)
            circle_radius = 0.04
            theta = np.linspace(0, 2 * np.pi, 100)
            ax.plot(cup_x + circle_radius * np.cos(theta),
                    cup_y + circle_radius * np.sin(theta),
                    'k--', alpha=0.3, linewidth=1, label='Gripper offset circle')

            # Draw target angle indicator if defined
            if self.target_angle is not None:
                line_length = 0.08
                angle_rad = np.radians(self.target_angle)
                ax.plot([cup_x, cup_x + line_length * np.cos(angle_rad)],
                        [cup_y, cup_y + line_length * np.sin(angle_rad)],
                        'g-', linewidth=3, alpha=0.8, label=f'Target: {self.target_angle}°')

                # Draw tolerance arc spanning target ± tolerance
                arc_angles = np.linspace(
                    np.radians(self.target_angle - self.angle_tolerance),
                    np.radians(self.target_angle + self.angle_tolerance),
                    30
                )
                arc_radius = 0.06
                ax.plot(cup_x + arc_radius * np.cos(arc_angles),
                        cup_y + arc_radius * np.sin(arc_angles),
                        'g--', linewidth=2, alpha=0.5)

            # Legend proxies: empty scatters only contribute legend entries.
            for i, color in enumerate(MODE_COLORS):
                ax.scatter([], [], c=color, s=80, marker='o', label=f'{i * 90}°')
            ax.scatter([], [], c='green', s=100, marker='o', edgecolors='black', linewidths=1, label='Style Success')
            ax.scatter([], [], c='gray', s=80, marker='o', edgecolors='black', linewidths=0.5, label='Task Success')
            ax.scatter([], [], c='gray', s=80, marker='x', linewidths=2, label='Fail')

            ax.set_xlabel('X (m)', fontsize=12)
            ax.set_ylabel('Y (m)', fontsize=12)

            # Build title with both success rates
            title = f'{self.model_name}: Grasp Points (Top-Down View)\nTask Success: {100*success_rate:.1f}%'
            if style_success_rate is not None:
                title += f' | Style Success: {100*style_success_rate:.1f}%'
            ax.set_title(title, fontsize=14)

            ax.legend(loc='upper right')
            ax.set_aspect('equal')
            ax.grid(True, alpha=0.3)

            # Operate on this figure explicitly rather than pyplot's implicit current figure.
            fig.tight_layout()
            save_path = os.path.join(self.render_dir, 'grasp_points.png')
            fig.savefig(save_path, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure, even if layout or saving fails, so
            # repeated evaluations do not accumulate open pyplot figures.
            plt.close(fig)
        log.info(f"Saved grasp points plot to {save_path}")
