import argparse
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

from models import NetWithPE, LaplacianPE
from dataset_handler import get_dataset

def train(model, train_loader, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network called as ``model(batch)``; its output is fed to
            ``F.nll_loss`` (which expects log-probabilities -- presumably the
            model ends in ``log_softmax``; confirm in ``models``).
        train_loader: Iterable of batches exposing ``.to()``, ``.y`` and
            ``.num_graphs``; must also expose ``.dataset`` for averaging.
        optimizer: Optimizer bound to the model's parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The loss summed over batches (each weighted by its graph count)
        divided by ``len(train_loader.dataset)``.
    """
    model.train()
    running_loss = 0

    progress = tqdm(train_loader, desc='Training', leave=False)
    for batch in progress:
        batch = batch.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(batch), batch.y)
        batch_loss.backward()
        optimizer.step()

        # nll_loss averages within the batch; weight by batch size so the
        # final division yields a per-graph mean even with a ragged last batch.
        running_loss += batch_loss.item() * batch.num_graphs

    num_samples = len(train_loader.dataset)
    return running_loss / num_samples

def validate(model, loader, device):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Args:
        model: Network called as ``model(batch)``; switched to eval mode here.
        loader: Iterable of batches exposing ``.to()``, ``.y`` and
            ``.num_graphs``; must also expose ``.dataset`` for averaging.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(avg_loss, accuracy)``, both normalised by
        ``len(loader.dataset)``.
    """
    model.eval()
    loss_sum = 0
    num_correct = 0

    for batch in tqdm(loader, desc='Validating', leave=False):
        batch = batch.to(device)
        with torch.no_grad():
            scores = model(batch)
            batch_loss = F.nll_loss(scores, batch.y)
            # Weight the batch-mean loss by the number of graphs it covers.
            loss_sum += batch_loss.item() * batch.num_graphs
            # Index of the highest score along dim 1 is the predicted class.
            _, predicted = scores.max(dim=1)
            num_correct += (predicted == batch.y).sum().item()

    dataset_size = len(loader.dataset)
    return loss_sum / dataset_size, num_correct / dataset_size

def plot_training_curves(results, save_path='training_curves.png'):
    """Plot loss and accuracy histories on a twin-axis figure and save it.

    Args:
        results: Mapping with the keys ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``, each a per-epoch sequence.
            The x-axis length is taken from ``results['train_loss']``.
        save_path: Destination passed to ``savefig``.

    The figure is held through an explicit handle and closed in a
    ``finally`` block, so a failure while plotting or saving (missing key,
    unwritable path, ...) does not leave an open figure behind in pyplot's
    global registry.
    """
    epochs = range(1, len(results['train_loss']) + 1)

    fig, loss_ax = plt.subplots(figsize=(12, 6))
    try:
        # Losses on the primary (left) axis.
        loss_ax.plot(epochs, results['train_loss'], label='Train Loss', linestyle='-')
        loss_ax.plot(epochs, results['val_loss'], label='Val Loss', linestyle='--')

        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Metrics')
        loss_ax.legend()
        loss_ax.grid(True)

        # Accuracies on a secondary (right) axis sharing the same x-axis.
        acc_ax = loss_ax.twinx()
        acc_ax.plot(epochs, results['train_acc'], label='Train Acc', linestyle=':', color='g')
        acc_ax.plot(epochs, results['val_acc'], label='Val Acc', linestyle='-.', color='r')
        acc_ax.set_ylabel('Accuracy')
        acc_ax.legend(loc='center right')

        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        plt.close(fig)

def get_model(dataset, args, device):
    """Build a ``NetWithPE`` sized for ``dataset`` and move it to ``device``.

    Args:
        dataset: Object optionally exposing ``num_features`` and
            ``num_classes``; missing attributes default to 1 and 10.
        args: Parsed CLI namespace; only ``args.k`` is read here, as the
            positional-encoding dimension.
        device: Target device for the model.

    Returns:
        The constructed model, already moved to ``device``.
    """
    feature_dim = getattr(dataset, 'num_features', 1)
    class_count = getattr(dataset, 'num_classes', 10)

    network = NetWithPE(
        num_features=feature_dim,
        pos_enc_dim=args.k,
        num_classes=class_count,
    )
    return network.to(device)

def _resolve_device(requested):
    """Turn the ``--device`` string into a ``torch.device``.

    Any CUDA request (``'cuda'``, ``'cuda:1'``, ...) falls back to CPU when
    CUDA is unavailable; every other device string is honoured as given.
    """
    if requested.startswith('cuda') and not torch.cuda.is_available():
        return torch.device('cpu')
    return torch.device(requested)


def main():
    """Command-line entry point: train, evaluate, checkpoint and plot.

    Parses the CLI options, builds the transform pipeline and data loaders,
    trains for ``--epochs`` epochs, saves the checkpoint with the best
    held-out accuracy to ``best_model_<dataset>.pt`` and writes the metric
    curves to ``training_curves_<dataset>.png``.
    """
    # Available datasets
    available_datasets = ['mnist', 'PROTEINS', 'MUTAG', 'IMDB-BINARY',
                         'REDDIT-BINARY', 'COLLAB', 'ENZYMES']

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='mnist', choices=available_datasets,
                       help='Dataset name')
    parser.add_argument('--k', type=int, default=3, help='Dimension of positional encoding')
    parser.add_argument('--epochs', type=int, default=50, help='Number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--device', type=str, default='cuda', help='Device to use')
    parser.add_argument('--data_dir', type=str, default='data', help='Data directory')
    parser.add_argument('--with_pos_enc', action='store_true',
                       help='Use positional encoding')
    parser.add_argument('--with_proj', action='store_true',
                       help='Use edge projections')
    parser.add_argument('--with_virtual_node', action='store_true',
                       help='Add virtual node to graphs')
    parser.add_argument('--use_node_attr', action='store_true',
                       help='Use additional node attributes for TU datasets if available')

    args = parser.parse_args()

    # Set device. Previously anything other than the exact string 'cuda'
    # (e.g. 'cuda:1') was silently replaced by CPU.
    device = _resolve_device(args.device)
    print(f'Using device: {device}')
    print(f'Dataset: {args.dataset}')

    # Define transforms
    transform = T.Compose([
        LaplacianPE(k=args.k,
                   with_pos_enc=args.with_pos_enc,
                   with_proj=args.with_proj,
                   with_virtual=args.with_virtual_node),
        T.NormalizeFeatures()
    ])

    # Load dataset
    train_dataset, test_dataset = get_dataset(
        args.dataset,
        args.data_dir,
        transform,
        args.k,
        args.use_node_attr
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size
    )

    # Initialize model and optimizer
    model = get_model(train_dataset, args, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Initialize metrics storage
    results = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    # Training loop
    best_val_acc = 0
    for epoch in range(1, args.epochs + 1):
        # Train
        train_loss = train(model, train_loader, optimizer, device)

        # Validate. The second split returned by get_dataset serves as the
        # validation set; the training split is re-scored in eval mode so
        # both curves are measured the same way.
        val_loss, val_acc = validate(model, test_loader, device)
        train_loss_val, train_acc = validate(model, train_loader, device)

        # Store metrics (eval-mode train loss, so it is comparable to val_loss)
        results['train_loss'].append(train_loss_val)
        results['val_loss'].append(val_loss)
        results['train_acc'].append(train_acc)
        results['val_acc'].append(val_acc)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
            }, f'best_model_{args.dataset}.pt')

        # The printed train loss is the running average gathered while the
        # weights were being updated, not the eval-mode value stored above.
        print(f'Epoch: {epoch:02d}, '
              f'Train Loss: {train_loss:.4f}, '
              f'Val Loss: {val_loss:.4f}, '
              f'Train Acc: {train_acc:.4f}, '
              f'Val Acc: {val_acc:.4f}, '
              f'Best Val Acc: {best_val_acc:.4f}')

    # Plot results
    plot_training_curves(results, save_path=f'training_curves_{args.dataset}.png')

    print(f'\nFinal Results for {args.dataset}:')
    print(f'Best Validation Accuracy: {best_val_acc:.4f}')

# Run the training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()