#!/usr/bin/env python3
"""Compute baseline classification accuracies on raw AVMNIST data.

This script loads the raw image and audio data (flattened) and computes
classification accuracies for all metadata labels to provide a baseline
for comparison with learned representations.

Usage:
  python 02_paper_experiments/analysis/054_audiomnist_raw_data_baselines.py --full_spectrum --seed 0
"""
import os
import argparse
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
import pandas as pd
from torchvision.datasets import MNIST


def match_datasets_by_label(images, img_labels, audio, audio_labels, speaker_labels,
                            n_samples_per_class=6000, seed=42):
    """Pair MNIST images with AudioMNIST clips that carry the same digit label.

    For every digit 0-9 the same number of image rows and audio rows is drawn
    without replacement: the smaller of the two available counts, capped at
    ``n_samples_per_class``. Digits with nothing to draw are skipped. The
    speaker label of each drawn audio row is taken with the same indices, so
    it stays aligned with that audio row.

    NumPy's global random generator is seeded with ``seed`` before sampling.

    Returns:
        ``(images, audio, digit_labels, speaker_labels)`` - four arrays with
        the same number of rows, grouped by ascending digit.
    """
    np.random.seed(seed)

    image_chunks, audio_chunks, speaker_chunks = [], [], []
    digit_column = []

    for digit in range(10):
        image_pool = np.nonzero(img_labels == digit)[0]
        audio_pool = np.nonzero(audio_labels == digit)[0]
        count = min(len(image_pool), len(audio_pool), n_samples_per_class)

        if count > 0:
            # Draw order (images first, then audio) determines the random
            # stream, so it must not be swapped.
            image_rows = np.random.choice(image_pool, count, replace=False)
            audio_rows = np.random.choice(audio_pool, count, replace=False)

            image_chunks.append(images[image_rows])
            audio_chunks.append(audio[audio_rows])
            digit_column += [digit] * count
            # Same indices as the audio rows keeps speakers aligned with audio.
            speaker_chunks.append(speaker_labels[audio_rows])

    stacked_images = np.concatenate(image_chunks, axis=0)
    stacked_audio = np.concatenate(audio_chunks, axis=0)
    digit_array = np.array(digit_column)
    stacked_speakers = np.concatenate(speaker_chunks, axis=0)
    return stacked_images, stacked_audio, digit_array, stacked_speakers


def compute_multi_classifier_accs(rep, labels):
    """Score KNN, logistic regression and Gaussian NB on an 80/20 split.

    The first 80% of the rows are used for fitting and the remaining 20% for
    scoring. Rows are NOT shuffled, so the caller's row order decides which
    samples end up in the held-out part. String/object labels are
    integer-encoded before fitting.

    KNN (1 neighbour) is fitted on the raw features; logistic regression
    (saga, falling back to lbfgs) and Gaussian NB are fitted on standardized
    features.

    Every classifier is best-effort: if one fails, its entry stays ``None``,
    the reason is printed, and the others are still attempted. Only
    ``Exception`` subclasses are handled, so Ctrl-C / ``SystemExit`` still
    stop the run instead of being swallowed.

    Args:
        rep: Array-like of shape (n_samples, ...) holding the features.
        labels: Array-like with one label per row of ``rep``.

    Returns:
        Dict with keys ``'knn'``, ``'logistic'`` and ``'gnb'``; each value is
        the held-out accuracy as a float, or ``None`` if it could not be
        computed (missing input, fewer than 10 rows, or a failure).
    """
    out = {'knn': None, 'logistic': None, 'gnb': None}
    if rep is None or labels is None:
        return out
    try:
        X = np.asarray(rep)
        y = np.asarray(labels)
        # Encode non-numeric labels
        if y.dtype.kind in {'U', 'S', 'O'}:
            y = LabelEncoder().fit_transform(y.astype(str))
        n = X.shape[0]
        if n < 10:
            return out
        split = int(0.8 * n)
        X_train, X_test = X[:split], X[split:]
        y_train, y_test = y[:split], y[split:]

        # KNN works on the raw (unscaled) features.
        try:
            knn = KNeighborsClassifier(n_neighbors=1)
            knn.fit(X_train, y_train)
            out['knn'] = float(knn.score(X_test, y_test))
        except Exception as e:
            print(f"KNN accuracy failed: {e}")

        # Standardize once; logistic regression and Gaussian NB share the
        # result, which avoids a second scaled copy of a large matrix.
        try:
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)
        except Exception as e:
            print(f"Feature scaling failed: {e}")
            return out

        # Logistic regression: try saga first, fall back to lbfgs on failure.
        # NOTE(review): recent scikit-learn releases deprecate `multi_class`;
        # confirm against the pinned version, otherwise both attempts fail.
        for solver, multi_class in (('saga', 'multinomial'), ('lbfgs', 'auto')):
            try:
                log = LogisticRegression(max_iter=2000, solver=solver, multi_class=multi_class)
                log.fit(X_train_scaled, y_train)
                out['logistic'] = float(log.score(X_test_scaled, y_test))
                break
            except Exception as e:
                print(f"Logistic regression ({solver}) failed: {e}")

        # Gaussian NB
        try:
            gnb = GaussianNB()
            gnb.fit(X_train_scaled, y_train)
            out['gnb'] = float(gnb.score(X_test_scaled, y_test))
        except Exception as e:
            print(f"Gaussian NB accuracy failed: {e}")

        return out
    except Exception as e:
        print(f"Multi-classifier accuracy failed: {e}")
        return out


def main():
    """Compute raw-data baseline accuracies and write them to a CSV file.

    Steps:
      1. Parse ``--seed``, ``--full_spectrum``, ``--data-dir`` and ``--out-dir``.
      2. Load MNIST through torchvision (downloading it if necessary) and the
         audio arrays, digit labels and speaker labels from ``--data-dir``.
      3. Pair images with audio per digit via ``match_datasets_by_label`` and
         concatenate the matched train and test parts (train rows first).
      4. Look up per-speaker metadata (gender, native speaker, accent, origin,
         recording room) if an ``audioMNIST_meta`` file can be found.
      5. For the image, audio and concatenated features, score every label
         type with ``compute_multi_classifier_accs``.
      6. Write one CSV row per (modality, label type) into ``--out-dir``.
    """
    # Local import: json is only needed for the metadata file and is not
    # imported at module level.
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0, help='Random seed (matches training)')
    parser.add_argument('--full_spectrum', action='store_true')
    parser.add_argument('--data-dir', default='01_data/avmnist_data_from_source')
    parser.add_argument('--out-dir', default='03_results/processed')
    args = parser.parse_args()

    print("="*80)
    print("Computing baseline classification accuracies on raw data")
    print("="*80)

    # Load MNIST images
    mnist_train = MNIST(root='01_data/processed/MNIST', train=True, download=True)
    mnist_test = MNIST(root='01_data/processed/MNIST', train=False, download=True)

    # Scale pixel values from [0, 255] to [0, 1].
    mnist_train_images = mnist_train.data.numpy().astype('float32') / 255.0
    mnist_train_labels = mnist_train.targets.numpy()
    mnist_test_images = mnist_test.data.numpy().astype('float32') / 255.0
    mnist_test_labels = mnist_test.targets.numpy()

    # Load audio data and speaker labels
    train_audio = np.load(os.path.join(args.data_dir, "audio/train_data.npy"))
    train_labels = np.load(os.path.join(args.data_dir, "train_labels.npy"))
    train_speaker_labels = np.load(os.path.join(args.data_dir, "train_speaker_labels.npy"))

    test_audio = np.load(os.path.join(args.data_dir, "audio/test_data.npy"))
    test_labels = np.load(os.path.join(args.data_dir, "test_labels.npy"))
    test_speaker_labels = np.load(os.path.join(args.data_dir, "test_speaker_labels.npy"))

    print(f"Loaded MNIST train: {mnist_train_images.shape}")
    print(f"Loaded audio train: {train_audio.shape}")

    # Match datasets (same as training)
    print("\nMatching training datasets...")
    train_images, train_audio_matched, train_labels_matched, train_speakers = match_datasets_by_label(
        mnist_train_images, mnist_train_labels, train_audio, train_labels, train_speaker_labels,
        n_samples_per_class=5000, seed=args.seed
    )

    print("\nMatching test datasets...")
    test_images, test_audio_matched, test_labels_matched, test_speakers = match_datasets_by_label(
        mnist_test_images, mnist_test_labels, test_audio, test_labels, test_speaker_labels,
        n_samples_per_class=1000, seed=args.seed
    )

    # Flatten images
    train_images_flat = train_images.reshape(train_images.shape[0], -1)
    test_images_flat = test_images.reshape(test_images.shape[0], -1)

    # Process audio: either average over axis 1 or flatten the whole clip.
    # (Per the original notes the results are (N, 112) and (N, 112*112).)
    if not args.full_spectrum:
        train_audio_proc = train_audio_matched.mean(axis=1)
        test_audio_proc = test_audio_matched.mean(axis=1)
    else:
        train_audio_proc = train_audio_matched.reshape(train_audio_matched.shape[0], -1)
        test_audio_proc = test_audio_matched.reshape(test_audio_matched.shape[0], -1)

    print("\nProcessed data shapes:")
    print(f"  Images: {train_images_flat.shape} (train), {test_images_flat.shape} (test)")
    print(f"  Audio: {train_audio_proc.shape} (train), {test_audio_proc.shape} (test)")

    # Concatenate train and test (same order as reps)
    all_images = np.concatenate([train_images_flat, test_images_flat], axis=0)
    all_audio = np.concatenate([train_audio_proc, test_audio_proc], axis=0)
    all_labels = np.concatenate([train_labels_matched, test_labels_matched], axis=0)
    all_speakers = np.concatenate([train_speakers, test_speakers], axis=0)

    # Load metadata
    def load_meta_candidates(base_dirs):
        """Return the first parseable audioMNIST_meta file found, else None."""
        for base in base_dirs:
            for fname in ('audioMNIST_meta.txt', 'audioMNIST_meta.json', 'audioMNIST_meta'):
                p = os.path.join(base, fname)
                if not os.path.exists(p):
                    continue
                try:
                    with open(p, 'r') as f:
                        return json.load(f)
                except (OSError, ValueError) as e:
                    # Unreadable or not valid JSON: report it and keep looking.
                    print(f"  Skipping metadata candidate {p}: {e}")
        return None

    meta = load_meta_candidates([args.data_dir, '01_data/processed/avmnist', 'AudioMNIST'])
    if meta is None:
        print('\nWarning: audioMNIST metadata not found. Only digit/speaker baselines will be computed.')

    def map_meta_field(speakers_array, field):
        """Map speaker IDs to metadata field values ('UNKNOWN' when missing)."""
        if meta is None:
            return None
        out = []
        for s in speakers_array:
            try:
                speaker_int = int(s)
            except (TypeError, ValueError):
                speaker_int = None  # non-numeric ID: only the literal key is tried

            # Try the ID as-is first, then shifted by one.
            # NOTE(review): presumably covers 0- vs 1-based speaker numbering
            # - confirm against the metadata file's keys.
            candidates = [str(s)]
            if speaker_int is not None:
                candidates.append(str(speaker_int + 1))

            val = None
            for c in candidates:
                if c in meta:
                    val = meta[c].get(field)
                    break
            if val is None and speaker_int is not None:
                try:
                    val = meta.get(speaker_int, {}).get(field)
                except (AttributeError, TypeError):
                    pass
            out.append(val if val is not None else 'UNKNOWN')
        return np.array(out, dtype=object)

    # Extract all metadata fields
    gender_all = map_meta_field(all_speakers, 'gender')
    native_all = map_meta_field(all_speakers, 'native speaker')
    accent_all = map_meta_field(all_speakers, 'accent')
    origin_all = map_meta_field(all_speakers, 'origin')

    # Extract continent from origin (first word)
    if origin_all is not None:
        def first_word(o):
            if o == 'UNKNOWN':
                return o
            words = str(o).split()
            # An empty/whitespace-only origin has no first word.
            return words[0] if words else 'UNKNOWN'

        origin_all = np.array([first_word(o) for o in origin_all], dtype=object)

    room_all = map_meta_field(all_speakers, 'recordingroom')

    # Compute accuracies for images, audio, and concatenated
    print("\n" + "="*80)
    print("Computing classification accuracies...")
    print("="*80)

    results = []

    # Define data modalities to test
    modalities = [
        ('raw_image', all_images),
        ('raw_audio', all_audio),
        ('raw_concat', np.concatenate([all_images, all_audio], axis=1))
    ]

    # (CSV label_type, console title, label array). Metadata arrays are None
    # when no metadata file was found and are skipped in that case.
    label_sets = [
        ('digit', 'Digits', all_labels),
        ('speaker', 'Speakers', all_speakers),
        ('gender', 'Gender', gender_all),
        ('native', 'Native speaker', native_all),
        ('accent', 'Accent', accent_all),
        ('origin', 'Origin', origin_all),
        ('room', 'Room', room_all),
    ]

    def fmt_acc(value):
        """Format an accuracy for the console; None becomes 'N/A'."""
        return f"{value:.3f}" if value is not None else 'N/A'

    for mod_name, mod_data in modalities:
        print(f"\n{mod_name} (shape: {mod_data.shape}):")

        for label_type, title, label_values in label_sets:
            if label_values is None:
                continue
            print(f"  {title}...")
            accs = compute_multi_classifier_accs(mod_data, label_values)
            results.append({
                'modality': mod_name,
                'label_type': label_type,
                'knn_acc': accs['knn'],
                'logistic_acc': accs['logistic'],
                'gnb_acc': accs['gnb']
            })
            print(f"    KNN={fmt_acc(accs['knn'])}, Logistic={fmt_acc(accs['logistic'])}, GNB={fmt_acc(accs['gnb'])}")

    # Save results
    os.makedirs(args.out_dir, exist_ok=True)
    df = pd.DataFrame(results)
    out_csv = os.path.join(args.out_dir, f'audiomnist{"_fullspec" if args.full_spectrum else ""}_raw_data_baselines_rseed-{args.seed}.csv')
    df.to_csv(out_csv, index=False)

    print("\n" + "="*80)
    print(f"Saved baseline accuracies to: {out_csv}")
    print("="*80)

# Run the baseline computation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
