import argparse
import torch
import matplotlib.pyplot as plt
import numpy as np


from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

import seaborn as sns


def main(args):
    data = torch.load(f"{args.result_dir}/trajectory.pkl")

    sampled_x = data["x"]  # List of tensors: [tensor1, tensor2, ...]
    x0 = data["x0"]  # Shape: [n_samples, feature_dim] - ground truth
    y = data["y"]  # Shape: [n_samples, feature_dim] - measurements
    sampled_z = data["z"]  # List of tensors: [tensor1, tensor2, ...]

    # Convert list of tensors to numpy array
    if isinstance(sampled_x, list):
        sampled_x = (
            torch.stack(sampled_x, dim=1).cpu().numpy()
        )  # Stack along iteration dimension
    elif isinstance(sampled_x, torch.Tensor):
        sampled_x = sampled_x.cpu().numpy()

    if isinstance(sampled_z, list):
        sampled_z = (
            torch.stack(sampled_z, dim=1).cpu().numpy()
        )  # Stack along iteration dimension
    elif isinstance(sampled_z, torch.Tensor):
        sampled_z = sampled_z.cpu().numpy()

    if isinstance(x0, torch.Tensor):
        x0 = x0.cpu().numpy()
    if isinstance(y, torch.Tensor):
        y = y.cpu().numpy()

    n_samples, n_iterations, feature_dim = sampled_z.shape
    print(
        f"Data shape: {n_samples} samples, {n_iterations} iterations, {feature_dim} features"
    )

    if args.pca:
        # Prepare data for PCA - combine all points for consistent transformation
        all_data = []

        # Add ground truth and measurements
        all_data.append(x0)  # Ground truth
        all_data.append(y)  # Measurements

        # Add all iterations of sampled_x and sampled_z
        for i in range(n_iterations):
            all_data.append(sampled_x[:, i, :])
            all_data.append(sampled_z[:, i, :])

        combined_data = np.vstack(all_data)

        # Standardize and apply PCA
        scaler = StandardScaler()
        combined_data_scaled = scaler.fit_transform(combined_data)

        pca = PCA(n_components=2)
        combined_data_pca = pca.fit_transform(combined_data_scaled)

        print(f"PCA explained variance ratio: {pca.explained_variance_ratio_}")
        print(f"Total explained variance: {pca.explained_variance_ratio_.sum():.3f}")

        # Extract PCA-transformed data
        idx = 0
        x0_pca = combined_data_pca[idx : idx + n_samples]
        idx += n_samples
        y_pca = combined_data_pca[idx : idx + n_samples]
        idx += n_samples

        sampled_x_pca = np.zeros((n_samples, n_iterations, 2))
        sampled_z_pca = np.zeros((n_samples, n_iterations, 2))

        for i in range(n_iterations):
            sampled_x_pca[:, i, :] = combined_data_pca[idx : idx + n_samples]
            idx += n_samples
            sampled_z_pca[:, i, :] = combined_data_pca[idx : idx + n_samples]
            idx += n_samples

        # Interactive visualization
        class TrajectoryViewer:
            def __init__(self, data):
                """Project all points to 2-D with PCA and open the viewer.

                ``data`` is indexed with the keys ``"sampled_x"``,
                ``"sampled_z"``, ``"x0"`` and ``"y"``. The two trajectory arrays
                are unpacked as (sample, iteration, feature); ``x0`` and ``y``
                are stacked row-wise with them, so they must share the same
                feature width.
                """
                traj_x = data["sampled_x"]
                traj_z = data["sampled_z"]
                truth = data["x0"]
                measured = data["y"]

                # Collapse (sample, iteration) into rows so that one PCA fit
                # covers every point that will be drawn.
                rows_x = traj_x.reshape(-1, traj_x.shape[-1])
                rows_z = traj_z.reshape(-1, traj_z.shape[-1])
                stacked = np.vstack([rows_x, rows_z, truth, measured])

                projector = PCA(n_components=2)
                projected = projector.fit_transform(stacked)

                # Row offsets at which each group ends inside ``stacked``.
                end_x = rows_x.shape[0]
                end_z = end_x + rows_z.shape[0]
                end_truth = end_z + truth.shape[0]

                self.sampled_x_pca = projected[:end_x].reshape(
                    traj_x.shape[:-1] + (2,)
                )
                self.sampled_z_pca = projected[end_x:end_z].reshape(
                    traj_z.shape[:-1] + (2,)
                )
                self.x0_pca = projected[end_z:end_truth]
                self.y_pca = projected[end_truth:]
                self.pca = projector
                self.data = data  # Keep a handle on the un-projected input
                self.current_iteration = 0
                self.current_sample = 0
                self.n_samples, self.n_iterations, _ = self.sampled_x_pca.shape
                self.show_all = False

                # Fixed axis limits keep the view from jumping between redraws.
                self._calculate_axis_limits()

                # Everything the distance / step-length panels show is cached
                # once here.
                self._precompute_distances_and_trajectories()

                self.fig, self.axes = plt.subplots(2, 2, figsize=(15, 12))
                self.fig.canvas.mpl_connect("key_press_event", self.on_key_press)

                self._initialize_plots()
                self.update_plots()

                for help_line in (
                    "Controls:",
                    "  Right arrow: Next iteration",
                    "  Left arrow: Previous iteration",
                    "  Up arrow: Next sample",
                    "  Down arrow: Previous sample",
                    "  'a': Show all trajectories",
                    "  'r': Reset to show single trajectory",
                ):
                    print(help_line)

            def _calculate_axis_limits(self):
                """Calculate stable axis limits based on all data points"""
                # Combine all trajectory points
                all_x_points = self.sampled_x_pca.reshape(-1, 2)
                all_z_points = self.sampled_z_pca.reshape(-1, 2)
                all_points = np.vstack(
                    [all_x_points, all_z_points, self.x0_pca, self.y_pca]
                )

                # Add some margin
                margin = 0.1
                x_range = all_points[:, 0].max() - all_points[:, 0].min()
                y_range = all_points[:, 1].max() - all_points[:, 1].min()

                self.xlim = [
                    all_points[:, 0].min() - margin * x_range,
                    all_points[:, 0].max() + margin * x_range,
                ]
                self.ylim = [
                    all_points[:, 1].min() - margin * y_range,
                    all_points[:, 1].max() + margin * y_range,
                ]

            def _precompute_distances_and_trajectories(self):
                """Pre-compute all distances and trajectories for performance"""
                print(
                    "Pre-computing distances and trajectories for optimal performance..."
                )

                # Pre-compute all alternating trajectories for each sample
                self.alternating_trajectories = {}
                for sample_idx in range(self.n_samples):
                    trajectory = []
                    point_types = []
                    for step in range(self.n_iterations * 2):
                        if step % 2 == 0:  # Even steps: X points
                            x_idx = step // 2
                            if x_idx < self.sampled_x_pca.shape[1]:
                                trajectory.append(self.sampled_x_pca[sample_idx, x_idx])
                                point_types.append("X")
                        else:  # Odd steps: Z points
                            z_idx = step // 2
                            if z_idx < self.sampled_z_pca.shape[1]:
                                trajectory.append(self.sampled_z_pca[sample_idx, z_idx])
                                point_types.append("Z")
                    self.alternating_trajectories[sample_idx] = (
                        trajectory,
                        point_types,
                    )

                # Pre-compute all distances to ground truth and measurements
                self.x_distances_to_gt = []
                self.z_distances_to_gt = []
                self.x_distances_to_y = []
                self.z_distances_to_y = []

                # For all samples
                for i in range(self.n_iterations):
                    if i < self.sampled_x_pca.shape[1]:
                        # X distances to ground truth (mean across samples)
                        x_dist = np.mean(
                            np.linalg.norm(
                                self.sampled_x_pca[:, i, :] - self.x0_pca, axis=1
                            )
                        )
                        self.x_distances_to_gt.append((i * 2, x_dist))

                        # X distances to measurements (mean across samples)
                        x_y_dist = np.mean(
                            np.linalg.norm(
                                self.sampled_x_pca[:, i, :] - self.y_pca, axis=1
                            )
                        )
                        self.x_distances_to_y.append((i * 2, x_y_dist))

                    if i < self.sampled_z_pca.shape[1]:
                        # Z distances to ground truth (mean across samples)
                        z_dist = np.mean(
                            np.linalg.norm(
                                self.sampled_z_pca[:, i, :] - self.x0_pca, axis=1
                            )
                        )
                        self.z_distances_to_gt.append((i * 2 + 1, z_dist))

                        # Z distances to measurements (mean across samples)
                        z_y_dist = np.mean(
                            np.linalg.norm(
                                self.sampled_z_pca[:, i, :] - self.y_pca, axis=1
                            )
                        )
                        self.z_distances_to_y.append((i * 2 + 1, z_y_dist))

                # Pre-compute individual sample distances
                self.sample_x_distances_to_gt = {}
                self.sample_z_distances_to_gt = {}
                self.sample_x_distances_to_y = {}
                self.sample_z_distances_to_y = {}

                for sample_idx in range(self.n_samples):
                    x_to_gt = []
                    z_to_gt = []
                    x_to_y = []
                    z_to_y = []

                    for i in range(self.n_iterations):
                        if i < self.sampled_x_pca.shape[1]:
                            x_gt_dist = np.linalg.norm(
                                self.sampled_x_pca[sample_idx, i, :]
                                - self.x0_pca[sample_idx]
                            )
                            x_y_dist = np.linalg.norm(
                                self.sampled_x_pca[sample_idx, i, :]
                                - self.y_pca[sample_idx]
                            )
                            x_to_gt.append((i * 2, x_gt_dist))
                            x_to_y.append((i * 2, x_y_dist))

                        if i < self.sampled_z_pca.shape[1]:
                            z_gt_dist = np.linalg.norm(
                                self.sampled_z_pca[sample_idx, i, :]
                                - self.x0_pca[sample_idx]
                            )
                            z_y_dist = np.linalg.norm(
                                self.sampled_z_pca[sample_idx, i, :]
                                - self.y_pca[sample_idx]
                            )
                            z_to_gt.append((i * 2 + 1, z_gt_dist))
                            z_to_y.append((i * 2 + 1, z_y_dist))

                    self.sample_x_distances_to_gt[sample_idx] = x_to_gt
                    self.sample_z_distances_to_gt[sample_idx] = z_to_gt
                    self.sample_x_distances_to_y[sample_idx] = x_to_y
                    self.sample_z_distances_to_y[sample_idx] = z_to_y

                # Pre-compute step lengths
                self.x_to_z_lengths = []
                self.z_to_x_lengths = []

                # Calculate all X -> Z transitions (mean across samples)
                for i in range(self.n_iterations):
                    if (
                        i < self.sampled_x_pca.shape[1]
                        and i < self.sampled_z_pca.shape[1]
                    ):
                        step_lengths = np.linalg.norm(
                            self.sampled_z_pca[:, i, :] - self.sampled_x_pca[:, i, :],
                            axis=1,
                        )
                        mean_length = np.mean(step_lengths)
                        self.x_to_z_lengths.append((i * 2 + 1, mean_length))

                # Calculate all Z -> X transitions (mean across samples)
                for i in range(self.n_iterations - 1):
                    if (
                        i < self.sampled_z_pca.shape[1]
                        and (i + 1) < self.sampled_x_pca.shape[1]
                    ):
                        step_lengths = np.linalg.norm(
                            self.sampled_x_pca[:, i + 1, :]
                            - self.sampled_z_pca[:, i, :],
                            axis=1,
                        )
                        mean_length = np.mean(step_lengths)
                        self.z_to_x_lengths.append(((i + 1) * 2, mean_length))

                # Pre-compute individual sample step lengths
                self.sample_x_to_z_lengths = {}
                self.sample_z_to_x_lengths = {}

                for sample_idx in range(self.n_samples):
                    x_to_z = []
                    z_to_x = []

                    # X -> Z step lengths for current sample
                    for i in range(self.n_iterations):
                        if (
                            i < self.sampled_x_pca.shape[1]
                            and i < self.sampled_z_pca.shape[1]
                        ):
                            step_length = np.linalg.norm(
                                self.sampled_z_pca[sample_idx, i, :]
                                - self.sampled_x_pca[sample_idx, i, :]
                            )
                            x_to_z.append((i * 2 + 1, step_length))

                    # Z -> X step lengths for current sample
                    for i in range(self.n_iterations - 1):
                        if (
                            i < self.sampled_z_pca.shape[1]
                            and (i + 1) < self.sampled_x_pca.shape[1]
                        ):
                            step_length = np.linalg.norm(
                                self.sampled_x_pca[sample_idx, i + 1, :]
                                - self.sampled_z_pca[sample_idx, i, :]
                            )
                            z_to_x.append(((i + 1) * 2, step_length))

                    self.sample_x_to_z_lengths[sample_idx] = x_to_z
                    self.sample_z_to_x_lengths[sample_idx] = z_to_x

                print("Pre-computation complete!")

            def _initialize_plots(self):
                """Initialize static plot elements for efficient updates"""
                # Clear all axes
                for ax in self.axes.flat:
                    ax.clear()

                # Initialize plot containers for efficient updates
                self.plot_objects = {}

                # Plot 1: Trajectory plot containers
                ax = self.axes[0, 0]
                self.plot_objects["traj_points"] = []
                self.plot_objects["traj_lines"] = []
                self.plot_objects["gt_scatter"] = None
                self.plot_objects["meas_scatter"] = None

                # Plot 2: Current step plot containers
                ax = self.axes[0, 1]
                self.plot_objects["current_scatter"] = None
                self.plot_objects["current_highlight"] = None

                # Plot 3: Distance analysis containers
                ax = self.axes[1, 0]
                self.plot_objects["dist_lines"] = []
                self.plot_objects["dist_highlights"] = []

                # Plot 4: Step length containers
                ax = self.axes[1, 1]
                self.plot_objects["step_lines"] = []
                self.plot_objects["step_highlights"] = []
                self.plot_objects["rho_line"] = None

            def on_key_press(self, event):
                step_size = (
                    10
                    if event.key
                    in ["shift+up", "shift+down", "shift+left", "shift+right"]
                    or (hasattr(event, "shift") and event.shift)
                    else 1
                )

                if event.key == "right" or event.key == "shift+right":
                    self.current_iteration = min(
                        self.current_iteration + step_size, self.n_iterations - 1
                    )
                elif event.key == "left" or event.key == "shift+left":
                    self.current_iteration = max(self.current_iteration - step_size, 0)
                elif event.key == "up" or event.key == "shift+up":
                    self.current_sample = min(
                        self.current_sample + step_size, self.n_samples - 1
                    )
                elif event.key == "down" or event.key == "shift+down":
                    self.current_sample = max(self.current_sample - step_size, 0)
                elif event.key == "a":
                    self.show_all = True
                elif event.key == "r":
                    self.show_all = False
                else:
                    return

                self.update_plots()

            def update_plots(self):
                """Efficiently update plots without full recomputation"""
                # Clear all axes
                for ax in self.axes.flat:
                    ax.clear()

                # Plot 1: Alternating X-Z trajectory (optimized)
                self._update_trajectory_plot()

                # Plot 2: Current step points for all samples (optimized)
                self._update_current_step_plot()

                # Plot 3: Distance analysis (optimized using pre-computed data)
                self._update_distance_plot()

                # Plot 4: Step length analysis (optimized using pre-computed data)
                self._update_step_length_plot()

                plt.tight_layout()
                self.fig.canvas.draw()

            def _update_trajectory_plot(self):
                """Update trajectory plot efficiently"""
                ax = self.axes[0, 0]

                if not self.show_all:
                    sample_idx = self.current_sample
                    trajectory, point_types = self.alternating_trajectories[sample_idx]

                    # Limit to current iteration
                    max_steps = min(self.current_iteration + 1, len(trajectory))
                    current_points = trajectory[:max_steps]
                    current_types = point_types[:max_steps]

                    # Plot points and connections efficiently
                    for i, (point, point_type) in enumerate(
                        zip(current_points, current_types)
                    ):
                        color = "blue" if point_type == "X" else "red"
                        size = 80 if i == len(current_points) - 1 else 50
                        edge = "yellow" if i == len(current_points) - 1 else "black"
                        linewidth = 3 if i == len(current_points) - 1 else 1

                        ax.scatter(
                            point[0],
                            point[1],
                            color=color,
                            s=size,
                            edgecolor=edge,
                            linewidth=linewidth,
                            alpha=0.8,
                        )

                        # Connect to previous point
                        if i > 0:
                            prev_point = current_points[i - 1]
                            line_color = (
                                "red" if current_types[i - 1] == "X" else "blue"
                            )
                            ax.plot(
                                [prev_point[0], point[0]],
                                [prev_point[1], point[1]],
                                color=line_color,
                                alpha=0.7,
                                linewidth=2,
                            )

                    # Ground truth and measurement (static)
                    ax.scatter(
                        self.x0_pca[sample_idx, 0],
                        self.x0_pca[sample_idx, 1],
                        marker="*",
                        s=200,
                        color="gold",
                        alpha=0.9,
                        edgecolor="black",
                        label="Ground Truth",
                        zorder=5,
                    )
                    ax.scatter(
                        self.y_pca[sample_idx, 0],
                        self.y_pca[sample_idx, 1],
                        marker="s",
                        s=120,
                        color="green",
                        alpha=0.9,
                        edgecolor="black",
                        label="Measurement",
                        zorder=5,
                    )

                    current_type = "X" if self.current_iteration % 2 == 0 else "Z"
                    ax.set_title(
                        f"X-Z Trajectory - Step {self.current_iteration} (at {current_type}), Sample {self.current_sample}"
                    )
                    ax.legend()
                else:
                    # Show multiple trajectories efficiently
                    for sample_idx in range(min(10, self.n_samples)):
                        trajectory, _ = self.alternating_trajectories[sample_idx]
                        max_steps = min(self.current_iteration + 1, len(trajectory))
                        current_points = trajectory[:max_steps]

                        for i, point in enumerate(current_points):
                            color = "blue" if i % 2 == 0 else "red"
                            ax.scatter(point[0], point[1], color=color, s=20, alpha=0.6)

                            if i > 0:
                                prev_point = current_points[i - 1]
                                line_color = "red" if i % 2 == 1 else "blue"
                                ax.plot(
                                    [prev_point[0], point[0]],
                                    [prev_point[1], point[1]],
                                    color=line_color,
                                    alpha=0.4,
                                    linewidth=1,
                                )

                    ax.set_title(
                        f"All X-Z Trajectories - Step {self.current_iteration}"
                    )

                ax.set_xlabel(f"PC1 ({self.pca.explained_variance_ratio_[0]:.2f})")
                ax.set_ylabel(f"PC2 ({self.pca.explained_variance_ratio_[1]:.2f})")
                ax.grid(True, alpha=0.3)
                ax.set_xlim(self.xlim)
                ax.set_ylim(self.ylim)

            def _update_current_step_plot(self):
                """Update current step plot efficiently"""
                ax = self.axes[0, 1]

                if self.current_iteration % 2 == 0:  # Currently at X
                    x_idx = self.current_iteration // 2
                    if x_idx < self.sampled_x_pca.shape[1]:
                        ax.scatter(
                            self.sampled_x_pca[:, x_idx, 0],
                            self.sampled_x_pca[:, x_idx, 1],
                            alpha=0.6,
                            label=f"X at step {self.current_iteration}",
                            s=40,
                            color="blue",
                        )
                else:  # Currently at Z
                    z_idx = self.current_iteration // 2
                    if z_idx < self.sampled_z_pca.shape[1]:
                        ax.scatter(
                            self.sampled_z_pca[:, z_idx, 0],
                            self.sampled_z_pca[:, z_idx, 1],
                            alpha=0.6,
                            label=f"Z at step {self.current_iteration}",
                            s=40,
                            color="red",
                        )

                # Static elements
                ax.scatter(
                    self.x0_pca[:, 0],
                    self.x0_pca[:, 1],
                    marker="*",
                    s=100,
                    color="gold",
                    label="Ground Truth",
                    alpha=0.8,
                )
                ax.scatter(
                    self.y_pca[:, 0],
                    self.y_pca[:, 1],
                    marker="s",
                    s=60,
                    color="green",
                    label="Measurements",
                    alpha=0.8,
                )

                if not self.show_all:
                    # Highlight current sample
                    if self.current_iteration % 2 == 0:  # Currently at X
                        x_idx = self.current_iteration // 2
                        if x_idx < self.sampled_x_pca.shape[1]:
                            ax.scatter(
                                self.sampled_x_pca[self.current_sample, x_idx, 0],
                                self.sampled_x_pca[self.current_sample, x_idx, 1],
                                s=150,
                                color="blue",
                                edgecolor="yellow",
                                linewidth=3,
                                zorder=5,
                            )
                    else:  # Currently at Z
                        z_idx = self.current_iteration // 2
                        if z_idx < self.sampled_z_pca.shape[1]:
                            ax.scatter(
                                self.sampled_z_pca[self.current_sample, z_idx, 0],
                                self.sampled_z_pca[self.current_sample, z_idx, 1],
                                s=150,
                                color="red",
                                edgecolor="yellow",
                                linewidth=3,
                                zorder=5,
                            )

                current_type = "X" if self.current_iteration % 2 == 0 else "Z"
                ax.set_title(
                    f"All Samples at Step {self.current_iteration} ({current_type})"
                )
                ax.set_xlabel(f"PC1 ({self.pca.explained_variance_ratio_[0]:.2f})")
                ax.set_ylabel(f"PC2 ({self.pca.explained_variance_ratio_[1]:.2f})")
                ax.legend()
                ax.grid(True, alpha=0.3)
                ax.set_xlim(self.xlim)
                ax.set_ylim(self.ylim)

            def _update_distance_plot(self):
                """Update distance plot using pre-computed data"""
                ax = self.axes[1, 0]

                # Plot mean distances (pre-computed)
                if self.x_distances_to_gt:
                    x_steps, x_dists = zip(*self.x_distances_to_gt)
                    ax.plot(
                        x_steps,
                        x_dists,
                        "o-",
                        color="blue",
                        label="X distance to ground truth (mean)",
                        markersize=4,
                        alpha=0.7,
                    )

                if self.z_distances_to_gt:
                    z_steps, z_dists = zip(*self.z_distances_to_gt)
                    ax.plot(
                        z_steps,
                        z_dists,
                        "s-",
                        color="red",
                        label="Z distance to ground truth (mean)",
                        markersize=4,
                        alpha=0.7,
                    )

                if self.x_distances_to_y:
                    x_y_steps, x_y_dists = zip(*self.x_distances_to_y)
                    ax.plot(
                        x_y_steps,
                        x_y_dists,
                        "o--",
                        color="lightblue",
                        label="X distance to measurement (mean)",
                        markersize=3,
                        alpha=0.6,
                    )

                if self.z_distances_to_y:
                    z_y_steps, z_y_dists = zip(*self.z_distances_to_y)
                    ax.plot(
                        z_y_steps,
                        z_y_dists,
                        "s--",
                        color="lightcoral",
                        label="Z distance to measurement (mean)",
                        markersize=3,
                        alpha=0.6,
                    )

                # Plot current sample distances (pre-computed)
                current_sample_x_to_gt = self.sample_x_distances_to_gt[
                    self.current_sample
                ]
                current_sample_z_to_gt = self.sample_z_distances_to_gt[
                    self.current_sample
                ]
                current_sample_x_to_y = self.sample_x_distances_to_y[
                    self.current_sample
                ]
                current_sample_z_to_y = self.sample_z_distances_to_y[
                    self.current_sample
                ]

                if current_sample_x_to_gt:
                    x_steps_sample, x_dists_sample = zip(*current_sample_x_to_gt)
                    ax.plot(
                        x_steps_sample,
                        x_dists_sample,
                        "o-",
                        color="darkblue",
                        label=f"X sample {self.current_sample} (to ground truth)",
                        markersize=3,
                        alpha=0.8,
                        linewidth=1,
                    )

                if current_sample_z_to_gt:
                    z_steps_sample, z_dists_sample = zip(*current_sample_z_to_gt)
                    ax.plot(
                        z_steps_sample,
                        z_dists_sample,
                        "s-",
                        color="darkred",
                        label=f"Z sample {self.current_sample} (to ground truth)",
                        markersize=3,
                        alpha=0.8,
                        linewidth=1,
                    )

                # Highlight current step
                if self.current_iteration % 2 == 0:  # Currently at X
                    x_idx = self.current_iteration // 2
                    if x_idx < len(self.x_distances_to_gt):
                        ax.scatter(
                            self.x_distances_to_gt[x_idx][0],
                            self.x_distances_to_gt[x_idx][1],
                            s=150,
                            color="blue",
                            edgecolor="yellow",
                            linewidth=3,
                            zorder=5,
                        )
                    if x_idx < len(current_sample_x_to_gt):
                        ax.scatter(
                            current_sample_x_to_gt[x_idx][0],
                            current_sample_x_to_gt[x_idx][1],
                            s=100,
                            color="darkblue",
                            edgecolor="yellow",
                            linewidth=2,
                            zorder=6,
                        )
                else:  # Currently at Z
                    z_idx = self.current_iteration // 2
                    if z_idx < len(self.z_distances_to_gt):
                        ax.scatter(
                            self.z_distances_to_gt[z_idx][0],
                            self.z_distances_to_gt[z_idx][1],
                            s=150,
                            color="red",
                            edgecolor="yellow",
                            linewidth=3,
                            zorder=5,
                        )
                    if z_idx < len(current_sample_z_to_gt):
                        ax.scatter(
                            current_sample_z_to_gt[z_idx][0],
                            current_sample_z_to_gt[z_idx][1],
                            s=100,
                            color="darkred",
                            edgecolor="yellow",
                            linewidth=2,
                            zorder=6,
                        )

                ax.set_title("Complete Convergence Analysis")
                ax.set_xlabel("Algorithm Step")
                ax.set_ylabel("Mean Distance")
                ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
                ax.grid(True, alpha=0.3)
                ax.set_xlim(-0.5, self.n_iterations * 2 - 0.5)

            def _update_step_length_plot(self):
                """Update step length plot using pre-computed data"""
                ax = self.axes[1, 1]

                # Plot mean step lengths (pre-computed)
                if self.x_to_z_lengths:
                    red_steps, red_lengths = zip(*self.x_to_z_lengths)
                    ax.plot(
                        red_steps,
                        red_lengths,
                        "ro-",
                        label="X→Z step length (mean)",
                        alpha=0.7,
                        markersize=4,
                    )

                if self.z_to_x_lengths:
                    blue_steps, blue_lengths = zip(*self.z_to_x_lengths)
                    ax.plot(
                        blue_steps,
                        blue_lengths,
                        "bo-",
                        label="Z→X step length (mean)",
                        alpha=0.7,
                        markersize=4,
                    )

                # Plot current sample step lengths (pre-computed)
                current_sample_x_to_z = self.sample_x_to_z_lengths[self.current_sample]
                current_sample_z_to_x = self.sample_z_to_x_lengths[self.current_sample]

                if current_sample_x_to_z:
                    red_steps_sample, red_lengths_sample = zip(*current_sample_x_to_z)
                    ax.plot(
                        red_steps_sample,
                        red_lengths_sample,
                        "ro-",
                        label=f"X→Z sample {self.current_sample}",
                        alpha=0.8,
                        markersize=3,
                        linewidth=1,
                    )

                if current_sample_z_to_x:
                    blue_steps_sample, blue_lengths_sample = zip(*current_sample_z_to_x)
                    ax.plot(
                        blue_steps_sample,
                        blue_lengths_sample,
                        "bo-",
                        label=f"Z→X sample {self.current_sample}",
                        alpha=0.8,
                        markersize=3,
                        linewidth=1,
                    )

                # Add rho_schedule if available
                if "rho_schedule" in self.data:
                    rho_schedule = self.data["rho_schedule"]
                    if isinstance(rho_schedule, torch.Tensor):
                        rho_schedule = rho_schedule.cpu().numpy()

                    ax2 = ax.twinx()
                    rho_iterations = list(range(len(rho_schedule)))
                    rho_plot_steps = [i * 2 for i in rho_iterations]

                    ax2.plot(
                        rho_plot_steps,
                        rho_schedule,
                        "g--",
                        label="ρ schedule",
                        alpha=0.8,
                        linewidth=2,
                        marker="o",
                        markersize=3,
                    )
                    ax2.set_ylabel("ρ schedule", color="green")
                    ax2.tick_params(axis="y", labelcolor="green")

                    # Highlight current rho value
                    current_iteration_idx = self.current_iteration // 2
                    if current_iteration_idx < len(rho_schedule):
                        current_step = current_iteration_idx * 2
                        ax2.scatter(
                            current_step,
                            rho_schedule[current_iteration_idx],
                            s=150,
                            color="green",
                            edgecolor="yellow",
                            linewidth=3,
                            zorder=5,
                        )

                    # Combine legends
                    lines1, labels1 = ax.get_legend_handles_labels()
                    lines2, labels2 = ax2.get_legend_handles_labels()
                    ax.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
                else:
                    ax.legend()

                # Highlight current step length
                if self.current_iteration > 0:
                    if self.current_iteration % 2 == 1:  # Just completed X->Z
                        step_idx = (self.current_iteration - 1) // 2
                        if step_idx < len(self.x_to_z_lengths):
                            step, length = self.x_to_z_lengths[step_idx]
                            ax.scatter(
                                step,
                                length,
                                s=150,
                                color="red",
                                edgecolor="yellow",
                                linewidth=3,
                                zorder=5,
                            )
                        if step_idx < len(current_sample_x_to_z):
                            step, length = current_sample_x_to_z[step_idx]
                            ax.scatter(
                                step,
                                length,
                                s=100,
                                color="darkred",
                                edgecolor="yellow",
                                linewidth=2,
                                zorder=6,
                            )
                    elif (
                        self.current_iteration % 2 == 0 and self.current_iteration > 0
                    ):  # Just completed Z->X
                        step_idx = (self.current_iteration // 2) - 1
                        if step_idx < len(self.z_to_x_lengths):
                            step, length = self.z_to_x_lengths[step_idx]
                            ax.scatter(
                                step,
                                length,
                                s=150,
                                color="blue",
                                edgecolor="yellow",
                                linewidth=3,
                                zorder=5,
                            )
                        if step_idx < len(current_sample_z_to_x):
                            step, length = current_sample_z_to_x[step_idx]
                            ax.scatter(
                                step,
                                length,
                                s=100,
                                color="darkblue",
                                edgecolor="yellow",
                                linewidth=2,
                                zorder=6,
                            )

                ax.set_title("Complete Step Lengths and ρ Schedule")
                ax.set_xlabel("Algorithm Step")
                ax.set_ylabel("Mean Step Length")
                ax.grid(True, alpha=0.3)
                ax.set_xlim(-0.5, self.n_iterations * 2 - 0.5)

        # Update data dict to include processed arrays
        data_dict = {"sampled_x": sampled_x, "sampled_z": sampled_z, "x0": x0, "y": y}
        # Copy any additional data
        for key, value in data.items():
            if key not in data_dict:
                data_dict[key] = value

        viewer = TrajectoryViewer(data_dict)
        plt.show()

    else:
        # Without PCA - just plot in original space (first 2 dimensions)
        print("Plotting first 2 dimensions of original space")

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Similar plots but using original dimensions
        ax = axes[0, 0]
        for sample_idx in range(min(10, n_samples)):
            trajectory = sampled_x[sample_idx, :, :2]  # First 2 dims
            ax.plot(trajectory[:, 0], trajectory[:, 1], "o-", alpha=0.7, markersize=3)
            ax.scatter(
                x0[sample_idx, 0],
                x0[sample_idx, 1],
                marker="*",
                s=100,
                color="red",
                alpha=0.8,
            )
            ax.scatter(
                y[sample_idx, 0],
                y[sample_idx, 1],
                marker="s",
                s=50,
                color="green",
                alpha=0.8,
            )

        ax.set_title("Sampled X Trajectories (Original Space)")
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    parser = argparse.ArgumentParser(
        description="Visualize the sampled trajectories stored in a result folder.",
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        # main() builds "<result_dir>/trajectory.pkl"; without this flag it would
        # try to open "None/trajectory.pkl", so fail early with a usage error.
        required=True,
        help="Path to the folder where results are saved",
    )
    parser.add_argument(
        "--pca",
        action="store_true",
        help="Take the PCA branch of main() instead of plotting the first "
        "2 dimensions of the original space",
    )
    args = parser.parse_args()

    main(args)
