###
# using real MNIST images with AVMNIST audio data, architecture based on the MVAE from MultiBench
# https://github.com/pliang279/MultiBench/blob/main/examples/multimedia/avmnist_MVAE_mixed.py
# This version uses actual MNIST digit images instead of eigendigits
###

import torch
import torch.nn as nn
import sys
import os
import requests
import tarfile
import argparse
from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import numpy as np
import random
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
import pandas as pd
from torchvision.datasets import MNIST
from sklearn.decomposition import PCA

# add src to path
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.models.larrp_unimodal import AdaptiveRankReducedLinear
from src.functions.train_avmnist import train_overcomplete_ae_with_pretrained

def prepare_data_from_archive(data_dir="avmnist_data"):
    """
    Extract a manually downloaded 'avmnist.tar.gz' found inside ``data_dir``.

    The directory is created if needed. When ``data_dir/image`` already exists the
    data is considered extracted and nothing else happens.

    Args:
        data_dir: Directory that holds (or will hold) 'avmnist.tar.gz' and the
            extracted files.

    Raises:
        SystemExit: With exit code 1 when the archive is missing or cannot be
            read as a gzip-compressed tar file.
    """
    os.makedirs(data_dir, exist_ok=True)
    archive_path = os.path.join(data_dir, "avmnist.tar.gz")

    # An existing "image" sub-directory is the marker for "already extracted".
    if os.path.exists(os.path.join(data_dir, "image")):
        print("AVMNIST data already extracted.")
        return

    # The archive cannot be fetched automatically; the user has to place it by hand.
    if not os.path.exists(archive_path):
        print("\n--- ACTION REQUIRED ---")
        print(f"Could not find '{archive_path}'.")
        # Report the directory that was actually searched, not the default name.
        print(f"Please manually download 'avmnist.tar.gz' and place it in the '{data_dir}' directory.")
        print("You can find it here: https://drive.google.com/file/d/1KvKynJJca5tDtI5Mmp6CoRh9pQywH8Xp/view?usp=sharing")
        print("-----------------------\n")
        sys.exit(1)  # Exit because we cannot proceed without the data.

    print(f"Found '{archive_path}'. Extracting files...")
    try:
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive is an external download: on interpreters that support
                # extraction filters, refuse members that would escape data_dir.
                tar.extractall(path=data_dir, filter="data")
            else:
                tar.extractall(path=data_dir)
        print("Extraction complete.")
    except tarfile.ReadError as e:
        print(f"Error extracting the tar.gz file: {e}")
        print("The file may be corrupted. Please try downloading it again.")
        sys.exit(1)

# DISABLED CODE: the loader below is wrapped in a module-level string literal, so it is
# never defined or executed. It is kept only as a reference for loading both modalities
# from the extracted AVMNIST .npy files. It relies on an `AVMNISTDataset` class that is
# not defined or imported in the visible part of this file, so that class would have to
# be restored before re-enabling it.
"""
def get_manual_avmnist_dataloader(batch_size=64, data_dir="avmnist_data"):
    '''
    Manually downloads, extracts, and prepares the AVMNIST dataset.
    Returns train, validation, and test dataloaders.
    '''
    # This function now handles extraction from a manually placed .tar.gz file
    prepare_data_from_archive(data_dir)
    
    # Load from the .npy files as per the MultiBench source code
    train_images = np.load(os.path.join(data_dir, "image/train_data.npy")) / 255.0
    train_audio = np.load(os.path.join(data_dir, "audio/train_data.npy")) / 255.0
    train_labels = np.load(os.path.join(data_dir, "train_labels.npy"))
    
    test_images = np.load(os.path.join(data_dir, "image/test_data.npy")) / 255.0
    test_audio = np.load(os.path.join(data_dir, "audio/test_data.npy")) / 255.0
    test_labels = np.load(os.path.join(data_dir, "test_labels.npy"))

    # Create custom datasets
    train_full_dataset = AVMNISTDataset((train_images, train_audio, train_labels))
    test_dataset = AVMNISTDataset((test_images, test_audio, test_labels))

    # Split training data into train and validation
    train_size = 55000
    val_size = len(train_full_dataset) - train_size
    train_dataset, val_dataset = random_split(train_full_dataset, [train_size, val_size])

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print("AVMNIST DataLoaders created successfully from manually downloaded files.")
    return train_loader, valid_loader, test_loader
"""

# DISABLED CODE: this helper is also wrapped in a module-level string literal and never
# runs. It fitted a 16-component PCA on the flattened MNIST training images and returned
# the components ("eigendigits") together with the mean image. The file header states
# that this version of the script uses actual MNIST digit images instead of eigendigits.
"""
def get_original_images():
    num_components = 16
    train_images_raw = MNIST(root='./01_data/processed/MNIST', train=True, download=True)
    x_train = train_images_raw.data.numpy()
    num_pixels = x_train.shape[1] * x_train.shape[2]
    x_train_flat = x_train.reshape(-1, num_pixels).astype('float32') / 255.0
    pca = PCA(n_components=num_components)
    pca.fit(x_train_flat)
    pca_components = pca.components_  # The "eigendigits"
    mean_image = pca.mean_           # The average digit image
    return pca_components, mean_image
"""

if __name__ == '__main__':
    # ---- Command-line options ----
    parser = argparse.ArgumentParser(description="Train AdaptiveRankReducedAE on AVMNIST")
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='Random seed for reproducibility',
    )
    parser.add_argument(
        '--gpu',
        type=int,
        default=0,
        help='GPU id to use.',
    )
    parser.add_argument(
        '--full_spectrum',
        action='store_true',
        help='Use full 112x112 audio spectrogram instead of averaging to 112',
    )
    args = parser.parse_args()

    # ---- 1. Load the two modalities ----
    # Images come straight from torchvision's MNIST; only the audio and its labels are read
    # from the extracted AVMNIST .npy files. The earlier dataloader / eigendigit code paths
    # (get_manual_avmnist_dataloader, prepare_data_from_archive, get_original_images) are
    # intentionally not called here.
    print("Loading real MNIST dataset...")
    mnist_train, mnist_test = (
        MNIST(root='01_data/processed/MNIST', train=is_train_split, download=True)
        for is_train_split in (True, False)
    )

    # Pixel intensities are rescaled from 0-255 to floats in [0, 1]; labels stay untouched.
    mnist_train_images = mnist_train.data.numpy().astype('float32') / 255.0
    mnist_train_labels = mnist_train.targets.numpy()
    mnist_test_images = mnist_test.data.numpy().astype('float32') / 255.0
    mnist_test_labels = mnist_test.targets.numpy()

    # AVMNIST audio is divided by 255 as well, mirroring the image scaling above.
    data_dir = "01_data/processed/avmnist"
    train_audio = np.load(os.path.join(data_dir, "audio/train_data.npy")) / 255.0
    train_labels = np.load(os.path.join(data_dir, "train_labels.npy"))
    test_audio = np.load(os.path.join(data_dir, "audio/test_data.npy")) / 255.0
    test_labels = np.load(os.path.join(data_dir, "test_labels.npy"))

    # Quick sanity report of what was loaded.
    print(
        f"MNIST train images shape: {mnist_train_images.shape}",
        f"MNIST train labels shape: {mnist_train_labels.shape}",
        f"AVMNIST train audio shape: {train_audio.shape}",
        f"AVMNIST train labels shape: {train_labels.shape}",
        sep="\n",
    )
    
    # Pair MNIST images with AVMNIST audio class by class (random pairing within each digit).
    def match_datasets_by_label(images, img_labels, audio, audio_labels, n_samples_per_class=6000, save_mapping=True, seed=42):
        """
        Pair MNIST images with AVMNIST audio samples that share the same digit label.

        For every digit 0-9 the same number of images and audio samples is drawn without
        replacement, so each image is used at most once. The count per digit is the
        smallest of: available images, available audio samples, ``n_samples_per_class``.
        Digits missing from either dataset are skipped. The output is ordered by digit.

        Args:
            images: Array of images, indexable by an integer index array along axis 0.
            img_labels: Digit label for every entry of ``images``.
            audio: Array of audio samples, indexable like ``images``.
            audio_labels: Digit label for every entry of ``audio``.
            n_samples_per_class: Upper bound on the pairs drawn per digit.
            save_mapping: Unused; kept for backward compatibility. This function never
                writes a file, the caller is expected to persist the returned indices.
            seed: Seed of the private random generator used for sampling. The same
                inputs and seed always produce the same pairing.

        Returns:
            Tuple ``(matched_images, matched_audio, matched_labels, mnist_image_indices)``
            where ``mnist_image_indices`` (int64) holds, for every output row, the index
            into ``images`` that was selected.

        Raises:
            ValueError: If no digit has samples in both datasets.
        """
        # A private generator makes the pairing depend only on `seed`, not on whatever
        # state the global NumPy RNG happens to be in when this function is called.
        rng = np.random.default_rng(seed)
        img_labels = np.asarray(img_labels)
        audio_labels = np.asarray(audio_labels)

        matched_images = []
        matched_audio = []
        matched_labels = []
        mnist_image_indices = []  # Track which MNIST images were used

        for digit in range(10):
            # Positions of this digit in both datasets
            img_indices = np.where(img_labels == digit)[0]
            audio_indices = np.where(audio_labels == digit)[0]

            # n_samples never exceeds either pool, so drawing without replacement is
            # always possible and every selected image is unique.
            n_samples = min(len(img_indices), len(audio_indices), n_samples_per_class)
            if n_samples <= 0:
                continue

            img_sample_indices = rng.choice(img_indices, n_samples, replace=False)
            audio_sample_indices = rng.choice(audio_indices, n_samples, replace=False)

            matched_images.append(images[img_sample_indices])
            matched_audio.append(audio[audio_sample_indices])
            matched_labels.extend([digit] * n_samples)
            mnist_image_indices.extend(img_sample_indices.tolist())

            print(f"Digit {digit}: matched {n_samples} samples (using {n_samples} unique images)")

        if not matched_images:
            raise ValueError("No digit class has samples in both the image and the audio dataset.")

        # Return matched arrays and MNIST indices for combined mapping
        return (np.concatenate(matched_images, axis=0),
                np.concatenate(matched_audio, axis=0),
                np.array(matched_labels),
                np.array(mnist_image_indices, dtype=np.int64))
    
    
    # Seed NumPy's global RNG before the matching step. The script-wide seeding further
    # down runs only after the datasets have been matched, so without this call any
    # sampling that relies on the global RNG would differ between runs of the same --seed,
    # while the mapping file below is named after that seed.
    np.random.seed(args.seed)

    # Match train and test sets
    print("\nMatching training datasets...")
    train_images, train_audio, train_labels, train_mnist_indices = match_datasets_by_label(
        mnist_train_images, mnist_train_labels, train_audio, train_labels, 
        n_samples_per_class=6000, save_mapping=False, seed=args.seed  # Don't save individual mapping
    )
    
    print("\nMatching test datasets...")
    test_images, test_audio, test_labels, test_mnist_indices = match_datasets_by_label(
        mnist_test_images, mnist_test_labels, test_audio, test_labels, 
        n_samples_per_class=1000, save_mapping=False, seed=args.seed  # Don't save individual mapping
    )
    
    # Save one combined mapping file: which MNIST image index was paired with every
    # matched train/test sample, plus the labels and the settings of this run.
    mapping_prefix = f"avmnist_real{'_fullspec' if args.full_spectrum else ''}_rseed-{args.seed}"
    mapping_file = f"03_results/models/{mapping_prefix}_mnist_mapping.npz"
    os.makedirs(os.path.dirname(mapping_file), exist_ok=True)
    mapping_data = {
        'mnist_train_indices': np.array(train_mnist_indices, dtype=np.int64),
        'mnist_test_indices': np.array(test_mnist_indices, dtype=np.int64),
        'train_labels': np.array(train_labels, dtype=np.int64),
        'test_labels': np.array(test_labels, dtype=np.int64),
        'seed': args.seed,
        'full_spectrum': args.full_spectrum
    }
    np.savez(mapping_file, **mapping_data)
    print(f"Saved combined MNIST-AVMNIST mapping to: {mapping_file}")
    
    # Output directory for the debugging image grid written later in the script
    os.makedirs('03_results/plots', exist_ok=True)
    def save_image_grid(imgs, out_path, n=16):
        """
        Save the first images of ``imgs`` as a 4x4 grayscale grid.

        Args:
            imgs: Images shaped (N, 28, 28), (N, 1, 28, 28) or (N, 784).
            out_path: File path the figure is written to.
            n: Maximum number of images to draw. It is capped by the number of
                available images and by the 16 cells of the grid; unused cells
                are left blank.
        """
        # asarray avoids copying the (potentially very large) input array.
        imgs = np.asarray(imgs)
        if imgs.ndim == 2 and imgs.shape[1] == 784:
            imgs = imgs.reshape(-1, 28, 28)
        if imgs.ndim == 4 and imgs.shape[1] == 1:
            imgs = imgs[:, 0]
        fig, axes = plt.subplots(4, 4, figsize=(6,6))
        cells = axes.flatten()
        n = min(n, imgs.shape[0], len(cells))
        for i, ax in enumerate(cells):
            # Only index images that exist; the remaining cells stay empty.
            if i < n:
                ax.imshow(imgs[i], cmap='gray')
            ax.axis('off')
        plt.tight_layout()
        fig.savefig(out_path, dpi=150)
        plt.close(fig)

    # Flatten each image into one row vector for the model (28x28 -> 784)
    train_images_flat = train_images.reshape(len(train_images), -1)
    test_images_flat = test_images.reshape(len(test_images), -1)
    
    # Debug aid: write a grid of the matched (unflattened) training images to disk.
    plot_prefix = f"avmnist_real{'_fullspec' if args.full_spectrum else ''}_train_seed{args.seed}"
    save_image_grid(train_images, f'03_results/plots/{plot_prefix}.png')
    
    # Shape report for everything that enters training.
    print(
        f"\nFinal dataset shapes:",
        f"Train Images: {train_images_flat.shape}",
        f"Train Audio: {train_audio.shape}",
        f"Train Labels: {train_labels.shape}",
        f"Test Images: {test_images_flat.shape}",
        f"Test Audio: {test_audio.shape}",
        f"Test Labels: {test_labels.shape}",
        sep="\n",
    )
    
    # Value-range summary (min, max, mean, std) of both training modalities.
    print(
        f"\nTrain Images - min: {train_images_flat.min()}, max: {train_images_flat.max()}, mean: {train_images_flat.mean()}, std: {train_images_flat.std()}",
        f"Train Audio - min: {train_audio.min()}, max: {train_audio.max()}, mean: {train_audio.mean()}, std: {train_audio.std()}",
        sep="\n",
    )

    # One float tensor per modality, training rows first and test rows appended after them;
    # n_train_samples marks where the training part ends.
    n_train_samples = len(train_images_flat)
    data = [
        torch.FloatTensor(np.concatenate(train_and_test, axis=0))
        for train_and_test in ((train_images_flat, test_images_flat), (train_audio, test_audio))
    ]
    n_samples = data[0].size(0)

    # Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device(f'cuda:{args.gpu}')
    else:
        DEVICE = torch.device('cpu')
    
    # Seed every random number generator used from here on.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    
    # Settings forwarded to the training routine further below.
    rank_reduction_frequency, rank_reduction_threshold = 10, 0.01
    early_stopping, patience = 50, 10
    r_square_threshold = 0.05


    class Args:
        """Plain attribute container with the training settings of this experiment.

        All values are fixed except ``gpu``, which is taken from the parsed
        command-line arguments (``args.gpu``) of the enclosing script.
        """

        def __init__(self):
            settings = {
                # latent
                'latent_dim': 200,
                # Training parameters
                'batch_size': 512,
                'lr': 1e-3,
                'weight_decay': 2e-5,
                'dropout': 0.0,
                'ae_depth': 2,
                'ae_width': 0.5,
                'epochs': 5000,
                # Rank reduction parameters
                'rank_or_sparse': 'rank',
                # GPU parameters
                'num_workers': 8,
                'multi_gpu': False,
                'gpu_ids': '',
                'gpu': args.gpu,
            }
            # Expose every setting as an instance attribute, in the order listed above.
            for key, value in settings.items():
                setattr(self, key, value)

    train_args = Args()

    # --- 2. Train the model ---
    # An earlier revision hard-coded the sizes here, kept for reference:
    #   input_dimensions = (784, 112)    # image: 28*28 = 784, audio: 112
    #   latent_dimensions = (16, 8, 32)  # image-specific, audio-specific, shared
    # Now only a single latent size (train_args.latent_dim) is handed to the trainer.

    # One prefix for the name passed as `pretrained_name` and for every artifact saved below.
    model_prefix = f"avmnist_real{'_fullspec' if args.full_spectrum else ''}_rseed-{args.seed}"

    model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae_with_pretrained(
        data,
        n_train_samples,
        train_args.latent_dim,
        DEVICE,
        train_args,
        # Optimisation and architecture settings taken from train_args
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        lr_schedule='step',
        # Stopping and rank-reduction settings defined earlier in the script
        early_stopping=early_stopping,
        warmup_epochs=early_stopping,
        patience=patience,
        initial_rank_ratio=1.0,
        min_rank=1,
        rank_reduction_frequency=rank_reduction_frequency,
        rank_reduction_threshold=rank_reduction_threshold,
        r_square_threshold=r_square_threshold,
        threshold_type='absolute',
        compressibility_type='direct',
        decision_metric='ExVarScore',
        # Remaining switches
        verbose=True,
        compute_jacobian=False,
        sharedwhenall=False,
        pretrained_name=model_prefix,
        full_spectrum=args.full_spectrum,
    )

    # Persist the trained weights ...
    torch.save(model.state_dict(), f"03_results/models/{model_prefix}_final_model.pth")

    # ... every returned representation as its own .npy file ...
    for i, rep in enumerate(reps):
        np.save(f"03_results/models/{model_prefix}_rep{i}.npy", rep.cpu().numpy())

    # ... and the rank history as a CSV table.
    pd.DataFrame(rank_history).to_csv(f"03_results/models/{model_prefix}_rank_history.csv", index=False)

    print(f"\nExperiment completed! Model and results saved with prefix '{model_prefix}'")
