import torch
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_score, recall_score, f1_score
from explainer.explainer_utils import create_edge_embeds, sample_graph

def evaluate_auc(loader, explainer, model, device, training=False):
    """Compute the edge-level ROC AUC of an explainer's edge masks.

    For every batch in ``loader`` the steps are:
      1. Run ``model`` and take the third element of its output as node embeddings.
      2. Build edge embeddings with ``create_edge_embeds``.
      3. Let ``explainer`` turn them into sampling weights.
      4. Convert the weights to a per-edge mask with ``sample_graph``
         (always with ``training=False``).

    All masks are scored against ``batch.edge_gt`` with
    :func:`sklearn.metrics.roc_auc_score`.

    Args:
        loader: Iterable of batches supporting ``.to(device)`` and exposing
            ``.edge_index`` and ``.edge_gt``.
        explainer: Network mapping edge embeddings to sampling weights.
        model: Network whose forward returns a 3-tuple; the third element is
            used as node embeddings.
        device: Device each batch is moved to.
        training: Unused. ``sample_graph`` is always called with
            ``training=False``; the parameter is kept so existing callers
            keep working. NOTE(review): confirm it is meant to be ignored.

    Returns:
        ROC AUC over the edges of all batches.

    Raises:
        ValueError: If ``loader`` yields no batches, or (raised by
            ``roc_auc_score``) if the ground truth contains a single class.

    Note:
        ``explainer`` and ``model`` are left in eval mode on return.
    """
    # Switch to inference mode once rather than on every batch.
    explainer.eval()
    model.eval()

    gt_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            _, _, node_embeds = model(batch)

            edge_embeds = create_edge_embeds(batch.edge_index, node_embeds).unsqueeze(dim=0)
            sampling_weights = explainer(edge_embeds)  # [1, E, 1]
            mask = sample_graph(sampling_weights, device, training=False)

            # Flatten with reshape(-1) rather than squeeze(): a graph with a
            # single edge would squeeze to a 0-d tensor, whose numpy form
            # cannot be iterated or concatenated.
            gt_chunks.append(batch.edge_gt.detach().reshape(-1).cpu().numpy())
            pred_chunks.append(mask.detach().reshape(-1).cpu().numpy())

    if not gt_chunks:
        raise ValueError("loader yielded no batches; cannot compute ROC AUC")

    return roc_auc_score(np.concatenate(gt_chunks), np.concatenate(pred_chunks))

def evaluate_auc_class(loader, classifier, device, model=None):
    """Compute the edge-level ROC AUC of an edge classifier.

    Each batch from ``loader`` is moved to ``device`` and passed to
    ``classifier``. When ``model`` is given, it is passed along as
    ``classifier(batch, model)``. The returned scores are compared with
    ``batch.edge_gt`` using :func:`sklearn.metrics.roc_auc_score`.

    Args:
        loader: Iterable of batches supporting ``.to(device)`` and exposing
            ``.edge_gt``.
        classifier: Callable whose output must contain one score per entry
            of ``batch.edge_gt`` (it is flattened before scoring).
        device: Device each batch is moved to.
        model: Optional network forwarded to ``classifier``.

    Returns:
        ROC AUC over the edges of all batches.

    Raises:
        ValueError: If ``loader`` yields no batches, or (raised by
            ``roc_auc_score``) if the ground truth contains a single class.

    Note:
        ``classifier`` (and ``model``, when given) are left in eval mode.
    """
    classifier.eval()
    if model is not None:
        # Evaluate the forwarded network in inference mode as well, so
        # train-only behavior (e.g. dropout) does not perturb the scores.
        model.eval()

    gt_chunks = []
    pred_chunks = []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)

            prob = classifier(batch) if model is None else classifier(batch, model)

            # reshape(-1) keeps 1-element batches 1-d (squeeze() would give
            # a 0-d tensor) and aligns predictions with the flat ground truth.
            gt_chunks.append(batch.edge_gt.detach().reshape(-1).cpu().numpy())
            pred_chunks.append(prob.detach().reshape(-1).cpu().numpy())

    if not gt_chunks:
        raise ValueError("loader yielded no batches; cannot compute ROC AUC")

    return roc_auc_score(np.concatenate(gt_chunks), np.concatenate(pred_chunks))