#!/usr/bin/env python3
"""Combined analysis for AVMNIST pretrained model results.

This script is adapted from 054_audiomnist_run2_plots.py to work with models
trained using pretrained encoders/decoders (056_train_figuro_from_pretrained.py).

It loads representations saved as 03_results/models/{prefix}_rep{i}.npy,
reconstructs matched labels (same procedure as 056_audiomnist.py), computes
classification accuracies (KNN + multi-classifier comparisons), creates 2D
scatterplots (PCA) colored by digit and speaker, and — if rank history is
available — creates a two-panel ranks / R² figure.

Outputs are written with a _run2 suffix so earlier results are not overwritten.
"""
import os
import argparse
import glob
import numpy as np
import matplotlib.pyplot as plt
import csv
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
import pandas as pd

# Small helpers (extracted/condensed from the originals)

def find_rep_files(prefix, results_dir='03_results/models'):
    """Return the saved representation files ``{prefix}_rep*.npy`` in ``results_dir``.

    Files are ordered by their numeric rep index (``rep2`` before ``rep10``);
    matches whose suffix is not a plain integer come last, in name order.
    ``prefix`` and ``results_dir`` are escaped so glob metacharacters in them
    (e.g. ``[``) are matched literally.
    """
    stem = f"{prefix}_rep"
    suffix = ".npy"
    pattern = os.path.join(glob.escape(results_dir), glob.escape(stem) + "*" + suffix)

    def _rep_order(path):
        # Text between "_rep" and ".npy", e.g. "10" for "..._rep10.npy".
        index_text = os.path.basename(path)[len(stem):-len(suffix)]
        if index_text.isdecimal():
            return (0, int(index_text), path)
        return (1, 0, path)

    return sorted(glob.glob(pattern), key=_rep_order)


def load_labels(data_dir):
    """Load digit labels and (optionally) speaker labels stored under ``data_dir``.

    ``train_labels.npy`` and ``test_labels.npy`` must exist. The speaker files
    ``train_speaker_labels.npy`` / ``test_speaker_labels.npy`` are optional.

    Returns:
        (train_labels, test_labels, train_speakers, test_speakers); each speaker
        entry is None when its file is absent.
    """
    def _path(fname):
        return os.path.join(data_dir, fname)

    def _load_if_present(fname):
        candidate = _path(fname)
        return np.load(candidate) if os.path.exists(candidate) else None

    digits_train = np.load(_path('train_labels.npy'))
    digits_test = np.load(_path('test_labels.npy'))
    return (
        digits_train,
        digits_test,
        _load_if_present('train_speaker_labels.npy'),
        _load_if_present('test_speaker_labels.npy'),
    )


def match_datasets_by_label_simple(audio_labels, speaker_labels, n_samples_per_class=6000, seed=42):
    """Subsample up to ``n_samples_per_class`` examples of each digit 0-9.

    Returns the matched digit labels and the speaker labels of the sampled
    examples (same as the training script). The output is laid out digit by
    digit: all sampled 0s, then all 1s, and so on.

    The legacy global-RNG calls are kept on purpose: ``np.random.seed(seed)``,
    then one ``np.random.choice(..., replace=False)`` per non-empty digit in
    ascending order. Changing the RNG API or call order would silently
    misalign the reconstructed labels with the saved representations.
    Side effect: reseeds NumPy's global RNG.

    Returns:
        (labels, speakers) arrays of equal length. Both are empty when no
        digit 0-9 occurs in ``audio_labels``.
    """
    audio_labels = np.asarray(audio_labels)
    speaker_labels = np.asarray(speaker_labels)
    np.random.seed(seed)
    matched_labels = []
    matched_speaker_labels = []
    for digit in range(10):
        audio_indices = np.where(audio_labels == digit)[0]
        n_samples = min(len(audio_indices), n_samples_per_class)
        if n_samples > 0:
            audio_sample_indices = np.random.choice(audio_indices, n_samples, replace=False)
            matched_labels.extend([digit] * n_samples)
            matched_speaker_labels.append(speaker_labels[audio_sample_indices])
    if not matched_speaker_labels:
        # np.concatenate([]) raises ValueError; return typed empty arrays instead.
        return np.array(matched_labels, dtype=int), speaker_labels[:0]
    return np.array(matched_labels), np.concatenate(matched_speaker_labels, axis=0)


def prep_2d(rep, use_pca=True):
    """Return a two-column view of ``rep`` suitable for a 2D scatterplot.

    - A 1-D array is treated as a single feature column.
    - One feature: a zero column is appended, so points lie on the x-axis.
    - Two (or zero) features: returned unchanged.
    - More than two features: projected onto the first two principal
      components when ``use_pca`` is True; otherwise the first two raw
      columns are kept.
    """
    rep = np.asarray(rep)
    if rep.ndim == 1:
        rep = rep.reshape(-1, 1)
    n_features = rep.shape[1]
    if n_features == 1:
        return np.concatenate([rep, np.zeros((rep.shape[0], 1))], axis=1)
    if n_features <= 2:
        return rep
    if not use_pca:
        return rep[:, :2]
    return PCA(n_components=2).fit_transform(rep)


def compute_knn_acc(rep, labels, n_neighbors=1, n_train=None):
    """Return the hold-out accuracy of a k-NN classifier predicting ``labels`` from ``rep``.

    The split is positional (no shuffling). The first ``n_train`` rows train
    the classifier and the remaining rows are scored. ``n_train=None`` keeps
    the historical default of the first 80% of rows. String/object labels are
    integer-encoded first.

    Returns:
        The accuracy as a float. Returns None when an input is None, the
        sample count is below 5, rep and labels differ in length, the split
        leaves one side empty, or fitting fails. A warning is printed for the
        last two failure kinds.
    """
    if rep is None or labels is None:
        return None
    X = np.asarray(rep)
    y = np.asarray(labels)
    if X.ndim == 0 or y.ndim == 0:
        return None
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    n = X.shape[0]
    if y.shape[0] != n:
        print(f'  Warning: compute_knn_acc got {n} samples but {y.shape[0]} labels; skipping.')
        return None
    if n < 5:
        return None
    split = int(0.8 * n) if n_train is None else int(n_train)
    if not 0 < split < n:
        return None
    try:
        if y.dtype.kind in {'U', 'S', 'O'}:
            y = LabelEncoder().fit_transform(y.astype(str))
        knn = KNeighborsClassifier(n_neighbors=n_neighbors)
        knn.fit(X[:split], y[:split])
        return float(knn.score(X[split:], y[split:]))
    except Exception as exc:  # best-effort: report and let the caller continue
        print(f'  Warning: k-NN accuracy failed: {exc}')
        return None


def compute_multi_classifier_accs(rep, labels, n_train=None):
    """Return hold-out accuracies of 1-NN, logistic regression and Gaussian NB.

    Uses a positional split: the first ``n_train`` rows train and the rest are
    scored. ``n_train=None`` means the first 80% of rows. Logistic regression
    and Gaussian NB are fitted on standardized features.

    Returns:
        A dict with keys 'knn', 'logistic' and 'gnb'. Each value is a float
        accuracy, or None when that classifier could not be evaluated. All are
        None if an input is None, there are fewer than 10 samples, rep and
        labels differ in length, or the split leaves one side empty.
    """
    out = {'knn': None, 'logistic': None, 'gnb': None}
    if rep is None or labels is None:
        return out
    X = np.asarray(rep)
    y = np.asarray(labels)
    if X.ndim == 0 or y.ndim == 0:
        return out
    if X.ndim == 1:
        X = X.reshape(-1, 1)
    n = X.shape[0]
    if y.shape[0] != n:
        print(f'  Warning: compute_multi_classifier_accs got {n} samples but {y.shape[0]} labels; skipping.')
        return out
    if n < 10:
        return out
    split = int(0.8 * n) if n_train is None else int(n_train)
    if not 0 < split < n:
        return out
    try:
        if y.dtype.kind in {'U', 'S', 'O'}:
            y = LabelEncoder().fit_transform(y.astype(str))
    except Exception as exc:
        print(f'  Warning: label encoding failed: {exc}')
        return out
    X_train, X_test = X[:split], X[split:]
    y_train, y_test = y[:split], y[split:]

    try:
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(X_train, y_train)
        out['knn'] = float(knn.score(X_test, y_test))
    except Exception as exc:
        print(f'  Warning: k-NN failed: {exc}')

    # StandardScaler is deterministic, so one fit serves both scaled models.
    try:
        scaler = StandardScaler().fit(X_train)
        X_train_scaled = scaler.transform(X_train)
        X_test_scaled = scaler.transform(X_test)
    except Exception as exc:
        print(f'  Warning: feature scaling failed: {exc}')
        return out

    # `multi_class` is deprecated in scikit-learn (1.5+); the default already
    # selects a multinomial fit for multiclass targets with these solvers.
    last_error = None
    for solver in ('saga', 'lbfgs'):
        try:
            clf = LogisticRegression(max_iter=2000, solver=solver)
            clf.fit(X_train_scaled, y_train)
            out['logistic'] = float(clf.score(X_test_scaled, y_test))
            break
        except Exception as exc:
            last_error = exc
    else:
        print(f'  Warning: logistic regression failed: {last_error}')

    try:
        gnb = GaussianNB()
        gnb.fit(X_train_scaled, y_train)
        out['gnb'] = float(gnb.score(X_test_scaled, y_test))
    except Exception as exc:
        print(f'  Warning: Gaussian NB failed: {exc}')
    return out


def plot_scatter(X2, labels, title, out_path, cmap='tab10', colorbar=False):
    """Save a scatterplot of the first two columns of ``X2`` colored by ``labels``.

    String/object labels (e.g. metadata categories) are mapped to integer codes
    in sorted order so matplotlib's ``c=`` accepts them. The parent directory of
    ``out_path`` is created if needed. The figure is always closed, even if
    saving fails.
    """
    X2 = np.asarray(X2)
    colors = labels
    if labels is not None:
        colors = np.asarray(labels)
        if colors.ndim > 0 and colors.dtype.kind in {'U', 'S', 'O'}:
            _, colors = np.unique(colors.astype(str), return_inverse=True)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        sc = ax.scatter(X2[:, 0], X2[:, 1], c=colors, s=6, cmap=cmap, alpha=0.8)
        ax.set_title(title)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
        if colorbar:
            fig.colorbar(sc, ax=ax)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)


def _parse_rank_rows(rank_strings):
    """Parse rank strings like "3, 2, 1" or "[3, 2, 1]" into a 2D float array (NaN-padded)."""
    rows = []
    for raw in rank_strings:
        body = str(raw).strip().strip('[]()')
        rows.append([float(tok) for tok in body.split(',') if tok.strip()])
    width = max((len(row) for row in rows), default=0)
    parsed = np.full((len(rows), width), np.nan)
    for i, row in enumerate(rows):
        parsed[i, :len(row)] = row
    return parsed


def try_create_ranks_r2_figure(prefix, reps, labels, speakers, results_dir='03_results/models', out_dir='03_results/plots', seed=42, full_spectrum=False):
    """Best-effort plot of rank and R² histories for ``prefix``.

    Reads ``{results_dir}/{prefix}_rank_history.csv`` and saves a two-panel
    figure to ``{out_dir}/{prefix}_ranks_r2_run2_rseed-{seed}[_fullspec].png``.
    - Left panel: per-component ranks from a 'ranks' column of comma-separated
      values (first three labelled Shared/Image/Audio), or a 'total_rank'
      column when 'ranks' is absent.
    - Right panel: 'rsquare 0' / 'rsquare 1' columns, labelled Image/Audio.
    The x-axis is the 'epoch' column if present, otherwise the row index. A
    missing or unreadable CSV is reported and skipped.

    ``reps``, ``labels`` and ``speakers`` are accepted for interface
    compatibility but are not used.
    """
    rank_file = os.path.join(results_dir, f"{prefix}_rank_history.csv")
    if not os.path.exists(rank_file):
        print(f"Rank history not found for {prefix}; skipping ranks/R² figure.")
        return
    try:
        rank_history = pd.read_csv(rank_file)
    except Exception as e:  # pandas raises several types (EmptyDataError, ParserError, OSError, ...)
        print(f"Failed to read rank history {rank_file}: {e}")
        return

    columns = rank_history.columns
    epochs = rank_history['epoch'].values if 'epoch' in columns else np.arange(len(rank_history))
    out_file = os.path.join(out_dir, f'{prefix}_ranks_r2_run2_rseed-{seed}{"_fullspec" if full_spectrum else ""}.png')
    os.makedirs(out_dir, exist_ok=True)

    # Scope the style to this figure rather than resetting global rcParams.
    with plt.style.context('default'):
        fig, axes = plt.subplots(1, 2, figsize=(10, 4))
        try:
            if 'ranks' in columns:
                try:
                    ranks = _parse_rank_rows(rank_history['ranks'].values)
                except ValueError as exc:
                    print(f"Could not parse 'ranks' column in {rank_file}: {exc}")
                else:
                    component_names = ('Shared', 'Image', 'Audio')[:ranks.shape[1]]
                    for col, label in enumerate(component_names):
                        axes[0].plot(epochs, ranks[:, col], label=label)
                    if component_names:
                        axes[0].set_ylabel('Rank')
                        axes[0].set_xlabel('Epoch')
                        axes[0].set_yscale('log')
                        axes[0].legend()
            elif 'total_rank' in columns:
                axes[0].plot(epochs, rank_history['total_rank'].values)
                axes[0].set_ylabel('Total rank')
                axes[0].set_xlabel('Epoch')

            plotted = False
            for i, name in enumerate(['Image', 'Audio']):
                col = f'rsquare {i}'
                if col in columns:
                    axes[1].plot(epochs, rank_history[col].values, label=name)
                    plotted = True
            if plotted:
                axes[1].set_ylabel('R²')
                axes[1].set_xlabel('Epoch')
                axes[1].legend()

            fig.tight_layout()
            fig.savefig(out_file, dpi=200)
        finally:
            plt.close(fig)
    print(f"Saved ranks/R² figure to: {out_file}")


def _load_audiomnist_meta(base_dirs):
    """Return the first AudioMNIST metadata file in ``base_dirs`` that parses as a JSON object, else None."""
    import json

    for base in dict.fromkeys(base_dirs):  # de-duplicate while keeping order
        for fname in ('audioMNIST_meta.txt', 'audioMNIST_meta.json', 'audioMNIST_meta'):
            path = os.path.join(base, fname)
            if not os.path.exists(path):
                continue
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    meta = json.load(f)
            except (OSError, ValueError) as exc:
                print(f'Warning: could not read metadata file {path}: {exc}')
                continue
            if isinstance(meta, dict):
                return meta
            print(f'Warning: metadata file {path} is not a JSON object; ignoring it.')
    return None


def _meta_key(meta, speaker_id, offset):
    """Return the key in ``meta`` for ``speaker_id + offset`` (plain or zero-padded), or None."""
    try:
        number = int(speaker_id) + offset
    except (TypeError, ValueError):
        # Non-numeric IDs can only match verbatim, without an offset.
        key = str(speaker_id)
        return key if offset == 0 and key in meta else None
    for key in (str(number), f'{number:02d}'):
        if key in meta:
            return key
    return None


def _build_speaker_meta_lookup(meta, speakers):
    """Map each unique speaker ID to its metadata dict using one consistent ID offset.

    Label arrays and metadata keys may use different bases (0- vs 1-based), so
    both offsets are tried globally and the one matching more speakers wins
    (ties keep offset 0). Applying a single offset avoids mapping neighbouring
    IDs to the same metadata entry.
    """
    if meta is None:
        return None
    unique_ids = np.unique(speakers).tolist()
    best = {}
    for offset in (0, 1):
        lookup = {}
        for sid in unique_ids:
            key = _meta_key(meta, sid, offset)
            if key is not None and isinstance(meta[key], dict):
                lookup[sid] = meta[key]
        if len(lookup) > len(best):
            best = lookup
    return best


def _map_meta_field(lookup, speakers, field):
    """Return an object array with ``field`` for each speaker, 'UNKNOWN' when unavailable."""
    values = []
    for sid in np.asarray(speakers).tolist():
        val = lookup.get(sid, {}).get(field)
        values.append(val if val is not None else 'UNKNOWN')
    return np.array(values, dtype=object)


def _continent(origin):
    """Return the first word of an origin string (commas ignored), or 'UNKNOWN'."""
    if origin == 'UNKNOWN':
        return origin
    parts = str(origin).replace(',', ' ').split()
    return parts[0] if parts else 'UNKNOWN'


def _fmt_acc(acc):
    """Format an accuracy with three decimals, or 'N/A' when it is None."""
    return f"{acc:.3f}" if acc is not None else "N/A"


def main():
    """Evaluate and plot every saved representation for one model prefix.

    Steps:
    1. Resolve the model prefix and find its ``{prefix}_rep*.npy`` files.
    2. Reconstruct the matched digit/speaker labels with the training-time
       sampling (5000 train / 1000 test per digit, same seed), train first.
    3. Optionally attach AudioMNIST speaker metadata.
    4. For each representation, save PCA scatterplots plus k-NN and
       multi-classifier accuracies.

    Outputs use a ``_run2`` suffix. A ranks/R² figure is added when a rank
    history exists.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-prefix', default=None, help='Prefix used when saving reps in 03_results/models (e.g. audiomnist_fullspec_pretrained_rseed-0)')
    parser.add_argument('--data-dir', default='01_data/avmnist_data_from_source', help='Where the train/test labels live')
    parser.add_argument('--results-dir', default='03_results/models', help='Where rep files and rank history live')
    parser.add_argument('--out-dir', default='03_results/plots', help='Where to save plots')
    parser.add_argument('--seed', type=int, default=0, help='Random seed (matches training)')
    parser.add_argument('--full_spectrum', action='store_true')
    parser.add_argument('--frozen', action='store_true', help='Add _frozen suffix if model was trained with frozen pretrained components')
    args = parser.parse_args()

    # Without an explicit prefix, rebuild it from the training script's naming convention.
    if args.model_prefix is None:
        model_prefix = f'audiomnist{"_fullspec" if args.full_spectrum else ""}_pretrained_rseed-{args.seed}'
        if args.frozen:
            model_prefix += '_frozen'
    else:
        model_prefix = args.model_prefix
    print(f"Using model prefix: {model_prefix}")

    rep_files = find_rep_files(model_prefix, results_dir=args.results_dir)
    if not rep_files:
        print(f"No rep files found for prefix {model_prefix}.")
        print(f"Expected files like: {args.results_dir}/{model_prefix}_rep0.npy")
        return

    # Scatterplots are saved before the CSVs, so the output dir must exist now.
    os.makedirs(args.out_dir, exist_ok=True)

    # Source labels (before matching). Speaker labels are required because the
    # matching step subsamples them alongside the digits.
    train_labels_src, test_labels_src, train_speakers_src, test_speakers_src = load_labels(args.data_dir)
    if train_speakers_src is None or test_speakers_src is None:
        raise SystemExit(f"Speaker label files (train_speaker_labels.npy / test_speaker_labels.npy) not found in {args.data_dir}; they are required to reconstruct the matched labels.")

    # Match using the same parameters as training (5000 train, 1000 test per digit).
    matched_train_labels, matched_train_speakers = match_datasets_by_label_simple(
        train_labels_src, train_speakers_src, n_samples_per_class=5000, seed=args.seed
    )
    matched_test_labels, matched_test_speakers = match_datasets_by_label_simple(
        test_labels_src, test_speakers_src, n_samples_per_class=1000, seed=args.seed
    )

    # Concatenate to match the training data ordering (train then test).
    all_labels = np.concatenate([matched_train_labels, matched_test_labels])
    all_speakers = np.concatenate([matched_train_speakers, matched_test_speakers])
    n_train = len(matched_train_labels)
    print(f"Reconstructed {len(all_labels)} samples ({n_train} train, {len(all_labels)-n_train} test)")

    # --- Speaker metadata -> additional per-sample labels ---
    meta_fields = (
        ('gender', 'gender'),
        ('native', 'native speaker'),
        ('accent', 'accent'),
        ('origin', 'origin'),
        ('room', 'recordingroom'),
    )
    meta = _load_audiomnist_meta([args.data_dir, '01_data/processed/avmnist', 'AudioMNIST'])
    if meta is None:
        print('Warning: audioMNIST metadata not found. Gender/native/accent/origin/room labels will be unavailable.')
    lookup = _build_speaker_meta_lookup(meta, all_speakers)
    if lookup is not None and not lookup:
        print('Warning: audioMNIST metadata loaded but no speaker IDs matched its keys; metadata labels will all be UNKNOWN.')

    meta_labels = {}  # label_type -> per-sample labels; empty without metadata
    if lookup is not None:
        for label_type, field in meta_fields:
            meta_labels[label_type] = _map_meta_field(lookup, all_speakers, field)
        # Reduce origin to its continent (first word).
        meta_labels['origin'] = np.array([_continent(o) for o in meta_labels['origin']], dtype=object)
        for label_type, title in (('gender', 'Gender'), ('native', 'Native speaker'), ('origin', 'Origin'), ('room', 'Room')):
            print(f'  {title} unique: {np.unique(meta_labels[label_type].astype(str))}')

    # --- Per-representation plots and accuracies ---
    results = []
    comp_rows = []
    for rep_path in rep_files:
        rep = np.load(rep_path)
        name = os.path.splitext(os.path.basename(rep_path))[0]
        print(f'Processing {name} with shape {rep.shape}')

        if rep.shape[0] != len(all_labels):
            print(f'  Warning: rep has {rep.shape[0]} samples but labels have {len(all_labels)}')
        # Truncate everything to the common length so rows stay aligned.
        n = min(rep.shape[0], len(all_labels))
        rep = rep[:n]
        label_sets = [('digits', all_labels[:n]), ('speakers', all_speakers[:n])]
        label_sets += [(label_type, meta_labels[label_type][:n]) for label_type, _ in meta_fields if label_type in meta_labels]
        labels_used = label_sets[0][1]
        speakers_used = label_sets[1][1]

        X2 = prep_2d(rep, use_pca=True)
        out_plot_digits = os.path.join(args.out_dir, f'{name}_digits_run2.png')
        plot_scatter(X2, labels_used, f'{name} colored by digit', out_plot_digits, cmap='tab10', colorbar=False)
        out_plot_speakers = os.path.join(args.out_dir, f'{name}_speakers_run2.png')
        plot_scatter(X2, speakers_used, f'{name} colored by speaker', out_plot_speakers, cmap='viridis', colorbar=True)

        knn_accs = {label_type: compute_knn_acc(rep, y) for label_type, y in label_sets}
        print(f"  {name}: digit={_fmt_acc(knn_accs['digits'])}, speaker={_fmt_acc(knn_accs['speakers'])}")
        if knn_accs.get('gender') is not None:
            print(f"         gender={_fmt_acc(knn_accs.get('gender'))}, native={_fmt_acc(knn_accs.get('native'))}, "
                  f"origin={_fmt_acc(knn_accs.get('origin'))}, room={_fmt_acc(knn_accs.get('room'))}")

        row = {
            'rep': name,
            'dim': rep.shape[1],
            'digit_acc': knn_accs['digits'],
            'speaker_acc': knn_accs['speakers'],
        }
        for label_type, _ in meta_fields:
            row[f'{label_type}_acc'] = knn_accs.get(label_type)
        results.append(row)

        # Classifier comparisons, in the same label order as above.
        for label_type, y in label_sets:
            accs = compute_multi_classifier_accs(rep, y)
            comp_rows.append({'rep': name, 'label_type': label_type, 'knn_acc': accs.get('knn'), 'logistic_acc': accs.get('logistic'), 'gnb_acc': accs.get('gnb')})

    out_csv = os.path.join(args.out_dir, f'{model_prefix}_rep_accuracies_run2.csv')
    pd.DataFrame(results).to_csv(out_csv, index=False)
    print(f'Saved accuracies to {out_csv}')

    comp_out = os.path.join(args.out_dir, f'{model_prefix}_rep_classifier_comparisons_run2.csv')
    pd.DataFrame(comp_rows).to_csv(comp_out, index=False)
    print(f'Saved classifier comparisons to {comp_out}')

    # Best-effort ranks/R² figure.
    try_create_ranks_r2_figure(model_prefix, rep_files, all_labels, all_speakers, results_dir=args.results_dir, out_dir=args.out_dir, seed=args.seed, full_spectrum=args.full_spectrum)


# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
