#!/usr/bin/env python3
"""Combined analysis for AVMNIST run-2 results.

This script merges functionality from 053_audiomnist_analysis.py and
053_audiomnist_testplots.py. It loads representations saved as
03_results/models/{prefix}_rep{i}.npy, reconstructs matched labels (same
procedure as 056_audiomnist.py), computes classification accuracies (KNN +
multi-classifier comparisons), creates 2D scatterplots (PCA) colored by digit
and speaker, and — if a rank-history CSV is available — creates a figure with
rank and R² panels (no thumbnail panel is drawn).

Outputs are written with a _run2 suffix so earlier results are not overwritten.
"""
import os
import argparse
import glob
import numpy as np
import matplotlib.pyplot as plt
import csv
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
import pandas as pd

# Small helpers (extracted/condensed from the originals)

def find_rep_files(prefix, results_dir='03_results/models'):
    """Return the ``{prefix}_rep*.npy`` files in *results_dir*, ordered by rep index.

    Files whose suffix after the final ``_rep`` is an integer are sorted
    numerically (``_rep2`` before ``_rep10``, unlike a plain lexicographic sort).
    Any other matches (e.g. ``_rep1_extra.npy``) follow, sorted by path.
    """
    pattern = os.path.join(results_dir, f"{prefix}_rep*.npy")

    def _rep_sort_key(path):
        # Text between the last '_rep' and the '.npy' extension, e.g. '10'.
        tail = os.path.basename(path)[:-len('.npy')].rsplit('_rep', 1)[-1]
        if tail.isdigit():
            return (0, int(tail), path)
        return (1, 0, path)

    return sorted(glob.glob(pattern), key=_rep_sort_key)


def load_labels(data_dir):
    """Load digit labels and, where available, speaker labels from *data_dir*.

    Returns:
        (train_labels, test_labels, train_speakers, test_speakers). The digit
        arrays come from train_labels.npy / test_labels.npy; each speaker entry
        is loaded from train_/test_speaker_labels.npy, or is None if that file
        does not exist.
    """
    def _in_dir(fname):
        return os.path.join(data_dir, fname)

    def _optional(fname):
        # Speaker labels are optional; a missing file yields None.
        candidate = _in_dir(fname)
        if not os.path.exists(candidate):
            return None
        return np.load(candidate)

    digits_train = np.load(_in_dir('train_labels.npy'))
    digits_test = np.load(_in_dir('test_labels.npy'))
    speakers_train = _optional('train_speaker_labels.npy')
    speakers_test = _optional('test_speaker_labels.npy')
    return digits_train, digits_test, speakers_train, speakers_test


def prep_2d(rep, use_pca=True):
    """Return a 2-column version of *rep* suitable for a scatterplot.

    - Fewer than 2 columns (a 1-D vector is treated as one column): zero-padded
      on the right to exactly 2 columns.
    - Exactly 2 columns: returned unchanged.
    - More than 2 columns: projected onto the first two principal components
      when *use_pca* is True; otherwise truncated to the first two columns.

    Args:
        rep: array-like of shape (n_samples, n_features) or (n_samples,).
        use_pca: whether inputs with more than 2 columns are reduced via PCA.

    Returns:
        ndarray of shape (n_samples, 2).
    """
    rep = np.asarray(rep)
    if rep.ndim == 1:
        rep = rep[:, None]
    n_rows, n_cols = rep.shape[0], rep.shape[1]
    if n_cols < 2:
        # Pad with zero columns so the caller can always plot (x, y).
        return np.concatenate([rep, np.zeros((n_rows, 2 - n_cols))], axis=1)
    if n_cols == 2 or not use_pca:
        return rep[:, :2]
    return PCA(n_components=2).fit_transform(rep)


def compute_knn_acc(rep, labels, n_neighbors=1):
    """Return the held-out accuracy of a k-NN classifier, or None.

    The first 80% of rows (in the given order, no shuffling) train a
    ``KNeighborsClassifier(n_neighbors)``; the remaining rows are scored.
    String/bytes/object labels are integer-encoded first.

    Returns None when *rep* or *labels* is None, when there are fewer than 5
    rows, or when any step of encoding/fitting/scoring raises.
    """
    if rep is None or labels is None:
        return None
    try:
        targets = np.asarray(labels)
        if targets.dtype.kind in ('U', 'S', 'O'):
            targets = LabelEncoder().fit_transform(targets.astype(str))
        total = rep.shape[0]
        if total < 5:
            return None
        cut = int(0.8 * total)
        model = KNeighborsClassifier(n_neighbors=n_neighbors)
        model.fit(rep[:cut], targets[:cut])
        return float(model.score(rep[cut:], targets[cut:]))
    except Exception:
        # Best-effort metric: any failure is reported as "no accuracy".
        return None


def compute_multi_classifier_accs(rep, labels):
    """Compare 1-NN, logistic-regression and Gaussian-NB accuracies on *rep*.

    Rows are split sequentially: the first 80% train every classifier and the
    last 20% are scored (no shuffling). String/bytes/object labels are
    integer-encoded. 1-NN uses the raw features; logistic regression and
    Gaussian NB use features standardized with a scaler fit on the training
    rows only.

    Args:
        rep: array-like of shape (n_samples, n_features).
        labels: array-like of length n_samples.

    Returns:
        dict with keys 'knn', 'logistic', 'gnb', each a float accuracy or None
        if that classifier could not be evaluated. All values are None when an
        input is None or there are fewer than 10 rows.
    """
    out = {'knn': None, 'logistic': None, 'gnb': None}
    if rep is None or labels is None:
        return out
    try:
        X = np.asarray(rep)
        y = np.asarray(labels)
        if y.dtype.kind in {'U', 'S', 'O'}:
            y = LabelEncoder().fit_transform(y.astype(str))
        n = X.shape[0]
        if n < 10:
            return out
        split = int(0.8 * n)
        X_train, X_test = X[:split], X[split:]
        y_train, y_test = y[:split], y[split:]
    except Exception:
        return out

    # 1-NN on the raw representation.
    try:
        knn = KNeighborsClassifier(n_neighbors=1)
        knn.fit(X_train, y_train)
        out['knn'] = float(knn.score(X_test, y_test))
    except Exception:
        out['knn'] = None

    # Standardize once; logistic regression and GNB share these features.
    try:
        scaler = StandardScaler().fit(X_train)
        X_train_s = scaler.transform(X_train)
        X_test_s = scaler.transform(X_test)
    except Exception:
        return out

    # Logistic regression: try saga first, fall back to lbfgs. ``multi_class``
    # is deliberately not passed. It is deprecated since scikit-learn 1.5, and
    # once it is removed, passing it would make both attempts raise TypeError.
    # Left at its default, a multinomial model is fit for 3+ classes with these
    # solvers.
    for solver in ('saga', 'lbfgs'):
        try:
            log_reg = LogisticRegression(max_iter=2000, solver=solver)
            log_reg.fit(X_train_s, y_train)
            out['logistic'] = float(log_reg.score(X_test_s, y_test))
            break
        except Exception:
            continue

    # Gaussian Naive Bayes on the standardized features.
    try:
        gnb = GaussianNB()
        gnb.fit(X_train_s, y_train)
        out['gnb'] = float(gnb.score(X_test_s, y_test))
    except Exception:
        out['gnb'] = None
    return out


def plot_scatter(X2, labels, title, out_path, cmap='tab10', colorbar=False):
    """Save a scatterplot of the 2-column array *X2*, colored by *labels*.

    Axes are labelled PC1/PC2. String/bytes/object labels are mapped to integer
    codes (sorted unique order) before coloring, because matplotlib would
    otherwise interpret strings in ``c=`` as color names. The parent directory
    of *out_path* is created if needed, and the figure is always closed, even
    when saving fails.

    Args:
        X2: array of shape (n_samples, >=2); columns 0 and 1 are plotted.
        labels: array-like of length n_samples used for coloring.
        title: plot title.
        out_path: destination image path (saved at 150 dpi).
        cmap: matplotlib colormap name.
        colorbar: whether to draw a colorbar.
    """
    colors = np.asarray(labels)
    if colors.dtype.kind in {'U', 'S', 'O'}:
        _, colors = np.unique(colors.astype(str), return_inverse=True)
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        sc = ax.scatter(X2[:, 0], X2[:, 1], c=colors, s=6, cmap=cmap, alpha=0.8)
        ax.set_title(title)
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
        if colorbar:
            fig.colorbar(sc, ax=ax)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)


def _resolve_rank_file(prefix, results_dir):
    """Return the rank-history CSV path for *prefix* (or its 'audiomnist2' variant), else None."""
    candidates = [prefix]
    # e.g. 'audiomnist_rseed-42' -> 'audiomnist2_rseed-42'; never rewrite an
    # 'audiomnist2' prefix into 'audiomnist22'.
    if 'audiomnist2' not in prefix:
        alt = prefix.replace('audiomnist', 'audiomnist2')
        if alt != prefix:
            candidates.append(alt)
    for candidate in candidates:
        path = os.path.join(results_dir, f"{candidate}_rank_history.csv")
        if os.path.exists(path):
            return path
    return None


def _parse_rank_strings(values):
    """Parse strings like '3, 5, 7' or '[3, 5, 7]' into a 2-D int array, one row per entry.

    Raises ValueError on non-integer tokens (e.g. 'nan') or rows of unequal length.
    """
    rows = []
    for raw in values:
        body = str(raw).strip().strip('[]()')
        rows.append([int(tok) for tok in body.split(',') if tok.strip()])
    if len({len(row) for row in rows}) > 1:
        raise ValueError("rank strings contain differing numbers of entries")
    return np.array(rows)


def try_create_ranks_r2_figure(prefix, reps, labels, speakers, results_dir='03_results/models', out_dir='03_results/plots', seed=42, full_spectrum=False):
    """Best-effort figure of rank and R² histories for *prefix*.

    Looks for ``{prefix}_rank_history.csv`` (or its 'audiomnist2' variant) in
    *results_dir*. The left panel plots per-component ranks parsed from a
    ``ranks`` column (labelled Shared/Image/Audio in column order, log y-axis),
    or ``total_rank`` when there is no ``ranks`` column. The right panel plots
    ``rsquare 0`` (Image) and ``rsquare 1`` (Audio) when present. Every series
    is plotted against the ``epoch`` column.

    A missing or unreadable CSV, or one without an ``epoch`` column, prints a
    message and returns None instead of raising.

    Args:
        prefix: model prefix used to find the CSV and to name the output file.
        reps, labels, speakers: accepted for call compatibility; unused here.
        results_dir: directory searched for the rank-history CSV.
        out_dir: directory the PNG is written to (created if needed).
        seed, full_spectrum: only used to build the output filename.
    """
    rank_file = _resolve_rank_file(prefix, results_dir)
    if rank_file is None:
        print(f"Rank history not found for {prefix}; skipping ranks/R² figure.")
        return
    try:
        rank_history = pd.read_csv(rank_file)
    except Exception as e:
        print(f"Failed to read rank history {rank_file}: {e}")
        return
    if 'epoch' not in rank_history.columns:
        # Every panel is plotted against epoch; skip rather than raise KeyError.
        print(f"Rank history {rank_file} has no 'epoch' column; skipping ranks/R² figure.")
        return
    epochs = rank_history['epoch'].values

    plt.style.use('default')
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    try:
        # Left panel: individual ranks if present, otherwise the total rank.
        if 'ranks' in rank_history.columns:
            try:
                individual_ranks = _parse_rank_strings(rank_history['ranks'].values)
                for col, name in enumerate(['Shared', 'Image', 'Audio']):
                    if col < individual_ranks.shape[1]:
                        axes[0].plot(epochs, individual_ranks[:, col], label=name)
                axes[0].set_ylabel('Rank')
                axes[0].set_xlabel('Epoch')
                axes[0].set_yscale('log')
                axes[0].legend()
            except Exception as e:
                print(f"Could not plot 'ranks' from {rank_file}: {e}")
        elif 'total_rank' in rank_history.columns:
            axes[0].plot(epochs, rank_history['total_rank'].values)
            axes[0].set_ylabel('Total rank')
            axes[0].set_xlabel('Epoch')

        # Right panel: R² columns 'rsquare 0' (Image) and 'rsquare 1' (Audio).
        plotted = False
        for i, name in enumerate(['Image', 'Audio']):
            col = f'rsquare {i}'
            if col in rank_history.columns:
                axes[1].plot(epochs, rank_history[col].values, label=name)
                plotted = True
        if plotted:
            axes[1].set_ylabel('R²')
            axes[1].set_xlabel('Epoch')
            axes[1].legend()

        out_file = os.path.join(out_dir, f'{prefix}_ranks_r2_run2_rseed-{seed}{"_fullspec" if full_spectrum else ""}.png')
        os.makedirs(out_dir, exist_ok=True)
        fig.tight_layout()
        fig.savefig(out_file, dpi=200)
    finally:
        plt.close(fig)
    print(f"Saved ranks/R² figure to: {out_file}")


def _resolve_model_prefix(model_prefix, results_dir, seed, full_spectrum):
    """Return *model_prefix*, or guess one from *seed*/*full_spectrum* when it is None.

    The guess is 'audiomnist...' if its _rep0.npy exists in *results_dir*,
    otherwise 'audiomnist2...' (which may not exist either).
    """
    if model_prefix is not None:
        return model_prefix
    tail = f'{"_fullspec" if full_spectrum else ""}_rseed-{seed}'
    base = f'audiomnist{tail}'
    # prefer the base name (matches how training saved files earlier)
    if os.path.exists(os.path.join(results_dir, base + '_rep0.npy')):
        return base
    return f'audiomnist2{tail}'


def _process_rep(rep_path, all_labels, all_speakers, out_dir):
    """Plot one saved representation and compute its accuracies.

    Returns (accuracy_row, classifier_comparison_rows), or None when the rep's
    row count does not match the number of labels (nothing is plotted then).
    """
    rep = np.load(rep_path)
    if rep.ndim == 1:
        # Treat a flat vector as a single-feature representation.
        rep = rep[:, None]
    name = os.path.basename(rep_path).replace('.npy', '')
    print(f'Processing {name} with shape {rep.shape}')
    if rep.shape[0] != len(all_labels):
        print(f'  Skipping {name}: {rep.shape[0]} rows but {len(all_labels)} labels.')
        return None

    # One 2-D projection shared by the digit and speaker scatterplots.
    X2 = prep_2d(rep, use_pca=True)
    plot_scatter(X2, all_labels, f'{name} colored by digit',
                 os.path.join(out_dir, f'{name}_digits_run2.png'), cmap='tab10', colorbar=False)
    if all_speakers is not None:
        plot_scatter(X2, all_speakers, f'{name} colored by speaker',
                     os.path.join(out_dir, f'{name}_speakers_run2.png'), cmap='viridis', colorbar=True)

    # Accuracies are computed on the original representation, not the 2-D projection.
    acc_digit = compute_knn_acc(rep, all_labels)
    acc_speaker = compute_knn_acc(rep, all_speakers) if all_speakers is not None else None
    print(f'  {name}: digit acc={acc_digit}, speaker acc={acc_speaker}')
    result = {'rep': name, 'dim': rep.shape[1], 'digit_acc': acc_digit, 'speaker_acc': acc_speaker}

    comp_rows = []
    for label_type, labels in (('digits', all_labels), ('speakers', all_speakers)):
        if labels is None:
            continue
        accs = compute_multi_classifier_accs(rep, labels)
        comp_rows.append({'rep': name, 'label_type': label_type, 'knn_acc': accs.get('knn'),
                          'logistic_acc': accs.get('logistic'), 'gnb_acc': accs.get('gnb')})
    return result, comp_rows


def main():
    """Command-line entry point: analyse every saved rep for one model prefix.

    For each ``{prefix}_rep*.npy`` in --results-dir, this saves PCA scatterplots
    colored by digit, and by speaker when speaker labels exist. It computes
    1-NN accuracies and KNN/logistic/GaussianNB comparisons, and writes both
    tables as _run2 CSVs to --out-dir. Finally it attempts the ranks/R² figure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-prefix', default=None, help='Prefix used when saving reps in 03_results/models (e.g. audiomnist_fullspec_rseed-42)')
    parser.add_argument('--data-dir', default='01_data/avmnist_data_from_source', help='Where the train/test labels live')
    parser.add_argument('--results-dir', default='03_results/models', help='Where rep files and rank history live')
    parser.add_argument('--out-dir', default='03_results/plots', help='Where to save plots')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--full_spectrum', action='store_true')
    args = parser.parse_args()

    model_prefix = _resolve_model_prefix(args.model_prefix, args.results_dir, args.seed, args.full_spectrum)
    print(f"Using model prefix: {model_prefix}")

    # Every {prefix}_rep*.npy file is processed; there is no upper limit.
    rep_files = find_rep_files(model_prefix, results_dir=args.results_dir)
    if len(rep_files) == 0:
        print(f"No rep files found for prefix {model_prefix}. Try the alternative 'audiomnist2' prefix or pass --model-prefix.")
        return

    # Labels for all rows: train rows followed by test rows.
    train_labels, test_labels, train_speakers, test_speakers = load_labels(args.data_dir)
    all_labels = np.concatenate([train_labels, test_labels])
    all_speakers = None
    if (train_speakers is not None) and (test_speakers is not None):
        all_speakers = np.concatenate([train_speakers, test_speakers])
        if len(all_speakers) != len(all_labels):
            print(f'Speaker labels ({len(all_speakers)}) do not match digit labels ({len(all_labels)}); ignoring speakers.')
            all_speakers = None

    # Create the output directory before any plot is saved into it.
    os.makedirs(args.out_dir, exist_ok=True)

    results = []
    comp_rows = []
    for rep_path in rep_files:
        processed = _process_rep(rep_path, all_labels, all_speakers, args.out_dir)
        if processed is None:
            continue
        row, rep_comp_rows = processed
        results.append(row)
        comp_rows.extend(rep_comp_rows)

    # Save rep accuracies CSV with _run2 suffix
    df = pd.DataFrame(results)
    out_csv = os.path.join(args.out_dir, f'{model_prefix}_rep_accuracies_run2.csv')
    df.to_csv(out_csv, index=False)
    print(f'Saved accuracies to {out_csv}')

    # Save classifier comparisons CSV with _run2 suffix
    comp_df = pd.DataFrame(comp_rows)
    comp_out = os.path.join(args.out_dir, f'{model_prefix}_rep_classifier_comparisons_run2.csv')
    comp_df.to_csv(comp_out, index=False)
    print(f'Saved classifier comparisons to {comp_out}')

    # Try to create ranks/R² figure (best-effort)
    try_create_ranks_r2_figure(model_prefix, rep_files, all_labels, all_speakers, results_dir=args.results_dir, out_dir=args.out_dir, seed=args.seed, full_spectrum=args.full_spectrum)


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
