import argparse
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


from torch.utils.data import DataLoader

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

import seaborn as sns


def _to_numpy(value, *, stack_list=False):
    """Return ``value`` as a NumPy array when it is a tensor.

    Args:
        value: A ``torch.Tensor``, a list of tensors, or anything else.
        stack_list: If True and ``value`` is a list, its tensors are first
            stacked along a new axis 1 (the iteration axis).

    Returns:
        A NumPy array for tensor input (detached and moved to CPU first);
        any other value is returned unchanged.
    """
    if stack_list and isinstance(value, list):
        value = torch.stack(value, dim=1)
    if isinstance(value, torch.Tensor):
        # detach(): .numpy() refuses tensors that still require grad.
        return value.detach().cpu().numpy()
    return value


def _project_2d(points, scaler, pca):
    """Standardize with ``scaler`` and project onto ``pca``'s two components.

    Only the last axis is transformed; leading axes are preserved, so an array
    of shape [..., feature_dim] becomes [..., 2].
    """
    flat = points.reshape(-1, points.shape[-1])
    return pca.transform(scaler.transform(flat)).reshape(*points.shape[:-1], 2)


def main(args):
    """Load ``trajectory.pkl`` from ``args.result_dir`` and plot its vector fields.

    The loaded dict is read for the keys ``"x"`` and ``"z"`` (per-iteration
    iterates: a list of tensors stacked along axis 1, or an already-stacked
    tensor), ``"x0"`` (ground truth) and ``"y"`` (measurements). With
    ``args.pca`` every point is projected onto two principal components fitted
    jointly on all of them; otherwise the first two coordinates are plotted.
    """
    # map_location="cpu": everything is converted to NumPy on the CPU anyway, and
    # this lets trajectories saved from a GPU run be opened on a CPU-only machine.
    data = torch.load(f"{args.result_dir}/trajectory.pkl", map_location="cpu")

    sampled_x = _to_numpy(data["x"], stack_list=True)  # [n_samples, n_iter_x, feature_dim]
    sampled_z = _to_numpy(data["z"], stack_list=True)  # [n_samples, n_iterations, feature_dim]
    x0 = _to_numpy(data["x0"])  # Shape: [n_samples, feature_dim] - ground truth
    y = _to_numpy(data["y"])  # Shape: [n_samples, feature_dim] - measurements

    n_samples, n_iterations, feature_dim = sampled_z.shape
    print(
        f"Data shape: {n_samples} samples, {n_iterations} iterations, {feature_dim} features"
    )

    if args.pca:
        # Fit one scaler + PCA on every point (ground truth, measurements and all
        # X/Z iterates) so all arrays share one projection. Every column of
        # sampled_x is kept, including a final iterate beyond sampled_z, so the
        # last Z→X step is still drawn, exactly as in the non-PCA branch.
        combined_data = np.vstack(
            [
                x0,
                y,
                sampled_x.reshape(-1, sampled_x.shape[-1]),
                sampled_z.reshape(-1, feature_dim),
            ]
        )
        scaler = StandardScaler().fit(combined_data)
        pca = PCA(n_components=2).fit(scaler.transform(combined_data))

        print(f"PCA explained variance ratio: {pca.explained_variance_ratio_}")
        print(f"Total explained variance: {pca.explained_variance_ratio_.sum():.3f}")

        create_vector_field_plots(
            _project_2d(sampled_x, scaler, pca),
            _project_2d(sampled_z, scaler, pca),
            _project_2d(x0, scaler, pca),
            _project_2d(y, scaler, pca),
            pca,
            n_iterations,
        )
    else:
        # Without PCA, plot the first two coordinates directly.
        create_vector_field_plots(
            sampled_x[:, :, :2],
            sampled_z[:, :, :2],
            x0[:, :2],
            y[:, :2],
            None,
            n_iterations,
        )


class VectorFieldViewer:
    """Interactive, keyboard-driven viewer for per-iteration displacement fields.

    For the selected iteration ``i`` it draws red X→Z arrows from
    ``sampled_x[:, i]`` to ``sampled_z[:, i]`` and blue Z→X arrows from
    ``sampled_z[:, i]`` to ``sampled_x[:, i + 1]`` (when that column exists).
    It also draws the X/Z points, the ground truth ``x0`` and the measurements
    ``y``, for either every sample or one selected sample.

    Keep a reference to the viewer while its figure is shown: matplotlib holds
    bound-method callbacks registered with ``mpl_connect`` only weakly.

    Args:
        sampled_x: Array indexed [sample, iteration, coord]; coords 0 and 1 are plotted.
        sampled_z: Array with the same layout as ``sampled_x``.
        x0: Ground-truth points indexed [sample, coord].
        y: Measurement points indexed [sample, coord].
        pca: Object exposing ``explained_variance_ratio_`` (used for axis
            labels), or None when plotting original coordinates.
        n_iterations: Number of iterations the arrow keys step through.
    """

    def __init__(self, sampled_x, sampled_z, x0, y, pca, n_iterations):
        self.sampled_x = sampled_x
        self.sampled_z = sampled_z
        self.x0 = x0
        self.y = y
        self.pca = pca
        self.n_iterations = n_iterations
        self.current_iteration = 0
        self.current_sample = 0
        self.n_samples = sampled_x.shape[0]
        self.show_all_samples = True

        # Fixed limits so the view does not jump between iterations.
        self._calculate_axis_limits()

        self.fig, self.ax = plt.subplots(1, 1, figsize=(12, 8))
        self.fig.canvas.mpl_connect("key_press_event", self.on_key_press)

        self.update_plot()
        print("Controls:")
        print("  Right arrow: Next iteration")
        print("  Left arrow: Previous iteration")
        print("  Up arrow: Next sample (or show all)")
        print("  Down arrow: Previous sample (or show all)")
        print("  Home: First iteration")
        print("  End: Last iteration")
        print("  'a': Toggle all samples view")
        print("  's': Show single sample view")

    def _calculate_axis_limits(self):
        """Compute fixed x/y limits covering every point the viewer can draw.

        Covers the X columns up to index ``n_iterations`` (the end points of
        the last Z→X arrows), the Z columns up to ``n_iterations - 1``, ``x0``
        and ``y``, padded by 10% of the data range on each side.
        """
        x_cols = self.sampled_x[:, : self.n_iterations + 1, :2]
        z_cols = self.sampled_z[:, : self.n_iterations, :2]
        all_points = np.vstack(
            [
                x_cols.reshape(-1, 2),
                z_cols.reshape(-1, 2),
                self.x0[:, :2],
                self.y[:, :2],
            ]
        )

        margin = 0.1
        lows = all_points.min(axis=0)
        highs = all_points.max(axis=0)
        spans = highs - lows
        # A zero span (e.g. a single point) would give identical lower and upper
        # limits; pad by a unit span instead so matplotlib gets a valid range.
        spans = np.where(spans > 0, spans, 1.0)

        self.xlim = [lows[0] - margin * spans[0], highs[0] + margin * spans[0]]
        self.ylim = [lows[1] - margin * spans[1], highs[1] + margin * spans[1]]

    def on_key_press(self, event):
        """Update the iteration/sample selection from a key press and redraw."""
        key = event.key
        if key == "right":
            self.current_iteration = min(
                self.current_iteration + 1, self.n_iterations - 1
            )
        elif key == "left":
            self.current_iteration = max(self.current_iteration - 1, 0)
        elif key == "up":
            if self.show_all_samples:
                self.show_all_samples = False
                self.current_sample = 0
            else:
                self.current_sample = min(self.current_sample + 1, self.n_samples - 1)
        elif key == "down":
            if self.show_all_samples:
                self.current_sample = self.n_samples - 1
                self.show_all_samples = False
            elif self.current_sample == 0:
                # Leave single-sample mode only after sample 0 has been shown.
                self.show_all_samples = True
            else:
                self.current_sample -= 1
        elif key == "home":
            self.current_iteration = 0
        elif key == "end":
            self.current_iteration = self.n_iterations - 1
        elif key == "a":
            self.show_all_samples = True
        elif key == "s":
            self.show_all_samples = False
        else:
            return

        self.update_plot()

    def _draw_arrows(self, starts, ends, color, label, style):
        """Draw one quiver arrow per row from ``starts`` to ``ends`` in data units."""
        deltas = ends - starts
        self.ax.quiver(
            starts[:, 0],
            starts[:, 1],
            deltas[:, 0],
            deltas[:, 1],
            angles="xy",
            scale_units="xy",
            scale=1,
            color=color,
            label=label,
            headwidth=3,
            headlength=3,
            headaxislength=2,
            **style,
        )

    def update_plot(self):
        """Redraw the axes for the current iteration and sample selection."""
        self.ax.clear()
        i = self.current_iteration

        if self.show_all_samples:
            rows = slice(None)
            arrow_style = {"alpha": 0.6, "width": 0.002}
            point_style = {"s": 20, "alpha": 0.6}
            x_label, z_label = "X points", "Z points"
            truth_style = {"s": 40, "alpha": 0.7, "linewidth": 0.5}
            meas_style = {"s": 30, "alpha": 0.7, "linewidth": 0.5}
            meas_label = "Measurements"
            title_suffix = " (All Samples)"
        else:
            sample_idx = self.current_sample
            # A length-1 slice keeps the sample axis, so both modes share code.
            rows = slice(sample_idx, sample_idx + 1)
            arrow_style = {"alpha": 0.8, "width": 0.004}
            point_style = {"s": 60, "alpha": 0.9, "edgecolor": "black", "linewidth": 1}
            x_label, z_label = "X point", "Z point"
            truth_style = {"s": 120, "alpha": 0.9, "linewidth": 1}
            meas_style = {"s": 80, "alpha": 0.9, "linewidth": 1}
            meas_label = "Measurement"
            title_suffix = f" (Sample {sample_idx})"

        n_x = self.sampled_x.shape[1]
        n_z = self.sampled_z.shape[1]

        # X→Z arrows: from X_i to Z_i.
        if i < n_x and i < n_z:
            self._draw_arrows(
                self.sampled_x[rows, i, :],
                self.sampled_z[rows, i, :],
                "red",
                "X→Z displacement",
                arrow_style,
            )
        # Z→X arrows: from Z_i to the next iterate X_{i+1}, when it exists.
        if i < n_z and (i + 1) < n_x:
            self._draw_arrows(
                self.sampled_z[rows, i, :],
                self.sampled_x[rows, i + 1, :],
                "blue",
                "Z→X displacement",
                arrow_style,
            )

        if i < n_x:
            self.ax.scatter(
                self.sampled_x[rows, i, 0],
                self.sampled_x[rows, i, 1],
                c="darkblue",
                label=x_label,
                **point_style,
            )
        if i < n_z:
            self.ax.scatter(
                self.sampled_z[rows, i, 0],
                self.sampled_z[rows, i, 1],
                c="darkred",
                label=z_label,
                **point_style,
            )

        self.ax.scatter(
            self.x0[rows, 0],
            self.x0[rows, 1],
            marker="*",
            c="gold",
            edgecolor="black",
            label="Ground Truth",
            **truth_style,
        )
        self.ax.scatter(
            self.y[rows, 0],
            self.y[rows, 1],
            marker="s",
            c="green",
            edgecolor="black",
            label=meas_label,
            **meas_style,
        )

        if self.pca is not None:
            self.ax.set_xlabel(f"PC1 ({self.pca.explained_variance_ratio_[0]:.2f})")
            self.ax.set_ylabel(f"PC2 ({self.pca.explained_variance_ratio_[1]:.2f})")
            space = "PCA Space"
        else:
            self.ax.set_xlabel("Dimension 1")
            self.ax.set_ylabel("Dimension 2")
            space = "Original Space"
        self.ax.set_title(f"Vector Field - Iteration {i} ({space}){title_suffix}")

        self.ax.grid(True, alpha=0.3)
        self.ax.set_xlim(self.xlim)
        self.ax.set_ylim(self.ylim)
        self.ax.legend()

        # draw_idle() coalesces redraws requested from event callbacks.
        self.fig.canvas.draw_idle()


def create_vector_field_plots(sampled_x, sampled_z, x0, y, pca, n_iterations):
    """Show the interactive X→Z / Z→X viewer, then one static summary figure.

    Args:
        sampled_x: Array indexed [sample, iteration, coord] of X iterates.
        sampled_z: Array indexed [sample, iteration, coord] of Z iterates.
        x0: Ground-truth points indexed [sample, coord].
        y: Measurement points indexed [sample, coord].
        pca: Fitted PCA used for axis labels, or None for original coordinates.
        n_iterations: Number of iterations to visualize.
    """
    # Keep a named reference for the duration of plt.show(): matplotlib stores
    # bound-method callbacks (the viewer's key handler) only as weak references,
    # so an unreferenced viewer could be garbage-collected and stop responding.
    viewer = VectorFieldViewer(sampled_x, sampled_z, x0, y, pca, n_iterations)
    plt.show()
    del viewer

    # Summary plot with all iterations overlaid (previously created twice).
    create_summary_vector_field(sampled_x, sampled_z, x0, y, pca, n_iterations)


def _label_summary_axes(ax, pca, title):
    """Set axis labels and a title that say whether coordinates are PCA components."""
    if pca is not None:
        ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]:.2f})")
        ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]:.2f})")
        ax.set_title(f"{title} (PCA Space)")
    else:
        ax.set_xlabel("Dimension 1")
        ax.set_ylabel("Dimension 2")
        ax.set_title(f"{title} (Original Space)")


def _scatter_summary_references(ax, x0, y):
    """Scatter ground truth (gold stars) and measurements (green squares) on ``ax``."""
    ax.scatter(
        x0[:, 0],
        x0[:, 1],
        marker="*",
        s=80,
        c="gold",
        alpha=0.9,
        edgecolor="black",
        label="Ground Truth",
    )
    ax.scatter(
        y[:, 0],
        y[:, 1],
        marker="s",
        s=50,
        c="green",
        alpha=0.9,
        edgecolor="black",
        label="Measurements",
    )


def _draw_summary_segments(ax, starts, ends, color, iteration):
    """Overlay headless quiver segments from ``starts`` to ``ends`` for one iteration.

    Only the first five iterations get an "Iter k" label; later ones get an
    empty label so the legend stays short.
    """
    deltas = ends - starts
    ax.quiver(
        starts[:, 0],
        starts[:, 1],
        deltas[:, 0],
        deltas[:, 1],
        angles="xy",
        scale_units="xy",
        scale=1,
        color=color,
        alpha=0.6,
        width=0.002,
        label=f"Iter {iteration}" if iteration < 5 else "",
        headwidth=0,
        headlength=0,
        headaxislength=0,
    )


def create_summary_vector_field(
    sampled_x, sampled_z, x0, y, pca, n_iterations, save_path="vector_field_summary.pdf"
):
    """Save and show a two-panel figure overlaying every iteration's displacements.

    The left panel shows X_i → Z_i segments in reds and the right panel shows
    Z_i → X_{i+1} segments in blues. Darker colors mark later iterations.

    Args:
        sampled_x: Array indexed [sample, iteration, coord] of X iterates.
        sampled_z: Array indexed [sample, iteration, coord] of Z iterates.
        x0: Ground-truth points indexed [sample, coord].
        y: Measurement points indexed [sample, coord].
        pca: Fitted PCA used for axis labels, or None for original coordinates.
        n_iterations: Upper bound on the number of iterations drawn.
        save_path: Output file for ``savefig`` (defaults to the original
            hard-coded name in the current working directory).
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    ax1.set_aspect("equal", adjustable="box")
    ax2.set_aspect("equal", adjustable="box")

    # Panel 1: X_i → Z_i for every iteration present in both arrays.
    n_x_to_z = max(min(n_iterations, sampled_x.shape[1], sampled_z.shape[1]), 0)
    colors_x_to_z = plt.cm.Reds(np.linspace(0.3, 1.0, n_x_to_z))
    for i in range(n_x_to_z):
        _draw_summary_segments(
            ax1, sampled_x[:, i, :], sampled_z[:, i, :], colors_x_to_z[i], i
        )

    if sampled_x.shape[1] > 0:
        ax1.scatter(
            sampled_x[:, 0, 0],
            sampled_x[:, 0, 1],
            c="darkblue",
            s=30,
            alpha=0.8,
            label="X points (start)",
        )
    _scatter_summary_references(ax1, x0, y)
    _label_summary_axes(ax1, pca, "All X→Z Displacements")
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Panel 2: Z_i → X_{i+1}. When sampled_x carries one more iterate than
    # sampled_z, this includes the final step, matching the interactive viewer.
    n_z_to_x = max(min(n_iterations, sampled_z.shape[1], sampled_x.shape[1] - 1), 0)
    colors_z_to_x = plt.cm.Blues(np.linspace(0.3, 1.0, n_z_to_x))
    for i in range(n_z_to_x):
        _draw_summary_segments(
            ax2, sampled_z[:, i, :], sampled_x[:, i + 1, :], colors_z_to_x[i], i
        )

    if sampled_z.shape[1] > 0:
        ax2.scatter(
            sampled_z[:, 0, 0],
            sampled_z[:, 0, 1],
            c="darkred",
            s=30,
            alpha=0.8,
            label="Z points (start)",
        )
    _scatter_summary_references(ax2, x0, y)
    _label_summary_axes(ax2, pca, "All Z→X Displacements")
    # The right panel is drawn without a grid or legend.

    fig.tight_layout()
    fig.savefig(save_path, bbox_inches="tight")
    plt.show()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Visualize X→Z / Z→X displacement fields stored in trajectory.pkl."
    )
    # required=True: without it a missing flag silently loads "None/trajectory.pkl".
    parser.add_argument(
        "--result_dir",
        type=str,
        required=True,
        help="Path to the model folder.",
    )
    parser.add_argument(
        "--pca",
        action="store_true",
        help="Project all points onto two principal components before plotting.",
    )
    args = parser.parse_args()

    main(args)
