import torch
import torch.nn.functional as F
from sklearn.model_selection import KFold
from kernels import get_kernel_matrix
from models import PXGL_GNN
from utils.utils import evaluate, visualize_embeddings

def train_gnn(model, optimizer, train_loader, val_loader, device, epochs, pattern_subgraphs):
    """
    Train the PXGL-GNN model and keep the weights of the best validation epoch.

    Each epoch runs one pass of cross-entropy training over ``train_loader``,
    then scores the model on both loaders with ``evaluate``. The epoch with the
    highest validation accuracy determines the returned metrics and weights.

    Args:
        model (torch.nn.Module): PXGL-GNN model; called as
            ``model(batch_graph, batch_graph.ndata['feat'], pattern_subgraphs)``.
        optimizer (torch.optim.Optimizer): Optimizer for the model.
        train_loader (dgl.dataloading.GraphDataLoader): Data loader for training.
            Must yield ``(batch_graph, labels)`` pairs and support ``len()``.
        val_loader (dgl.dataloading.GraphDataLoader): Data loader for validation.
        device (torch.device): Device to run the model on.
        epochs (int): Number of training epochs.
        pattern_subgraphs (dict): Dictionary containing sampled subgraphs for each pattern.

    Returns:
        best_val_acc (float): Best validation accuracy achieved during training.
        best_val_f1 (float): Validation F1 score of the epoch that achieved
            ``best_val_acc``.
        best_model_state (dict): Independent copy of the model's state
            dictionary at that epoch (``None`` only when ``epochs`` is 0).
    """
    # Imported locally so this function carries its own dependency.
    import copy

    best_val_acc = 0
    best_val_f1 = 0
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_graph, labels in train_loader:
            batch_graph = batch_graph.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            logits = model(batch_graph, batch_graph.ndata['feat'], pattern_subgraphs)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        train_acc, train_f1 = evaluate(model, train_loader, device, pattern_subgraphs)
        val_acc, val_f1 = evaluate(model, val_loader, device, pattern_subgraphs)

        # Always capture the first epoch so a state is returned even when
        # validation accuracy never rises above the initial 0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_val_f1 = val_f1
            # state_dict() hands back references to the live parameter tensors;
            # without a deep copy, later optimizer steps would silently
            # overwrite the "best" snapshot with the final weights.
            best_model_state = copy.deepcopy(model.state_dict())

        # Guard the average against an empty training loader.
        num_batches = max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {total_loss/num_batches:.4f} | "
              f"Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f} | "
              f"Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")

    return best_val_acc, best_val_f1, best_model_state

def train_kernel(graphs, labels, pattern_subgraphs, weights):
    """
    Cross-validate a precomputed-kernel SVM built on the pattern graph kernel.

    The Gram matrix over all graphs is computed once; each of the 10 shuffled
    folds then slices out the (train x train) block to fit the SVM and the
    (test x train) block to predict, which is the shape scikit-learn requires
    for ``kernel='precomputed'``.

    Args:
        graphs (list): List of networkx graphs.
        labels (list): Corresponding labels for each graph.
        pattern_subgraphs (dict): Dictionary containing sampled subgraphs for each pattern.
        weights (list): Weights for each pattern.

    Returns:
        best_acc (float): Highest accuracy over the cross-validation folds
            (the maximum, not the mean).
        best_f1 (float): Highest F1 score over the cross-validation folds
            (may come from a different fold than ``best_acc``).

    Raises:
        ValueError: Raised by ``KFold.split`` when fewer than 10 graphs are given.
    """
    # Local imports keep this function self-contained. The metrics helper is
    # imported under an alias because this module defines its own 4-argument
    # `evaluate`, which shadows the one imported from utils.utils.
    # NOTE(review): assumes utils.utils.evaluate(labels, preds) returns
    # (acc, f1), as the call sites in this file use it -- confirm.
    import numpy as np
    from sklearn.svm import SVC
    from utils.utils import evaluate as score_predictions

    kf = KFold(n_splits=10, shuffle=True, random_state=42)

    # One Gram matrix for the whole dataset; folds only index into it.
    # NOTE(review): assumes get_kernel_matrix returns a square (n, n) array
    # whose (i, j) entry depends only on graphs i and j -- confirm.
    full_kernel = np.asarray(get_kernel_matrix(graphs, pattern_subgraphs, weights))

    accs = []
    f1s = []

    for train_index, test_index in kf.split(graphs):
        train_labels = [labels[i] for i in train_index]
        test_labels = [labels[i] for i in test_index]

        # Rows are the samples being scored, columns are always the training
        # samples the SVM was fitted on.
        train_kernel_matrix = full_kernel[np.ix_(train_index, train_index)]
        test_kernel_matrix = full_kernel[np.ix_(test_index, train_index)]

        # Train an SVM classifier using the kernel matrix
        clf = SVC(kernel='precomputed')
        clf.fit(train_kernel_matrix, train_labels)

        # Predict labels for the test set
        test_preds = clf.predict(test_kernel_matrix)
        acc, f1 = score_predictions(test_labels, test_preds)

        accs.append(acc)
        f1s.append(f1)

    best_acc = max(accs)
    best_f1 = max(f1s)
    return best_acc, best_f1

def evaluate(model, data_loader, device, pattern_subgraphs):
    """
    Evaluate the model on the given data loader.

    Runs the model in eval mode without gradient tracking, takes the argmax of
    the logits as the predicted class for every graph, and scores all
    predictions against the labels in one call to the metrics helper.

    Args:
        model (torch.nn.Module): PXGL-GNN model.
        data_loader (dgl.dataloading.GraphDataLoader): Data loader for evaluation;
            must yield ``(batch_graph, batch_labels)`` pairs.
        device (torch.device): Device to run the model on.
        pattern_subgraphs (dict): Dictionary containing sampled subgraphs for each pattern.

    Returns:
        acc (float): Accuracy score.
        f1 (float): F1 score.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    # This function shadows the `evaluate` imported from utils.utils at module
    # level, so the metrics helper is re-imported here under a distinct name;
    # calling `evaluate(labels, preds)` would recurse with the wrong arguments.
    # NOTE(review): assumes utils.utils.evaluate(labels, preds) returns
    # (acc, f1), as the call sites in this file use it -- confirm.
    from utils.utils import evaluate as score_predictions

    model.eval()
    preds = []
    labels = []

    with torch.no_grad():
        for batch_graph, batch_labels in data_loader:
            batch_graph = batch_graph.to(device)
            logits = model(batch_graph, batch_graph.ndata['feat'], pattern_subgraphs)
            preds.append(logits.argmax(dim=1).cpu())
            # Labels are only compared on the host, so they never need to
            # visit the device.
            labels.append(batch_labels.cpu())

    if not preds:
        raise ValueError("data_loader yielded no batches to evaluate")

    # Concatenate with torch (already imported) and hand NumPy arrays to the
    # metrics helper.
    preds = torch.cat(preds).numpy()
    labels = torch.cat(labels).numpy()
    acc, f1 = score_predictions(labels, preds)
    return acc, f1