import argparse
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mixed_diffusion.data_loading.data_loading import get_data

from torch.utils.data import DataLoader

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import umap.umap_ as umap

import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go


def _is_projectable(dataset):
    """Return True when *dataset* can be fed to the projection plots.

    Requires a ``tensors`` attribute (as on ``torch.utils.data.TensorDataset``)
    holding at least features and labels. The dataset must be non-empty and its
    samples must be 1-D feature vectors with more than two entries. Image
    datasets (samples of rank > 1) and 2-D point clouds are rejected.
    """
    tensors = getattr(dataset, "tensors", None)
    if tensors is None or len(tensors) < 2 or len(tensors[0]) == 0:
        return False
    sample = dataset[0][0]
    return sample.ndim == 1 and sample.shape[0] > 2


def _reduce_with_pca(X_std, max_components=10):
    """Project *X_std* onto its leading principal components.

    Args:
        X_std: 2-D array of standardised features (samples x features).
        max_components: Upper bound on the number of components kept.

    Returns:
        ``(X_reduced, var_table)``. ``var_table`` has columns ``PC`` (1-based)
        and ``Explained_Var%`` (explained variance ratio in percent, 2 d.p.).
    """
    # PCA rejects n_components > min(n_samples, n_features), so cap it rather
    # than hard-coding 10 (which breaks for 3..9-dimensional data).
    n_components = min(max_components, *X_std.shape)
    pca = PCA(n_components=n_components, random_state=0)
    X_reduced = pca.fit_transform(X_std)
    var_table = pd.DataFrame(
        {
            "PC": np.arange(1, n_components + 1),
            "Explained_Var%": np.round(pca.explained_variance_ratio_ * 100, 2),
        }
    )
    return X_reduced, var_table


def _reduce_with_umap(X_std, max_components=10, max_input_dims=15):
    """Embed *X_std* with UMAP, PCA-compressing very wide inputs first.

    Returns:
        ``(X_reduced, var_table)``. ``var_table`` simply lists the UMAP
        components, because UMAP has no explained-variance notion.
    """
    if X_std.shape[1] > max_input_dims:
        X_in = PCA(n_components=max_input_dims, random_state=0).fit_transform(X_std)
    else:
        X_in = X_std

    reducer = umap.UMAP(
        n_components=min(max_components, X_in.shape[1]),
        random_state=0,
        n_neighbors=15,
        min_dist=0.1,
    )
    X_reduced = reducer.fit_transform(X_in)
    n_kept = X_reduced.shape[1]
    var_table = pd.DataFrame(
        {
            "Component": [f"UMAP{i + 1}" for i in range(n_kept)],
            "Type": ["UMAP"] * n_kept,
        }
    )
    return X_reduced, var_table


def _axis_titles(reduction_method, var_table, n_axes):
    """Return axis titles for the first *n_axes* reduced components.

    For ``"PCA"`` each title includes the explained variance read from the
    ``Explained_Var%`` column of *var_table*. Otherwise titles are
    ``UMAP1``, ``UMAP2``, ... and *var_table* is ignored.
    """
    if reduction_method == "PCA":
        return [
            f"PC{i + 1} ({var_table.loc[i, 'Explained_Var%']:.1f}% variance)"
            for i in range(n_axes)
        ]
    return [f"UMAP{i + 1}" for i in range(n_axes)]


def _plot_2d(X_reduced, labels, reduction_method, axis_titles, n_features):
    """Show a matplotlib scatter of the first two reduced components."""
    plt.figure(figsize=(7, 5))
    plt.scatter(
        X_reduced[:, 0],
        X_reduced[:, 1],
        c=labels,
        s=18,
        alpha=0.75,
        cmap="tab10",
        edgecolors="k",
    )
    plt.xlabel(axis_titles[0])
    plt.ylabel(axis_titles[1])
    # n_features is the ORIGINAL data dimension, not the number of kept PCs.
    plt.title(
        f"{reduction_method} projection of {n_features}-D data (coloured by mode)"
    )
    plt.tight_layout()
    plt.show()


def _plot_3d(X_reduced, labels, reduction_method, axis_titles, n_features, show):
    """Write an interactive Plotly 3-D scatter to HTML; optionally open it.

    The file is written to ``3d_<method>_plot.html`` in the current working
    directory. When *show* is true it is opened with :mod:`webbrowser`.
    """
    import webbrowser
    from pathlib import Path

    print(f"\nCreating interactive 3D {reduction_method} plot...")

    # Column names stay PC1..PC3 for both methods; axis_titles relabels them.
    plot_df = pd.DataFrame(
        {
            "PC1": X_reduced[:, 0],
            "PC2": X_reduced[:, 1],
            "PC3": X_reduced[:, 2],
            "Labels": np.asarray(labels).astype(str),
        }
    )
    fig_3d = px.scatter_3d(
        plot_df,
        x="PC1",
        y="PC2",
        z="PC3",
        color="Labels",
        title=(
            f"Interactive 3D {reduction_method} of {n_features}-D Data "
            "(colored by labels)"
        ),
        labels=dict(zip(("PC1", "PC2", "PC3"), axis_titles)),
        opacity=0.7,
    )
    fig_3d.update_layout(
        width=900,
        height=700,
        scene=dict(
            xaxis_title=axis_titles[0],
            yaxis_title=axis_titles[1],
            zaxis_title=axis_titles[2],
            camera=dict(eye=dict(x=1.2, y=1.2, z=1.2)),
        ),
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=1.01),
    )
    fig_3d.update_traces(
        marker=dict(size=4, line=dict(width=0.5, color="DarkSlateGrey"))
    )

    html_filename = f"3d_{reduction_method.lower()}_plot.html"
    fig_3d.write_html(html_filename)
    print(f"✓ 3D Interactive plot saved to: {html_filename}")

    if show:
        # as_uri() yields a correctly escaped file:// URI on every platform.
        webbrowser.open(Path(html_filename).resolve().as_uri())
        print("✓ Opening plot in browser...")
    else:
        print("  Use --show flag to automatically open in browser")
        print(f"  Or manually open: {html_filename}")


def _visualise_projections(train, args):
    """Heatmap, 2-D scatter and interactive 3-D scatter for a TensorDataset.

    Reads ``args.umap`` to choose UMAP over PCA, and ``args.show`` to decide
    whether to open the HTML plot in a browser.
    """
    x_data = train.tensors[0].numpy()
    labels = train.tensors[1].numpy()
    n_features = x_data.shape[1]

    _, ax = plt.subplots(figsize=(7, 5))
    sns.heatmap(x_data, cmap="viridis", ax=ax)
    plt.tight_layout()
    plt.show()

    X_std = StandardScaler().fit_transform(x_data)

    if args.umap:
        print("Using UMAP for dimensionality reduction")
        X_reduced, var_table = _reduce_with_umap(X_std)
        reduction_method = "UMAP"
    else:
        print("Using PCA for dimensionality reduction")
        X_reduced, var_table = _reduce_with_pca(X_std)
        reduction_method = "PCA"

    print(var_table.to_string(index=False))

    titles = _axis_titles(reduction_method, var_table, 3)
    _plot_2d(X_reduced, labels, reduction_method, titles, n_features)
    _plot_3d(X_reduced, labels, reduction_method, titles, n_features, args.show)

    print(f"✓ 3D Interactive {reduction_method} plot created")
    if reduction_method == "PCA":
        # .loc slicing is label-inclusive: rows 0..2 are PC1..PC3.
        total = var_table.loc[:2, "Explained_Var%"].sum()
        print(f"  - Total explained variance (PC1+PC2+PC3): {total:.1f}%")
    else:
        print("  - UMAP preserves local structure and reveals non-linear patterns")
    print("  - Interactive features: rotate, zoom, pan, hover, legend toggle")


def main(args):
    """Load the training split and visualise it with PCA or UMAP projections.

    When the training samples are 1-D feature vectors with more than two
    dimensions (see ``_is_projectable``), this draws a heatmap of the raw
    features and a 2-D scatter of the first two reduced components. It also
    writes an interactive 3-D Plotly scatter to ``3d_<method>_plot.html``.

    Args:
        args: Parsed command-line namespace. ``umap`` and ``show`` are read
            here, and the whole namespace is passed to ``get_data``.
            ``sample_size`` is set to 10000 unless it is already set.
    """
    # Presumably consumed by get_data to size the dataset (TODO confirm).
    # Only fill in a default so that a caller-supplied value is not clobbered.
    if getattr(args, "sample_size", None) is None:
        args.sample_size = 10000

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    train, _, _ = get_data(args, transform)

    if _is_projectable(train):
        _visualise_projections(train, args)
    else:
        print(
            "Skipping projection plots: training samples are not feature "
            "vectors with more than 2 dimensions"
        )

    print(f"Loaded training data with {len(train)} samples")
    # TODO: a sample-grid plot (cf. --num_samples) is not implemented yet.


def build_parser():
    """Build the command-line parser for this visualisation script.

    Returns:
        An ``argparse.ArgumentParser``. Its parsed namespace is what ``main``
        expects.
    """
    parser = argparse.ArgumentParser(description="Display training data samples")
    parser.add_argument(
        "--data_file",
        type=str,
        default="/home/ubuntu/data",
        help="Path to the data directory.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        help="Dataset to use for training.",
    )
    # Matplotlib figures are always displayed; this flag only controls whether
    # the saved interactive HTML plot is opened in a web browser.
    parser.add_argument(
        "--show",
        action="store_true",
        help="Open the saved interactive 3-D HTML plot in a web browser",
    )
    parser.add_argument(
        "--num_samples", type=int, default=100, help="Number of samples to show"
    )
    parser.add_argument(
        "--config_file",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--umap",
        action="store_true",
        help="Use UMAP for dimensionality reduction instead of PCA",
    )
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())
