###
# Post-training script for AVMNIST model with real MNIST images
# Loads the final model from 054_avmnist_real.py and continues training without rank reduction
###

import torch
import torch.nn as nn
import sys
import os
import argparse
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
from torchvision.datasets import MNIST
from pathlib import Path

# add the project root to sys.path so that the `src` package can be imported
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(str(project_root))

from src.functions.train_avmnist import posttrain_overcomplete_ae

if __name__ == '__main__':
    # Command-line options; --seed and --full_spectrum also determine which
    # saved mapping/model files are looked up later in the script.
    cli = argparse.ArgumentParser(description="Post-train AVMNIST model")
    cli.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    cli.add_argument('--gpu', type=int, default=0, help='GPU id to use.')
    cli.add_argument(
        '--full_spectrum',
        action='store_true',
        help='Use full 112x112 audio spectrogram instead of averaging to 112',
    )
    args = cli.parse_args()

    # Real MNIST digits (downloaded on first use) - the image modality
    print("Loading real MNIST dataset...")
    mnist_root = '01_data/processed/MNIST'
    train_split = MNIST(root=mnist_root, train=True, download=True)
    test_split = MNIST(root=mnist_root, train=False, download=True)

    # Pixel values are rescaled from [0, 255] to [0, 1]; labels are kept as-is
    mnist_train_images = train_split.data.numpy().astype('float32') / 255.0
    mnist_train_labels = train_split.targets.numpy()
    mnist_test_images = test_split.data.numpy().astype('float32') / 255.0
    mnist_test_labels = test_split.targets.numpy()

    # AVMNIST audio and labels - the audio modality
    avmnist_dir = "01_data/processed/avmnist"

    def _load_avmnist(relative_path):
        """Load one .npy array stored below the processed AVMNIST directory."""
        return np.load(os.path.join(avmnist_dir, relative_path))

    # Audio is divided by 255 in the same way as the images
    train_audio = _load_avmnist("audio/train_data.npy") / 255.0
    train_labels = _load_avmnist("train_labels.npy")
    test_audio = _load_avmnist("audio/test_data.npy") / 255.0
    test_labels = _load_avmnist("test_labels.npy")
    
    # Match MNIST images to AVMNIST labels (same logic as in training script)
    def match_datasets_by_label(images, img_labels, audio, audio_labels, n_samples_per_class=6000,
                               use_saved_mapping=True, seed=42, dataset_kind='train', full_spectrum=False):
        """
        Pair MNIST images with AVMNIST audio using the index mapping saved by training.

        The image side is read from the mapping file
        ``03_results/models/avmnist_real[_fullspec]_rseed-<seed>_mnist_mapping.npz``.
        The audio side is re-drawn here: for every digit 0-9, as many audio samples
        as the mapping holds labels of that digit are sampled without replacement
        from ``audio``. The draw uses a generator seeded with ``seed`` so repeated
        runs select the same audio.
        NOTE(review): whether this reproduces the audio subset drawn by
        054_avmnist_real.py cannot be verified from this script - confirm.

        Args:
            images: MNIST images of one split, indexed along axis 0.
            img_labels: Digit label of every entry in ``images``.
            audio: AVMNIST audio samples, indexed along axis 0.
            audio_labels: Digit label of every entry in ``audio``.
            n_samples_per_class: Unused; kept for call compatibility.
            use_saved_mapping: Unused; the saved mapping is always required.
            seed: Selects the mapping file and seeds the audio draw.
            dataset_kind: ``'train'`` or ``'test'``; selects the per-split indices.
            full_spectrum: Selects the ``_fullspec`` variant of the mapping file.

        Returns:
            Tuple ``(matched_images, matched_audio, matched_labels)``. Audio and
            labels are ordered digit by digit (0..9); images follow the order of
            the saved indices.

        Exits the process with status 1 (via ``sys.exit``) when the mapping file is
        missing, lacks the expected keys, holds out-of-range indices, or when no
        audio could be matched at all.
        """
        # Try to load saved mapping - REQUIRED for reproducibility
        mapping_file = f"03_results/models/avmnist_real{'_fullspec' if full_spectrum else ''}_rseed-{seed}_mnist_mapping.npz"
        if not os.path.exists(mapping_file):
            print(f"ERROR: Required MNIST-AVMNIST mapping file not found: {mapping_file}")
            print("Please run 054_avmnist_real.py first to create the mapping.")
            sys.exit(1)

        def _first_present(archive, *keys):
            """Return the array stored under the first of `keys` found in `archive`, else None."""
            for key in keys:
                if key in archive:
                    return archive[key]
            return None

        img_labels = np.asarray(img_labels)
        audio_labels = np.asarray(audio_labels)

        print(f"Loading saved MNIST-AVMNIST mapping from: {mapping_file}")
        # NOTE: allow_pickle=True can execute code embedded in a crafted file; only
        # load mapping files written by the training script.
        # The context manager closes the underlying archive once the arrays are read.
        with np.load(mapping_file, allow_pickle=True) as mapping_data:
            # Prefer per-split indices if available (saved by newer training script)
            if dataset_kind == 'train' and 'mnist_train_indices' in mapping_data:
                mnist_indices = mapping_data['mnist_train_indices']
                saved_labels = _first_present(mapping_data, 'train_labels', 'labels')
            elif dataset_kind == 'test' and 'mnist_test_indices' in mapping_data:
                mnist_indices = mapping_data['mnist_test_indices']
                saved_labels = _first_present(mapping_data, 'test_labels', 'labels')
            elif 'mnist_image_indices' in mapping_data:
                # Older mapping only contained a single list (likely for the training split).
                if dataset_kind == 'train':
                    mnist_indices = mapping_data['mnist_image_indices']
                    saved_labels = _first_present(mapping_data, 'labels')
                else:
                    # Mapping lacks explicit test indices. Create a deterministic mapping for the test split
                    # by sampling from the provided `images`/`img_labels` according to the audio label counts.
                    print(f"WARNING: Mapping file {mapping_file} does not contain separate test indices. Creating deterministic test mapping from MNIST test split using saved seed.")
                    seed_used = int(mapping_data['seed']) if 'seed' in mapping_data else int(seed)
                    rng = np.random.RandomState(seed_used)
                    mnist_indices_list = []
                    saved_labels_list = []
                    for digit in range(10):
                        img_idx_for_digit = np.where(img_labels == digit)[0]
                        n_needed = int(np.sum(audio_labels == digit))
                        if n_needed <= 0:
                            continue
                        if len(img_idx_for_digit) == 0:
                            # Tiling below would divide by zero; report the real problem instead.
                            print(f"ERROR: No MNIST images with label {digit} available to build the test mapping.")
                            sys.exit(1)
                        if len(img_idx_for_digit) >= n_needed:
                            chosen = rng.choice(img_idx_for_digit, n_needed, replace=False)
                        else:
                            # Not enough unique images; tile and shuffle
                            reps = (n_needed // len(img_idx_for_digit)) + 1
                            tiled = np.tile(img_idx_for_digit, reps)[:n_needed]
                            rng.shuffle(tiled)
                            chosen = tiled
                        mnist_indices_list.extend(chosen.tolist())
                        saved_labels_list.extend([digit] * n_needed)
                    mnist_indices = np.array(mnist_indices_list, dtype=np.int64)
                    saved_labels = np.array(saved_labels_list, dtype=np.int64)
            else:
                print(f"ERROR: Mapping file {mapping_file} missing expected keys ('mnist_train_indices' or 'mnist_image_indices').")
                sys.exit(1)

        # Validate indices are within bounds for the provided `images` array
        mnist_indices = np.array(mnist_indices, dtype=np.int64)
        if mnist_indices.size == 0:
            print("ERROR: Loaded MNIST index list is empty. Cannot match datasets.")
            sys.exit(1)
        if mnist_indices.max() >= images.shape[0] or mnist_indices.min() < 0:
            print(f"ERROR: MNIST mapping indices out of bounds for the provided images array (max index {mnist_indices.max()}, images size {images.shape[0]}).")
            print("This usually means the mapping refers to a different MNIST split. Ensure the mapping contains per-split indices or run the matching/training script that saved both train and test indices.")
            sys.exit(1)

        # Use the exact same MNIST images as training (or the appropriate split)
        matched_images = images[mnist_indices]
        selected_image_labels = img_labels[mnist_indices]

        # Without stored labels nothing would match below; fall back to the labels
        # of the selected images themselves.
        if saved_labels is None or np.ndim(saved_labels) == 0:
            print("WARNING: Mapping file contains no labels; using the labels of the mapped MNIST images instead.")
            saved_labels = selected_image_labels
        saved_labels = np.asarray(saved_labels)

        # Still need to match audio (reconstruct the audio matching).
        # A dedicated seeded generator keeps the draw independent of the global
        # NumPy random state, which the script only seeds after matching.
        audio_rng = np.random.RandomState(seed)
        matched_audio = []
        matched_labels = []

        for digit in range(10):
            audio_indices = np.where(audio_labels == digit)[0]
            n_samples = int(np.sum(saved_labels == digit))

            if n_samples == 0:
                continue
            if len(audio_indices) < n_samples:
                print(f"WARNING: Digit {digit}: only {len(audio_indices)} audio samples available but {n_samples} required; skipping this digit.")
                continue

            audio_sample_indices = audio_rng.choice(audio_indices, n_samples, replace=False)
            matched_audio.append(audio[audio_sample_indices])
            matched_labels.extend([digit] * n_samples)
            print(f"Digit {digit}: matched {n_samples} samples (using saved MNIST mapping)")

        if not matched_audio:
            print("ERROR: No audio samples could be matched to the saved MNIST mapping.")
            sys.exit(1)

        matched_labels = np.array(matched_labels)
        # The audio/labels are emitted digit by digit; flag the case where the
        # selected images do not follow that same order.
        if not np.array_equal(selected_image_labels, matched_labels):
            print("WARNING: Labels of the mapped MNIST images do not line up with the matched audio labels; image/audio pairs may be misaligned.")

        return (matched_images,
                np.concatenate(matched_audio, axis=0),
                matched_labels)
    
    # Pair images with audio for both splits, reusing the mapping saved by training
    shared_match_kwargs = dict(seed=args.seed, full_spectrum=args.full_spectrum)

    print("\nMatching training datasets...")
    train_images, train_audio, train_labels = match_datasets_by_label(
        mnist_train_images, mnist_train_labels, train_audio, train_labels,
        n_samples_per_class=6000, dataset_kind='train', **shared_match_kwargs
    )

    print("\nMatching test datasets...")
    test_images, test_audio, test_labels = match_datasets_by_label(
        mnist_test_images, mnist_test_labels, test_audio, test_labels,
        n_samples_per_class=1000, dataset_kind='test', **shared_match_kwargs
    )

    # Sanity check: both modalities should hold the same number of samples per split
    for split_name, split_images, split_audio in (
        ("train", train_images, train_audio),
        ("test", test_images, test_audio),
    ):
        n_img, n_aud = split_images.shape[0], split_audio.shape[0]
        if n_img != n_aud:
            print(f"WARNING: {split_name}_images ({n_img}) and {split_name}_audio ({n_aud}) lengths differ")

    # Flatten every image into a single feature vector
    train_images_flat = train_images.reshape(train_images.shape[0], -1)
    test_images_flat = test_images.reshape(test_images.shape[0], -1)

    # One float32 tensor per modality, train rows first and test rows appended
    n_train_samples = train_images_flat.shape[0]
    all_images = np.concatenate([train_images_flat, test_images_flat], axis=0)
    all_audio = np.concatenate([train_audio, test_audio], axis=0)
    data = [
        torch.as_tensor(all_images, dtype=torch.float32),
        torch.as_tensor(all_audio, dtype=torch.float32),
    ]
    n_samples = data[0].shape[0]

    if torch.cuda.is_available():
        DEVICE = torch.device(f'cuda:{args.gpu}')
    else:
        DEVICE = torch.device('cpu')

    # Seed every random number generator used downstream
    for seed_everything in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
        seed_everything(args.seed)

    print(f"Data loaded: {n_samples} total samples, {n_train_samples} for training")
    print(f"Using device: {DEVICE}")

    # Define Args class to match the original training setup
    class Args:
        """Plain settings holder passed as `args` to the post-training routine."""

        def __init__(self):
            settings = {
                # latent
                'latent_dim': 200,

                # Training parameters
                'batch_size': 512,
                'lr': 1e-6,
                'weight_decay': 0,
                'dropout': 0.0,
                'ae_depth': 2,
                'ae_width': 0.5,
                'epochs': 100,
                'early_stopping': 50,

                # Rank reduction parameters (not used in post-training)
                'rank_or_sparse': 'rank',

                # GPU parameters
                'num_workers': 8,
                'multi_gpu': False,
                'gpu_ids': '',
                'gpu': args.gpu,  # taken from the command-line arguments
            }
            # Store each setting as an instance attribute, in the order listed above
            for attr_name, attr_value in settings.items():
                setattr(self, attr_name, attr_value)

    train_args = Args()

    # Locate the checkpoint written by 054_avmnist_real.py for this seed / audio variant
    spectrum_tag = '_fullspec' if args.full_spectrum else ''
    model_prefix = f"avmnist_real{spectrum_tag}_rseed-{args.seed}"
    model_path = f"03_results/models/{model_prefix}_final_model.pth"

    # Without the initial model there is nothing to post-train
    if not os.path.exists(model_path):
        for message in (
            f"Error: Model file not found at {model_path}",
            "Please run 054_avmnist_real.py first to create the initial model.",
        ):
            print(message)
        sys.exit(1)

    # Summary of what is about to run
    banner = "=" * 80
    for message in (
        f"Loading model from: {model_path}",
        "Final dataset shapes:",
        f"Train Images: {train_images_flat.shape}",
        f"Train Audio: {train_audio.shape}",
        f"Test Images: {test_images_flat.shape}",
        f"Test Audio: {test_audio.shape}",
        banner,
        "POST-TRAINING CONFIGURATION (Real MNIST):",
        "- Using real MNIST handwritten digit images",
        "- Architecture will be FROZEN (ranks determined during initial training)",
        "- Only model weights will be fine-tuned",
        "- Training will continue from the saved model state",
        banner,
    ):
        print(message)

    # Post-train the model
    posttrain_kwargs = dict(
        model_path=model_path,
        data=data,
        n_samples_train=n_train_samples,
        device=DEVICE,
        args=train_args,
        epochs=train_args.epochs,
        early_stopping=train_args.early_stopping,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        wd=train_args.weight_decay,
        patience=10,
        verbose=True,
        recon_loss_balancing=False,
        paired=False,
        # lr_schedule='exponential',
        model_name=model_prefix,
        trained_ranks=[8, 6, 7],
        full_spectrum=args.full_spectrum,
    )
    model, train_losses, val_losses = posttrain_overcomplete_ae(**posttrain_kwargs)

    print("Post-training completed!")
    print(f"Final training loss: {train_losses[-1]:.4f}")
    print(f"Final validation loss: {val_losses[-1]:.4f}")

    # Persist the per-epoch losses as CSV (epochs are numbered from 0)
    epoch_axis = list(range(len(train_losses)))
    loss_df = pd.DataFrame({
        'epoch': epoch_axis,
        'train_loss': train_losses,
        'val_loss': val_losses,
    })
    loss_csv_path = f"03_results/models/{model_prefix}_posttrain_losses.csv"
    loss_df.to_csv(loss_csv_path, index=False)
    print(f"Loss curves saved to: {loss_csv_path}")

    # Draw both curves on a single axis and write the figure to disk
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epoch_axis, train_losses, 'b-', label='Training Loss', linewidth=2)
    ax.plot(epoch_axis, val_losses, 'r-', label='Validation Loss', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title(f'AVMNIST Real{" Full Spectrum" if args.full_spectrum else ""} Post-Training Loss Curves (Seed {args.seed})')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plot_path = f"03_results/plots/{model_prefix}_posttrain_losses.png"
    # The plots directory may not exist yet
    os.makedirs(os.path.dirname(plot_path), exist_ok=True)
    fig.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"Loss plot saved to: {plot_path}")
    
    # Generate new representations with the post-trained model
    print("Generating representations with post-trained model...")

    # Unwrap a DataParallel-style wrapper once, so that attribute access and
    # encode() both go through the same underlying model.
    core_model = model.module if hasattr(model, 'module') else model

    # One output buffer per adaptive layer, truncated to that layer's active dims.
    # Buffers live on the CPU: they are only saved to disk afterwards, so keeping
    # them on the device would just cost memory and extra transfers.
    final_ranks = [int(layer.active_dims) for layer in core_model.adaptive_layers]
    reps = [torch.empty((n_samples, rank)) for rank in final_ranks]
    model.eval()

    # Calculate latent representations in batches
    batch_size = 512
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            stop = min(start + batch_size, n_samples)
            x_batch = [modality[start:stop].to(DEVICE) for modality in data]
            if args.full_spectrum:
                # --full_spectrum means "do not average the spectrogram".
                # NOTE(review): assumes the full-spectrum model consumes the
                # spectrogram flattened per sample - confirm against
                # src/functions/train_avmnist.py.
                x_batch[1] = x_batch[1].reshape(x_batch[1].shape[0], -1)
            else:
                x_batch[1] = torch.mean(x_batch[1], dim=1)

            # Encode to get representations: first element plus the entries of the second
            batch_reps = core_model.encode(x_batch)
            batch_rep_list = [batch_reps[0], *batch_reps[1]]

            for j, rep in enumerate(reps):
                rep[start:stop, :] = batch_rep_list[j][:, :final_ranks[j]].cpu()

            del x_batch, batch_reps

    # Release cached GPU memory once, after the whole pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Save the new representations
    for i, rep in enumerate(reps):
        rep_path = f"03_results/models/{model_prefix}_posttrain_rep{i}.npy"
        np.save(rep_path, rep.numpy())
        print(f"Saved representation {i} to: {rep_path}")

    print(f"Post-training complete! New representations shapes: {[rep.shape for rep in reps]}")
    print("You can now run the plotting script with these new representations.")
    print(f"\nAll outputs saved with prefix '{model_prefix}'")
    # Name the *other* audio variant: that is what these outputs are distinguished from
    print(f"This allows you to distinguish from {'full spectrum' if not args.full_spectrum else 'averaged audio'} results.")
    print("This allows you to distinguish from eigendigit-based results.")
