import argparse
import torch
from data_processing import load_dataset, subgraph_sampling, get_dataloader
from models import PXGL_GNN
from experiment import train_gnn, train_kernel
from utils.utils import set_seed, save_results, visualize_embeddings

def _parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. When ``None`` argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='MUTAG', help='Name of the dataset')
    parser.add_argument('--model', type=str, default='gnn', choices=['gnn', 'kernel'], help='Model to use: GNN or Kernel')
    parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension of the GNN')
    parser.add_argument('--num_layers', type=int, default=3, help='Number of GNN layers')
    # NOTE(review): --dropout is accepted but never forwarded to the model or
    # the trainers below; wire it through once the PXGL_GNN signature is
    # confirmed to take a dropout argument.
    parser.add_argument('--dropout', type=float, default=0.5, help='Dropout probability')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    return parser.parse_args(argv)


def _run_gnn(args, dataset, num_classes, device):
    """Train PXGL-GNN, save the best weights and visualize graph embeddings.

    Args:
        args: Parsed command-line namespace (uses ``dataset``, ``batch_size``,
            ``hidden_dim``, ``num_layers``, ``lr`` and ``epochs``).
        dataset: Dataset returned by ``load_dataset``; iterated as
            ``(graph, label)`` pairs and read for ``dim_nfeats`` and
            ``pattern_subgraphs``.
        num_classes: Number of target classes passed to the model.
        device: ``torch.device`` the model and graphs are moved to.
    """
    train_loader, val_loader = get_dataloader(dataset, batch_size=args.batch_size)

    model = PXGL_GNN(
        dataset.dim_nfeats,
        args.hidden_dim,
        num_classes,
        args.num_layers,
        num_patterns=7,
    ).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    print(f"Training PXGL-GNN on {args.dataset} dataset...")
    best_val_acc, best_val_f1, best_model_state = train_gnn(
        model,
        optimizer,
        train_loader,
        val_loader,
        device,
        args.epochs,
        dataset.pattern_subgraphs,
    )
    print(f"Best Validation Accuracy: {best_val_acc:.4f} | Best Validation F1: {best_val_f1:.4f}")

    # Persist the best weights before they are reloaded for visualization.
    torch.save(best_model_state, f"{args.dataset}_pxgl_gnn.pth")

    model.load_state_dict(best_model_state)
    # no_grad() only turns off gradient tracking. eval() is what switches
    # dropout / batch-norm layers (if the model has any) to inference
    # behaviour, so the embeddings below are deterministic.
    model.eval()
    embeddings = []
    labels = []
    with torch.no_grad():
        for graph, label in dataset:
            graph = graph.to(device)
            embedding = model(graph, graph.ndata['feat'], dataset.pattern_subgraphs)
            embeddings.append(embedding.cpu().numpy())
            labels.append(label.item())
    visualize_embeddings(embeddings, labels, f"PXGL-GNN Embeddings ({args.dataset})")


def _run_kernel(args, dataset):
    """Run the pattern-based graph-kernel baseline and save its metrics.

    Args:
        args: Parsed command-line namespace (uses ``dataset`` for messages
            and the output file name).
        dataset: Dataset returned by ``load_dataset``; iterated as
            ``(graph, label)`` pairs.
    """
    # Convert graphs and collect labels in a single pass over the dataset.
    graphs = []
    labels = []
    for g, label in dataset:
        graphs.append(g.to_networkx(node_attrs=['feat']))
        labels.append(label.item())

    # Pool the sampled subgraphs of every graph by pattern type.
    pattern_subgraphs = {pattern: [] for pattern in ['path', 'tree', 'cycle', 'clique', 'graphlet', 'wheel', 'star']}
    for graph in graphs:
        sampled_subgraphs = subgraph_sampling(graph)
        for pattern, subgraphs in sampled_subgraphs.items():
            pattern_subgraphs[pattern].extend(subgraphs)

    # One weight per pattern (seven in total); adjust these based on your
    # experiments.
    # NOTE(review): presumably matched to the pattern order above - confirm
    # against train_kernel.
    weights = [0.2, 0.1, 0.1, 0.2, 0.2, 0.1, 0.1]

    print(f"Training Graph Kernel on {args.dataset} dataset...")
    best_acc, best_f1 = train_kernel(graphs, labels, pattern_subgraphs, weights)
    print(f"Best Accuracy: {best_acc:.4f} | Best F1: {best_f1:.4f}")

    save_results({'accuracy': best_acc, 'f1': best_f1}, f"{args.dataset}_kernel_results.txt")


def main(argv=None):
    """Entry point: parse options, load the dataset and run the chosen model.

    Args:
        argv: Optional list of command-line argument strings. Defaults to
            ``None``, in which case ``sys.argv[1:]`` is parsed, so existing
            ``main()`` callers are unaffected.

    Raises:
        SystemExit: Raised by argparse on invalid command-line arguments.
        ValueError: If ``args.model`` is neither ``'gnn'`` nor ``'kernel'``
            (defensive; argparse ``choices`` already rejects other values).
    """
    args = _parse_args(argv)

    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load and preprocess the dataset
    dataset, num_classes = load_dataset(args.dataset)

    if args.model == 'gnn':
        _run_gnn(args, dataset, num_classes, device)
    elif args.model == 'kernel':
        _run_kernel(args, dataset)
    else:
        raise ValueError(f"Invalid model: {args.model}")

# Run the training pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()