#!/usr/bin/env python3
"""
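Compute digit classification accuracy for AudioMNIST representations.

This script:
1. Pairs MNIST images with AudioMNIST audio clips by digit label
2. Loads pretrained image and audio autoencoders
3. Extracts training-set representations from each modality's encoder
4. Decomposes the two representations with multimodal baselines
   (JIVE, AJIVE, SLIDE, ShIndICA; the PPD decomposition, with subspaces
   joint_X, joint_Y, individual_X, individual_Y, is currently disabled)
5. Evaluates 5-fold cross-validated classification accuracy (GaussianNB,
   1-NN, logistic regression) for each baseline subspace and for all of a
   baseline's subspaces concatenated

Accuracy for the individual-modality and concatenated encoder
representations and for raw image data, as well as saving the results to
CSV, are present in main() but switched off by default.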
"""

import numpy as np
import torch
import pandas as pd
import os
import sys
import argparse
import csv
from torchvision.datasets import MNIST
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.model_selection import cross_val_score
from pathlib import Path

# The repository root is three directory levels above this script; appending
# it to sys.path lets the `src.` package imports below resolve.
project_root = Path(__file__).parent.parent.parent.absolute()
sys.path.append(os.fspath(project_root))

from src.functions.pretrain_avmnist import ImageAutoencoder, AudioAutoencoder
from src.models.mm_baselines import PPD, JIVE, AJIVE, SLIDE, ShIndICA


def match_datasets_by_label(images, img_labels, audio, audio_labels, speaker_labels,
                            n_samples_per_class=6000, seed=42):
    """
    Pair MNIST images with AudioMNIST audio clips of the same digit.

    For every digit 0-9, ``n = min(#images, #audio clips, n_samples_per_class)``
    images and ``n`` audio clips of that digit are drawn without replacement
    and paired row by row. Speaker labels are gathered with the same indices
    as the audio, so they stay aligned with the returned audio rows. The
    output is grouped by digit (all 0s, then all 1s, ...) and is not shuffled.

    NOTE: this reseeds NumPy's *global* RNG (``np.random.seed(seed)``), so any
    later code drawing from ``np.random`` depends on this call.

    Args:
        images: Images, indexable along axis 0 by an integer index array.
        img_labels: Digit label per image (array-like).
        audio: Audio clips, indexable along axis 0 by an integer index array.
        audio_labels: Digit label per audio clip (array-like).
        speaker_labels: Speaker id per audio clip (array-like, same length as audio).
        n_samples_per_class: Upper bound on the number of pairs per digit.
        seed: Seed for the global NumPy RNG used for sampling.

    Returns:
        Tuple ``(images, audio, labels, speaker_labels)`` of matched arrays that
        share the same first dimension.

    Raises:
        ValueError: If no digit has at least one image and one audio clip.
    """
    np.random.seed(seed)
    # Accept plain lists too: `list == digit` would be a scalar False and
    # silently match nothing.
    img_labels = np.asarray(img_labels)
    audio_labels = np.asarray(audio_labels)
    speaker_labels = np.asarray(speaker_labels)

    matched_images = []
    matched_audio = []
    matched_labels = []
    matched_speaker_labels = []

    for digit in range(10):
        img_indices = np.where(img_labels == digit)[0]
        audio_indices = np.where(audio_labels == digit)[0]

        n_samples = min(len(img_indices), len(audio_indices), n_samples_per_class)
        if n_samples <= 0:
            continue

        selected_img_indices = np.random.choice(img_indices, n_samples, replace=False)
        selected_audio_indices = np.random.choice(audio_indices, n_samples, replace=False)

        matched_images.append(images[selected_img_indices])
        matched_audio.append(audio[selected_audio_indices])
        matched_labels.extend([digit] * n_samples)
        # Same indices as the audio, so speakers stay aligned with their clips.
        matched_speaker_labels.append(speaker_labels[selected_audio_indices])

    if not matched_images:
        raise ValueError("No digit 0-9 has both image and audio samples; nothing to match.")

    return (np.concatenate(matched_images, axis=0),
            np.concatenate(matched_audio, axis=0),
            np.array(matched_labels),
            np.concatenate(matched_speaker_labels, axis=0))


def compute_knn_acc(rep, labels, n_neighbors=1, shuffle=True, seed=0):
    """Compute k-NN classification accuracy on a single 80/20 train/test split.

    Args:
        rep: Feature matrix with one row per sample.
        labels: Per-sample class labels; string/object labels are label-encoded.
        n_neighbors: Number of neighbours for the k-NN classifier.
        shuffle: Randomly permute samples before splitting (default). Without
            this, class-sorted input (e.g. the digit-grouped output of
            match_datasets_by_label) puts only the last classes into the test
            split and the accuracy is meaningless.
        seed: Seed for the permutation. A local RandomState is used, so the
            global NumPy RNG is left untouched.

    Returns:
        Test-split accuracy as a float, or None if an input is missing, there
        are fewer than 5 samples, or fitting/scoring fails.
    """
    if rep is None or labels is None:
        return None
    try:
        X = np.asarray(rep)
        y = np.asarray(labels)
        # Encode string/object labels to integers (same scheme as
        # compute_multi_classifier_accs).
        if y.dtype.kind in {'U', 'S', 'O'}:
            y = LabelEncoder().fit_transform(y.astype(str))
        n = X.shape[0]
        if n < 5:
            return None
        if shuffle:
            order = np.random.RandomState(seed).permutation(n)
            X, y = X[order], y[order]
        split = int(0.8 * n)
        knn = KNeighborsClassifier(n_neighbors=n_neighbors)
        knn.fit(X[:split], y[:split])
        return float(knn.score(X[split:], y[split:]))
    except Exception as e:
        # Best effort: report and return None rather than abort the analysis.
        print(f"KNN accuracy computation failed: {e}")
        return None


def compute_multi_classifier_accs(rep, labels, cv_folds=5):
    """Estimate classification accuracy of ``rep`` with several classifiers.

    Each classifier is scored with ``cross_val_score(cv=cv_folds)``; with an
    integer ``cv`` and a classifier, scikit-learn uses stratified folds
    without shuffling.

    Classifiers:
        knn:      1-nearest-neighbour on the raw features.
        logistic: StandardScaler -> LogisticRegression (lbfgs, max_iter=2000).
        gnb:      StandardScaler -> GaussianNB.

    Args:
        rep: Feature matrix with one row per sample; a 1-D array is treated
            as a single feature column.
        labels: Per-sample labels; string/object labels are label-encoded.
        cv_folds: Number of cross-validation folds (default 5).

    Returns:
        Dict with keys 'knn', 'logistic', 'gnb', each the mean CV accuracy
        (float), or None when an input is missing, there are fewer than 10
        samples, or that classifier failed.
    """
    out = {'knn': None, 'logistic': None, 'gnb': None}
    if rep is None or labels is None:
        return out
    try:
        X = np.asarray(rep)
        if X.ndim == 1:
            X = X.reshape(-1, 1)
        y = np.asarray(labels)
        # encode non-numeric labels
        if y.dtype.kind in {'U', 'S', 'O'}:
            y = LabelEncoder().fit_transform(y.astype(str))
        if X.shape[0] < 10:
            return out

        from sklearn.pipeline import Pipeline

        # (result key, name used in failure messages, estimator factory).
        # multi_class is not passed: 'auto' is already the LogisticRegression
        # default and the parameter is deprecated in newer scikit-learn.
        classifiers = [
            ('knn', 'KNN', lambda: KNeighborsClassifier(n_neighbors=1)),
            ('logistic', 'Logistic', lambda: Pipeline([
                ('scaler', StandardScaler()),
                ('classifier', LogisticRegression(max_iter=2000, solver='lbfgs')),
            ])),
            ('gnb', 'GaussianNB', lambda: Pipeline([
                ('scaler', StandardScaler()),
                ('classifier', GaussianNB()),
            ])),
        ]
        for key, display_name, make_estimator in classifiers:
            try:
                scores = cross_val_score(make_estimator(), X, y, cv=cv_folds, scoring='accuracy')
                out[key] = float(np.mean(scores))
            except Exception as e:
                # One failing classifier must not prevent the others from running.
                print(f"{display_name} CV failed: {e}")
                out[key] = None
        return out
    except Exception as e:
        print(f"Multi-classifier accuracy failed: {e}")
        return out


def _banner(title):
    """Print ``title`` between two 80-character '=' rules (section header)."""
    print("\n" + "=" * 80)
    print(title)
    print("=" * 80)


def _fmt_acc(value):
    """Format an accuracy with 4 decimals; ``None`` (classifier failed) -> 'N/A'."""
    return "N/A" if value is None else f"{value:.4f}"


def _as_2d_numpy(rep):
    """Convert a representation (numpy array or torch tensor) to a 2-D numpy array.

    Tensors are detached and moved to CPU; a 1-D input becomes a single column.
    """
    if torch.is_tensor(rep):
        rep = rep.detach().cpu().numpy()
    rep = np.asarray(rep)
    return rep.reshape(-1, 1) if rep.ndim == 1 else rep


def _evaluate_representation(results, name, display, rep, labels, source):
    """Score ``rep`` against ``labels``, print the accuracies and append a result row.

    The appended dict has keys representation / accuracy / accuracy_knn /
    accuracy_logistic / dim / source. GaussianNB is the primary ``accuracy``
    (matching 053_audiomnist_analysis.py). Missing accuracies are printed as
    'N/A' instead of raising a TypeError inside the format string.
    """
    rep = _as_2d_numpy(rep)
    accs = compute_multi_classifier_accs(rep, labels)
    if all(acc is None for acc in accs.values()):
        print(f"{display}: Failed")
    else:
        print(f"{display} - GNB: {_fmt_acc(accs['gnb'])}, KNN: {_fmt_acc(accs['knn'])}, "
              f"Logistic: {_fmt_acc(accs['logistic'])}")
    results.append({
        'representation': name,
        'accuracy': accs['gnb'],
        'accuracy_knn': accs['knn'],
        'accuracy_logistic': accs['logistic'],
        'dim': rep.shape[1],
        'source': source,
    })


def _audio_to_nchw(audio):
    """Reshape audio features to the (N, 1, H, W) layout fed to the audio encoder.

    Handled inputs (numpy array or torch tensor):
      * (N, 112*112)      -> (N, 1, 112, 112)
      * (N, L), other L   -> (N, 1, L); whether the encoder accepts this is
                             not checked here
      * (N, H, W)         -> (N, 1, H, W)
      * (N, C, H, W)      -> channels averaged to (N, 1, H, W) when C != 1

    Raises:
        ValueError: If ``audio`` is None or has any other rank.
    """
    if audio is None:
        raise ValueError("train_audio_processed is None")
    if isinstance(audio, torch.Tensor):
        audio = audio.cpu().numpy()

    if audio.ndim == 2:
        n, length = audio.shape
        if length == 112 * 112:
            return audio.reshape(n, 1, 112, 112)
        return audio[:, np.newaxis, :]
    if audio.ndim == 3:
        return audio[:, np.newaxis, :, :]
    if audio.ndim == 4:
        return audio if audio.shape[1] == 1 else np.mean(audio, axis=1, keepdims=True)
    raise ValueError(f"Unsupported audio array shape: {audio.shape}")


def _encode_in_batches(encoder, inputs, device, batch_size=512):
    """Apply ``encoder`` to ``inputs`` in mini-batches and return a numpy array.

    Only the current batch is converted to float32 and moved to ``device``, so
    the full dataset never has to be resident on the accelerator at once.
    """
    outputs = []
    with torch.no_grad():
        for start in range(0, len(inputs), batch_size):
            batch = torch.as_tensor(inputs[start:start + batch_size], dtype=torch.float32).to(device)
            outputs.append(encoder(batch).cpu().numpy())
    return np.concatenate(outputs, axis=0)


def _load_autoencoder(path, build, title, device):
    """Load an autoencoder checkpoint and return the model in eval mode on ``device``.

    Reads the checkpoint keys 'latent_dim', 'encoder_state_dict',
    'decoder_state_dict' and 'final_val_loss'. ``build`` maps a latent
    dimension to a fresh autoencoder; ``title`` ("Image"/"Audio") is used in
    messages.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"{title} autoencoder not found: {path}")
    checkpoint = torch.load(path, map_location='cpu')
    model = build(checkpoint['latent_dim'])
    model.encoder.load_state_dict(checkpoint['encoder_state_dict'])
    model.decoder.load_state_dict(checkpoint['decoder_state_dict'])
    model.to(device)
    model.eval()
    print(f"✓ Loaded {title.lower()} autoencoder from: {path}")
    print(f"  {title} AE latent dim: {checkpoint['latent_dim']}")
    print(f"  {title} AE final val loss: {checkpoint['final_val_loss']:.6f}")
    return model


def _load_saved_reps(results_dir, prefix, n_train):
    """Load ``{prefix}_rep{0,1,2}.npy`` (shared, image-specific, audio-specific).

    Missing or unreadable files yield None (with a message). Loaded arrays are
    cut to their first ``n_train`` rows, assuming they are stored in
    train+test order.
    """
    saved = []
    for i in range(3):
        rep_path = os.path.join(results_dir, f"{prefix}_rep{i}.npy")
        if not os.path.exists(rep_path):
            print(f"  Warning: Saved rep {i} not found at: {rep_path}")
            saved.append(None)
            continue
        try:
            rep = np.load(rep_path)
        except Exception as e:
            print(f"Failed to load rep file {rep_path}: {e}")
            saved.append(None)
            continue
        print(f"✓ Loaded saved rep {i} from: {rep_path} (shape: {rep.shape})")
        saved.append(rep[:n_train])
    return saved


def _run_baselines(results, train_image_reps, train_audio_reps, labels):
    """Decompose both modality representations with each baseline and score the subspaces.

    Every subspace returned by ``decompose`` gets a row named
    '<method>_<subspace>', plus a '<method>_all_concat' row for all subspaces
    concatenated along the feature axis. A failing method is reported and
    skipped so the remaining baselines still run.
    """
    data_for_baselines = [
        torch.FloatTensor(train_image_reps),
        torch.FloatTensor(train_audio_reps),
    ]
    # CCA and DIVAS are deliberately excluded; PPD is currently disabled.
    baseline_methods = [
        # ('PPD', lambda: PPD()),
        ('JIVE', lambda: JIVE()),
        ('AJIVE', lambda: AJIVE()),
        ('SLIDE', lambda: SLIDE()),
        ('ShIndICA', lambda: ShIndICA(joint_rank_options=[1, 5, 10, 20])),
    ]
    for method_name, make_method in baseline_methods:
        print(f"\n--- {method_name} Baseline ---")
        source = method_name.lower()
        try:
            decomposed_reps, rank_info = make_method().decompose(data_for_baselines)
            print(f"{method_name} subspaces: {list(decomposed_reps.keys())}")
            print(f"{method_name} rank info: {rank_info}")

            subspaces = {key: _as_2d_numpy(val) for key, val in decomposed_reps.items()}
            for key, rep in subspaces.items():
                _evaluate_representation(results, f'{source}_{key}', f"{method_name} {key}",
                                         rep, labels, source)
            if subspaces:
                concat = np.concatenate(list(subspaces.values()), axis=1)
                _evaluate_representation(results, f'{source}_all_concat',
                                         f"{method_name} All subspaces concatenated",
                                         concat, labels, source)
        except Exception as e:
            # Best effort: report and continue with the next baseline.
            print(f"{method_name} decomposition failed: {e}")


def _save_results(results, seed):
    """Write result rows to 03_results/processed/audiomnist_rseed-<seed>_pretrained_classification.csv."""
    os.makedirs('03_results/processed', exist_ok=True)
    save_prefix = f"audiomnist_rseed-{seed}"
    csv_file = f"03_results/processed/{save_prefix}_pretrained_classification.csv"
    fieldnames = ['representation', 'accuracy', 'accuracy_logistic', 'accuracy_knn', 'dim', 'source']
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
    print(f"✓ Saved results to: {csv_file}")
    return csv_file


def _print_summary(results):
    """Print a fixed-width table of all result rows."""
    header = (f"{'Representation':<35} {'Acc (GNB)':<12} {'Acc (Logistic)':<15} "
              f"{'Acc (KNN)':<12} {'Dim':<15} {'Source':<15}")
    print(header)
    print("-" * len(header))
    for row in results:
        print(f"{row['representation']:<35} {_fmt_acc(row.get('accuracy')):<12} "
              f"{_fmt_acc(row.get('accuracy_logistic')):<15} {_fmt_acc(row.get('accuracy_knn')):<12} "
              f"{row['dim']:<15} {row.get('source', 'N/A'):<15}")


def main():
    """Run the AudioMNIST representation-classification pipeline.

    Steps:
      1. Load MNIST and AudioMNIST.
      2. Pair them by digit.
      3. Load the pretrained image/audio autoencoders.
      4. Encode the paired training set.
      5. Decompose the two representations with the baseline methods.
      6. Score every subspace with compute_multi_classifier_accs.

    Optional flags:
      * ``--eval_encoders`` also scores the encoder outputs, raw images and
        previously saved representation files.
      * ``--save_results`` writes all rows to CSV and prints a summary table.
    """
    parser = argparse.ArgumentParser(description="Compute digit classification accuracy for AudioMNIST")
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--gpu', type=int, default=0, help='GPU id to use')
    parser.add_argument('--full_spectrum', action='store_true',
                        help='Ignored (kept for CLI compatibility): the audio autoencoder is '
                             'always loaded with full_spectrum=True')
    parser.add_argument('--pretrained_dir', type=str, default='03_results/models/pretrained_unimodal',
                        help='Directory containing pretrained models')
    parser.add_argument('--data_dir', type=str, default='01_data/avmnist_data_from_source',
                        help='Directory containing AudioMNIST data')
    parser.add_argument('--train_samples_per_digit', type=int, default=5000)
    parser.add_argument('--test_samples_per_digit', type=int, default=1000)
    parser.add_argument('--eval_encoders', action='store_true',
                        help='Also score encoder, raw-image and saved-file representations')
    parser.add_argument('--save_results', action='store_true',
                        help='Write all result rows to CSV and print a summary table')
    args = parser.parse_args()

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # --- 1. Load data ---
    _banner("Loading data from AudioMNIST source...")
    mnist_train = MNIST(root='01_data/processed/MNIST', train=True, download=True)
    mnist_test = MNIST(root='01_data/processed/MNIST', train=False, download=True)
    mnist_train_images = mnist_train.data.numpy().astype('float32') / 255.0
    mnist_train_labels = mnist_train.targets.numpy()
    mnist_test_images = mnist_test.data.numpy().astype('float32') / 255.0
    mnist_test_labels = mnist_test.targets.numpy()

    train_audio = np.load(os.path.join(args.data_dir, "audio/train_data.npy"))
    train_labels = np.load(os.path.join(args.data_dir, "train_labels.npy"))
    train_speaker_labels = np.load(os.path.join(args.data_dir, "train_speaker_labels.npy"))
    test_audio = np.load(os.path.join(args.data_dir, "audio/test_data.npy"))
    test_labels = np.load(os.path.join(args.data_dir, "test_labels.npy"))
    test_speaker_labels = np.load(os.path.join(args.data_dir, "test_speaker_labels.npy"))

    print(f"MNIST train images shape: {mnist_train_images.shape}")
    print(f"Audio train shape: {train_audio.shape}")

    # --- 2. Match datasets ---
    print("\nMatching training datasets...")
    train_images, train_audio_matched, train_labels_matched, _ = match_datasets_by_label(
        mnist_train_images, mnist_train_labels, train_audio, train_labels, train_speaker_labels,
        n_samples_per_class=args.train_samples_per_digit, seed=args.seed
    )

    # The matched test split is not scored below. The call is kept anyway
    # because match_datasets_by_label reseeds and advances NumPy's global RNG;
    # keeping that state unchanged keeps results reproducible in case later
    # code (e.g. the baselines) draws from np.random.
    print("\nMatching test datasets...")
    match_datasets_by_label(
        mnist_test_images, mnist_test_labels, test_audio, test_labels, test_speaker_labels,
        n_samples_per_class=args.test_samples_per_digit, seed=args.seed
    )

    # --- 3. Prepare data ---
    train_images_flat = train_images.reshape(train_images.shape[0], -1)
    n_train = train_images_flat.shape[0]
    print("\nFinal dataset shapes:")
    print(f"Train Images: {train_images_flat.shape}")
    print(f"Train Audio: {train_audio_matched.shape}")
    print(f"Train Labels: {train_labels_matched.shape}")

    # --- 4. Load pretrained encoders ---
    _banner("Loading pretrained encoders...")
    # The audio AE was always trained on the full spectrum, hence the
    # '_fullspec' file prefix and full_spectrum=True regardless of
    # --full_spectrum; the image AE has no such variant.
    img_ae = _load_autoencoder(
        os.path.join(args.pretrained_dir, f"audiomnist_rseed-{args.seed}_image_ae.pth"),
        lambda latent_dim: ImageAutoencoder(latent_dim=latent_dim), "Image", device)
    audio_ae = _load_autoencoder(
        os.path.join(args.pretrained_dir, f"audiomnist_fullspec_rseed-{args.seed}_audio_ae.pth"),
        lambda latent_dim: AudioAutoencoder(latent_dim=latent_dim, full_spectrum=True), "Audio", device)

    # --- 5. Extract representations (train only) ---
    _banner("Extracting representations from pretrained encoders...")
    audio_inputs = _audio_to_nchw(train_audio_matched)
    # The image encoder expects (N, 1, 28, 28) images, not flattened vectors.
    train_image_reps = _encode_in_batches(img_ae.encoder, train_images[:, np.newaxis], device)
    train_audio_reps = _encode_in_batches(audio_ae.encoder, audio_inputs, device)
    print(f"Train image representations shape: {train_image_reps.shape}")
    print(f"Train audio representations shape: {train_audio_reps.shape}")

    results = []

    # --- 6. Optional: encoder / raw / saved-file representations ---
    if args.eval_encoders:
        train_concat_reps = np.concatenate([train_image_reps, train_audio_reps], axis=1)
        print(f"Train concatenated representations shape: {train_concat_reps.shape}")

        _banner("Loading saved representation files for verification...")
        # Written by the training script (056_audiomnist.py) with the naming
        # used in 053_audiomnist_analysis.py; always '_fullspec'-suffixed.
        saved_shared, saved_image, saved_audio = _load_saved_reps(
            '03_results/models', f"audiomnist_fullspec_rseed-{args.seed}", n_train)

        _banner("Computing digit classification accuracies (from encoder extractions)...")
        encoder_candidates = [
            ('image_encoder', 'Image representations (encoder)', train_image_reps, 'encoder'),
            ('audio_encoder', 'Audio representations (encoder)', train_audio_reps, 'encoder'),
            ('concatenated', 'Concatenated representations (encoder)', train_concat_reps, 'encoder'),
            ('raw_image', 'Raw image data', train_images_flat, 'raw'),
        ]
        for name, display, rep, source in encoder_candidates:
            _evaluate_representation(results, name, display, rep, train_labels_matched, source)

        _banner("Computing digit classification accuracies (from saved files - VERIFICATION)...")
        saved_candidates = [
            ('shared_saved', 'Saved Shared representation', [saved_shared]),
            ('image_specific_saved', 'Saved Image-specific representation', [saved_image]),
            ('audio_specific_saved', 'Saved Audio-specific representation', [saved_audio]),
            ('shared_image_concat_saved', 'Saved Shared+Image concat', [saved_shared, saved_image]),
            ('shared_audio_concat_saved', 'Saved Shared+Audio concat', [saved_shared, saved_audio]),
            ('all_concat_saved', 'Saved All concat', [saved_shared, saved_image, saved_audio]),
        ]
        for name, display, parts in saved_candidates:
            if any(part is None for part in parts):
                continue
            rep = parts[0] if len(parts) == 1 else np.concatenate(parts, axis=1)
            _evaluate_representation(results, name, display, rep, train_labels_matched, 'saved_file')

    # --- 7. Baseline methods ---
    _banner("Computing baseline method decompositions...")
    _run_baselines(results, train_image_reps, train_audio_reps, train_labels_matched)

    # --- 8. Optional: save results and print a summary ---
    if args.save_results:
        _banner("Saving results...")
        _save_results(results, args.seed)
        _banner("Classification Accuracy Summary:")
        _print_summary(results)
        _banner("Analysis complete!")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, never on import.
    main()
