import argparse
import torch
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import json
from datetime import datetime

from egnn_model import EGNNWithPE, LaplacianPE

def train(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Args:
        model: Module called as ``model(batch)``; its output is fed straight
            into ``F.nll_loss`` together with ``batch.y``.
        train_loader: Iterable of batches exposing ``to(device)``, ``y`` and
            ``num_graphs``.
        optimizer: Optimizer that owns ``model``'s parameters.
        device: Device every batch is moved to before the forward pass.

    Returns:
        The loss averaged per graph over the batches that were actually used
        for an optimisation step. Batches skipped because of a NaN/Inf loss
        are excluded from both the sum and the count, so they no longer pull
        the average towards zero. ``nan`` is returned when no batch produced
        a finite loss (including an empty loader).
    """
    model.train()
    loss_sum = 0.0
    graphs_counted = 0

    for data in tqdm(train_loader, desc='Training', leave=False):
        data = data.to(device)
        optimizer.zero_grad()

        output = model(data)
        loss = F.nll_loss(output, data.y)

        # Stepping on a non-finite loss would corrupt the weights, so the
        # whole batch is dropped instead.
        if not torch.isfinite(loss).all():
            print("Warning: NaN or Inf loss detected, skipping batch")
            continue

        loss.backward()

        # Clip the global gradient norm for stability before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # nll_loss is a per-batch mean; weight it by batch size so the final
        # figure is a per-graph mean.
        loss_sum += loss.item() * data.num_graphs
        graphs_counted += data.num_graphs

    if graphs_counted == 0:
        return float('nan')
    return loss_sum / graphs_counted

def validate(model, loader, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Module called as ``model(batch)``; its output is fed into
            ``F.nll_loss`` and its per-row maximum is taken as the prediction.
        loader: Iterable of batches exposing ``to(device)``, ``y`` and
            ``num_graphs``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        A ``(avg_loss, accuracy)`` tuple.

        ``avg_loss`` is the per-graph mean over batches with a finite loss;
        batches with a NaN/Inf loss are left out of both the sum and the
        count. It is ``nan`` when no batch had a finite loss.

        ``accuracy`` is the fraction of correctly classified graphs among all
        graphs iterated (``0.0`` for an empty loader).
    """
    model.eval()
    loss_sum = 0.0
    loss_graphs = 0
    correct = 0
    graphs_seen = 0

    for data in tqdm(loader, desc='Validating', leave=False):
        data = data.to(device)
        with torch.no_grad():
            output = model(data)

            # Replace NaN scores with a very negative value so the loss and
            # the prediction below can still be computed.
            if torch.isnan(output).any():
                output = torch.nan_to_num(output, nan=-1e10)

            loss = F.nll_loss(output, data.y)

            # Only finite losses contribute to the average.
            if torch.isfinite(loss).all():
                loss_sum += loss.item() * data.num_graphs
                loss_graphs += data.num_graphs

            pred = output.max(dim=1)[1]
            correct += pred.eq(data.y).sum().item()
            graphs_seen += data.num_graphs

    accuracy = correct / graphs_seen if graphs_seen else 0.0
    avg_loss = loss_sum / loss_graphs if loss_graphs else float('nan')
    return avg_loss, accuracy

def _plot_metric_curves(all_results, train_key, val_key, label, ylabel, title, out_path):
    """Plot one metric's train/validation curves for every configuration and save the figure."""
    plt.figure(figsize=(12, 6))

    for config, results in all_results.items():
        epochs = range(1, len(results[train_key]) + 1)
        plt.plot(epochs, results[train_key],
                 label=f'Train {label} {config}', linestyle='-')
        plt.plot(epochs, results[val_key],
                 label=f'Val {label} {config}', linestyle='--')

    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def plot_training_curves(all_results, save_path='egnn_training_curves.png'):
    """Save loss and accuracy curves for all configurations as two figures.

    Args:
        all_results: Mapping ``config name -> dict`` holding the per-epoch
            lists ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and
            ``'val_acc'``.
        save_path: File the loss figure is written to. The accuracy figure
            goes next to it with ``_acc`` inserted before the extension
            (``curves.png`` -> ``curves_acc.png``).
    """
    _plot_metric_curves(
        all_results, 'train_loss', 'val_loss', 'Loss', 'Loss',
        'EGNN Training and Validation Loss for Different Configurations',
        save_path,
    )

    # Derive the second file name from the extension rather than by replacing
    # the literal '.png': a path without that substring would otherwise map
    # to itself and the accuracy plot would overwrite the loss plot.
    stem, ext = os.path.splitext(save_path)
    _plot_metric_curves(
        all_results, 'train_acc', 'val_acc', 'Acc', 'Accuracy',
        'EGNN Training and Validation Accuracy for Different Configurations',
        f'{stem}_acc{ext}',
    )

def _parse_args():
    """Parse the command-line options of the experiment sweep."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--k_values', nargs='+', type=int, default=[3, 8, 16],
                        help='List of k values for positional encoding dimensions')
    parser.add_argument('--epochs', type=int, default=30, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate')
    parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension')
    parser.add_argument('--num_layers', type=int, default=3, help='Number of EGNN layers')
    parser.add_argument('--dropout', type=float, default=0.2, help='Dropout rate')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--data_dir', type=str, default='data', help='Data directory')

    # These three switches default to True, so the historical store_true flags
    # alone could never turn them off; each gets a store_false counterpart
    # writing to the same destination.
    parser.add_argument('--with_pos_enc', action='store_true', help='Use positional encoding', default=True)
    parser.add_argument('--no_pos_enc', dest='with_pos_enc', action='store_false',
                        help='Disable positional encoding')
    parser.add_argument('--with_proj', action='store_true', help='Use edge projectors')
    parser.add_argument('--with_virtual_node', action='store_true', help='Use virtual node')
    parser.add_argument('--norm_features', action='store_true', help='Normalize features', default=True)
    parser.add_argument('--no_norm_features', dest='norm_features', action='store_false',
                        help='Disable feature normalization')
    parser.add_argument('--norm_coords', action='store_true', help='Normalize coordinates', default=True)
    parser.add_argument('--no_norm_coords', dest='norm_coords', action='store_false',
                        help='Disable coordinate normalization')

    parser.add_argument('--coord_weights_clamp', type=float, default=1.0,
                        help='Clamping value for coordinate weights')
    parser.add_argument('--weight_decay', type=float, default=1e-5,
                        help='Weight decay for regularization')
    parser.add_argument('--early_stopping', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--output_dir', type=str, default='results', help='Output directory for results')
    return parser.parse_args()


def _select_device(requested):
    """Return the requested torch device, using CPU when CUDA is asked for but unavailable."""
    if requested.startswith('cuda') and not torch.cuda.is_available():
        return torch.device('cpu')
    # Any other string ('cpu', 'cuda:1', 'mps', ...) is honoured as given; an
    # invalid name makes torch.device raise instead of silently using the CPU.
    return torch.device(requested)


def _run_configuration(args, k, update_coords, config_name, device, history):
    """Train one (k, update_coords) configuration and return (best_val_acc, epochs_trained).

    Per-epoch metrics are appended to the lists in ``history`` and the best
    checkpoint is written to ``args.output_dir``.
    """
    # The Laplacian PE depends on k, so the transform is rebuilt (and the
    # datasets re-created) for every configuration.
    transform = T.Compose([
        LaplacianPE(
            k=k,
            with_pos_enc=args.with_pos_enc,
            with_proj=args.with_proj,
            with_virtual=args.with_virtual_node
        ),
        T.NormalizeFeatures(),
    ])

    train_dataset = MNISTSuperpixels(root=args.data_dir, train=True, transform=transform)
    test_dataset = MNISTSuperpixels(root=args.data_dir, train=False, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size)

    model = EGNNWithPE(
        num_features=1,  # MNIST has 1 feature per node
        pos_enc_dim=k,
        hidden_dim=args.hidden_dim,
        num_classes=10,  # MNIST has 10 classes
        dropout=args.dropout,
        num_layers=args.num_layers,
        norm_features=args.norm_features,
        norm_coords=args.norm_coords,
        coord_weights_clamp_value=args.coord_weights_clamp,
        update_coords=update_coords
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay
    )

    # `verbose=True` is deprecated since PyTorch 2.2 (newer releases may
    # reject it), so learning-rate reductions are reported manually below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
    )

    checkpoint_path = os.path.join(args.output_dir, f'best_egnn_model_{config_name}.pt')
    best_val_acc = 0
    patience_counter = 0
    # Stays 0 when --epochs is 0, instead of reading an unbound/stale loop
    # variable after the loop.
    epochs_trained = 0

    for epoch in range(1, args.epochs + 1):
        epochs_trained = epoch

        train_loss = train(model, train_loader, optimizer, device)

        # NOTE: the held-out test split doubles as the validation set here.
        val_loss, val_acc = validate(model, test_loader, device)
        # Re-evaluate on the training split in eval mode so the stored
        # train/val curves are measured the same way.
        train_loss_val, train_acc = validate(model, train_loader, device)

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_acc)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f'Epoch {epoch:05d}: reducing learning rate of group 0 to {lr_after:.4e}.')

        history['train_loss'].append(train_loss_val)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        # Keep the checkpoint with the best validation accuracy and reset the
        # early-stopping counter whenever it improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            patience_counter = 0
        else:
            patience_counter += 1

        print(f'Epoch: {epoch:02d}, '
              f'Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, '
              f'Val Acc: {val_acc:.4f}, '
              f'Best Val Acc: {best_val_acc:.4f}')

        if patience_counter >= args.early_stopping:
            print(f"Early stopping triggered after {epoch} epochs")
            break

    return best_val_acc, epochs_trained


def _report_results(final_results, output_dir, timestamp):
    """Write the per-configuration summary to CSV/JSON and print the comparison tables."""
    results_df = pd.DataFrame(final_results)
    results_df.to_csv(os.path.join(output_dir, f'egnn_results_{timestamp}.csv'), index=False)

    # Save as JSON for easier reading
    with open(os.path.join(output_dir, f'egnn_results_{timestamp}.json'), 'w') as f:
        json.dump(final_results, f, indent=2)

    print("\n\n" + "="*60)
    print("FINAL RESULTS SUMMARY")
    print("="*60)
    print(results_df.to_string(index=False))
    print("="*60)

    # k values as rows, coordinate-update setting as columns.
    pivot_table = results_df.pivot_table(
        index='k',
        columns='update_coords',
        values='best_val_accuracy',
        aggfunc='max'
    ).round(4)

    # Name each column from its own flag value rather than by position, so
    # the labels stay correct whatever columns the pivot produces.
    pivot_table.columns = ['Coords ON' if flag else 'Coords OFF'
                           for flag in pivot_table.columns]

    print("\nACCURACY COMPARISON TABLE:")
    print(pivot_table)

    pivot_table.to_csv(os.path.join(output_dir, f'egnn_comparison_table_{timestamp}.csv'))

    # idxmax returns an index label, so select the row with .loc.
    best_config = results_df.loc[results_df['best_val_accuracy'].idxmax()]

    print("\nBEST CONFIGURATION:")
    print(f"k = {best_config['k']}, update_coords = {best_config['update_coords']}")
    print(f"Best validation accuracy: {best_config['best_val_accuracy']:.4f}")


def main():
    """Train EGNN on MNISTSuperpixels for every (k, update_coords) pair and report the results.

    For each positional-encoding size in ``--k_values``, with coordinate
    updates on and off, a fresh model is trained with early stopping. The
    best checkpoint per configuration, the training curves, and CSV/JSON
    summaries are written to ``--output_dir``.

    Note that "validation" metrics are computed on the MNISTSuperpixels test
    split (``train=False``), which is also what drives checkpoint selection,
    the LR scheduler and early stopping.
    """
    args = _parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Timestamp keeps the artefacts of different runs apart.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    device = _select_device(args.device)
    print(f'Using device: {device}')

    # For reproducibility
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    all_results = {}
    final_results = []

    # Train with and without coordinate updates
    for update_coords in (True, False):
        for k in args.k_values:
            print(f'\n{"="*50}')
            print(f'Training EGNN model with k={k}, update_coords={update_coords}')
            print(f'{"="*50}')

            config_name = f"k{k}_coords{'On' if update_coords else 'Off'}"

            history = {
                'train_loss': [],
                'val_loss': [],
                'train_acc': [],
                'val_acc': []
            }
            all_results[config_name] = history

            best_val_acc, epochs_trained = _run_configuration(
                args, k, update_coords, config_name, device, history
            )

            final_results.append({
                'k': k,
                'update_coords': update_coords,
                'best_val_accuracy': best_val_acc,
                'epochs_trained': epochs_trained
            })

    plot_training_curves(
        all_results,
        save_path=os.path.join(args.output_dir, f'egnn_training_curves_{timestamp}.png')
    )

    _report_results(final_results, args.output_dir, timestamp)

# Run the experiment sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()