import torch
import torch.nn as nn
import sys
import os
import argparse
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
from torchvision.datasets import MNIST

# Add the project root to sys.path so the `src` package can be imported
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.models.larrp_unimodal import AdaptiveRankReducedLinear
from src.functions.train_avmnist import train_overcomplete_ae_with_pretrained

if __name__ == '__main__':
    # Command-line options: RNG seed, GPU index, and the audio representation switch.
    parser = argparse.ArgumentParser(description="Train Model on AVMNIST with real images")
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility',
    )
    parser.add_argument(
        '--gpu',
        type=int,
        default=0,
        help='GPU id to use.',
    )
    parser.add_argument(
        '--full_spectrum',
        action='store_true',
        help='Use full 112x112 audio spectrogram instead of averaging to 112',
    )
    args = parser.parse_args()

    # --- 1. Load Data ---
    print("Loading data from the newly prepared AudioMNIST source...")

    def _scaled_pixels(dataset):
        """Return the dataset's images as float32 values divided by 255."""
        return dataset.data.numpy().astype('float32') / 255.0

    # MNIST images; download=True lets torchvision fetch the files if they are missing.
    mnist_train = MNIST(root='01_data/processed/MNIST', train=True, download=True)
    mnist_test = MNIST(root='01_data/processed/MNIST', train=False, download=True)

    mnist_train_images = _scaled_pixels(mnist_train)
    mnist_train_labels = mnist_train.targets.numpy()
    mnist_test_images = _scaled_pixels(mnist_test)
    mnist_test_labels = mnist_test.targets.numpy()

    # Prepared audio arrays plus digit and speaker labels, all stored under data_dir.
    data_dir = "01_data/avmnist_data_from_source"

    def _load_prepared(relative_path):
        """Load one .npy file located below data_dir."""
        return np.load(os.path.join(data_dir, relative_path))

    train_audio = _load_prepared("audio/train_data.npy")
    train_labels = _load_prepared("train_labels.npy")
    train_speaker_labels = _load_prepared("train_speaker_labels.npy")

    test_audio = _load_prepared("audio/test_data.npy")
    test_labels = _load_prepared("test_labels.npy")
    test_speaker_labels = _load_prepared("test_speaker_labels.npy")

    # Quick sanity report of what was loaded.
    print(f"MNIST train images shape: {mnist_train_images.shape}")
    print(f"Prepared train audio shape: {train_audio.shape}")
    print(f"Prepared train speaker labels shape: {train_speaker_labels.shape}")
    
    # --- 2. Matching Function (keeps speaker labels aligned with the sampled audio) ---
    def match_datasets_by_label(images, img_labels, audio, audio_labels, speaker_labels, n_samples_per_class=6000, seed=42):
        """
        Pair MNIST images with AudioMNIST clips of the same digit, keeping speaker labels aligned.

        For every digit 0-9 the same number of images and audio clips is drawn
        at random without replacement. That number is the smallest of: available
        images, available clips, and ``n_samples_per_class``. Speaker labels are
        taken with exactly the indices used for the audio, so row ``i`` of the
        returned audio and row ``i`` of the returned speaker labels always
        belong together. Digits with no usable samples are skipped.

        Args:
            images: Array of images, indexable by an integer index array.
            img_labels: Digit label per image (same length as ``images``).
            audio: Array of audio samples, indexable by an integer index array.
            audio_labels: Digit label per audio sample (same length as ``audio``).
            speaker_labels: Speaker label per audio sample (same length as ``audio``).
            n_samples_per_class: Upper bound on the pairs drawn per digit.
            seed: Seed for the sampling; the same seed gives the same selection.

        Returns:
            Tuple ``(images, audio, labels, speaker_labels)`` of arrays with one
            row per matched pair, ordered by digit.

        Raises:
            ValueError: If the per-sample arrays have inconsistent lengths, or
                if no digit has samples in both modalities.
        """
        if len(images) != len(img_labels):
            raise ValueError(
                f"images and img_labels differ in length: {len(images)} vs {len(img_labels)}"
            )
        if not (len(audio) == len(audio_labels) == len(speaker_labels)):
            raise ValueError(
                "audio, audio_labels and speaker_labels must have the same length, got "
                f"{len(audio)}, {len(audio_labels)} and {len(speaker_labels)}"
            )

        # A private RandomState yields the same draws as np.random.seed(seed)
        # followed by np.random.choice, but leaves the global NumPy RNG untouched.
        rng = np.random.RandomState(seed)

        matched_images = []
        matched_audio = []
        matched_labels = []
        matched_speaker_labels = []

        for digit in range(10):
            img_indices = np.where(img_labels == digit)[0]
            audio_indices = np.where(audio_labels == digit)[0]

            n_samples = min(len(img_indices), len(audio_indices), n_samples_per_class)
            if n_samples <= 0:
                continue

            # Images are drawn first, then audio, so the random stream is consumed
            # in a fixed order and a given seed always produces the same pairing.
            img_sample_indices = rng.choice(img_indices, n_samples, replace=False)
            audio_sample_indices = rng.choice(audio_indices, n_samples, replace=False)

            matched_images.append(images[img_sample_indices])
            matched_audio.append(audio[audio_sample_indices])
            matched_labels.extend([digit] * n_samples)
            # Same indices as the audio, so each clip keeps its speaker.
            matched_speaker_labels.append(speaker_labels[audio_sample_indices])

        if not matched_labels:
            # np.concatenate would fail on the empty lists with an opaque message.
            raise ValueError(
                "No digit class has samples in both the image and audio datasets."
            )

        return (np.concatenate(matched_images, axis=0),
                np.concatenate(matched_audio, axis=0),
                np.array(matched_labels),
                np.concatenate(matched_speaker_labels, axis=0))

    # --- 3. Match Datasets and Keep Speaker Labels ---
    # The loaded audio/label arrays are replaced by their matched counterparts.
    print("\nMatching training datasets...")
    train_images, train_audio, train_labels, train_speaker_labels = match_datasets_by_label(
        images=mnist_train_images,
        img_labels=mnist_train_labels,
        audio=train_audio,
        audio_labels=train_labels,
        speaker_labels=train_speaker_labels,
        n_samples_per_class=5000,
        seed=args.seed,
    )

    print("\nMatching test datasets...")
    test_images, test_audio, test_labels, test_speaker_labels = match_datasets_by_label(
        images=mnist_test_images,
        img_labels=mnist_test_labels,
        audio=test_audio,
        audio_labels=test_labels,
        speaker_labels=test_speaker_labels,
        n_samples_per_class=1000,
        seed=args.seed,
    )

    # --- 4. Prepare Data for the Model ---
    # Collapse every image into a single flat vector per sample.
    train_images_flat = train_images.reshape(len(train_images), -1)
    test_images_flat = test_images.reshape(len(test_images), -1)

    # Default mode: average each spectrogram over axis 1.
    # NOTE: with --full_spectrum the audio arrays are passed on exactly as
    # matched; no flattening happens in this script.
    if not args.full_spectrum:
        train_audio = np.mean(train_audio, axis=1)
        test_audio = np.mean(test_audio, axis=1)

    # Shape report so misaligned arrays are visible before training starts.
    print("\nFinal dataset shapes:")
    shape_report = (
        ("Train Images", train_images_flat),
        ("Train Audio", train_audio),
        ("Train Labels", train_labels),
        ("Train Speaker Labels", train_speaker_labels),
        ("Test Images", test_images_flat),
        ("Test Audio", test_audio),
        ("Test Labels", test_labels),
        ("Test Speaker Labels", test_speaker_labels),
    )
    for title, array in shape_report:
        print(f"{title}: {array.shape}")

    # --- 5. Proceed with Model Training ---
    # Train and test rows are stacked per modality (images first, audio second);
    # n_train_samples records how many leading rows are training rows.
    n_train_samples = len(train_images_flat)
    data = [
        torch.FloatTensor(np.concatenate(pair, axis=0))
        for pair in (
            (train_images_flat, test_images_flat),
            (train_audio, test_audio),
        )
    ]

    # Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device(f'cuda:{args.gpu}')
    else:
        DEVICE = torch.device('cpu')
    print(f"\nUsing device: {DEVICE}")

    # Seed every random number generator used by this script.
    for seed_everything in (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    ):
        seed_everything(args.seed)

    # --- Model hyperparameters forwarded to the training call ---
    rank_reduction_frequency, rank_reduction_threshold = 10, 0.01
    early_stopping, patience = 50, 10
    r_square_threshold = 0.05

    class Args:
        """Plain settings container handed to the training function.

        All values are fixed here except ``gpu``, which mirrors the ``--gpu``
        command-line option parsed into the enclosing script's ``args``.
        """

        def __init__(self):
            # Populate the instance attributes in one step; the keyword order
            # fixes the attribute order seen through vars().
            vars(self).update(
                latent_dim=200,
                batch_size=512,
                lr=1e-3,
                weight_decay=2e-5,
                dropout=0.0,
                ae_depth=2,
                ae_width=0.5,
                epochs=20000,
                num_workers=8,
                multi_gpu=False,
                gpu_ids='',
                gpu=args.gpu,
            )

    train_args = Args()

    # One run identifier shared by the pretrained-model name and every saved
    # artifact, so the two can never drift apart.
    model_prefix = f"audiomnist{'_fullspec' if args.full_spectrum else ''}_rseed-{args.seed}"
    results_dir = "03_results/models"

    print("\nStarting model training...")
    # Only model, reps and rank_history are used below; the remaining return
    # values are unpacked to match the function's return tuple.
    model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae_with_pretrained(
        data,
        n_train_samples,
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=early_stopping,
        initial_rank_ratio=1.0,
        rank_reduction_frequency=rank_reduction_frequency,
        rank_reduction_threshold=rank_reduction_threshold,
        # The warm-up length deliberately reuses the early-stopping value.
        warmup_epochs=early_stopping,
        patience=patience,
        min_rank=1,
        r_square_threshold=r_square_threshold,
        threshold_type='absolute',
        compressibility_type='direct',
        verbose=True,
        compute_jacobian=False,
        sharedwhenall=False,
        pretrained_name=model_prefix,
        lr_schedule='step',
        decision_metric='ExVarScore',
        full_spectrum=args.full_spectrum
    )

    # Save results: model weights, one array per representation, and the rank history.
    os.makedirs(results_dir, exist_ok=True)
    torch.save(model.state_dict(), f"{results_dir}/{model_prefix}_final_model.pth")
    for i, rep in enumerate(reps):
        # detach() first: Tensor.numpy() raises on tensors that still require grad.
        # NOTE(review): whether the returned reps carry grad is not visible here.
        np.save(f"{results_dir}/{model_prefix}_rep{i}.npy", rep.detach().cpu().numpy())
    pd.DataFrame(rank_history).to_csv(f"{results_dir}/{model_prefix}_rank_history.csv", index=False)

    print(f"\nExperiment completed! Model and results saved with prefix '{model_prefix}'")

