###
# Post-training script for AVMNIST model
# Loads the final model from 054_avmnist.py and continues training without rank reduction
###

import torch
import torch.nn as nn
import sys
import os
import argparse
import numpy as np
import random
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

# Put the repository root (the parent of this script's directory) on sys.path
# so that the project-local `src` package can be imported further down.
project_root = Path(__file__).parent.parent.absolute()
sys.path.append(os.fspath(project_root))

from src.functions.train_avmnist import posttrain_overcomplete_ae

def parse_args(argv=None):
    """Parse the command-line options of the post-training script.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``gpu`` and
        ``trained_ranks``.
    """
    parser = argparse.ArgumentParser(description="Post-train AVMNIST model")
    parser.add_argument('--seed', type=int, default=42, help='Random seed for reproducibility')
    parser.add_argument('--gpu', type=int, default=0, help='GPU id to use.')
    parser.add_argument('--trained_ranks', type=int, nargs='+', default=[8, 1, 5],
                        help='Ranks forwarded to posttrain_overcomplete_ae as trained_ranks.')
    return parser.parse_args(argv)


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def minmax_normalize(array):
    """Return *array* shifted and scaled so its min maps to 0 and its max to 1.

    The statistics are taken over the whole array that is passed in.
    """
    low = array.min()
    high = array.max()
    return (array - low) / (high - low)


def load_modalities(data_dir, split):
    """Load the image and audio arrays of one split from ``.npy`` files.

    Args:
        data_dir: Directory containing the ``image/`` and ``audio/`` folders.
        split: File-name prefix of the split, i.e. ``"train"`` or ``"test"``.

    Returns:
        Tuple ``(images, audio)``: images are min-max normalised per split,
        audio is divided by 255.
    """
    images = np.load(os.path.join(data_dir, f"image/{split}_data.npy"))
    audio = np.load(os.path.join(data_dir, f"audio/{split}_data.npy")) / 255.0
    return minmax_normalize(images), audio


class Args:
    """Settings object passed to ``posttrain_overcomplete_ae`` as ``args``.

    The values mirror the original training setup.
    """

    def __init__(self, gpu=0):
        # latent
        self.latent_dim = 100

        # Training parameters
        self.batch_size = 512
        self.lr = 1e-5
        self.weight_decay = 0
        self.dropout = 0.0
        self.ae_depth = 2
        self.ae_width = 0.5
        self.epochs = 1000
        self.early_stopping = 50

        # Rank reduction parameters (not used in post-training)
        self.rank_or_sparse = 'rank'

        # GPU parameters
        self.num_workers = 8
        self.multi_gpu = False
        self.gpu_ids = ''
        self.gpu = gpu


def save_loss_curves(train_losses, val_losses, seed):
    """Write the per-epoch losses to a CSV file and to a PNG line plot.

    Args:
        train_losses: Sequence of training losses, one entry per epoch.
        val_losses: Sequence of validation losses, one entry per epoch.
        seed: Random seed; only used to build the output file names.
    """
    loss_data = {
        'epoch': list(range(len(train_losses))),
        'train_loss': train_losses,
        'val_loss': val_losses
    }

    loss_csv_path = f"03_results/models/avmnist_rseed-{seed}_posttrain_losses.csv"
    os.makedirs(os.path.dirname(loss_csv_path), exist_ok=True)
    pd.DataFrame(loss_data).to_csv(loss_csv_path, index=False)
    print(f"Loss curves saved to: {loss_csv_path}")

    plt.figure(figsize=(10, 6))
    plt.plot(loss_data['epoch'], loss_data['train_loss'], 'b-', label='Training Loss', linewidth=2)
    plt.plot(loss_data['epoch'], loss_data['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'AVMNIST Post-Training Loss Curves (Seed {seed})')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plot_path = f"03_results/plots/avmnist_rseed-{seed}_posttrain_losses.png"
    os.makedirs(os.path.dirname(plot_path), exist_ok=True)
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Loss plot saved to: {plot_path}")


def encode_dataset(model, data, batch_size, device):
    """Encode all samples in mini-batches and collect the representations.

    Args:
        model: Trained model exposing ``adaptive_layers`` (each with an
            ``active_dims`` attribute) and ``encode``; it may be wrapped in
            an object that exposes the real model as ``.module``.
        data: List of two tensors whose first dimension indexes samples.
            The second tensor is averaged over ``dim=1`` before encoding.
        batch_size: Number of samples encoded per forward pass.
        device: Device the batches are moved to for encoding.

    Returns:
        List of CPU tensors, one per adaptive layer, each of shape
        ``(n_samples, active_dims)``.
    """
    # Unwrap once so attribute access and encode() use the same object;
    # a DataParallel-style wrapper does not forward custom attributes.
    core = model.module if hasattr(model, 'module') else model
    ranks = [layer.active_dims for layer in core.adaptive_layers]

    n_samples = data[0].shape[0]
    # Buffers live on the CPU: every batch result is moved there anyway and
    # the final arrays are saved with NumPy.
    reps = [torch.empty((n_samples, rank)) for rank in ranks]

    model.eval()
    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            stop = min(start + batch_size, n_samples)
            x_batch = [modality[start:stop].to(device) for modality in data]
            x_batch[1] = torch.mean(x_batch[1], dim=1)

            batch_reps = core.encode(x_batch)
            # encode() yields (first_rep, sequence_of_further_reps); flatten it.
            batch_rep_list = [batch_reps[0], *batch_reps[1]]

            for j, buffer in enumerate(reps):
                buffer[start:stop, :] = batch_rep_list[j][:, :ranks[j]].cpu()

            del x_batch, batch_reps, batch_rep_list
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return reps


def main():
    """Load the saved AVMNIST model, post-train it and export the results."""
    args = parse_args()

    # Load the same data as used in training (.npy files as per the
    # MultiBench source code). Labels are not loaded: nothing below uses them.
    data_dir = "01_data/processed/avmnist"
    train_images, train_audio = load_modalities(data_dir, "train")
    test_images, test_audio = load_modalities(data_dir, "test")

    n_train_samples = train_images.shape[0]
    data = [torch.FloatTensor(np.concatenate([train_images, test_images], axis=0)),
            torch.FloatTensor(np.concatenate([train_audio, test_audio], axis=0))]
    n_samples = data[0].shape[0]

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')

    seed_everything(args.seed)

    print(f"Data loaded: {n_samples} total samples, {n_train_samples} for training")
    print(f"Using device: {device}")

    train_args = Args(gpu=args.gpu)

    # Path to the saved model from 054_avmnist.py
    model_path = f"03_results/models/avmnist_rseed-{args.seed}_final_model.pth"

    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}")
        print("Please run 054_avmnist.py first to create the initial model.")
        sys.exit(1)

    print(f"Loading model from: {model_path}")
    print("="*80)
    print("POST-TRAINING CONFIGURATION:")
    print("- Architecture will be FROZEN (ranks determined during initial training)")
    print("- Only model weights will be fine-tuned")
    print("- Training will continue from the saved model state")
    print("="*80)

    # Post-train the model
    model, train_losses, val_losses = posttrain_overcomplete_ae(
        model_path=model_path,
        data=data,
        n_samples_train=n_train_samples,
        device=device,
        args=train_args,
        epochs=train_args.epochs,
        early_stopping=train_args.early_stopping,
        lr=train_args.lr,
        batch_size=train_args.batch_size,
        wd=train_args.weight_decay,
        patience=10,
        verbose=True,
        recon_loss_balancing=False,
        paired=False,
        #lr_schedule='exponential',
        model_name=f"avmnist_rseed-{args.seed}",
        trained_ranks=args.trained_ranks
    )

    print("Post-training completed!")
    # Guard against an empty history so the summary cannot raise IndexError.
    if len(train_losses) > 0 and len(val_losses) > 0:
        print(f"Final training loss: {train_losses[-1]:.4f}")
        print(f"Final validation loss: {val_losses[-1]:.4f}")

    save_loss_curves(train_losses, val_losses, args.seed)

    # Generate new representations with the post-trained model
    print("Generating representations with post-trained model...")
    reps = encode_dataset(model, data, train_args.batch_size, device)

    # Save the new representations
    for i, rep in enumerate(reps):
        rep_path = f"03_results/models/avmnist_rseed-{args.seed}_posttrain_rep{i}.npy"
        np.save(rep_path, rep.cpu().numpy())
        print(f"Saved representation {i} to: {rep_path}")

    print(f"Post-training complete! New representations shapes: {[rep.shape for rep in reps]}")
    print("You can now run the plotting script with these new representations.")


if __name__ == '__main__':
    main()
