#!/usr/bin/env python3
"""Plot latent representations and compute classification accuracy.

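Loads every representation file matching `03_results/models/{prefix}_rep*.npy`,
loads digit labels from `01_data/avmnist_data_from_source/{train,test}_labels.npy`
(directory configurable with `--data-dir`) and, if present, speaker labels from
`{train,test}_speaker_labels.npy` in the same directory. Each representation is
split positionally: its first rows are the training samples (one per training
label), followed by the test samples. Then, for each representation:
- if dim > 2, reduce to 2D with PCA (a 1D representation is padded with a zero column)
- create scatter plots colored by digit and, when speaker labels exist, by speaker
- train a logistic regression on the training portion of the reps and evaluate on the test portion
- save plots to `03_results/plots` (configurable with `--out-dir`) and all accuracies
  to `{out_dir}/{prefix}_rep_accuracies.csv`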

Usage:
  python 02_paper_experiments/analysis/053_audiomnist_testplots.py --model-prefix audiomnist2_rseed-42
"""
import os
import glob
import argparse
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
import pandas as pd


def load_labels(data_dir):
    """Load digit labels and, when available, speaker labels from *data_dir*.

    The digit-label files ``train_labels.npy`` and ``test_labels.npy`` are
    required (``np.load`` raises if either is missing). The speaker-label files
    ``train_speaker_labels.npy`` / ``test_speaker_labels.npy`` are optional;
    each one that does not exist yields ``None`` in its slot.

    Returns:
        ``(train_labels, test_labels, train_speakers, test_speakers)``.
    """
    def _path(fname):
        return os.path.join(data_dir, fname)

    def _load_optional(fname):
        # Optional file: load it only if it is actually on disk.
        candidate = _path(fname)
        return np.load(candidate) if os.path.exists(candidate) else None

    digits_train = np.load(_path('train_labels.npy'))
    digits_test = np.load(_path('test_labels.npy'))
    speakers_train = _load_optional('train_speaker_labels.npy')
    speakers_test = _load_optional('test_speaker_labels.npy')
    return digits_train, digits_test, speakers_train, speakers_test


def find_rep_files(prefix, models_dir=os.path.join('03_results', 'models')):
    """Return the saved representation files for *prefix*, ordered by rep index.

    Matches ``{models_dir}/{prefix}_rep*.npy``. Files named
    ``{prefix}_rep<integer>.npy`` are ordered numerically by that integer, so
    ``rep10`` comes after ``rep9`` rather than after ``rep1`` as a plain
    lexicographic sort would place it. Any other matches follow, in
    lexicographic order.

    Args:
        prefix: Prefix used when the reps were saved. Glob metacharacters in it
            (``*``, ``?``, ``[``) are matched literally.
        models_dir: Directory to search; defaults to ``03_results/models``
            (relative to the current working directory).

    Returns:
        List of matching file paths (empty if none match).
    """
    head = f'{prefix}_rep'
    suffix = '.npy'
    pattern = os.path.join(glob.escape(models_dir), f'{glob.escape(head)}*{suffix}')

    def sort_key(path):
        # Everything between "{prefix}_rep" and ".npy" is the rep index.
        index = os.path.basename(path)[len(head):-len(suffix)]
        if index.isdecimal():
            return (0, int(index), path)
        return (1, 0, path)

    return sorted(glob.glob(pattern), key=sort_key)


def prep_2d(rep, use_pca=True):
    """Return a two-column version of *rep* suitable for a 2-D scatter plot.

    Args:
        rep: Array of shape ``(N, D)``. A 1-D array of length N is treated as
            ``(N, 1)``.
        use_pca: Only relevant when ``D > 2``. If True, project onto the first
            two principal components. If False, keep the first two columns
            unchanged. Previously this flag was ignored in that case and PCA
            always ran.

    Returns:
        ``(N, 2)`` array for ``D >= 1``. ``D == 1`` is padded with a zero
        second column, and ``D == 2`` is returned unchanged.
    """
    rep = np.asarray(rep)
    if rep.ndim == 1:
        rep = rep[:, None]
    n_dims = rep.shape[1]
    if n_dims == 1:
        # Pad with zeros so a 1-D rep can still be drawn as x/y points.
        return np.concatenate([rep, np.zeros((rep.shape[0], 1))], axis=1)
    if n_dims <= 2:
        return rep
    if not use_pca:
        return rep[:, :2]
    return PCA(n_components=2).fit_transform(rep)


def plot_scatter(X2, labels, title, out_path, cmap='tab10', colorbar=False,
                 xlabel='PC1', ylabel='PC2'):
    """Save a scatter plot of the first two columns of *X2*, colored by *labels*.

    Args:
        X2: Array whose column 0 is plotted on x and column 1 on y.
        labels: One color value per row of *X2*, mapped through *cmap*.
        title: Plot title.
        out_path: Destination image path (saved at 150 dpi).
        cmap: Matplotlib colormap name.
        colorbar: If True, draw a colorbar for the color mapping.
        xlabel: X-axis label. The default ``'PC1'`` assumes a PCA projection;
            override it when *X2* holds raw features.
        ylabel: Y-axis label (default ``'PC2'``).
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        sc = ax.scatter(X2[:, 0], X2[:, 1], c=labels, s=6, cmap=cmap, alpha=0.8)
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        if colorbar:
            fig.colorbar(sc, ax=ax)
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving raises, so
        # repeated calls do not accumulate open pyplot figures.
        plt.close(fig)


def train_and_eval(train_X, train_y, test_X, test_y):
    """Fit a standardized logistic-regression probe and return its test accuracy.

    Feature scaling statistics come from the training split only and are then
    applied to the test split. The classifier is scikit-learn's
    ``LogisticRegression`` with the lbfgs solver and up to 2000 iterations.

    Returns:
        Fraction of test samples classified correctly (``accuracy_score``).
    """
    scaler = StandardScaler()
    train_X_s = scaler.fit_transform(train_X)
    test_X_s = scaler.transform(test_X)
    # ``multi_class`` was deprecated in scikit-learn 1.5 and is slated for
    # removal. With the lbfgs solver the default already fits a multinomial
    # (softmax) model whenever there are three or more classes. For exactly two
    # classes it fits a standard binary model instead.
    clf = LogisticRegression(max_iter=2000, solver='lbfgs')
    clf.fit(train_X_s, train_y)
    return accuracy_score(test_y, clf.predict(test_X_s))


def _safe_accuracy(task, name, X_train, y_train, X_test, y_test):
    """Return ``train_and_eval`` accuracy, or None (with a warning) if it raises.

    This is best-effort by design: one failing probe should not abort
    processing of the remaining representation files.
    """
    try:
        return train_and_eval(X_train, y_train, X_test, y_test)
    except Exception as e:
        print(f'  Warning: {task} classifier failed for {name}: {e}')
        return None


def main():
    """Plot and probe every saved representation for ``--model-prefix``.

    For each file returned by ``find_rep_files``:

    1. Check that its row count equals ``len(train labels) + len(test labels)``.
       Rows are split positionally, training rows first, so mismatched files
       are skipped with a warning.
    2. Save 2-D scatter plots colored by digit, and by speaker when speaker
       labels exist.
    3. Fit logistic-regression probes on the training rows and score them on
       the test rows.

    All accuracies are written to ``{out_dir}/{model_prefix}_rep_accuracies.csv``.

    Raises:
        FileNotFoundError: If no representation files match the prefix.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--model-prefix', required=True, help='Prefix used when saving reps in 03_results/models (e.g. audiomnist2_rseed-42)')
    p.add_argument('--data-dir', default='01_data/avmnist_data_from_source', help='Where the train/test labels live')
    p.add_argument('--out-dir', default='03_results/plots', help='Where to save plots')
    args = p.parse_args()

    os.makedirs(args.out_dir, exist_ok=True)

    rep_files = find_rep_files(args.model_prefix)
    if not rep_files:
        raise FileNotFoundError(f'No rep files found for prefix {args.model_prefix} in 03_results/models')

    train_labels, test_labels, train_speakers, test_speakers = load_labels(args.data_dir)
    n_train = len(train_labels)
    n_test = len(test_labels)
    n_total = n_train + n_test
    all_labels = np.concatenate([train_labels, test_labels])

    all_speakers = None
    if train_speakers is not None and test_speakers is not None:
        if len(train_speakers) == n_train and len(test_speakers) == n_test:
            all_speakers = np.concatenate([train_speakers, test_speakers])
        else:
            print('Warning: speaker label counts do not match digit label counts; ignoring speaker labels')

    results = []
    for rep_path in rep_files:
        rep = np.load(rep_path)
        name = os.path.splitext(os.path.basename(rep_path))[0]
        print(f'Processing {name} with shape {rep.shape}')

        # PCA and the classifiers need a (samples, features) matrix: a 1-D rep
        # becomes a single feature and trailing dimensions are flattened.
        if rep.ndim != 2:
            rep = rep.reshape(len(rep), -1)

        # The train/test split below is positional, so a row-count mismatch
        # would silently pair reps with the wrong labels.
        if rep.shape[0] != n_total:
            print(f'  Warning: skipping {name}: expected {n_total} rows '
                  f'({n_train} train + {n_test} test), got {rep.shape[0]}')
            continue

        X_train = rep[:n_train]
        X_test = rep[n_train:]

        # One 2-D projection over all rows (train + test) for plotting.
        X2 = prep_2d(rep, use_pca=True)

        out_plot_digits = os.path.join(args.out_dir, f'{name}_digits.png')
        plot_scatter(X2, all_labels, f'{name} colored by digit', out_plot_digits, cmap='tab10', colorbar=False)

        if all_speakers is not None:
            out_plot_speakers = os.path.join(args.out_dir, f'{name}_speakers.png')
            # Speaker IDs are drawn with a continuous colormap plus a colorbar.
            plot_scatter(X2, all_speakers, f'{name} colored by speaker', out_plot_speakers, cmap='viridis', colorbar=True)

        # Probes use the original (not PCA-reduced) features.
        acc_digit = _safe_accuracy('digit', name, X_train, train_labels, X_test, test_labels)
        if all_speakers is not None:
            acc_speaker = _safe_accuracy('speaker', name, X_train, train_speakers, X_test, test_speakers)
        else:
            acc_speaker = None

        print(f'  {name}: digit acc={acc_digit}, speaker acc={acc_speaker}')
        results.append({'rep': name, 'dim': rep.shape[1], 'digit_acc': acc_digit, 'speaker_acc': acc_speaker})

    # Explicit columns keep a valid CSV header even if every rep was skipped.
    df = pd.DataFrame(results, columns=['rep', 'dim', 'digit_acc', 'speaker_acc'])
    out_csv = os.path.join(args.out_dir, f'{args.model_prefix}_rep_accuracies.csv')
    df.to_csv(out_csv, index=False)
    print(f'Saved accuracies to {out_csv}')


# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
