import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
import argparse
import os
from torchvision.datasets import MNIST
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

# Paper-style matplotlib configuration: start from the library defaults, then
# override settings one rc group at a time (plt.rc('axes', titlesize=11) sets
# rcParams['axes.titlesize'] and so on).
plt.style.use('default')
plt.rc('font', size=11, family='sans-serif')
plt.rc('axes', titlesize=11, labelsize=11, linewidth=0.8)
plt.rc('xtick', labelsize=10)
plt.rc('ytick', labelsize=10)
plt.rc('legend', fontsize=10)
plt.rc('figure', titlesize=12, dpi=300)
plt.rc('mathtext', fontset='stix')
plt.rc('savefig', dpi=300, bbox='tight', pad_inches=0.05)
plt.rc('grid', linewidth=0.5)

# Hex colors used by the plots, keyed by role. 'shared'/'mod1'/'mod2' color the
# three subspaces; 'train' reuses the same blue as 'shared'.
colors = dict(
    shared='#1f78b4',  # blue
    mod1='#33a02c',    # green
    mod2='#e31a1c',    # red
    train='#1f78b4',   # blue
    val='#ff7f00',     # orange
)

def load_avmnist_labels(data_dir="01_data/processed/avmnist"):
    """Return the AVMNIST labels as one array: train labels followed by test labels.

    Reads ``train_labels.npy`` and ``test_labels.npy`` from ``data_dir`` and
    concatenates them along the first axis, in that order.
    """
    per_split = [
        np.load(os.path.join(data_dir, file_name))
        for file_name in ("train_labels.npy", "test_labels.npy")
    ]
    return np.concatenate(per_split, axis=0)

def plot_subspace_representation(ax, rep, labels, title, max_samples=100000, dim_reduction='pca', images=None, show_images=False, image_zoom=0.05, random_state=None):
    """Plot one subspace representation on ``ax``, labelled by digit.

    * 1 dimension: per-digit density histograms (seaborn ``histplot``).
    * 2 dimensions: plotted directly.
    * more than 2 dimensions: reduced to 2-D with PCA, or with UMAP when
      ``dim_reduction == 'umap'`` (falls back to PCA if ``umap`` cannot be
      imported).

    The title is extended with the accuracy of a 1-nearest-neighbour classifier
    fitted on the first 80% of the plotted samples (in the full, unreduced
    space) and scored on the remaining 20%. This is best effort: on failure a
    warning is printed and the plot is produced without the accuracy.

    Args:
        ax: Matplotlib axes to draw on.
        rep: 2-D array ``(n_samples, n_dims)``.
        labels: Array of labels indexable like ``rep`` along the first axis.
        title: Base title for the axes.
        max_samples: If ``rep`` has more rows than this, a random subset of
            this size (without replacement) is plotted instead.
        dim_reduction: ``'pca'`` or ``'umap'`` (case-insensitive).
        images: Optional array of images, row-aligned with ``rep``.
        show_images: If True and ``images`` covers every row of ``rep``, draw
            image thumbnails (at most 1000) instead of colored dots.
        image_zoom: Zoom factor passed to ``OffsetImage``.
        random_state: Optional seed for the subsampling and thumbnail
            selection. ``None`` (default) uses NumPy's global random state,
            which is the historical behavior.

    Returns:
        The axes that were passed in.
    """
    # A seeded private generator when requested, otherwise the global NumPy RNG
    rng = np.random if random_state is None else np.random.RandomState(random_state)

    # Subsample for visualization if too many points
    n_total = rep.shape[0]
    subsampled = n_total > max_samples
    if subsampled:
        indices = rng.choice(n_total, max_samples, replace=False)
        rep_plot = rep[indices]
        labels_plot = labels[indices]
    else:
        indices = None
        rep_plot = rep
        labels_plot = labels
    n_dims = rep_plot.shape[1]

    # Obtain 2-D coordinates; stays None for the 1-D (histogram) case
    rep_2d = None
    if n_dims > 2:
        if dim_reduction.lower() == 'umap':
            try:
                import umap
                rep_2d = umap.UMAP(n_components=2, random_state=42).fit_transform(rep_plot)
                title += '\n(UMAP)'
            except ImportError:
                print("Warning: UMAP not available, falling back to PCA")
                rep_2d = PCA(n_components=2).fit_transform(rep_plot)
        else:
            rep_2d = PCA(n_components=2).fit_transform(rep_plot)
    elif n_dims == 2:
        rep_2d = rep_plot

    # Best-effort 1-NN accuracy in the original (unreduced) space, 80/20 split
    # taken in the order of the plotted samples.
    try:
        split = int(0.8 * rep_plot.shape[0])
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(rep_plot[:split], labels_plot[:split])
        acc = knn.score(rep_plot[split:], labels_plot[split:])
        title = f"{title} (Acc: {acc:.2f})"
    except Exception as e:
        # Never block plotting because of the accuracy annotation, but say why
        print(f"Warning: could not compute 1-NN accuracy for '{title}': {e}")

    if n_dims == 1:
        # Histograms colored by digit
        sns.histplot(x=rep_plot[:, 0], hue=labels_plot, bins=30, palette='tab10', element='step', stat='density', common_norm=False, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Dim 1')
        ax.set_ylabel('Density')
        # seaborn builds the hue legend from its own handles; a bare
        # ax.legend(...) would replace it with an empty legend, so the
        # existing legend is moved/retitled instead.
        if ax.get_legend() is not None:
            sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.05, 1), title='Digit')
    else:
        # Thumbnails need an image for every row of `rep`, because subsampling
        # indices refer to rows of the full `rep`, not of the plotted subset.
        if show_images and images is not None and len(images) >= n_total:
            print(f"DEBUG: Using image thumbnails for {title} - images shape: {images.shape}, rep_2d shape: {rep_2d.shape}")
            images_plot = images[indices] if subsampled else images[:n_total]

            # Create invisible scatter to set up axes limits properly
            ax.scatter(rep_2d[:, 0], rep_2d[:, 1], alpha=0)

            # Limit the number of thumbnails for performance
            n_thumbnails = min(1000, len(rep_2d))
            thumb_indices = rng.choice(len(rep_2d), n_thumbnails, replace=False)
            print(f"DEBUG: Adding {n_thumbnails} thumbnails")

            for idx in thumb_indices:
                img = images_plot[idx]
                if img.ndim == 3 and img.shape[0] == 1:  # (1, H, W) -> (H, W)
                    img = img[0]
                imagebox = OffsetImage(img, zoom=image_zoom, cmap='gray')
                ab = AnnotationBbox(imagebox, (rep_2d[idx, 0], rep_2d[idx, 1]),
                                    frameon=False, pad=0)
                ax.add_artist(ab)
        else:
            print(f"DEBUG: Using standard scatter for {title} - show_images: {show_images}, images is None: {images is None}")
            # Standard scatter plot colored by label, with a digit colorbar
            scatter = ax.scatter(rep_2d[:, 0], rep_2d[:, 1], c=labels_plot, cmap='tab10',
                                 s=1, alpha=0.6, rasterized=True)
            cbar = plt.colorbar(scatter, ax=ax, ticks=range(10))
            cbar.set_label('Digit', rotation=0, labelpad=15)

        ax.set_title(title)
        ax.set_xlabel('Dim 1')
        ax.set_ylabel('Dim 2')

    return ax

# Root directory handed to torchvision's MNIST dataset class
_AVMNIST_MNIST_ROOT = '01_data/processed/MNIST'


def _avmnist_file_tag(full_spectrum):
    """Return the file-name stem shared by all inputs and outputs of the figure."""
    return f"avmnist_real{'_fullspec' if full_spectrum else ''}"


def _avmnist_load_mnist_arrays(download):
    """Load MNIST and return (train_images, test_images, train_labels, test_labels).

    Images are float32 arrays scaled to [0, 1] by dividing by 255.
    """
    train = MNIST(root=_AVMNIST_MNIST_ROOT, train=True, download=download)
    test = MNIST(root=_AVMNIST_MNIST_ROOT, train=False, download=download)
    train_images = train.data.numpy().astype('float32') / 255.0
    test_images = test.data.numpy().astype('float32') / 255.0
    return train_images, test_images, train.targets.numpy(), test.targets.numpy()


def _avmnist_check_bounds(indices, upper, message):
    """Raise ValueError(message) unless every index lies in [0, upper)."""
    if indices.max() >= upper or indices.min() < 0:
        raise ValueError(message)


def _avmnist_load_aligned_images(data_dir, results_dir, model_prefix):
    """Return MNIST images reordered with the saved MNIST-AVMNIST index mapping.

    Indices refer to the concatenation [MNIST train, MNIST test]. Two mapping
    formats are supported: the newer one with 'mnist_train_indices' and
    'mnist_test_indices', and the older one with only 'mnist_image_indices'
    (train part), for which the remaining AVMNIST samples are mapped
    deterministically to MNIST test images of the same digit (round-robin).

    Raises:
        FileNotFoundError: if the mapping file does not exist.
        ValueError: if the mapping lacks the expected keys or has indices out
            of bounds.
    """
    train_images, test_images, train_labels, test_labels = _avmnist_load_mnist_arrays(download=False)
    all_images = np.concatenate([train_images, test_images], axis=0)
    n_images = all_images.shape[0]

    mapping_file = os.path.join(results_dir, f"{model_prefix}_mnist_mapping.npz")
    if not os.path.exists(mapping_file):
        print(f"ERROR: Required MNIST-AVMNIST mapping file not found: {mapping_file}")
        print("Please run 054_avmnist_real.py first to create the mapping.")
        raise FileNotFoundError(f"Mapping file required for reproducibility: {mapping_file}")

    print(f"Loading saved MNIST-AVMNIST mapping from: {mapping_file}")
    # NOTE(review): allow_pickle=True is kept for compatibility with existing
    # mapping files; only load mapping files produced by this project.
    with np.load(mapping_file, allow_pickle=True) as mapping_data:
        if 'mnist_train_indices' in mapping_data and 'mnist_test_indices' in mapping_data:
            train_indices = np.array(mapping_data['mnist_train_indices'], dtype=np.int64)
            test_indices = np.array(mapping_data['mnist_test_indices'], dtype=np.int64)
            _avmnist_check_bounds(train_indices, n_images,
                                  'mnist_train_indices out of bounds for combined MNIST array')
            _avmnist_check_bounds(test_indices, n_images,
                                  'mnist_test_indices out of bounds for combined MNIST array')
            aligned_images = all_images[np.concatenate([train_indices, test_indices])]
            print(f"Loaded per-split train/test MNIST mapping and aligned {len(aligned_images)} images to AVMNIST labels")
            return aligned_images

        if 'mnist_image_indices' not in mapping_data:
            raise ValueError(f"Mapping file {mapping_file} missing expected keys for mapping ('mnist_train_indices' or 'mnist_image_indices').")

        mnist_indices = np.array(mapping_data['mnist_image_indices'], dtype=np.int64)

    # Legacy format: the saved indices cover the first len(mnist_indices)
    # AVMNIST samples; build a deterministic mapping for the rest.
    test_avmnist_labels = load_avmnist_labels(data_dir)[len(mnist_indices):]
    test_offset = len(train_labels)  # position of the first MNIST test image
    per_digit = {lab: np.where(test_labels == lab)[0] + test_offset for lab in range(10)}
    next_use = {lab: 0 for lab in range(10)}
    test_mnist_indices = []
    for test_label in test_avmnist_labels:
        lab = int(test_label)
        matching = per_digit.get(lab, np.array([]))
        if matching.size > 0:
            # Cycle through the MNIST test images of this digit
            test_mnist_indices.append(matching[next_use[lab] % matching.size])
            next_use[lab] += 1
        else:
            # No MNIST test image with this label: use the first test image
            test_mnist_indices.append(test_offset)

    all_indices = np.array(list(mnist_indices) + test_mnist_indices, dtype=np.int64)
    _avmnist_check_bounds(all_indices, n_images,
                          'Combined MNIST mapping indices out of bounds for combined MNIST array')
    aligned_images = all_images[all_indices]
    print("Used saved mapping for reproducible image alignment (train-only mapping with deterministic test mapping)")
    print(f"Aligned {len(aligned_images)} MNIST images to AVMNIST labels")
    print(f"DEBUG: aligned_images shape: {aligned_images.shape}, min: {aligned_images.min():.3f}, max: {aligned_images.max():.3f}")
    return aligned_images


def _avmnist_load_rank_history(results_dir, model_prefix):
    """Read the rank-history CSV and, if available, append the post-training losses.

    Post-training rows continue the epoch numbering and repeat the last
    recorded rank / R² values. A missing or unreadable post-training file only
    produces a warning.
    """
    rank_history = pd.read_csv(os.path.join(results_dir, f"{model_prefix}_rank_history.csv"))

    try:
        posttrain_losses = pd.read_csv(os.path.join(results_dir, f"{model_prefix}_posttrain_losses.csv"))
        n_post = len(posttrain_losses)

        # Extended epochs start right after the last original epoch
        first_post_epoch = rank_history['epoch'].max() + 1
        posttrain_data = {
            'epoch': range(first_post_epoch, first_post_epoch + n_post),
            'loss': posttrain_losses['train_loss'],
            'val_loss': posttrain_losses['val_loss'],
        }

        # Carry the final rank / R² values forward through post-training
        last_row = rank_history.iloc[-1]
        for col in ['total_rank', 'ranks'] + [f'rsquare {i}' for i in range(2)]:
            if col in rank_history.columns:
                posttrain_data[col] = [last_row[col]] * n_post

        rank_history = pd.concat([rank_history, pd.DataFrame(posttrain_data)], ignore_index=True)
        print(f"Loaded and appended post-training losses ({n_post} epochs)")
    except Exception as e:
        print(f"Warning: Could not load post-training losses: {e}")

    return rank_history


def _avmnist_save_single_dim_accuracies(reps, labels, results_dir, model_prefix):
    """Compute a 1-NN accuracy for every single dimension of every subspace.

    Each dimension is evaluated alone: fit on the first 80% of the samples,
    score on the remaining 20%. Results are printed and written to one CSV per
    subspace in ``results_dir``. Returns a dict {subspace index: accuracies}.
    """
    print("Computing single-dimension 1-NN accuracies for each subspace (this may take a moment)...")
    results = {}
    for idx, rep in enumerate(reps):
        # Only use samples available in both rep and labels
        n_used = min(labels.shape[0], rep.shape[0])
        X = rep[:n_used].astype(np.float32)  # float32 for speed
        y = labels[:n_used]
        try:
            y = y.astype(int)
        except Exception:
            y = np.array([int(v) for v in y])

        n_dims = X.shape[1]
        accs = np.full(n_dims, np.nan)
        split = int(0.8 * n_used)
        # Need at least one sample on each side of the split; otherwise all NaN
        if split >= 1 and (n_used - split) >= 1:
            for d in range(n_dims):
                Xd = X[:, d].reshape(-1, 1)
                try:
                    knn = KNeighborsClassifier(n_neighbors=1)
                    knn.fit(Xd[:split], y[:split])
                    accs[d] = knn.score(Xd[split:], y[split:])
                except Exception:
                    accs[d] = np.nan
        results[idx] = accs

        finite = np.isfinite(accs)
        if finite.any():
            mean_acc = accs[finite].mean()
            max_idx = int(np.nanargmax(accs))
            max_acc = accs[max_idx]
        else:
            mean_acc = max_acc = np.nan
            max_idx = None
        print(f"Subspace {idx} - dims: {n_dims}, mean single-dim acc: {mean_acc:.4f}, max acc: {max_acc:.4f} (dim {max_idx})")

        acc_file = os.path.join(results_dir, f'{model_prefix}_single_dim_acc_subspace{idx}.csv')
        pd.DataFrame({'dim': np.arange(n_dims), 'accuracy': accs}).to_csv(acc_file, index=False)
        print(f"Saved per-dimension accuracies to: {acc_file}")
    return results


def _avmnist_plot_training_curves(fig, gs, rank_history):
    """Draw the rank (top-left) and R² (bottom-left) curves over epochs."""
    epochs = rank_history['epoch'].values

    # Row 1: ranks over epochs. The 'ranks' column holds "rank1, rank2, rank3".
    ax_ranks = fig.add_subplot(gs[0, 0:2])
    individual_ranks = np.array([
        [int(x.strip()) for x in rank_str.split(',')]
        for rank_str in rank_history['ranks'].values
    ])
    for col, (color_key, label, marker) in enumerate(
            [('shared', 'Shared', 'o'), ('mod1', 'Image', 's'), ('mod2', 'Audio', '^')]):
        ax_ranks.plot(epochs, individual_ranks[:, col], color=colors[color_key],
                      linewidth=2, label=label, marker=marker, markersize=3)
    ax_ranks.set_xlabel('')
    ax_ranks.set_ylabel('Rank')
    ax_ranks.set_title('Subspace Ranks')
    ax_ranks.legend()
    ax_ranks.grid(True, alpha=0.3)
    # Hide epoch x-axis ticks/labels for a cleaner multi-panel layout
    ax_ranks.set_xticks([])

    # Row 2: R² over epochs, one curve per modality when the column exists
    ax_r2 = fig.add_subplot(gs[1, 0:2])
    for i, (modality_name, color_key) in enumerate([('Image', 'mod1'), ('Audio', 'mod2')]):
        r2_col = f'rsquare {i}'
        if r2_col not in rank_history.columns:
            continue
        color = colors[color_key]
        ax_r2.plot(epochs, rank_history[r2_col], color=color,
                   linewidth=2, label=modality_name, marker='o', markersize=3)
        # Dotted reference line 0.05 below the best R² of this modality
        # (Series.max() skips NaN entries).
        try:
            ref_val = rank_history[r2_col].max() - 0.05
            ax_r2.axhline(ref_val, color=color, linestyle=':', linewidth=1, alpha=0.8)
        except Exception as e:
            print(f"Warning: could not draw R² reference line for {modality_name}: {e}")
    ax_r2.set_xlabel('')
    ax_r2.set_ylabel('R²')
    ax_r2.set_title('R² by Modality')
    # Avoid an empty legend (and matplotlib warning) when no R² column exists
    if ax_r2.get_legend_handles_labels()[0]:
        ax_r2.legend()
    ax_r2.grid(True, alpha=0.3)
    ax_r2.set_xticks([])


def _avmnist_save_rep0_heatmap(rep0, labels, heat_output, max_dims=100):
    """Save a heatmap of ``rep0`` with samples (rows) sorted by label.

    At most ``max_dims`` leading dimensions are shown; the color range is
    clipped to the 1st-99th percentile of the displayed values.
    """
    order = np.argsort(labels)
    labels_sorted = labels[order]
    display_dims = min(rep0.shape[1], max_dims)
    data_heat = rep0[order][:, :display_dims]
    vmin, vmax = np.percentile(data_heat, [1, 99])

    fig_h, ax_h = plt.subplots(1, 1, figsize=(6, 10))
    try:
        im = ax_h.imshow(data_heat, aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax, interpolation='nearest')
        ax_h.set_xlabel(f'Representation dims (0..{display_dims-1})')
        ax_h.set_ylabel('Samples (sorted by label)')
        ax_h.set_title('Shared subspace (rep0) heatmap — samples sorted by label')
        fig_h.colorbar(im, ax=ax_h)

        # One tick per label at the center of its row group, separators between groups
        unique_labels, counts = np.unique(labels_sorted, return_counts=True)
        starts = np.cumsum(np.concatenate([[0], counts[:-1]]))
        ax_h.set_yticks(starts + counts / 2.0 - 0.5)
        ax_h.set_yticklabels([str(int(l)) for l in unique_labels])
        for s in starts[1:]:
            ax_h.axhline(s - 0.5, color='white', linewidth=0.6, alpha=0.7)

        fig_h.savefig(heat_output, dpi=200, bbox_inches='tight')
    finally:
        # Close even on failure so a half-built figure is not shown later
        plt.close(fig_h)


def _avmnist_save_sample_grid(target_labels, sample_output, cols=5):
    """Save a grid of real MNIST digits matching ``target_labels``; return the (closed) figure.

    For each label, the first MNIST image (train followed by test) of that
    digit not already used is chosen; if all are used, the first image of that
    digit is reused; if the digit does not occur, image 0 is used.
    """
    # download=True here (unlike the thumbnail loader) lets torchvision fetch
    # MNIST when it is not yet present under the MNIST root.
    imgs_train, imgs_test, lab_train, lab_test = _avmnist_load_mnist_arrays(download=True)
    imgs_all = np.concatenate([imgs_train, imgs_test], axis=0)
    mnist_labels_all = np.concatenate([lab_train, lab_test], axis=0)

    selected_imgs = []
    used_per_digit = {}
    for t in target_labels:
        digit = int(t)
        candidates = np.where(mnist_labels_all == digit)[0]
        k = used_per_digit.get(digit, 0)
        if k < len(candidates):
            chosen = int(candidates[k])  # first not-yet-used image of this digit
            used_per_digit[digit] = k + 1
        elif len(candidates) > 0:
            chosen = int(candidates[0])  # all used: fall back to the first one
        else:
            chosen = 0
        selected_imgs.append(imgs_all[chosen])

    imgs_np = np.stack(selected_imgs, axis=0)
    n_show = min(20, imgs_np.shape[0])
    rows = int(np.ceil(n_show / cols))
    sample_fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.6, rows * 1.6))
    try:
        axes = np.array(axes).reshape(-1)
        for i in range(n_show):
            axes[i].imshow(imgs_np[i], cmap='gray')
            axes[i].axis('off')
            try:
                axes[i].set_title(str(int(target_labels[i])), fontsize=8)
            except Exception:
                axes[i].set_title(str(target_labels[i]), fontsize=8)
        for j in range(n_show, len(axes)):
            axes[j].axis('off')
        sample_fig.tight_layout()
        sample_fig.savefig(sample_output, dpi=150, bbox_inches='tight')
    finally:
        plt.close(sample_fig)
    return sample_fig


def create_avmnist_figure(seed=42, data_dir="01_data/processed/avmnist", 
                         results_dir="03_results/models", dim_reduction='pca', full_spectrum=False):
    """Create the AVMNIST training / representation figure and companion plots.

    Loads the rank history, post-trained representations and labels for the
    given seed, then writes to ``03_results/plots``:
      * the main figure (rank and R² curves plus one panel per subspace),
      * a label-sorted heatmap of the first (shared) representation,
      * a grid of 20 real MNIST digits matching the first 20 labels.
    Per-dimension 1-NN accuracies are additionally saved as CSV files in
    ``results_dir``. Everything except the core data loading is best effort:
    failures are reported as warnings and the remaining outputs are still made.

    Args:
        seed: Seed of the run whose results are plotted (part of file names).
        data_dir: Directory containing the AVMNIST label files.
        results_dir: Directory containing the saved model results, including
            the MNIST-AVMNIST mapping file.
        dim_reduction: 'pca' or 'umap', forwarded to the subspace plots.
        full_spectrum: Selects the '_fullspec' variant of all file names.

    Returns:
        ``None`` if the core data could not be loaded; otherwise a tuple
        ``(fig, sample_fig)`` where ``sample_fig`` is the already-closed
        sample-grid figure, or ``None`` if that grid could not be created.
    """
    # Common stems for input and output file names
    file_tag = _avmnist_file_tag(full_spectrum)
    model_prefix = f"{file_tag}_rseed-{seed}"

    # MNIST images aligned to AVMNIST order, used for thumbnail scatter plots
    try:
        aligned_images = _avmnist_load_aligned_images(data_dir, results_dir, model_prefix)
    except Exception as e:
        print(f"Warning: Could not load MNIST images for thumbnails: {e}")
        aligned_images = None
        print("DEBUG: aligned_images set to None due to error")

    # Core data: without these nothing can be plotted
    try:
        rank_history = _avmnist_load_rank_history(results_dir, model_prefix)
        reps = [
            np.load(os.path.join(results_dir, f"{model_prefix}_posttrain_rep{i}.npy"))
            for i in range(3)  # shared, mod1, mod2
        ]
        labels = load_avmnist_labels(data_dir)

        print(f"Loaded data for seed {seed}")
        print(f"Representations shapes: {[rep.shape for rep in reps]}")
        print(f"Labels shape: {labels.shape}")
    except Exception as e:
        print(f"Error loading data: {e}")
        return None

    try:
        _avmnist_save_single_dim_accuracies(reps, labels, results_dir, model_prefix)
    except Exception as e:
        print(f"Warning: could not compute single-dimension accuracies: {e}")

    # 2 x 5 grid: columns 0-1 hold the two curve panels (one per row),
    # columns 2-4 hold one subspace panel each, spanning both rows.
    fig = plt.figure(figsize=(19.45, 3.72))
    gs = gridspec.GridSpec(2, 5, figure=fig, hspace=0.3, wspace=0.3)
    _avmnist_plot_training_curves(fig, gs, rank_history)

    n_train = 60000  # AVMNIST training size
    for i, (rep, name) in enumerate(zip(reps, ['Shared', 'Image', 'Audio'])):
        ax_rep = fig.add_subplot(gs[:, i + 2])
        print(f"DEBUG: Processing subspace {i}: '{name}', aligned_images is None: {aligned_images is None}")

        # Thumbnails for multi-dimensional subspaces; 1-D ones stay histograms
        thumbnail_kwargs = {}
        if aligned_images is not None and rep.shape[1] > 1:
            print(f"DEBUG: Using thumbnails for multi-dimensional '{name}' subspace")
            thumbnail_kwargs = dict(images=aligned_images[:n_train], show_images=True, image_zoom=0.3)
        elif rep.shape[1] == 1:
            print(f"DEBUG: Using histogram for 1D '{name}' subspace")
        else:
            print(f"DEBUG: Using standard scatter for '{name}' subspace (no images available)")
        plot_subspace_representation(ax_rep, rep[:n_train], labels[:n_train],
                                     name, max_samples=3000, dim_reduction=dim_reduction,
                                     **thumbnail_kwargs)

    # Save main figure
    output_dir = "03_results/plots"
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f'{file_tag}_analysis_posttrained_seed{seed}.png')
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Figure saved to: {output_file}")

    # Additional plot: label-sorted heatmap of the first representation
    try:
        heat_output = os.path.join(output_dir, f'{file_tag}_rep0_heatmap_by_label_seed{seed}.png')
        _avmnist_save_rep0_heatmap(reps[0][:n_train], labels[:n_train], heat_output)
        print(f"Saved rep0 label-sorted heatmap to: {heat_output}")
    except Exception as e:
        print(f"Warning: could not create rep0 heatmap: {e}")

    # Sample grid of real MNIST digits for the first 20 labels
    sample_fig = None
    try:
        sample_output = os.path.join(output_dir, f'{file_tag}_first20_mnist_seed{seed}.png')
        sample_fig = _avmnist_save_sample_grid(labels[:20], sample_output)
        print(f"Saved sample MNIST image grid to: {sample_output}")
    except Exception as e:
        print(f"Warning: could not create sample MNIST image grid: {e}")

    return fig, sample_fig

if __name__ == '__main__':
    # Command-line entry point: parse options, build the figures, then show
    # the interactive window if a main figure was produced.
    cli = argparse.ArgumentParser(description="Create AVMNIST analysis figure with real MNIST data")
    cli.add_argument('--seed', type=int, default=42,
                     help='Random seed for which results to plot')
    cli.add_argument('--data_dir', type=str, default="01_data/processed/avmnist",
                     help='Directory containing AVMNIST data')
    cli.add_argument('--results_dir', type=str, default="03_results/models",
                     help='Directory containing saved model results')
    cli.add_argument('--dim_reduction', type=str, default='pca', choices=['pca', 'umap'],
                     help='Dimensionality reduction method for visualization (default: pca)')
    cli.add_argument('--full_spectrum', action='store_true',
                     help='Use full spectrum (112x112) audio data instead of averaged (112)')
    opts = cli.parse_args()

    outcome = create_avmnist_figure(
        seed=opts.seed,
        data_dir=opts.data_dir,
        results_dir=opts.results_dir,
        dim_reduction=opts.dim_reduction,
        full_spectrum=opts.full_spectrum,
    )

    # The result is either None (loading failed), a (fig, sample_fig) pair, or
    # a single figure (older return convention). The sample figure, if any,
    # has already been written to disk.
    if isinstance(outcome, (tuple, list)):
        main_fig = outcome[0]
    else:
        main_fig = outcome
    if main_fig is not None:
        plt.show()
