import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import pandas as pd
import seaborn as sns
from sklearn.decomposition import PCA
import os
import argparse

# Paper-style matplotlib defaults: reset to the stock style sheet first, then
# apply the overrides collected below in a single update.
plt.style.use('default')
_PAPER_RC = {
    # Typography
    'font.size': 11,
    'axes.titlesize': 11,
    'axes.labelsize': 11,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 12,
    'font.family': 'sans-serif',
    'mathtext.fontset': 'stix',
    # Resolution and export
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.05,
    # Line weights
    'axes.linewidth': 0.8,
    'grid.linewidth': 0.5,
}
plt.rcParams.update(_PAPER_RC)

# Color-blind friendly palette, keyed by the role a series plays in the figure
# (subspace for the rank/R² panels, data split for the loss panel).
colors = dict(
    shared='#1f78b4',  # Blue
    mod1='#33a02c',    # Green
    mod2='#e31a1c',    # Red
    train='#1f78b4',   # Blue
    val='#ff7f00',     # Orange
)

def load_avmnist_labels(data_dir="01_data/processed/avmnist"):
    """Return the AVMNIST labels as one array, train split first, then test.

    Reads ``train_labels.npy`` and ``test_labels.npy`` from ``data_dir`` and
    joins them along axis 0, so the result lines up with representations that
    were saved in train-then-test order.
    """
    split_files = ("train_labels.npy", "test_labels.npy")
    per_split = [np.load(os.path.join(data_dir, fname)) for fname in split_files]
    return np.concatenate(per_split, axis=0)

def _reduce_to_2d(points, dim_reduction):
    """Project ``points`` (more than two feature columns) down to two dimensions.

    Returns ``(embedding, title_suffix)``. The suffix is ``'\\n(UMAP)'`` when
    UMAP produced the embedding and ``''`` when PCA did (either by request or
    as the fallback when UMAP cannot be imported).
    """
    if dim_reduction.lower() == 'umap':
        try:
            import umap
            reducer = umap.UMAP(n_components=2, random_state=42)
            return reducer.fit_transform(points), '\n(UMAP)'
        except ImportError:
            print("Warning: UMAP not available, falling back to PCA")
    return PCA(n_components=2).fit_transform(points), ''


def _knn_holdout_accuracy(embedding, labels, max_points=2000):
    """Estimate how separable the classes are in ``embedding`` with 1-NN.

    A random subset of at most ``max_points`` rows is split 80/20 into fit and
    hold-out parts; the hold-out accuracy is returned. Returns ``None`` when
    there are too few points for both parts or when the estimate fails — the
    score only decorates the title, so it must never block plotting.
    """
    n_points = min(max_points, embedding.shape[0])
    n_fit = int(0.8 * n_points)
    if n_fit < 1 or n_fit >= n_points:
        return None
    try:
        from sklearn.neighbors import KNeighborsClassifier
        chosen = np.random.choice(embedding.shape[0], n_points, replace=False)
        X_acc = embedding[chosen]
        y_acc = labels[chosen]
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(X_acc[:n_fit], y_acc[:n_fit])
        return knn.score(X_acc[n_fit:], y_acc[n_fit:])
    except Exception as exc:
        # Best-effort: report the problem instead of hiding it, then carry on.
        print(f"Warning: could not compute k-NN accuracy: {exc}")
        return None


def plot_subspace_representation(ax, rep, labels, title, max_samples=100000, dim_reduction='pca'):
    """Draw one subspace representation on ``ax``, colored by label.

    Args:
        ax: Matplotlib axes to draw on.
        rep: Array of shape ``(n_samples, n_features)``; a 1D array is treated
            as a single feature column.
        labels: Per-sample class labels aligned with ``rep`` (the colorbar
            ticks assume digits 0-9).
        title: Axes title. For 2D scatter plots a 1-NN hold-out accuracy
            ``(Acc: x.xx)`` is appended when it can be computed.
        max_samples: If ``rep`` has more rows than this, a random subset of
            this size (drawn from the global NumPy RNG) is plotted instead.
        dim_reduction: ``'pca'`` or ``'umap'``; only used when there are more
            than two features. UMAP falls back to PCA if it is not installed.

    Returns:
        The axes that were drawn on.

    Raises:
        ValueError: If ``rep`` is not 1D/2D or has no feature columns.

    Behavior by feature count:
        * 1 feature: per-label density histograms (no accuracy in the title).
        * 2 features: scatter of the raw coordinates.
        * more than 2: scatter of the PCA/UMAP embedding.
    """
    rep = np.asarray(rep)
    labels = np.asarray(labels)
    if rep.ndim == 1:
        rep = rep[:, np.newaxis]
    if rep.ndim != 2 or rep.shape[1] == 0:
        raise ValueError(
            f"rep must have shape (n_samples, n_features >= 1), got {rep.shape}"
        )

    # Subsample for visualization if too many points
    if rep.shape[0] > max_samples:
        indices = np.random.choice(rep.shape[0], max_samples, replace=False)
        rep_plot = rep[indices]
        labels_plot = labels[indices]
    else:
        rep_plot = rep
        labels_plot = labels

    if rep_plot.shape[1] == 1:
        # Single dimension: plot histograms colored by digit with seaborn
        sns.histplot(x=rep_plot[:,0], hue=labels_plot, bins=30, palette='tab10', element='step', stat='density', common_norm=False, ax=ax)
        ax.set_title(title)
        ax.set_xlabel('Dim 1')
        ax.set_ylabel('Density')
        # seaborn builds the hue legend from its own handles; calling
        # ax.legend() here would replace it with an empty legend, so relocate
        # the existing one instead.
        if ax.get_legend() is not None:
            sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.05, 1), title='Digit')
        return ax

    # Two dimensions are used directly; anything larger is reduced first.
    if rep_plot.shape[1] > 2:
        rep_2d, title_suffix = _reduce_to_2d(rep_plot, dim_reduction)
        title += title_suffix
    else:
        rep_2d = rep_plot

    acc = _knn_holdout_accuracy(rep_2d, labels_plot)
    if acc is not None:
        title = f"{title} (Acc: {acc:.2f})"

    # Scatter plot colored by label
    scatter = ax.scatter(rep_2d[:, 0], rep_2d[:, 1], c=labels_plot, cmap='tab10',
                         s=1, alpha=0.6, rasterized=True)
    ax.set_title(title)
    ax.set_xlabel('Dim 1')
    ax.set_ylabel('Dim 2')

    # Add colorbar for digit labels
    cbar = plt.colorbar(scatter, ax=ax, ticks=range(10))
    cbar.set_label('Digit', rotation=0, labelpad=15)

    return ax

def _append_posttrain_history(rank_history, posttrain_losses):
    """Return ``rank_history`` extended by one row per post-training epoch.

    Epoch numbers continue from the last recorded epoch. Losses come from the
    ``train_loss`` / ``val_loss`` columns of ``posttrain_losses``; rank and R²
    columns that exist in ``rank_history`` repeat their final value, since the
    ranks are frozen during post-training.
    """
    n_extra = len(posttrain_losses)
    first_epoch = rank_history['epoch'].max() + 1
    extension = {
        'epoch': range(first_epoch, first_epoch + n_extra),
        'loss': posttrain_losses['train_loss'],
        'val_loss': posttrain_losses['val_loss'],
    }
    last_row = rank_history.iloc[-1]
    frozen_cols = ['total_rank', 'ranks'] + [f'rsquare {i}' for i in range(2)]
    for col in frozen_cols:
        if col in rank_history.columns:
            extension[col] = [last_row[col]] * n_extra
    return pd.concat([rank_history, pd.DataFrame(extension)], ignore_index=True)


def create_avmnist_figure(seed=42, data_dir="01_data/processed/avmnist",
                          results_dir="03_results/models", dim_reduction='pca',
                          output_dir="03_results/plots", n_train=60000):
    """Create the AVMNIST training-history and representation figure.

    Left column (three stacked panels): subspace ranks, per-modality R², and
    train/validation loss over epochs. Right three columns: the shared, image
    and audio subspace representations of the training samples.

    Args:
        seed: Seed embedded in the result file names to load.
        data_dir: Directory holding ``train_labels.npy`` / ``test_labels.npy``.
        results_dir: Directory holding the rank history CSV, the optional
            post-training loss CSV and the three ``posttrained_rep{i}.npy``.
        dim_reduction: ``'pca'`` or ``'umap'``, forwarded to the
            representation panels.
        output_dir: Directory the PNG is written to (created if missing).
        n_train: Number of leading samples (the training split) to visualize.

    Returns:
        The matplotlib figure, or ``None`` if the required inputs could not be
        loaded (the error is printed). A missing or malformed post-training
        loss file only produces a warning.
    """
    # Load data
    try:
        rank_file = os.path.join(results_dir, f"avmnist_rseed-{seed}_rank_history.csv")
        rank_history = pd.read_csv(rank_file)

        # Post-training losses are optional: append them when available.
        try:
            posttrain_loss_file = os.path.join(results_dir, f"avmnist_rseed-{seed}_posttrained_loss_curves.csv")
            posttrain_losses = pd.read_csv(posttrain_loss_file)
            rank_history = _append_posttrain_history(rank_history, posttrain_losses)
            print(f"Loaded and appended post-training losses ({len(posttrain_losses)} epochs)")
        except Exception as e:
            print(f"Warning: Could not load post-training losses: {e}")

        # Post-trained representations: shared, mod1, mod2
        reps = [
            np.load(os.path.join(results_dir, f"avmnist_rseed-{seed}_posttrained_rep{i}.npy"))
            for i in range(3)
        ]

        labels = load_avmnist_labels(data_dir)

        print(f"Loaded data for seed {seed}")
        print(f"Representations shapes: {[rep.shape for rep in reps]}")
        print(f"Labels shape: {labels.shape}")

    except Exception as e:
        print(f"Error loading data: {e}")
        return None

    # Create figure with gridspec
    fig = plt.figure(figsize=(9.45, 3.72))
    gs = gridspec.GridSpec(3, 5, figure=fig, hspace=0.3, wspace=0.3)
    epochs = rank_history['epoch'].values

    # Row 1: Ranks over epochs (colored by subspace)
    ax_ranks = fig.add_subplot(gs[0, 0:2])

    # The 'ranks' column holds one comma-separated string per epoch
    # ("rank1, rank2, rank3"); int() tolerates the surrounding whitespace.
    individual_ranks = np.array([
        [int(token) for token in str(entry).split(',')]
        for entry in rank_history['ranks'].values
    ])

    rank_series = (
        (0, 'shared', 'Shared', 'o'),
        (1, 'mod1', 'Image', 's'),
        (2, 'mod2', 'Audio', '^'),
    )
    for column, color_key, label, marker in rank_series:
        ax_ranks.plot(epochs, individual_ranks[:, column], color=colors[color_key],
                      linewidth=2, label=label, marker=marker, markersize=3)

    ax_ranks.set_xlabel('')
    ax_ranks.set_ylabel('Rank')
    ax_ranks.set_title('Subspace Ranks')
    ax_ranks.legend()
    ax_ranks.grid(True, alpha=0.3)
    # Hide epoch x-axis ticks/labels for a cleaner multi-panel layout
    ax_ranks.set_xticks([])

    # Row 2: R² values over epochs (colored by modality)
    ax_r2 = fig.add_subplot(gs[1, 0:2])
    plotted_r2 = False
    for i, (modality_name, color_key) in enumerate((('Image', 'mod1'), ('Audio', 'mod2'))):
        r2_col = f'rsquare {i}'
        if r2_col not in rank_history.columns:
            continue
        color = colors[color_key]
        ax_r2.plot(epochs, rank_history[r2_col], color=color,
                   linewidth=2, label=modality_name, marker='o', markersize=3)
        plotted_r2 = True
        # Dotted reference line 0.05 below the best R² reached by this modality.
        try:
            ref_val = rank_history[r2_col].max() - 0.05
            if np.isfinite(ref_val):
                ax_r2.axhline(ref_val, color=color, linestyle=':', linewidth=1, alpha=0.8)
        except (TypeError, ValueError) as exc:
            # The reference line is decorative; report and keep plotting.
            print(f"Warning: could not draw R² reference line for '{r2_col}': {exc}")

    ax_r2.set_xlabel('')
    ax_r2.set_ylabel('R²')
    ax_r2.set_title('R² by Modality')
    if plotted_r2:
        # Skip the legend when no R² column exists to avoid an empty legend.
        ax_r2.legend()
    ax_r2.grid(True, alpha=0.3)
    # Hide epoch x-axis ticks/labels on the R² subplot as well
    ax_r2.set_xticks([])

    # Row 3: Train and validation loss
    ax_loss = fig.add_subplot(gs[2, 0:2])
    ax_loss.plot(epochs, rank_history['loss'], color=colors['train'],
                 linewidth=2, label='Train', marker='o', markersize=2)
    ax_loss.plot(epochs, rank_history['val_loss'], color=colors['val'],
                 linewidth=2, label='Validation', marker='s', markersize=2)

    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss (log scale)')
    ax_loss.set_yscale('log')
    ax_loss.set_title('Training & Validation Loss')
    ax_loss.legend()
    ax_loss.grid(True, alpha=0.3)

    # Columns 3-5: one representation panel per subspace, spanning all rows.
    # Only the first n_train samples (the training split) are shown.
    subspace_names = ['Shared', 'Image', 'Audio']
    for i, (rep, name) in enumerate(zip(reps, subspace_names)):
        ax_rep = fig.add_subplot(gs[:, i + 2])
        plot_subspace_representation(ax_rep, rep[:n_train], labels[:n_train],
                                     name, max_samples=3000, dim_reduction=dim_reduction)

    # Save through the figure object so the result does not depend on which
    # figure pyplot currently considers active.
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, f'avmnist_analysis_posttrained_seed{seed}.png')
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Figure saved to: {output_file}")

    return fig

# Script entry point: parse the command-line options, build the figure and
# show it interactively when it could be created.
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description="Create AVMNIST analysis figure")
    option_specs = (
        ('--seed', dict(type=int, default=42,
                        help='Random seed for which results to plot')),
        ('--data_dir', dict(type=str, default="01_data/processed/avmnist",
                            help='Directory containing AVMNIST data')),
        ('--results_dir', dict(type=str, default="03_results/models",
                               help='Directory containing saved model results')),
        ('--dim_reduction', dict(type=str, default='pca', choices=['pca', 'umap'],
                                 help='Dimensionality reduction method for visualization (default: pca)')),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    figure = create_avmnist_figure(
        seed=options.seed,
        data_dir=options.data_dir,
        results_dir=options.results_dir,
        dim_reduction=options.dim_reduction,
    )

    # create_avmnist_figure returns None when the inputs could not be loaded.
    if figure is not None:
        plt.show()
