import torch
import os
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_fscore_support, roc_auc_score

def aucPerformance(score, labels):
    """Return ``(roc_auc, average_precision)`` of ``score`` against ``labels``.

    Both metrics are computed by scikit-learn with ``labels`` as the ground
    truth and ``score`` as the ranking score (higher = more positive).
    """
    return (
        roc_auc_score(labels, score),
        average_precision_score(labels, score),
    )

def F1Performance(score, target):
    """Compute binary F1 after thresholding scores at the normal-class percentile.

    The threshold is the ``100 * normal_ratio`` percentile of ``score``, where
    ``normal_ratio`` is the fraction of samples whose label is 0. Samples with
    a score strictly above that threshold are predicted as 1, all others as 0.

    Args:
        score: array-like of per-sample scores; singleton dimensions are
            squeezed away.
        target: array-like of binary labels (0 = normal) aligned with ``score``.

    Returns:
        The binary F1 score for label 1, as computed by
        ``precision_recall_fscore_support(..., average='binary')``.

    Raises:
        ValueError: if ``target`` is empty.
    """
    # asarray lets plain Python lists work (list == 0 would just be False).
    target = np.asarray(target)
    if target.size == 0:
        raise ValueError("F1Performance requires at least one sample")

    # atleast_1d keeps a single-sample score (e.g. shape (1, 1)) indexable
    # after squeeze, which would otherwise produce a 0-d array.
    score = np.atleast_1d(np.squeeze(np.asarray(score)))

    normal_ratio = np.mean(target == 0)
    threshold = np.percentile(score, 100 * normal_ratio)
    pred = (score > threshold).astype(int)

    _, _, f1, _ = precision_recall_fscore_support(target, pred, average='binary')
    return f1

def calcuate_metric_ad(y_true, y_pred):
    """Return ``(auc, ap, f1)`` for anomaly scores ``y_pred`` against labels ``y_true``.

    Torch tensors are detached and moved to CPU NumPy arrays first; any other
    input is passed through unchanged to the metric helpers.
    """
    def _as_numpy(value):
        if not isinstance(value, torch.Tensor):
            return value
        return value.detach().cpu().numpy()

    labels = _as_numpy(y_true)   # ground-truth labels
    errors = _as_numpy(y_pred)   # per-sample anomaly score (e.g. reconstruction error)

    auc_value, ap_value = aucPerformance(errors, labels)
    f1_value = F1Performance(errors, labels)
    return auc_value, ap_value, f1_value

def log_results_to_csv(log_file, opt, best_test_auroc, best_test_ap, best_f1):
    """Append one row of run options plus best test metrics to a CSV file.

    The row contains every attribute of ``opt`` (as returned by ``vars``)
    followed by the columns ``best_test_auroc``, ``best_test_aucpr`` and
    ``best_test_f1``. A header row is written when the file does not exist
    yet or is empty.

    Args:
        log_file: path of the CSV file to append to.
        opt: object exposing ``__dict__`` (e.g. an ``argparse.Namespace``).
            It is only read, never modified.
        best_test_auroc: value stored under ``best_test_auroc``.
        best_test_ap: value stored under ``best_test_aucpr``.
        best_f1: value stored under ``best_test_f1``.

    Note:
        The header is written only once, so later rows built from an ``opt``
        with a different attribute set will not line up with it.
    """
    import csv

    # Copy: vars() returns opt's own __dict__, so updating it in place would
    # add the metric keys as attributes on the caller's opt object.
    results = dict(vars(opt))
    results.update({
        'best_test_auroc': best_test_auroc,
        'best_test_aucpr': best_test_ap,
        'best_test_f1': best_f1,
    })

    # An existing but empty file (e.g. pre-created) still needs a header.
    write_header = not os.path.isfile(log_file) or os.path.getsize(log_file) == 0

    with open(log_file, mode='a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=list(results))
        if write_header:
            writer.writeheader()
        writer.writerow(results)
        