#!/usr/bin/env python3
"""Plot AVMNIST representations saved by 056_audiomnist.py.

Reconstructs the matched train/test ordering using the same sampling logic
from `056_audiomnist.py` (same seed and per-class counts), loads the saved
representations, creates 3 panels (Shared, Image, Audio) and colors them by
digit and by speaker. Also computes 1-NN classification accuracies and saves
them to CSV.
"""
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler
import json
import pandas as pd
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from torchvision.datasets import MNIST


def match_datasets_by_label(images, img_labels, audio, audio_labels, speaker_labels, n_samples_per_class=6000, seed=42):
    """Pair image and audio samples digit by digit using a seeded random draw.

    For every digit 0-9 the same number of image rows and audio rows is drawn
    without replacement. That number is the smallest of: the images available
    for the digit, the audio clips available for the digit, and
    ``n_samples_per_class``. Digits for which that number is zero are skipped.

    The global NumPy RNG is re-seeded with ``seed`` and the draws always happen
    in the same order (image draw, then audio draw, for digit 0, 1, ..., 9), so
    identical inputs give identical selections.

    Returns:
        A tuple ``(images, audio, labels, speaker_labels)``. Each element is
        the per-digit selection concatenated in ascending digit order;
        ``speaker_labels`` is indexed with the same draw as ``audio``.

    Raises:
        ValueError: raised by ``np.concatenate`` when no digit produced any
            sample (nothing to concatenate).
    """
    np.random.seed(seed)

    image_chunks, audio_chunks, speaker_chunks = [], [], []
    digit_column = []

    for digit in range(10):
        image_pool = np.nonzero(img_labels == digit)[0]
        audio_pool = np.nonzero(audio_labels == digit)[0]
        count = min(len(image_pool), len(audio_pool), n_samples_per_class)
        if count <= 0:
            continue

        # The image draw must precede the audio draw: both consume the same
        # global random stream, and the ordering determines the result.
        picked_images = np.random.choice(image_pool, count, replace=False)
        picked_audio = np.random.choice(audio_pool, count, replace=False)

        image_chunks.append(images[picked_images])
        audio_chunks.append(audio[picked_audio])
        digit_column += [digit] * count
        speaker_chunks.append(speaker_labels[picked_audio])

    stacked_images = np.concatenate(image_chunks, axis=0)
    stacked_audio = np.concatenate(audio_chunks, axis=0)
    stacked_labels = np.array(digit_column)
    stacked_speakers = np.concatenate(speaker_chunks, axis=0)
    return stacked_images, stacked_audio, stacked_labels, stacked_speakers


def load_reps(results_dir, model_prefix, max_reps=3):
    """Load ``<model_prefix>_rep<i>.npy`` for ``i`` in ``range(max_reps)``.

    Returns a list of length ``max_reps``. Each entry is the loaded array, or
    ``None`` (after printing a warning) when the corresponding file does not
    exist, so list positions always line up with the rep index.
    """
    def _load_one(index):
        # Load a single rep file, or warn and return None if it is absent.
        rep_path = os.path.join(results_dir, f"{model_prefix}_rep{index}.npy")
        if os.path.exists(rep_path):
            return np.load(rep_path)
        print(f"Warning: rep file not found: {rep_path}")
        return None

    return [_load_one(index) for index in range(max_reps)]


def reduce_to_2d(rep, method='pca'):
    """Return a two-column version of ``rep`` suitable for a scatter plot.

    * ``None`` is passed through.
    * A representation that already has 2 columns is returned unchanged
      (same object, no copy).
    * A 1-column representation gets a second column of tiny, fixed-seed
      jitter so the points do not collapse onto a line.
    * Anything else is projected with a 2-component PCA.

    ``method`` is accepted for interface compatibility; every value currently
    results in PCA because no other reducer is implemented.
    """
    if rep is None:
        return None

    width = rep.shape[1]
    if width == 2:
        return rep
    if width == 1:
        # Fixed seed keeps the jitter (and therefore the plot) reproducible.
        jitter = (np.random.RandomState(0).rand(rep.shape[0], 1) - 0.5) * 1e-3
        return np.concatenate([rep, jitter], axis=1)

    # PCA is both the requested method and the fallback for any other value.
    return PCA(n_components=2).fit_transform(rep)


def compute_knn_acc(rep, labels, n_neighbors=1):
    """Score a k-NN classifier on an ordered 80/20 split of ``rep``.

    The first 80% of the rows (in their given order, no shuffling) are used
    for fitting and the remaining rows for scoring. String/object labels are
    integer-encoded first.

    Returns:
        The accuracy as a ``float``, or ``None`` when ``rep`` or ``labels`` is
        ``None``, when there are fewer than 5 rows, or when anything raises
        (the error is printed).
    """
    if rep is None or labels is None:
        return None
    try:
        targets = np.asarray(labels)
        # Non-numeric labels are mapped to integer codes before fitting.
        if targets.dtype.kind in ('U', 'S', 'O'):
            targets = LabelEncoder().fit_transform(targets.astype(str))

        total = rep.shape[0]
        if total < 5:
            return None

        cut = int(0.8 * total)
        classifier = KNeighborsClassifier(n_neighbors=n_neighbors)
        classifier.fit(rep[:cut], targets[:cut])
        return float(classifier.score(rep[cut:], targets[cut:]))
    except Exception as e:
        print(f"KNN accuracy computation failed: {e}")
        return None


def compute_multi_classifier_accs(rep, labels):
    """Compute accuracies for a small set of simple classifiers using an 80/20 split.

    The first 80% of the rows (in their given order, no shuffling) are used
    for fitting and the remaining rows for scoring. String/object labels are
    integer-encoded first. Logistic regression and Gaussian NB run on
    standardised features; 1-NN runs on the raw features.

    Returns:
        A dict with keys ``knn``, ``logistic`` and ``gnb``. Each value is a
        ``float`` accuracy, or ``None`` when that classifier could not be
        evaluated. All values are ``None`` when ``rep`` or ``labels`` is
        ``None``, when there are fewer than 10 rows, or when preparing the
        split fails. Failures are printed rather than raised.
    """
    out = {'knn': None, 'logistic': None, 'gnb': None}
    if rep is None or labels is None:
        return out

    try:
        X = np.asarray(rep)
        y = np.asarray(labels)
        # encode non-numeric labels
        if y.dtype.kind in {'U', 'S', 'O'}:
            y = LabelEncoder().fit_transform(y.astype(str))
        n = X.shape[0]
        if n < 10:
            return out
        split = int(0.8 * n)
        X_train, X_test = X[:split], X[split:]
        y_train, y_test = y[:split], y[split:]
    except Exception as e:
        print(f"Multi-classifier accuracy failed: {e}")
        return out

    # KNN on the raw (unscaled) features
    try:
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(X_train, y_train)
        out['knn'] = float(knn.score(X_test, y_test))
    except Exception as e:
        print(f"Multi-classifier accuracy: knn failed: {e}")

    # Standardise once; the same scaled features serve logistic regression
    # and Gaussian NB (a second StandardScaler fit would be identical).
    try:
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)
    except Exception as e:
        print(f"Multi-classifier accuracy: feature scaling failed: {e}")
        X_train_scaled = X_test_scaled = None

    # Logistic regression: try 'saga' first, fall back to 'lbfgs'.
    # `multi_class` is deliberately not passed: it is deprecated since
    # scikit-learn 1.5, and the default already uses the multinomial
    # formulation for multiclass targets with these solvers.
    if X_train_scaled is not None:
        for solver in ('saga', 'lbfgs'):
            try:
                log = LogisticRegression(max_iter=2000, solver=solver)
                log.fit(X_train_scaled, y_train)
                out['logistic'] = float(log.score(X_test_scaled, y_test))
                break
            except Exception as e:
                print(f"Multi-classifier accuracy: logistic ({solver}) failed: {e}")

    # Gaussian NB: scaled features when available, raw features otherwise.
    try:
        if X_train_scaled is not None:
            X_train_g, X_test_g = X_train_scaled, X_test_scaled
        else:
            X_train_g, X_test_g = X_train, X_test
        gnb = GaussianNB()
        gnb.fit(X_train_g, y_train)
        out['gnb'] = float(gnb.score(X_test_g, y_test))
    except Exception as e:
        print(f"Multi-classifier accuracy: gnb failed: {e}")

    return out


def plot_panel(ax, rep2d, labels, title, cmap='tab10'):
    """Scatter the two columns of ``rep2d`` on ``ax``, coloured by ``labels``.

    Numeric labels are handed to ``scatter`` unchanged. String/object labels
    are first converted to integer codes; if the encoder succeeded and there
    are at most 20 classes, the colorbar ticks are relabelled with the class
    names. Any error while building the colorbar is ignored so the scatter
    itself is still drawn.

    Returns:
        The object returned by ``ax.scatter``.
    """
    label_values = np.asarray(labels)
    is_categorical = label_values.dtype.kind in ('U', 'S', 'O')

    if not is_categorical:
        sc = ax.scatter(rep2d[:, 0], rep2d[:, 1], c=label_values, cmap=cmap, s=3, alpha=0.7)
        try:
            plt.colorbar(sc, ax=ax)
        except Exception:
            pass
    else:
        encoder = LabelEncoder()
        try:
            color_codes = encoder.fit_transform(label_values.astype(str))
        except Exception:
            # Encoder failed: build the string -> integer mapping by hand.
            code_of = {name: code for code, name in enumerate(np.unique(label_values.astype(str)))}
            color_codes = np.array([code_of[str(v)] for v in label_values], dtype=int)

        sc = ax.scatter(rep2d[:, 0], rep2d[:, 1], c=color_codes, cmap=cmap, s=3, alpha=0.7)
        try:
            cbar = plt.colorbar(sc, ax=ax)
            # If the manual fallback above was used the encoder has no
            # classes_, the lookup raises, and the ticks stay numeric.
            class_names = encoder.classes_
            if len(class_names) <= 20:
                cbar.set_ticks(np.arange(len(class_names)))
                cbar.set_ticklabels(class_names)
        except Exception:
            pass

    ax.set_title(title)
    ax.set_xlabel('Dim 1')
    ax.set_ylabel('Dim 2')
    return sc


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--results-dir', default='03_results/models')
    parser.add_argument('--data-dir', default='01_data/avmnist_data_from_source')
    parser.add_argument('--full_spectrum', action='store_true')
    parser.add_argument('--train_samples_per_digit', type=int, default=5000)
    parser.add_argument('--test_samples_per_digit', type=int, default=1000)
    args = parser.parse_args()

    model_prefix = f"audiomnist2{'_fullspec' if args.full_spectrum else ''}_rseed-{args.seed}"
    print(f"Looking for reps with prefix: {model_prefix} in {args.results_dir}")

    # Rebuild matched labels using the same procedure as 056_audiomnist.py
    # Load source arrays
    train_audio = np.load(os.path.join(args.data_dir, 'audio', 'train_data.npy'))
    train_labels = np.load(os.path.join(args.data_dir, 'train_labels.npy'))
    train_speakers = np.load(os.path.join(args.data_dir, 'train_speaker_labels.npy'))

    test_audio = np.load(os.path.join(args.data_dir, 'audio', 'test_data.npy'))
    test_labels = np.load(os.path.join(args.data_dir, 'test_labels.npy'))
    test_speakers = np.load(os.path.join(args.data_dir, 'test_speaker_labels.npy'))

    print('Loaded source audio and label arrays; reconstructing matched dataset order...')

    # For train
    _, _, matched_train_labels, matched_train_speakers = match_datasets_by_label(
        np.zeros((train_audio.shape[0], 1)), train_labels, train_audio, train_labels, train_speakers,
        n_samples_per_class=args.train_samples_per_digit, seed=args.seed)

    # For test
    _, _, matched_test_labels, matched_test_speakers = match_datasets_by_label(
        np.zeros((test_audio.shape[0], 1)), test_labels, test_audio, test_labels, test_speakers,
        n_samples_per_class=args.test_samples_per_digit, seed=args.seed)

    # Concatenate to match training data ordering used by 056 (train then test)
    labels_concat = np.concatenate([matched_train_labels, matched_test_labels], axis=0)
    speakers_concat = np.concatenate([matched_train_speakers, matched_test_speakers], axis=0)

    print(f"Reconstructed labels: total samples = {labels_concat.shape[0]}")

    # --- Load metadata and map speaker IDs to additional labels ---
    def load_meta_candidates(base_dirs):
        """Return the first parseable audioMNIST metadata found under ``base_dirs``.

        Each directory is probed, in order, for ``audioMNIST_meta.txt``,
        ``audioMNIST_meta.json`` and ``audioMNIST_meta``. An existing file is
        parsed as JSON first and, failing that, as a Python literal. A file
        that cannot be read or parsed is reported and skipped, and the search
        continues with the next candidate.

        Returns:
            The parsed object, or ``None`` when no candidate could be loaded.
        """
        import ast  # only needed for the Python-literal fallback below

        for base in base_dirs:
            for fname in ('audioMNIST_meta.txt', 'audioMNIST_meta.json', 'audioMNIST_meta'):
                p = os.path.join(base, fname)
                if not os.path.exists(p):
                    continue
                try:
                    # Context manager so the handle is closed even if read() fails.
                    with open(p, 'r', encoding='utf-8') as fh:
                        text = fh.read()
                except (OSError, ValueError) as e:
                    print(f"Warning: could not read metadata file {p}: {e}")
                    continue
                try:
                    return json.loads(text)
                except (ValueError, RecursionError):
                    # Not valid JSON; try the Python-literal form next.
                    pass
                try:
                    return ast.literal_eval(text)
                except Exception as e:
                    # Best effort: report and move on to the next candidate.
                    print(f"Warning: could not parse metadata file {p}: {e}")
        return None

    meta = load_meta_candidates([args.data_dir, '01_data/processed/avmnist', 'AudioMNIST', '01_data/processed/avmnist'])
    if meta is None:
        print('Warning: audioMNIST metadata not found. gender/native/accent labels will be unavailable.')

    def map_meta_field(speakers_array, field):
        """Look up metadata ``field`` for every entry of ``speakers_array``.

        For each speaker id the metadata is probed with these keys, in order:
        ``str(s)``, ``str(int(s) + 1)`` and finally the integer ``int(s)``.
        The integer-based keys are only tried when the id converts to ``int``.
        Ids with no match, or whose record lacks ``field``, map to the string
        ``'UNKNOWN'``.

        Returns:
            ``None`` when no metadata was loaded, otherwise an object-dtype
            array with one entry per speaker.
        """
        if meta is None:
            return None
        out = []
        for s in speakers_array:
            # A non-numeric id has no integer form. It must not abort the
            # run, because the plain string key may still match.
            try:
                s_int = int(s)
            except (TypeError, ValueError, OverflowError):
                s_int = None

            # metadata keys may be 1-indexed strings
            # NOTE(review): the exact key is tried before the +1 key for each
            # speaker individually, so an id that matches both is resolved
            # as-is. Confirm the id convention against the metadata file.
            candidates = [str(s)]
            if s_int is not None:
                candidates.append(str(s_int + 1))

            val = None
            for c in candidates:
                if c in meta:
                    val = meta[c].get(field, None)
                    break
            if val is None and s_int is not None:
                # try integer key access if meta keys are ints
                try:
                    mv = meta.get(s_int)
                    if mv is not None:
                        val = mv.get(field, None)
                except (AttributeError, TypeError):
                    pass
            # normalize missing values to a sentinel string to avoid mixed types
            out.append(val if val is not None else 'UNKNOWN')
        # object dtype: values are whatever the metadata holds, plus 'UNKNOWN'
        return np.array(out, dtype=object)
    
    gender_concat = map_meta_field(speakers_concat, 'gender')
    native_concat = map_meta_field(speakers_concat, 'native speaker')
    accent_concat = map_meta_field(speakers_concat, 'accent')
    origin_concat = map_meta_field(speakers_concat, 'origin')
    # origins need to be split to only include the continent (first word)
    if origin_concat is not None:
        origin_concat = np.array([str(o).split()[0] if o != 'UNKNOWN' else o for o in origin_concat], dtype=object)
    room_concat = map_meta_field(speakers_concat, 'recordingroom')

    # Print brief meta summary
    if gender_concat is not None:
        print('Gender unique:', np.unique(gender_concat))
    if native_concat is not None:
        print('Native speaker unique:', np.unique(native_concat))
    if accent_concat is not None:
        print('Accent sample unique (first 20):', np.unique(accent_concat)[:20])

    # Load reps
    reps = load_reps(args.results_dir, model_prefix, max_reps=3)
    # If every rep is missing (all None), abort gracefully
    if all(r is None for r in reps):
        print('No representations found; aborting.')
        return

    # Ensure labels length matches rep length (use first non-None rep)
    rep0 = next(r for r in reps if r is not None)
    n_rep = rep0.shape[0]
    if labels_concat.shape[0] != n_rep:
        print(f"WARNING: reconstructed labels length {labels_concat.shape[0]} != rep length {n_rep}")
        # Truncate labels, speakers and reps to their common length (no padding is performed)
        N = min(labels_concat.shape[0], n_rep)
        labels_concat = labels_concat[:N]
        speakers_concat = speakers_concat[:N]
        reps = [ (r[:N] if r is not None else None) for r in reps ]
        print(f"Truncated to N={N}")

    # Create rows for each label type we have available
    label_rows = [
        ('digits', labels_concat, 'digit_acc'),
        ('speakers', speakers_concat, 'speaker_acc')
    ]
    if origin_concat is not None:
        label_rows.append(('origin', origin_concat, 'origin_acc'))
    if room_concat is not None:
        label_rows.append(('room', room_concat, 'room_acc'))
    if gender_concat is not None:
        label_rows.append(('gender', gender_concat, 'gender_acc'))
    if native_concat is not None:
        label_rows.append(('native', native_concat, 'native_acc'))
    if accent_concat is not None:
        label_rows.append(('accent', accent_concat, 'accent_acc'))

    n_rows = len(label_rows)
    n_cols = len(reps)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    axes = np.atleast_2d(axes)

    accs = []
    comp_rows = []
    for i, rep in enumerate(reps):
        if rep is None:
            for r in range(n_rows):
                axes[r, i].text(0.5, 0.5, 'missing rep', ha='center')
            accs.append({'rep': i, 'digit_acc': None, 'speaker_acc': None})
            continue

        # If rep is torch tensor saved earlier, it will be numpy already
        rep_arr = np.asarray(rep)

        # Compute accuracies on the raw rep (not reduced)
        digit_acc_raw = compute_knn_acc(rep_arr, labels_concat)
        speaker_acc_raw = compute_knn_acc(rep_arr, speakers_concat)
        gender_acc_raw = compute_knn_acc(rep_arr, gender_concat) if gender_concat is not None else None
        native_acc_raw = compute_knn_acc(rep_arr, native_concat) if native_concat is not None else None
        accent_acc_raw = compute_knn_acc(rep_arr, accent_concat) if accent_concat is not None else None
        origin_acc_raw = compute_knn_acc(rep_arr, origin_concat) if origin_concat is not None else None
        room_acc_raw = compute_knn_acc(rep_arr, room_concat) if room_concat is not None else None

        # Create 2D reduction only for plotting; do NOT compute accuracies on 2D
        rep2d = reduce_to_2d(rep_arr, method='pca')

        accs.append({'rep': i,
             'digit_acc_raw': digit_acc_raw, 'speaker_acc_raw': speaker_acc_raw,
           'gender_acc_raw': gender_acc_raw, 'native_acc_raw': native_acc_raw, 'accent_acc_raw': accent_acc_raw,
           'origin_acc_raw': origin_acc_raw, 'room_acc_raw': room_acc_raw})

        # Compute multi-classifier comparisons for each label type and store rows for CSV
        for lname, lvals, acc_key in label_rows:
            try:
                accs_dict = compute_multi_classifier_accs(rep_arr, lvals)
            except Exception:
                accs_dict = {'knn': None, 'logistic': None, 'gnb': None}
            comp_rows.append({
                'rep': i,
                'label_type': lname,
                'knn_acc': accs_dict.get('knn'),
                'logistic_acc': accs_dict.get('logistic'),
                'gnb_acc': accs_dict.get('gnb')
            })

        # Plot each available label row for this rep
        for r, (lname, lvals, acc_key) in enumerate(label_rows):
            acc_val = None
            if accs and len(accs) > 0 and accs[-1]['rep'] == i:
                # already appended; get the last entry
                acc_val = accs[-1].get(acc_key)
            # Build title with accuracy when available (prefer 2D accuracy for plot titles)
            # Use raw accuracies for all rows (do not compute 2D accuracies)
            acc_val_for_title = None
            if accs and len(accs) > 0 and accs[-1]['rep'] == i:
                # map acc_key to the raw name stored in accs
                raw_key = acc_key + '_raw' if not acc_key.endswith('_raw') else acc_key
                acc_val_for_title = accs[-1].get(raw_key)
            if acc_val_for_title is not None:
                title = f"Rep {i} - {lname}\n(Acc: {acc_val_for_title:.2f})"
            else:
                title = f"Rep {i} - {lname}"

            # choose cmap for certain label types
            if lname == 'digits':
                cmap = 'tab10'
            elif lname == 'speakers':
                unique_speakers = len(np.unique(lvals))
                cmap = 'tab20' if unique_speakers <= 20 else 'viridis'
            elif lname in ('gender', 'native'):
                cmap = 'Set1'
            else:
                cmap = 'viridis'

            plot_panel(axes[r, i], rep2d, lvals, title=title, cmap=cmap)

    plt.tight_layout()

    out_dir = '03_results/plots'
    os.makedirs(out_dir, exist_ok=True)
    out_file = os.path.join(out_dir, f'audiomnist2_analysis_rseed-{args.seed}{"_fullspec" if args.full_spectrum else ""}.png')
    plt.savefig(out_file, dpi=200, bbox_inches='tight')
    print(f'Saved figure to: {out_file}')

    # --- Create a second figure with ranks, R² and thumbnail subspace panels (like 052_avmnist_plots_real.py)
    def load_aligned_mnist_images(model_prefix, results_dir, labels=None, mnist_root='01_data/processed/MNIST'):
        """Try to load MNIST train+test images aligned to AVMNIST using the saved mapping.

        Sources are tried in this order:

        1. ``<results_dir>/<model_prefix>_mnist_mapping.npz`` containing both
           ``mnist_train_indices`` and ``mnist_test_indices``: the images at
           those indices (train indices first) of the combined train+test
           image array.
        2. A mapping that only has ``mnist_image_indices``: those indices
           followed by every position from ``len(indices)`` to the end of the
           combined array. This is an approximation of the test ordering.
        3. ``labels`` given: for each label, the next image of that digit in
           MNIST order, wrapping around when a digit runs out. This only
           guarantees that the digit matches, not the exact sample.

        Returns aligned_images (N, H, W) or None on failure.
        """
        mapping_file = os.path.join(results_dir, f"{model_prefix}_mnist_mapping.npz")
        mapping = None
        if not os.path.exists(mapping_file):
            print(f"Mapping file for thumbnails not found: {mapping_file} - will try label-based alignment if labels provided")
        else:
            try:
                # NOTE(review): allow_pickle=True can run arbitrary code from a
                # crafted file; only point this at mapping files you produced.
                mapping = np.load(mapping_file, allow_pickle=True)
            except Exception as e:
                print(f"Failed to load mapping file {mapping_file}: {e}")

        try:
            mnist_train = MNIST(root=mnist_root, train=True, download=False)
            mnist_test = MNIST(root=mnist_root, train=False, download=False)
            mnist_train_images = mnist_train.data.numpy().astype('float32') / 255.0
            mnist_test_images = mnist_test.data.numpy().astype('float32') / 255.0
            all_mnist = np.concatenate([mnist_train_images, mnist_test_images], axis=0)
        except Exception as e:
            print(f"Could not load MNIST images for thumbnails: {e}")
            return None

        try:
            if mapping is not None:
                if 'mnist_train_indices' in mapping and 'mnist_test_indices' in mapping:
                    train_idx = np.array(mapping['mnist_train_indices'], dtype=np.int64)
                    test_idx = np.array(mapping['mnist_test_indices'], dtype=np.int64)
                    all_idx = list(train_idx) + list(test_idx)
                    return all_mnist[all_idx]
                elif 'mnist_image_indices' in mapping:
                    mnist_indices = np.array(mapping['mnist_image_indices'], dtype=np.int64)
                    # create deterministic test mapping fallback (simple append of next test indices)
                    # This mimics the deterministic behavior used elsewhere; may not be exact for all mappings
                    n_train = len(mnist_indices)
                    test_idx = list(range(n_train, all_mnist.shape[0]))
                    all_idx = list(mnist_indices) + test_idx
                    return all_mnist[all_idx]
            # If mapping not available but labels were provided, create a best-effort aligned_images
            if labels is not None:
                labels_arr = np.asarray(labels).astype(int)
                # Build indices per digit from MNIST combined (train then test),
                # the same ordering as all_mnist, so all_mnist is reused rather
                # than concatenating the images a second time.
                all_mnist_labels = np.concatenate([mnist_train.targets.numpy(), mnist_test.targets.numpy()], axis=0)
                per_digit_indices = {d: np.where(all_mnist_labels == d)[0] for d in range(10)}
                counters = dict.fromkeys(range(10), 0)
                aligned_list = []
                for lab in labels_arr:
                    digit = int(lab)
                    idxs = per_digit_indices.get(digit)
                    if idxs is None or idxs.size == 0:
                        # fallback: label outside 0-9 (or digit absent) -> first image
                        chosen = 0
                    else:
                        # round-robin through this digit's images
                        chosen = idxs[counters[digit] % idxs.size]
                        counters[digit] += 1
                    aligned_list.append(all_mnist[chosen])
                try:
                    return np.stack(aligned_list, axis=0)
                except Exception:
                    # np.stack rejects an empty list; np.array gives an empty array instead
                    return np.array(aligned_list)
        except Exception as e:
            print(f"Error constructing aligned MNIST images: {e}")
            return None
        return None

    def create_rank_r2_thumbnail_figure(args, model_prefix, reps, labels, speakers, results_dir='03_results/models'):
        # Try to load rank history
        rank_file = os.path.join(results_dir, f"{model_prefix}_rank_history.csv")
        if not os.path.exists(rank_file):
            print(f"Rank history not found ({rank_file}); skipping ranks/R² figure.")
            return
        try:
            rank_history = pd.read_csv(rank_file)
        except Exception as e:
            print(f"Failed to read rank history {rank_file}: {e}")
            return

        # optionally append posttrain losses if present (for continuity)
        try:
            posttrain_file = os.path.join(results_dir, f"{model_prefix}_posttrain_losses.csv")
            if os.path.exists(posttrain_file):
                posttrain = pd.read_csv(posttrain_file)
                orig_max = rank_history['epoch'].max()
                post_epochs = range(orig_max + 1, orig_max + 1 + len(posttrain))
                post_df = pd.DataFrame({'epoch': post_epochs, 'loss': posttrain['train_loss'], 'val_loss': posttrain['val_loss']})
                # copy last rank/r2 values forward if present
                last = rank_history.iloc[-1]
                for c in ['total_rank', 'ranks']:
                    if c in rank_history.columns:
                        post_df[c] = last[c]
                for i in range(2):
                    r2c = f'rsquare {i}'
                    if r2c in rank_history.columns:
                        post_df[r2c] = last[r2c]
                rank_history = pd.concat([rank_history, post_df], ignore_index=True)
        except Exception:
            pass

        # parse individual ranks from string column 'ranks' if present
        try:
            rank_strings = rank_history['ranks'].astype(str).values
            individual_ranks = np.array([[int(x.strip()) for x in s.split(',')] for s in rank_strings])
        except Exception:
            individual_ranks = None

        # Prepare aligned MNIST thumbnails for the sample ordering (best-effort)
        aligned_images = load_aligned_mnist_images(model_prefix, results_dir, labels=labels)

        # Ensure labels and speakers are numpy arrays (caller may pass lists)
        try:
            labels = np.asarray(labels)
        except Exception:
            labels = np.array(labels)
        try:
            speakers = np.asarray(speakers)
        except Exception:
            speakers = np.array(speakers)

        # Normalize aligned_images to numpy array if possible to safely use .shape/.size
        if aligned_images is not None:
            try:
                aligned_images = np.asarray(aligned_images)
            except Exception:
                aligned_images = None

        # Apply paper-style rcParams (match 040_mm_paramsim_plots.py)
        plt.style.use('default')
        plt.rcParams.update({
            'font.size': 11,
            'axes.titlesize': 11,
            'axes.labelsize': 11,
            'xtick.labelsize': 10,
            'ytick.labelsize': 10,
            'legend.fontsize': 10,
            'figure.titlesize': 12,
            'font.family': 'sans-serif',
            'figure.dpi': 300,
            'savefig.dpi': 300,
            'savefig.bbox': 'tight',
            'savefig.pad_inches': 0.05,
            'axes.linewidth': 0.8,
            'grid.linewidth': 0.5,
        })

        # Color palette matching 040_mm_paramsim_plots.py
        color_palette_subspaces = {'Shared': '#4c72b3', 'Image': '#f48e56', 'Audio': '#ea8ab2'}

        # Build figure similar to 052: GridSpec 2 x 4
        fig = plt.figure(figsize=(9.45, 2.72))
        gs = gridspec.GridSpec(2, 4, figure=fig, hspace=0.3, wspace=0.3)

        # Ranks subplot (row 0, cols 0-1)
        ax_ranks = fig.add_subplot(gs[0, 0])
        epochs = rank_history['epoch'].values
        if individual_ranks is not None:
            ax_ranks.plot(epochs, individual_ranks[:, 0], color=color_palette_subspaces['Shared'], linewidth=1, label='Shared', marker='o', markersize=4)
            if individual_ranks.shape[1] > 1:
                ax_ranks.plot(epochs, individual_ranks[:, 1], color=color_palette_subspaces['Image'], linewidth=1, label='Image', marker='o', markersize=4)
            if individual_ranks.shape[1] > 2:
                ax_ranks.plot(epochs, individual_ranks[:, 2], color=color_palette_subspaces['Audio'], linewidth=1, label='Audio', marker='o', markersize=4)
        ax_ranks.set_xlabel('')
        ax_ranks.set_ylabel('Rank')
        # Use log scale for ranks to show broad dynamic range
        try:
            ax_ranks.set_yscale('log')
        except Exception:
            pass
        # only set yticks for values 1, 10, 50, 200
        try:
            ax_ranks.set_yticks([1, 10, 50, 200])
            ax_ranks.get_yaxis().set_major_formatter(plt.ScalarFormatter())
        except Exception:
            pass
        ax_ranks.set_title('Subspace Ranks')
        # only show the first legend item (Shared) if handles are present
        try:
            leg_handles, leg_labels = ax_ranks.get_legend_handles_labels()
            if leg_handles and len(leg_handles) > 0:
                ax_ranks.legend([leg_handles[0]], [leg_labels[0]], loc='upper right', fontsize=9, frameon=False)
        except Exception:
            print("Could not set legend for ranks subplot.")
            # fallback: don't set a legend if anything goes wrong
            pass
        ax_ranks.grid(True, alpha=0.3)
        # Do not show epoch xticks on the ranks subplot; keep its xlabel empty
        ax_ranks.set_xlabel('')
        # Explicitly remove xticks from ranks subplot
        try:
            ax_ranks.set_xticks([])
            ax_ranks.set_xticklabels([])
        except Exception:
            pass

        # R² subplot (row 1, cols 0-1)
        ax_r2 = fig.add_subplot(gs[1, 0])
        for i in range(2):
            r2_col = f'rsquare {i}'
            if r2_col in rank_history.columns:
                modality_name = 'Image' if i == 0 else 'Audio'
                color = color_palette_subspaces['Image'] if i == 0 else color_palette_subspaces['Audio']
                ax_r2.plot(epochs, rank_history[r2_col], color=color, linewidth=2, label=modality_name, marker='o', markersize=3)
                try:
                    ref_val = max(rank_history[r2_col]) - 0.05
                    ax_r2.axhline(ref_val, color=color, linestyle=':', linewidth=1, alpha=0.8)
                except Exception:
                    pass
        ax_r2.set_xlabel('Epoch')
        ax_r2.set_ylabel('R²')
        ax_r2.set_title('Distortion Metric')
        #ax_r2.legend().remove()
        ax_r2.legend(loc='lower right', fontsize=9, frameon=False)
        ax_r2.grid(True, alpha=0.3)
        # Set epoch ticks on the R² subplot every 100 epochs; prefer starting at 500 if present
        try:
            if len(epochs) > 0:
                ep_min = int(min(epochs))
                ep_max = int(max(epochs))
                start = 500 if 500 >= ep_min else ep_min
                # ensure start <= ep_max
                if start > ep_max:
                    start = ep_min
                tick_vals = np.arange(start, ep_max + 1, 100, dtype=int)
                if tick_vals.size == 0:
                    tick_vals = np.array([ep_min])
                ax_r2.set_xticks(tick_vals)
                ax_r2.set_xticklabels([str(int(v)) for v in tick_vals], rotation=0)
        except Exception:
            pass

        # Thumbnail/subspace columns (cols 2-4)
        subspace_names = ['Shared', 'Image', 'Audio']
        n_train = min(60000, labels.shape[0])
        for i in range(3):
            ax_rep = fig.add_subplot(gs[:, i+1])
            try:
                rep = reps[i]
                if rep is None:
                    ax_rep.text(0.5, 0.5, 'missing rep', ha='center')
                    continue
                rep_arr = np.asarray(rep)[:n_train]
                rep2d = reduce_to_2d(rep_arr, method='pca')
                # compute accuracies for title (raw rep)
                try:
                    digit_acc = compute_knn_acc(rep_arr, labels)
                except Exception:
                    digit_acc = None
                try:
                    speaker_acc = compute_knn_acc(rep_arr, speakers)
                except Exception:
                    speaker_acc = None
                # compute GaussianNB (and other) classifier results for digits to show in titles
                try:
                    cls_results = compute_multi_classifier_accs(rep_arr, labels)
                    gnb_digit_acc = cls_results.get('gnb') if cls_results is not None else None
                except Exception:
                    gnb_digit_acc = None

                # Prefer thumbnails if aligned images exist; attempt to wrap indices so thumbnails
                # are shown even when aligned_images length doesn't exactly match the rep length
                try:
                    max_samples = 3000
                    if rep_arr.shape[0] > max_samples:
                        indices = np.random.choice(rep_arr.shape[0], max_samples, replace=False)
                        rep_plot = rep_arr[indices]
                        rep2d_plot = reduce_to_2d(rep_plot, method='pca')
                        labels_plot = labels[indices]
                    else:
                        indices = np.arange(rep_arr.shape[0])
                        rep_plot = rep_arr
                        rep2d_plot = rep2d
                        labels_plot = labels[:len(rep2d_plot)]

                    images_plot = None
                    if aligned_images is not None and aligned_images.size > 0:
                        # ensure we can index into aligned_images by wrapping indices
                        try:
                            wrapped_idx = np.mod(indices, aligned_images.shape[0])
                            images_plot = aligned_images[wrapped_idx]
                        except Exception:
                            try:
                                images_plot = aligned_images[:len(indices)]
                            except Exception:
                                images_plot = None

                    # Ensure arrays are numpy arrays for safe shape/indexing
                    try:
                        rep_plot = np.asarray(rep_plot)
                    except Exception:
                        rep_plot = np.array(rep_plot)
                    try:
                        rep2d_plot = np.asarray(rep2d_plot)
                    except Exception:
                        rep2d_plot = np.array(rep2d_plot)
                    try:
                        labels_plot = np.asarray(labels_plot)
                    except Exception:
                        labels_plot = np.array(labels_plot)
                    if images_plot is not None:
                        try:
                            images_plot = np.asarray(images_plot)
                        except Exception:
                            images_plot = None

                    if images_plot is not None and images_plot.shape[0] >= rep2d_plot.shape[0]:
                        ax_rep.scatter(rep2d_plot[:, 0], rep2d_plot[:, 1], alpha=0)
                        n_thumbnails = min(500, len(rep2d_plot))
                        thumb_indices = np.random.choice(len(rep2d_plot), n_thumbnails, replace=False)
                        for ti in thumb_indices:
                            img = images_plot[ti]
                            try:
                                img2 = np.squeeze(img)
                                img2 = img2.astype('float32')
                                if img2.max() > 1.0:
                                    img2 = img2 / 255.0
                            except Exception:
                                img2 = np.asarray(img)
                            try:
                                imagebox = OffsetImage(img2, zoom=0.3, cmap='gray')
                                imagebox.set_zorder(10)
                                ab = AnnotationBbox(imagebox, (rep2d_plot[ti, 0], rep2d_plot[ti, 1]), frameon=False, pad=0)
                                ab.set_clip_on(False)
                                ax_rep.add_artist(ab)
                            except Exception:
                                ax_rep.plot(rep2d_plot[ti, 0], rep2d_plot[ti, 1], marker='o', markersize=1, color='k', alpha=0.3)
                        # ensure extents include plotted points
                        try:
                            x_min, x_max = rep2d_plot[:, 0].min(), rep2d_plot[:, 0].max()
                            y_min, y_max = rep2d_plot[:, 1].min(), rep2d_plot[:, 1].max()
                            xpad = (x_max - x_min) * 0.02 if (x_max - x_min) > 0 else 1.0
                            ypad = (y_max - y_min) * 0.02 if (y_max - y_min) > 0 else 1.0
                            ax_rep.set_xlim(x_min - xpad, x_max + xpad)
                            ax_rep.set_ylim(y_min - ypad, y_max + ypad)
                        except Exception:
                            pass
                        print(f"DEBUG: Added {n_thumbnails} thumbnails to subspace {i}")
                    else:
                        ax_rep.scatter(rep2d[:, 0], rep2d[:, 1], c=labels[:rep2d.shape[0]], cmap='tab10', s=3, alpha=0.7)
                except Exception:
                    ax_rep.scatter(rep2d[:, 0], rep2d[:, 1], c=labels[:rep2d.shape[0]], cmap='tab10', s=3, alpha=0.7)

                # Title includes digit accuracy and GaussianNB digit accuracy when available
                title = subspace_names[i]
                #if digit_acc is not None:
                #    title = f"{title} (Acc {digit_acc:.2f})"
                if gnb_digit_acc is not None:
                    #title = f"{title} | GNB: {gnb_digit_acc:.2f}"
                    title = f"{title} (Acc {gnb_digit_acc:.2f})"
                ax_rep.set_title(title)
                # Keep labels for clarity but remove ticks and ticklabels for a cleaner thumbnail presentation
                ax_rep.set_xlabel('PC 1')
                ax_rep.set_ylabel('PC 2')
                try:
                    ax_rep.set_xticks([])
                    ax_rep.set_yticks([])
                    ax_rep.set_xticklabels([])
                    ax_rep.set_yticklabels([])
                except Exception:
                    pass
            except Exception as e:
                ax_rep.text(0.5, 0.5, f'error: {e}', ha='center')

        # Add A/B annotations on the left (match style from 040_mm_paramsim_plots_v2.py)
        try:
            fig.text(0.07, 0.95, 'A', fontsize=14, fontweight='bold', ha='center', va='center')
            fig.text(0.30, 0.95, 'B', fontsize=14, fontweight='bold', ha='center', va='center')
        except Exception:
            pass

        # Save second figure
        out_dir = '03_results/plots'
        os.makedirs(out_dir, exist_ok=True)
        out_file2 = os.path.join(out_dir, f'audiomnist2_ranks_r2_thumbnails_rseed-{args.seed}{"_fullspec" if args.full_spectrum else ""}.png')
        plt.savefig(out_file2, dpi=300, bbox_inches='tight')
        print(f'Saved ranks/R²/thumbnail figure to: {out_file2}')

    # Attempt to create the ranks/R² + thumbnails figure (best-effort)
    try:
        create_rank_r2_thumbnail_figure(args, model_prefix, reps, labels_concat, speakers_concat, results_dir=args.results_dir)
    except Exception as e:
        print(f'Warning: could not create ranks/R² thumbnail figure: {e}')

    # Save accuracies (include both raw and 2D accuracies)
    out_csv_dir = '03_results/processed'
    os.makedirs(out_csv_dir, exist_ok=True)
    import csv
    csv_file = os.path.join(out_csv_dir, f'audiomnist2_rep_accuracies_rseed-{args.seed}{"_fullspec" if args.full_spectrum else ""}.csv')
    fieldnames = [
        'rep',
        'digit_acc_raw', 'speaker_acc_raw', 'gender_acc_raw', 'native_acc_raw', 'accent_acc_raw',
        'origin_acc_raw', 'room_acc_raw',
        'digit_acc_2d', 'speaker_acc_2d', 'gender_acc_2d', 'native_acc_2d', 'accent_acc_2d', 'origin_acc_2d', 'room_acc_2d'
    ]
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        if not accs:
            print('Warning: no accuracy records to write (accs is empty)')
        for row in accs:
            out = {k: row.get(k, None) for k in fieldnames}
            writer.writerow(out)
    print(f'Saved accuracies to: {csv_file}')
    # Print a short console summary of accuracies
    if accs:
        print('\nAccuracy summary:')
        print(accs[0])
        print(accs[1])
        print(accs[2])

    # Save classifier comparisons CSV
    comp_csv = os.path.join(out_csv_dir, f'audiomnist2_rep_classifier_comparisons_rseed-{args.seed}{"_fullspec" if args.full_spectrum else ""}.csv')
    comp_fieldnames = ['rep', 'label_type', 'knn_acc', 'logistic_acc', 'gnb_acc']
    try:
        with open(comp_csv, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=comp_fieldnames)
            writer.writeheader()
            for r in comp_rows:
                writer.writerow(r)
        print(f'Saved classifier comparisons to: {comp_csv}')
    except Exception as e:
        print(f'Failed to save classifier comparisons CSV: {e}')


# Script entry point: call main() only when this file is executed directly
# (e.g. `python <this file>`), not when it is imported as a module.
if __name__ == '__main__':
    main()
