import torch
import torch.nn as nn
import sys
import os
import argparse
import numpy as np
import random
import pandas as pd
import csv
from torchvision.datasets import MNIST
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB

# Add src to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.functions.pretrain_avmnist import ImageAutoencoder, AudioAutoencoder, train_multimodal_ae_with_pretrained


if __name__ == '__main__':
    # Command-line options: RNG seed, CUDA device index, audio representation
    # and the directory holding the pretrained unimodal checkpoints.
    cli = argparse.ArgumentParser(
        description="Train multimodal model with pretrained encoders/decoders"
    )
    cli.add_argument('--seed', type=int, default=0,
                     help='Random seed for reproducibility')
    cli.add_argument('--gpu', type=int, default=0, help='GPU id to use')
    cli.add_argument(
        '--full_spectrum', action='store_true',
        help='Use full 112x112 audio spectrogram instead of averaging to 112',
    )
    cli.add_argument(
        '--pretrained_dir', type=str,
        default='03_results/models/pretrained_unimodal',
        help='Directory containing pretrained models',
    )
    args = cli.parse_args()

    # Seed every RNG the script touches, in the order NumPy, PyTorch (CPU),
    # PyTorch (all CUDA devices), stdlib ``random``.
    for seed_rng in (np.random.seed, torch.manual_seed,
                     torch.cuda.manual_seed_all, random.seed):
        seed_rng(args.seed)

    # Use the requested GPU when CUDA is available, otherwise fall back to CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device(f'cuda:{args.gpu}')
    else:
        DEVICE = torch.device('cpu')
    print(f"Using device: {DEVICE}")

    # --- 1. Load Data (same as 056_audiomnist.py) ---
    rule = "=" * 80
    print("\n" + rule)
    print("Loading data from AudioMNIST source...")
    print(rule)

    # MNIST images (downloaded on first use), scaled from [0, 255] to [0, 1].
    mnist_root = '01_data/processed/MNIST'
    mnist_train = MNIST(root=mnist_root, train=True, download=True)
    mnist_test = MNIST(root=mnist_root, train=False, download=True)

    mnist_train_images, mnist_test_images = (
        split.data.numpy().astype('float32') / 255.0
        for split in (mnist_train, mnist_test)
    )
    mnist_train_labels, mnist_test_labels = (
        split.targets.numpy() for split in (mnist_train, mnist_test)
    )

    # Audio arrays plus their digit and speaker labels, stored as .npy files.
    data_dir = "01_data/avmnist_data_from_source"

    train_audio, train_labels, train_speaker_labels = (
        np.load(os.path.join(data_dir, rel_path))
        for rel_path in ("audio/train_data.npy",
                         "train_labels.npy",
                         "train_speaker_labels.npy")
    )
    test_audio, test_labels, test_speaker_labels = (
        np.load(os.path.join(data_dir, rel_path))
        for rel_path in ("audio/test_data.npy",
                         "test_labels.npy",
                         "test_speaker_labels.npy")
    )

    print(f"MNIST train images shape: {mnist_train_images.shape}")
    print(f"Audio train shape: {train_audio.shape}")
    print(f"Audio train speaker labels shape: {train_speaker_labels.shape}")

    # --- 2. Match Datasets (same as 056_audiomnist.py) ---
    def match_datasets_by_label(images, img_labels, audio, audio_labels, speaker_labels, 
                               n_samples_per_class=6000, seed=42):
        """
        Pair images with audio clips of the same digit.

        For each digit 0-9 the same number of images and audio clips is drawn
        without replacement (the smaller of both pool sizes and
        ``n_samples_per_class``). Speaker labels are taken with the audio
        indices so they stay aligned with the returned audio.

        Note: reseeds the global NumPy RNG with ``seed``. Per digit, the image
        draw happens before the audio draw.

        Returns:
            Tuple ``(images, audio, digit_labels, speaker_labels)``, each
            concatenated over digits in ascending digit order.

        Raises:
            ValueError: from ``np.concatenate`` when no digit has samples in
                both modalities.
        """
        np.random.seed(seed)
        image_chunks, audio_chunks, label_chunks, speaker_chunks = [], [], [], []

        for digit in range(10):
            image_pool = np.nonzero(img_labels == digit)[0]
            audio_pool = np.nonzero(audio_labels == digit)[0]

            count = min(len(image_pool), len(audio_pool), n_samples_per_class)
            if count <= 0:
                # Digit missing from one modality (or nothing requested).
                continue

            # Keep this draw order: it determines the sampled indices.
            picked_images = np.random.choice(image_pool, count, replace=False)
            picked_audio = np.random.choice(audio_pool, count, replace=False)

            image_chunks.append(images[picked_images])
            audio_chunks.append(audio[picked_audio])
            label_chunks += [digit] * count
            speaker_chunks.append(speaker_labels[picked_audio])

        all_images = np.concatenate(image_chunks, axis=0)
        all_audio = np.concatenate(audio_chunks, axis=0)
        all_labels = np.array(label_chunks)
        all_speakers = np.concatenate(speaker_chunks, axis=0)
        return all_images, all_audio, all_labels, all_speakers

    print("\nMatching training datasets...")
    matched_train = match_datasets_by_label(
        mnist_train_images, mnist_train_labels,
        train_audio, train_labels, train_speaker_labels,
        n_samples_per_class=5000, seed=args.seed,
    )
    train_images, train_audio, train_labels, train_speaker_labels = matched_train

    print("\nMatching test datasets...")
    matched_test = match_datasets_by_label(
        mnist_test_images, mnist_test_labels,
        test_audio, test_labels, test_speaker_labels,
        n_samples_per_class=1000, seed=args.seed,
    )
    test_images, test_audio, test_labels, test_speaker_labels = matched_test

    # --- 3. Prepare Data ---
    # Flatten every image to a single feature vector per sample.
    train_images_flat = train_images.reshape(len(train_images), -1)
    test_images_flat = test_images.reshape(len(test_images), -1)

    # Unless the full spectrogram was requested, average the audio over axis 1.
    if not args.full_spectrum:
        train_audio, test_audio = train_audio.mean(axis=1), test_audio.mean(axis=1)

    print("\nFinal dataset shapes:")
    for caption, array in (("Train Images", train_images_flat),
                           ("Train Audio", train_audio),
                           ("Train Labels", train_labels),
                           ("Train Speaker Labels", train_speaker_labels)):
        print(f"{caption}: {array.shape}")

    # One tensor per modality; rows are the training samples followed by the
    # test samples, and n_train_samples marks the boundary between them.
    n_train_samples = train_images_flat.shape[0]
    data = [
        torch.FloatTensor(np.concatenate(split_pair, axis=0))
        for split_pair in ((train_images_flat, test_images_flat),
                           (train_audio, test_audio))
    ]

    # --- 4. Load Pretrained Encoders ---
    print("\n" + "="*80)
    print("Loading pretrained encoders...")
    print("="*80)

    def _restore_autoencoder(kind, checkpoint_path, build_model):
        """Load one unimodal autoencoder checkpoint and report its metadata.

        ``kind`` is the lower-case modality name used in messages,
        ``build_model`` a zero-argument callable creating the untrained model.
        Returns ``(checkpoint, model)``; raises FileNotFoundError when the
        checkpoint file does not exist.
        """
        title = kind.capitalize()
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"{title} autoencoder not found: {checkpoint_path}")

        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        autoencoder = build_model()
        autoencoder.encoder.load_state_dict(checkpoint['encoder_state_dict'])
        autoencoder.decoder.load_state_dict(checkpoint['decoder_state_dict'])
        print(f"✓ Loaded {kind} autoencoder from: {checkpoint_path}")
        print(f"  {title} AE latent dim: {checkpoint['latent_dim']}")
        print(f"  {title} AE final val loss: {checkpoint['final_val_loss']:.6f}")
        return checkpoint, autoencoder

    # Only the audio checkpoint name depends on the full_spectrum flag.
    fullspec_tag = '_fullspec' if args.full_spectrum else ''
    img_model_prefix = f"audiomnist_rseed-{args.seed}"
    audio_model_prefix = f"audiomnist{fullspec_tag}_rseed-{args.seed}"

    # Image autoencoder (latent size 500, as in the pretrained models).
    img_ae_path = os.path.join(args.pretrained_dir, f"{img_model_prefix}_image_ae.pth")
    img_checkpoint, img_ae = _restore_autoencoder(
        'image', img_ae_path, lambda: ImageAutoencoder(latent_dim=500)
    )

    # Audio autoencoder (latent size 500).
    audio_ae_path = os.path.join(args.pretrained_dir, f"{audio_model_prefix}_audio_ae.pth")
    audio_checkpoint, audio_ae = _restore_autoencoder(
        'audio', audio_ae_path,
        lambda: AudioAutoencoder(latent_dim=500, full_spectrum=args.full_spectrum),
    )

    # Encoder and decoder modules are both handed to the multimodal training.
    pretrained_encoders = [autoencoder.encoder for autoencoder in (img_ae, audio_ae)]
    pretrained_decoders = [autoencoder.decoder for autoencoder in (img_ae, audio_ae)]
    
    # --- 5. Training Configuration (same as 056_audiomnist.py) ---
    early_stopping, patience = 50, 10
    rank_reduction_frequency, rank_reduction_threshold = 10, 0.01
    r_square_threshold = 0.05

    class TrainArgs:
        """Plain attribute container with the hyperparameters passed to training."""

        def __init__(self):
            hyperparameters = {
                # 500 for all latent dims to match the pretrained encoders
                'latent_dim': 500,
                'batch_size': 512,
                'lr': 1e-3,
                'weight_decay': 2e-5,
                'dropout': 0.0,
                'ae_depth': 2,
                'ae_width': 0.5,
                'epochs': 10000,
                'num_workers': 8,
                'multi_gpu': False,
                'gpu_ids': '',
                # GPU index comes from the command line
                'gpu': args.gpu,
                'freeze_models': False,
            }
            # Set as instance attributes, in the order listed above.
            for option, value in hyperparameters.items():
                setattr(self, option, value)

    train_args = TrainArgs()
    
    print("\n" + "="*80)
    print("Starting multimodal training with pretrained encoders...")
    print("="*80)

    # Keyword options for the training routine, collected up front so the call
    # itself stays short. Values and order match the script's configuration.
    training_options = {
        # pretrained building blocks
        'pretrained_encoders': pretrained_encoders,
        'pretrained_decoders': pretrained_decoders,
        # optimisation and architecture
        'epochs': train_args.epochs,
        'lr': train_args.lr,
        'batch_size': train_args.batch_size,
        'ae_depth': train_args.ae_depth,
        'ae_width': train_args.ae_width,
        'dropout': train_args.dropout,
        'wd': train_args.weight_decay,
        # stopping / rank-reduction schedule
        'early_stopping': early_stopping,
        'initial_rank_ratio': 1.0,
        'rank_reduction_frequency': rank_reduction_frequency,
        'rank_reduction_threshold': rank_reduction_threshold,
        'warmup_epochs': early_stopping,
        'patience': patience,
        'min_rank': 1,
        'r_square_threshold': r_square_threshold,
        'threshold_type': 'absolute',
        'compressibility_type': 'direct',
        # misc switches
        'verbose': True,
        'compute_jacobian': False,
        'sharedwhenall': False,
        'pretrained_name': None,  # No pretrained name since we're training multimodal
        'lr_schedule': 'step',
        'decision_metric': 'ExVarScore',
        'full_spectrum': args.full_spectrum,
        'freeze_models': train_args.freeze_models,
    }

    # Train the multimodal model starting from the pretrained encoders/decoders.
    training_outputs = train_multimodal_ae_with_pretrained(
        data,
        n_train_samples,
        train_args.latent_dim,
        DEVICE,
        train_args,
        **training_options,
    )
    model, reps, train_loss, r_squares, rank_history, loss_curves = training_outputs
    
    # --- 6. Save Results (same as 056_audiomnist.py) ---
    # File prefix encodes the audio representation, the seed and whether the
    # pretrained models were frozen.
    save_prefix = f"audiomnist{'_fullspec' if args.full_spectrum else ''}_pretrained_rseed-{args.seed}"
    if train_args.freeze_models:
        save_prefix += "_frozen"
    os.makedirs("03_results/models", exist_ok=True)
    torch.save(model.state_dict(), f"03_results/models/{save_prefix}_final_model.pth")
    for i, rep in enumerate(reps):
        # The analysis further down tolerates missing or non-tensor
        # representations; apply the same handling here so a single absent
        # subspace cannot abort the script after training has finished.
        if rep is None:
            print(f"  Warning: representation {i} is None, not saved")
            continue
        # detach() makes the conversion safe for tensors that still track grad.
        rep_array = rep.detach().cpu().numpy() if torch.is_tensor(rep) else np.asarray(rep)
        np.save(f"03_results/models/{save_prefix}_rep{i}.npy", rep_array)
    pd.DataFrame(rank_history).to_csv(f"03_results/models/{save_prefix}_rank_history.csv", index=False)
    
    print("\n" + "="*80)
    print("Training complete!")
    print("="*80)
    
    # --- 7. Post-Training Analysis: Metadata Classification ---
    print("\n" + "="*80)
    print("Computing metadata classification metrics...")
    print("="*80)
    
    # Directory searched first for the speaker metadata file below.
    data_dir = "01_data/avmnist_data_from_source"
    
    # Use the digit and speaker labels produced by match_datasets_by_label when
    # the training data was built. Re-drawing them separately is NOT
    # equivalent: the matching routine draws image indices before audio
    # indices for every digit, so a second sampler that only draws audio
    # indices consumes the RNG differently and returns speakers that do not
    # correspond to the audio clips used for training.
    # Row order matches `data`: training samples first, then test samples.
    labels_concat = np.concatenate([train_labels, test_labels], axis=0)
    speakers_concat = np.concatenate([train_speaker_labels, test_speaker_labels], axis=0)
    
    # Reseed NumPy so that anything in the analysis below that falls back to
    # the global NumPy RNG is reproducible for a given --seed.
    np.random.seed(args.seed)
    
    # Load metadata
    def load_meta_candidates(base_dirs):
        """Return the first parseable audioMNIST metadata file as a JSON object.

        Each directory in ``base_dirs`` is searched, in order, for
        ``audioMNIST_meta.txt``, ``audioMNIST_meta.json`` and
        ``audioMNIST_meta``. Files that cannot be read or are not valid JSON
        are reported and skipped, so the search continues with the next
        candidate.

        Returns:
            The decoded JSON content of the first usable file, or ``None`` when
            no candidate exists or none can be parsed.
        """
        import json

        for base in base_dirs:
            for fname in ('audioMNIST_meta.txt', 'audioMNIST_meta.json', 'audioMNIST_meta'):
                p = os.path.join(base, fname)
                if not os.path.exists(p):
                    continue
                try:
                    with open(p, 'r', encoding='utf-8') as f:
                        return json.load(f)
                except (OSError, ValueError) as err:
                    # ValueError covers json.JSONDecodeError and decoding
                    # errors; stay best-effort but say why the file was skipped.
                    print(f"Warning: could not read metadata file {p}: {err}")
        return None
    
    # Search order for the speaker metadata file; the first parseable hit wins.
    # (Each directory is listed once; repeating one adds nothing to the search.)
    meta = load_meta_candidates([data_dir, '01_data/processed/avmnist', 'AudioMNIST'])
    if meta is None:
        print('Warning: audioMNIST metadata not found. Gender/native/accent labels will be unavailable.')
    
    def map_meta_field(speakers_array, field):
        """Map each speaker id to the value of ``field`` in the metadata.

        Uses the enclosing ``meta`` mapping. For every id the lookup keys are
        tried in this order: ``str(id)``, ``str(int(id) + 1)``, and finally the
        integer ``int(id)`` itself. The first string key that exists decides
        the entry, even if that entry lacks ``field``. Ids that cannot be
        converted to an integer are only looked up by their string form.

        NOTE(review): the "+1" fallback is decided per speaker, so metadata
        keyed with a different offset or zero padding than the speaker labels
        may resolve to a neighbouring speaker or to 'UNKNOWN' - confirm the
        key format of the metadata file against the stored speaker labels.

        Returns:
            ``None`` when no metadata is loaded, otherwise an object array with
            one value per speaker ('UNKNOWN' where nothing was found).
        """
        if meta is None:
            return None

        def _lookup(speaker):
            # Integer form is optional: non-numeric ids must not abort the run.
            try:
                speaker_int = int(speaker)
            except (TypeError, ValueError, OverflowError):
                speaker_int = None

            candidates = [str(speaker)]
            if speaker_int is not None:
                candidates.append(str(speaker_int + 1))

            val = None
            for c in candidates:
                if c in meta:
                    val = meta[c].get(field)
                    break
            if val is None and speaker_int is not None:
                entry = meta.get(speaker_int)
                if isinstance(entry, dict):
                    val = entry.get(field)
            return val if val is not None else 'UNKNOWN'

        # Resolve each distinct speaker once; there are far fewer speakers
        # than samples. Keyed by type and text so any id value is hashable.
        resolved = {}
        out = []
        for s in speakers_array:
            key = (type(s), str(s))
            if key not in resolved:
                resolved[key] = _lookup(s)
            out.append(resolved[key])
        return np.array(out, dtype=object)
    
    # Extract all metadata fields (each is None when no metadata was loaded)
    gender_concat = map_meta_field(speakers_concat, 'gender')
    native_concat = map_meta_field(speakers_concat, 'native speaker')
    accent_concat = map_meta_field(speakers_concat, 'accent')
    origin_concat = map_meta_field(speakers_concat, 'origin')
    
    # Extract continent from origin (first word)
    if origin_concat is not None:
        continents = []
        for origin in origin_concat:
            words = str(origin).split()
            if origin == 'UNKNOWN' or not words:
                # Empty or whitespace-only origins have no first word; treat
                # them like missing metadata instead of raising IndexError.
                continents.append('UNKNOWN')
                continue
            # Drop a trailing comma so "Europe," and "Europe" form one class.
            continents.append(words[0].rstrip(',') or 'UNKNOWN')
        origin_concat = np.array(continents, dtype=object)
    
    room_concat = map_meta_field(speakers_concat, 'recordingroom')
    
    # Helper functions for classification
    def compute_multi_classifier_accs(rep, labels):
        """Compute held-out accuracies for KNN, Logistic Regression, and Gaussian NB.

        The first 80% of the rows are used for fitting and the remaining 20%
        for scoring. String/object labels are integer-encoded first. KNN works
        on the raw features; the other two classifiers use standardised
        features (scaler fitted on the training rows only).

        NOTE(review): the split is positional, not shuffled. If the rows are
        ordered by class the held-out part is not class-balanced - confirm
        this is the intended evaluation protocol.

        Args:
            rep: array-like with one row per sample, or None.
            labels: array-like with one label per sample, or None.

        Returns:
            Dict with keys 'knn', 'logistic' and 'gnb'; each value is a float
            accuracy, or None when that classifier could not be evaluated
            (missing input, fewer than 10 samples, length mismatch, or a
            failure, which is printed).
        """
        out = {'knn': None, 'logistic': None, 'gnb': None}
        if rep is None or labels is None:
            return out

        def _report(stage, err):
            # Failures stay non-fatal, but are no longer silent.
            print(f"    {stage} failed: {err}")

        try:
            X = np.asarray(rep)
            y = np.asarray(labels)
            n = X.shape[0]
            if y.shape[0] != n:
                print(f"Multi-classifier accuracy skipped: {n} samples but {y.shape[0]} labels")
                return out
            # Encode non-numeric labels
            if y.dtype.kind in {'U', 'S', 'O'}:
                y = LabelEncoder().fit_transform(y.astype(str))
            if n < 10:
                return out
            split = int(0.8 * n)
            X_train, X_test = X[:split], X[split:]
            y_train, y_test = y[:split], y[split:]

            # KNN (1 nearest neighbour, raw features)
            try:
                knn = KNeighborsClassifier(n_neighbors=1)
                knn.fit(X_train, y_train)
                out['knn'] = float(knn.score(X_test, y_test))
            except Exception as err:
                _report('KNN', err)

            # Standardise once; both remaining classifiers use the same scaling.
            try:
                scaler = StandardScaler()
                X_train_scaled = scaler.fit_transform(X_train)
                X_test_scaled = scaler.transform(X_test)
            except Exception as err:
                _report('Feature scaling', err)
                return out

            # Logistic Regression: multinomial saga first, default lbfgs as fallback
            try:
                log = LogisticRegression(max_iter=2000, solver='saga', multi_class='multinomial')
                log.fit(X_train_scaled, y_train)
                out['logistic'] = float(log.score(X_test_scaled, y_test))
            except Exception as err:
                _report('Logistic regression (saga)', err)
                try:
                    # No multi_class keyword here: 'auto' is the library
                    # default, and newer scikit-learn releases deprecate the
                    # keyword, so omitting it keeps the fallback usable.
                    log = LogisticRegression(max_iter=2000, solver='lbfgs')
                    log.fit(X_train_scaled, y_train)
                    out['logistic'] = float(log.score(X_test_scaled, y_test))
                except Exception as fallback_err:
                    _report('Logistic regression (lbfgs)', fallback_err)

            # Gaussian NB
            try:
                gnb = GaussianNB()
                gnb.fit(X_train_scaled, y_train)
                out['gnb'] = float(gnb.score(X_test_scaled, y_test))
            except Exception as err:
                _report('Gaussian NB', err)

            return out
        except Exception as e:
            print(f"Multi-classifier accuracy failed: {e}")
            return out
    
    # Prepare label types: digit and speaker are always available, the
    # metadata-derived ones only when the metadata file could be loaded.
    label_types = [
        ('digit', labels_concat),
        ('speaker', speakers_concat)
    ]
    optional_label_types = (
        ('origin', origin_concat),
        ('room', room_concat),
        ('gender', gender_concat),
        ('native', native_concat),
        ('accent', accent_concat),
    )
    label_types.extend(
        (meta_name, meta_vals)
        for meta_name, meta_vals in optional_label_types
        if meta_vals is not None
    )
    
    def _format_acc(value):
        """Format an accuracy with three decimals, or 'N/A' when missing."""
        return f"{value:.3f}" if value is not None else 'N/A'
    
    # Compute accuracies for each subspace and label type.
    # Representations are paired with these names by position.
    subspace_names = ['shared', 'modality1', 'modality2']
    results = []
    
    for subspace_name, rep in zip(subspace_names, reps):
        if rep is None:
            print(f"  Warning: {subspace_name} representation is None, skipping")
            continue
        
        # detach() keeps the conversion valid for tensors that still track grad.
        rep_arr = rep.detach().cpu().numpy() if torch.is_tensor(rep) else np.asarray(rep)
        print(f"\n  Analyzing {subspace_name} (shape: {rep_arr.shape}):")
        
        for label_name, label_vals in label_types:
            accs = compute_multi_classifier_accs(rep_arr, label_vals)
            results.append({
                'subspace': subspace_name,
                'label_type': label_name,
                'knn_acc': accs['knn'],
                'logistic_acc': accs['logistic'],
                'gnb_acc': accs['gnb']
            })
            print(f"    {label_name}: KNN={_format_acc(accs['knn'])}, "
                  f"Logistic={_format_acc(accs['logistic'])}, GNB={_format_acc(accs['gnb'])}")
    
    # Save results to CSV: one row per (subspace, label type) combination.
    os.makedirs("03_results/processed", exist_ok=True)
    results_csv = f"03_results/processed/{save_prefix}_metadata_classification.csv"
    columns = ['subspace', 'label_type', 'knn_acc', 'logistic_acc', 'gnb_acc']

    with open(results_csv, 'w', newline='') as csv_file:
        table = csv.DictWriter(csv_file, fieldnames=columns)
        table.writeheader()
        table.writerows(results)

    print(f"\n✓ Saved metadata classification results to: {results_csv}")

    # Closing banner.
    closing_rule = "=" * 80
    print("\n" + closing_rule)
    print(f"\nExperiment completed! Model and results saved with prefix '{save_prefix}'")
    print(closing_rule)
