import torch
import torch.nn as nn
import sys
import os
import argparse
import numpy as np
import random
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST

# Make the `src` package importable: append the parent of this script's
# directory (presumably the project root -- confirm against the repo layout)
# to sys.path. This must run before the `from src...` import below.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.functions.pretrain_avmnist import (
    ImageAutoencoder, 
    AudioAutoencoder, 
    train_unimodal_autoencoder
)


# Output locations and banner rule shared by the helpers below.
_MODEL_DIR = "03_results/models/pretrained_unimodal"
_PLOT_DIR = "03_results/plots"
_RULE = "=" * 80

# Per-row column-0 titles of the 4-row reconstruction grid
# (rows 0-1 show train samples, rows 2-3 show test samples).
_ROW_TITLES = ('Original', 'Reconstruction', 'Original', 'Reconstruction')


def parse_args(argv=None):
    """Parse the command-line options of the pretraining script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script does when
            run from the command line.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Pretrain individual autoencoders for each modality")
    parser.add_argument('--modality', type=str, required=True,
                        choices=['image', 'audio'],
                        help='Which modality to train: image or audio')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed for reproducibility')
    parser.add_argument('--gpu', type=int, default=0, help='GPU id to use')
    parser.add_argument('--full_spectrum', action='store_true',
                        help='Use full 112x112 audio spectrogram instead of averaging to 112')
    parser.add_argument('--latent_dim', type=int, default=500,
                        help='Latent dimension for autoencoder')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='Maximum number of epochs')
    parser.add_argument('--early_stopping', type=int, default=50,
                        help='Early stopping patience')
    parser.add_argument('--batch_size', type=int, default=512, help='Batch size')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate')
    parser.add_argument('--lr_min', type=float, default=1e-6,
                        help='Minimum learning rate for cosine annealing')
    parser.add_argument('--weight_decay', type=float, default=2e-5,
                        help='Weight decay')
    parser.add_argument('--num_workers', type=int, default=8,
                        help='Number of data loader workers')
    return parser.parse_args(argv)


def _banner(title):
    """Print ``title`` framed by two 80-character rules, preceded by a blank line."""
    print("\n" + _RULE)
    print(title)
    print(_RULE)


def _seed_everything(seed):
    """Seed NumPy, PyTorch (CPU and all CUDA devices) and ``random`` with ``seed``."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)


def _load_arrays(modality, full_spectrum):
    """Load the train and test arrays for one modality.

    Args:
        modality: ``'image'`` or ``'audio'``.
        full_spectrum: Audio only. When false, the audio arrays are averaged
            over axis 1.

    Returns:
        ``(train_data, test_data)`` as NumPy arrays.

    Raises:
        ValueError: If ``modality`` is neither ``'image'`` nor ``'audio'``
            (unreachable through the CLI, which restricts the choices).
    """
    if modality == 'image':
        # Load MNIST images (downloaded on first use) and scale to [0, 1].
        mnist_train = MNIST(root='01_data/processed/MNIST', train=True, download=True)
        mnist_test = MNIST(root='01_data/processed/MNIST', train=False, download=True)

        train_data = mnist_train.data.numpy().astype('float32') / 255.0
        test_data = mnist_test.data.numpy().astype('float32') / 255.0

        # Reshape to flat vectors (28x28 -> 784)
        train_data = train_data.reshape(train_data.shape[0], -1)
        test_data = test_data.reshape(test_data.shape[0], -1)

        print(f"MNIST train images shape: {train_data.shape}")
        print(f"MNIST test images shape: {test_data.shape}")
        return train_data, test_data

    if modality == 'audio':
        data_dir = "01_data/avmnist_data_from_source"

        train_data = np.load(os.path.join(data_dir, "audio/train_data.npy"))
        test_data = np.load(os.path.join(data_dir, "audio/test_data.npy"))

        print(f"Audio train shape: {train_data.shape}")
        print(f"Audio test shape: {test_data.shape}")

        # If not using full spectrum, average the audio data
        if not full_spectrum:
            train_data = train_data.mean(axis=1)  # (N, 112, 112) -> (N, 112)
            test_data = test_data.mean(axis=1)    # (N, 112, 112) -> (N, 112)
            print(f"After averaging - Train: {train_data.shape}, Test: {test_data.shape}")
        return train_data, test_data

    raise ValueError(f"Unsupported modality: {modality!r}")


def _build_model(args, device):
    """Instantiate the autoencoder for ``args.modality`` on ``device`` and print a summary."""
    if args.modality == 'image':
        model = ImageAutoencoder(latent_dim=args.latent_dim).to(device)
    else:  # audio
        model = AudioAutoencoder(
            latent_dim=args.latent_dim,
            full_spectrum=args.full_spectrum,
        ).to(device)

    print(model)
    print(f"\n{args.modality.capitalize()} Autoencoder:")
    print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"  Latent dimension: {args.latent_dim}")
    if args.modality == 'audio':
        print(f"  Full spectrum: {args.full_spectrum}")
    return model


def _last_or_nan(losses):
    """Return the last entry of ``losses``, or NaN when the history is empty.

    Guards against an IndexError (e.g. when no epoch was run) so that a
    trained model is still written to disk.
    """
    # len() rather than truthiness so NumPy arrays are handled as well.
    return losses[-1] if len(losses) > 0 else float('nan')


def _save_checkpoint(model, args, train_losses, val_losses, final_val_loss, save_path):
    """Write encoder/decoder weights plus training metadata to ``save_path``."""
    save_dict = {
        'encoder_state_dict': model.encoder.state_dict(),
        'decoder_state_dict': model.decoder.state_dict(),
        'latent_dim': args.latent_dim,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'final_val_loss': final_val_loss,
        'modality': args.modality,
    }
    if args.modality == 'audio':
        save_dict['full_spectrum'] = args.full_spectrum

    torch.save(save_dict, save_path)


def _plot_training_curves(train_losses, val_losses, modality, plot_path):
    """Plot train/validation loss per epoch and save the figure to ``plot_path``."""
    fig, ax = plt.subplots(1, 1, figsize=(8, 5))

    ax.plot(train_losses, label='Train', alpha=0.7)
    ax.plot(val_losses, label='Val', alpha=0.7)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('BCE Loss')
    ax.set_title(f'{modality.capitalize()} Autoencoder Training')
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def _reconstruct_samples(model, all_data_tensor, n_train, device, n_samples):
    """Reconstruct ``n_samples`` random train and test examples.

    The first ``n_train`` rows of ``all_data_tensor`` are treated as the train
    split and the remainder as the test split. Indices are drawn with
    ``np.random.choice`` (train first, then test) without replacement. The
    model is switched to eval mode and left there.

    Returns:
        ``(train_samples, train_recons, test_samples, test_recons)`` as NumPy
        arrays on the CPU.
    """
    model.eval()
    with torch.no_grad():
        # Split data back into train and test for visualization
        train_tensor = all_data_tensor[:n_train]
        test_tensor = all_data_tensor[n_train:]

        train_indices = np.random.choice(len(train_tensor), n_samples, replace=False)
        test_indices = np.random.choice(len(test_tensor), n_samples, replace=False)

        train_samples = train_tensor[train_indices].to(device)
        test_samples = test_tensor[test_indices].to(device)

        # The model returns a pair; only the first element (the reconstruction) is used.
        train_recons, _ = model(train_samples)
        test_recons, _ = model(test_samples)

        # Move to CPU for plotting
        return (
            train_samples.cpu().numpy(),
            train_recons.cpu().numpy(),
            test_samples.cpu().numpy(),
            test_recons.cpu().numpy(),
        )


def _draw_image(ax, sample, row):
    """Draw one flattened 28x28 image in grayscale."""
    ax.imshow(sample.reshape(28, 28), cmap='gray', vmin=0, vmax=1)


def _draw_spectrogram(ax, sample, row):
    """Draw one 112x112 spectrogram."""
    ax.imshow(sample.reshape(112, 112), cmap='viridis', aspect='auto', vmin=0, vmax=1)


def _draw_signal(ax, sample, row):
    """Draw one 1D audio vector; originals (even rows) in blue, reconstructions in red."""
    ax.plot(sample, 'b-' if row % 2 == 0 else 'r-', alpha=0.7, linewidth=0.5)
    ax.set_ylim([0, 1])
    ax.set_xlim([0, 112])


def _plot_reconstructions(args, panels, n_samples, plot_path):
    """Save a 4 x ``n_samples`` grid of originals and reconstructions.

    Args:
        args: Parsed CLI namespace; ``modality``, ``full_spectrum`` and
            ``seed`` select the drawing style and the title.
        panels: ``(train_samples, train_recons, test_samples, test_recons)``,
            one array per grid row.
        n_samples: Number of columns.
        plot_path: Destination image path.
    """
    # Only the per-cell drawing, figure size, label offset and title differ
    # between the three layouts; the grid itself is shared.
    if args.modality == 'image':
        draw, figsize, label_x = _draw_image, (n_samples * 1.5, 6), -0.3
        title = f'Image Autoencoder Reconstructions - Train & Test (Seed {args.seed})'
    elif args.full_spectrum:
        draw, figsize, label_x = _draw_spectrogram, (n_samples * 1.5, 6), -0.15
        title = f'Audio Autoencoder Reconstructions - Full Spectrum, Train & Test (Seed {args.seed})'
    else:
        draw, figsize, label_x = _draw_signal, (n_samples * 2, 8), -0.15
        title = f'Audio Autoencoder Reconstructions - Averaged, Train & Test (Seed {args.seed})'

    fig, axes = plt.subplots(4, n_samples, figsize=figsize)

    for col in range(n_samples):
        for row, (samples, row_title) in enumerate(zip(panels, _ROW_TITLES)):
            ax = axes[row, col]
            draw(ax, samples[col], row)
            ax.axis('off')
            if col == 0:
                ax.set_title(row_title, fontsize=9, pad=3)

    # Add row labels on the left
    axes[0, 0].text(label_x, 0.5, 'Train', transform=axes[0, 0].transAxes,
                    fontsize=11, fontweight='bold', va='center', ha='right')
    axes[2, 0].text(label_x, 0.5, 'Test', transform=axes[2, 0].transAxes,
                    fontsize=11, fontweight='bold', va='center', ha='right')

    fig.suptitle(title, fontsize=12, y=0.99)
    fig.tight_layout()
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def main(argv=None):
    """Pretrain one unimodal autoencoder, then save its weights and diagnostic plots.

    Steps: parse options, seed the RNGs, load the chosen modality, train on
    the concatenation of the train and test arrays, write the checkpoint, and
    save the loss-curve and reconstruction figures.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`;
            ``None`` reads ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    _seed_everything(args.seed)

    # Setup device
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    print(f"Training modality: {args.modality}")
    if args.modality == 'audio':
        print(f"Full spectrum mode: {args.full_spectrum}")

    # --- 1. Load Data ---
    _banner(f"Loading {args.modality} data...")
    train_data, test_data = _load_arrays(args.modality, args.full_spectrum)

    # Train and test arrays are concatenated and handed to the trainer as one dataset.
    all_data = np.concatenate([train_data, test_data], axis=0)
    all_data_tensor = torch.FloatTensor(all_data)
    print(f"\nCombined dataset size: {all_data_tensor.shape}")

    # --- 2. Initialize and Train Autoencoder ---
    _banner(f"TRAINING {args.modality.upper()} AUTOENCODER")
    model = _build_model(args, device)

    # NOTE(review): `early_stopping` and `patience` both receive the same CLI
    # value; the trainer's signature is not visible here -- confirm both are needed.
    model, train_losses, val_losses = train_unimodal_autoencoder(
        model=model,
        data=all_data_tensor,
        device=device,
        epochs=args.epochs,
        early_stopping=args.early_stopping,
        lr=args.lr,
        batch_size=args.batch_size,
        wd=args.weight_decay,
        patience=args.early_stopping,
        verbose=True,
        num_workers=args.num_workers,
        lr_min=args.lr_min,
    )

    # --- 3. Save Model ---
    _banner("SAVING PRETRAINED MODEL")
    os.makedirs(_MODEL_DIR, exist_ok=True)

    fullspec_tag = '_fullspec' if args.full_spectrum and args.modality == 'audio' else ''
    model_prefix = f"audiomnist{fullspec_tag}_rseed-{args.seed}"
    save_path = f"{_MODEL_DIR}/{model_prefix}_{args.modality}_ae.pth"
    final_val_loss = _last_or_nan(val_losses)

    _save_checkpoint(model, args, train_losses, val_losses, final_val_loss, save_path)
    print(f"✓ Saved {args.modality} autoencoder to: {save_path}")
    print(f"  Final validation loss: {final_val_loss:.6f}")

    # --- 4. Plot Training Curves ---
    _banner("PLOTTING TRAINING CURVES")
    os.makedirs(_PLOT_DIR, exist_ok=True)

    plot_path = f"{_PLOT_DIR}/{model_prefix}_{args.modality}_pretraining_curves.png"
    _plot_training_curves(train_losses, val_losses, args.modality, plot_path)
    print(f"✓ Saved training curves to: {plot_path}")

    # --- 5. Plot Reconstructions ---
    _banner("PLOTTING RECONSTRUCTIONS")

    # Sample 8 examples from train and test
    n_samples = 8
    panels = _reconstruct_samples(
        model, all_data_tensor, train_data.shape[0], device, n_samples)

    recon_plot_path = f"{_PLOT_DIR}/{model_prefix}_{args.modality}_reconstructions.png"
    _plot_reconstructions(args, panels, n_samples, recon_plot_path)
    print(f"✓ Saved reconstruction plots to: {recon_plot_path}")

    _banner("PRETRAINING COMPLETE!")
    print(f"\nSummary:")
    print(f"  Modality: {args.modality}")
    print(f"  Seed: {args.seed}")
    if args.modality == 'audio':
        print(f"  Full spectrum: {args.full_spectrum}")
    print(f"  Latent dim: {args.latent_dim}")
    print(f"  Final val loss: {final_val_loss:.6f}")
    print(f"\nOutputs saved:")
    print(f"  Model: {save_path}")
    print(f"  Training curves: {plot_path}")
    print(f"  Reconstructions: {recon_plot_path}")


if __name__ == '__main__':
    main()
