###
# using AVMNIST and an architecture based on the MVAE from MultiBench
# https://github.com/pliang279/MultiBench/blob/main/examples/multimedia/avmnist_MVAE_mixed.py
###

import torch
import torch.nn as nn
import sys
import os
import requests
import tarfile
import argparse
#from torch.utils.data import TensorDataset, DataLoader, random_split, Dataset
import numpy as np
import random
from sklearn.metrics import silhouette_score
import matplotlib.pyplot as plt
import pandas as pd
#from torchvision.datasets import MNIST
#from sklearn.decomposition import PCA

# add the project root to sys.path so that the `src` package can be imported
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.models.larrp_unimodal import AdaptiveRankReducedLinear
from src.functions.train_avmnist import train_overcomplete_ae_with_pretrained

def prepare_data_from_archive(data_dir="avmnist_data"):
    """
    Make sure the AVMNIST files are extracted inside ``data_dir``.

    Looks for a manually downloaded 'avmnist.tar.gz' inside ``data_dir`` and
    unpacks it in place. Nothing is downloaded by this function.

    Args:
        data_dir: Directory expected to contain the archive; it is created if
            missing and is also the extraction target.

    Returns:
        None. Returns early (without touching the archive) when
        ``data_dir/image`` already exists.

    Raises:
        SystemExit: with exit code 1 when the archive is missing or cannot be
            extracted.
    """
    os.makedirs(data_dir, exist_ok=True)
    archive_path = os.path.join(data_dir, "avmnist.tar.gz")

    # The presence of the 'image' sub-directory is used as the "already extracted" marker.
    # NOTE(review): this assumes the archive unpacks 'image/' directly into data_dir;
    # confirm against the real archive layout.
    if os.path.exists(os.path.join(data_dir, "image")):
        print("AVMNIST data already extracted.")
        return

    # The archive has to be placed by hand; tell the user exactly where.
    if not os.path.exists(archive_path):
        print("\n--- ACTION REQUIRED ---")
        print(f"Could not find '{archive_path}'.")
        # Report the directory actually in use instead of the default name.
        print(f"Please manually download 'avmnist.tar.gz' and place it in the '{data_dir}' directory.")
        print("You can find it here: https://drive.google.com/file/d/1KvKynJJca5tDtI5Mmp6CoRh9pQywH8Xp/view?usp=sharing")
        print("-----------------------\n")
        sys.exit(1)  # Exit because we cannot proceed without the data.

    print(f"Found '{archive_path}'. Extracting files...")
    try:
        with tarfile.open(archive_path, "r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # The archive comes from an external download: use the 'data' filter
                # (where the running Python provides it) so members cannot escape
                # data_dir via absolute paths, '..' components or links.
                tar.extractall(path=data_dir, filter="data")
            else:
                # Older interpreters have no extraction filters; keep the legacy behaviour.
                tar.extractall(path=data_dir)
        print("Extraction complete.")
    except (tarfile.TarError, EOFError) as e:
        # TarError covers unreadable archives and members rejected by the filter;
        # EOFError is what a truncated gzip stream can surface as.
        print(f"Error extracting the tar.gz file: {e}")
        print("The file may be corrupted. Please try downloading it again.")
        sys.exit(1)

"""
def get_manual_avmnist_dataloader(batch_size=64, data_dir="avmnist_data"):
    '''
    Manually downloads, extracts, and prepares the AVMNIST dataset.
    Returns train, validation, and test dataloaders.
    '''
    # This function now handles extraction from a manually placed .tar.gz file
    prepare_data_from_archive(data_dir)
    
    # Load from the .npy files as per the MultiBench source code
    train_images = np.load(os.path.join(data_dir, "image/train_data.npy")) / 255.0
    train_audio = np.load(os.path.join(data_dir, "audio/train_data.npy")) / 255.0
    train_labels = np.load(os.path.join(data_dir, "train_labels.npy"))
    
    test_images = np.load(os.path.join(data_dir, "image/test_data.npy")) / 255.0
    test_audio = np.load(os.path.join(data_dir, "audio/test_data.npy")) / 255.0
    test_labels = np.load(os.path.join(data_dir, "test_labels.npy"))

    # Create custom datasets
    train_full_dataset = AVMNISTDataset((train_images, train_audio, train_labels))
    test_dataset = AVMNISTDataset((test_images, test_audio, test_labels))

    # Split training data into train and validation
    train_size = 55000
    val_size = len(train_full_dataset) - train_size
    train_dataset, val_dataset = random_split(train_full_dataset, [train_size, val_size])

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    print("AVMNIST DataLoaders created successfully from manually downloaded files.")
    return train_loader, valid_loader, test_loader
"""

"""
def get_original_images():
    num_components = 16
    train_images_raw = MNIST(root='./01_data/processed/MNIST', train=True, download=True)
    x_train = train_images_raw.data.numpy()
    num_pixels = x_train.shape[1] * x_train.shape[2]
    x_train_flat = x_train.reshape(-1, num_pixels).astype('float32') / 255.0
    pca = PCA(n_components=num_components)
    pca.fit(x_train_flat)
    pca_components = pca.components_  # The "eigendigits"
    mean_image = pca.mean_           # The average digit image
    return pca_components, mean_image
"""

if __name__ == '__main__':
    # Command-line interface: only the random seed and the GPU index are configurable.
    cli = argparse.ArgumentParser(description="Train AdaptiveRankReducedAE on AVMNIST")
    cli.add_argument('--seed', default=42, type=int, help='Random seed for reproducibility')
    cli.add_argument('--gpu', default=0, type=int, help='GPU id to use.')
    args = cli.parse_args()

    # --- 1. Load the AVMNIST training split ---
    # The arrays are read straight from the extracted .npy files (as per the MultiBench
    # source code). No extraction is triggered here: the files must already exist under
    # data_dir (prepare_data_from_archive above can unpack a manually downloaded archive).
    data_dir = "01_data/processed/avmnist"
    train_image_file = os.path.join(data_dir, "image/train_data.npy")
    train_audio_file = os.path.join(data_dir, "audio/train_data.npy")
    train_label_file = os.path.join(data_dir, "train_labels.npy")

    # Images are loaded unscaled on purpose; only the audio is divided by 255 at load time.
    train_images = np.load(train_image_file)
    train_audio = np.load(train_audio_file) / 255.0
    train_labels = np.load(train_label_file)

    # Output folder for the debug image grids of the training images.
    os.makedirs('03_results/plots', exist_ok=True)
    def save_image_grid(imgs, out_path, n=16):
        """
        Save the first ``n`` images of ``imgs`` as a grayscale grid figure.

        Args:
            imgs: Array-like of images with shape (N, 28, 28), (N, 1, 28, 28)
                or (N, 784); flat rows are reshaped to 28x28 and a singleton
                channel axis is dropped.
            out_path: Destination file passed to ``Figure.savefig``.
            n: Maximum number of images to draw; clamped to the number of
                images available. The default of 16 yields a 4x4 grid.

        Returns:
            None. Nothing is written when there are no images to draw.
        """
        imgs = np.asarray(imgs)
        if imgs.ndim == 2 and imgs.shape[1] == 784:
            imgs = imgs.reshape(-1, 28, 28)
        if imgs.ndim == 4 and imgs.shape[1] == 1:
            imgs = imgs[:, 0]

        # Never index past the end of the array, and honour the caller's n.
        n = min(n, imgs.shape[0])
        if n <= 0:
            return

        # Smallest near-square grid that holds n panels (4x4 for n == 16).
        cols = int(np.ceil(np.sqrt(n)))
        rows = int(np.ceil(n / cols))
        # squeeze=False keeps `axes` two-dimensional even for a single row/column.
        fig, axes = plt.subplots(rows, cols, figsize=(6, 6), squeeze=False)
        for i, ax in enumerate(axes.flatten()):
            if i < n:
                ax.imshow(imgs[i], cmap='gray')
            # Surplus panels are left empty but still get their axes hidden.
            ax.axis('off')
        plt.tight_layout()
        fig.savefig(out_path, dpi=150)
        plt.close(fig)

    # Debug snapshot of the training images exactly as loaded from disk.
    save_image_grid(train_images, f'03_results/plots/avmnist_train_raw_seed{args.seed}.png')

    # Min-max scale the training images into [0, 1]. The raw values are said to span
    # -255..255, so a plain division by 255 would not be sufficient.
    train_lo, train_hi = train_images.min(), train_images.max()
    train_images = (train_images - train_lo) / (train_hi - train_lo)
    save_image_grid(train_images, f'03_results/plots/avmnist_train_normalized_seed{args.seed}.png')

    # Summary statistics of both training modalities after scaling.
    print(f"Train Images - min: {train_images.min()}, max: {train_images.max()}, mean: {train_images.mean()}, std: {train_images.std()}")
    print(f"Train Audio - min: {train_audio.min()}, max: {train_audio.max()}, mean: {train_audio.mean()}, std: {train_audio.std()}")

    # Test split: same treatment as the training split.
    # NOTE(review): the test images are scaled with their own min/max rather than the
    # training statistics -- confirm this is intended.
    test_images = np.load(os.path.join(data_dir, "image/test_data.npy"))
    test_lo, test_hi = test_images.min(), test_images.max()
    test_images = (test_images - test_lo) / (test_hi - test_lo)
    test_audio = np.load(os.path.join(data_dir, "audio/test_data.npy")) / 255.0
    test_labels = np.load(os.path.join(data_dir, "test_labels.npy"))

    # One float tensor per modality, training rows first and test rows appended after
    # them; n_train_samples marks where the training rows end.
    n_train_samples = len(train_images)
    data = [
        torch.FloatTensor(np.concatenate(split_pair, axis=0))
        for split_pair in ((train_images, test_images), (train_audio, test_audio))
    ]
    n_samples = len(data[0])

    # Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        DEVICE = torch.device(f'cuda:{args.gpu}')
    else:
        DEVICE = torch.device('cpu')

    # Seed every random number generator in use with the same value.
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
        seed_fn(args.seed)

    # Settings forwarded to the training routine further down.
    rank_reduction_frequency, rank_reduction_threshold = 10, 0.01
    early_stopping, patience = 50, 10
    r_square_threshold = 0.05

    class Args:
        """Plain attribute container holding the fixed training configuration."""

        def __init__(self):
            # Every entry becomes an instance attribute of the same name, in this order.
            settings = {
                # latent
                'latent_dim': 100,
                # Training parameters
                'batch_size': 512,
                'lr': 1e-3,
                'weight_decay': 0,
                'dropout': 0.0,
                'ae_depth': 2,
                'ae_width': 0.5,
                'epochs': 5000,
                # Rank reduction parameters
                'rank_or_sparse': 'rank',
                # GPU parameters
                'num_workers': 8,
                'multi_gpu': False,
                'gpu_ids': '',
                # GPU index taken from the parsed command-line arguments.
                'gpu': args.gpu,
            }
            for attr_name, attr_value in settings.items():
                setattr(self, attr_name, attr_value)

    train_args = Args()

    # Create the output directory before training so that a long run cannot be lost
    # to a missing folder when the results are written at the end.
    models_dir = "03_results/models"
    os.makedirs(models_dir, exist_ok=True)
    run_prefix = f"{models_dir}/avmnist_rseed-{args.seed}"

    # --- 2. Train the model ---
    # `data` holds one tensor per modality (image, audio); the first n_train_samples
    # rows of each are the training split. The latent size comes from train_args.
    model, reps, train_loss, r_squares, rank_history, loss_curves = train_overcomplete_ae_with_pretrained(
        data,
        n_train_samples,
        train_args.latent_dim,
        DEVICE,
        train_args,
        epochs=train_args.epochs,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        ae_depth=train_args.ae_depth,
        ae_width=train_args.ae_width,
        dropout=train_args.dropout,
        wd=train_args.weight_decay,
        early_stopping=early_stopping,
        initial_rank_ratio=1.0,
        rank_reduction_frequency=rank_reduction_frequency,
        rank_reduction_threshold=rank_reduction_threshold,
        # The warm-up length is deliberately tied to the early-stopping value.
        warmup_epochs=early_stopping,
        patience=patience,
        min_rank=1,
        r_square_threshold=r_square_threshold,
        threshold_type='absolute',
        compressibility_type='direct',
        verbose=True,
        compute_jacobian=False,
        sharedwhenall=False,
        pretrained_name=f"avmnist_rseed-{args.seed}",
        lr_schedule='linear'
    )

    # --- 3. Persist the results ---
    # train_loss, r_squares and loss_curves are returned but not written to disk here.
    # save the final model
    torch.save(model.state_dict(), f"{run_prefix}_final_model.pth")

    # also save the reps as numpy arrays
    for i, rep in enumerate(reps):
        # detach() is a no-op for tensors outside the autograd graph and prevents
        # numpy() from failing on one that still requires grad.
        np.save(f"{run_prefix}_rep{i}.npy", rep.detach().cpu().numpy())

    # also save the rank history
    pd.DataFrame(rank_history).to_csv(f"{run_prefix}_rank_history.csv", index=False)
