import math
import json
import torch
import os
import time
import random
import numpy as np
from scipy.sparse import csr_matrix, lil_matrix, csgraph
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, classification_report, average_precision_score, precision_recall_curve, auc
from torch_geometric.utils.convert import from_scipy_sparse_matrix

from config import *
import warnings
warnings.filterwarnings('ignore')

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and enable deterministic kernels.

    Args:
        seed: integer seed. ``0`` means "pick one": the current Unix time in
            whole seconds is used instead.

    Returns:
        The seed that was actually applied, so callers can log it and rerun
        with the same value.
    """
    if seed == 0:
        seed = int(time.time())

    # Required by PyTorch for deterministic cuBLAS (CUDA >= 10.2) when
    # use_deterministic_algorithms(True) is on; set it before anything below
    # could start CUDA work.
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed already seeds every device; manual_seed_all is kept
    # as an explicit statement of intent for multi-GPU runs.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)

    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED only takes effect when set before the interpreter
    # starts, so it is intentionally not set here.
    return seed

def load_data(data):
    """Load a dataset's graph object, metadata and train/val/test index splits.

    Reads the following from ``os.path.join(DATADIR, data)``:
      * ``{data}.pt`` -- object saved with ``torch.save``, returned as-is;
      * ``info.json`` -- parsed with ``json.load``;
      * ``train.txt`` / ``val.txt`` / ``test.txt`` -- integers read with
        ``np.loadtxt``.

    Args:
        data: dataset name; used both as the sub-directory of ``DATADIR`` and
            as the stem of the ``.pt`` file.

    Returns:
        ``(graphs, info, train_index, val_index, test_index)``. Each index is
        an ``int64`` NumPy array with at least one dimension.
    """
    datadir = os.path.join(DATADIR, data)

    # NOTE(review): torch.load unpickles objects (subject to the installed
    # torch version's ``weights_only`` default); only load trusted files.
    graphs = torch.load(os.path.join(datadir, f'{data}.pt'))

    with open(os.path.join(datadir, 'info.json'), 'r', encoding='utf-8') as f:
        info = json.load(f)

    def _load_index(split):
        # ndmin=1: a split file with a single index would otherwise be
        # returned as a 0-d array, which cannot be iterated or len()'d.
        path = os.path.join(datadir, f'{split}.txt')
        return np.loadtxt(path, dtype=np.int64, ndmin=1)

    train_index = _load_index('train')
    val_index = _load_index('val')
    test_index = _load_index('test')
    return graphs, info, train_index, val_index, test_index

def compute_metrics(preds, truths):
    """Compute binary-classification metrics from per-class scores.

    Args:
        preds: tensor of per-class scores. Column 1 is used as the
            positive-class score, and ``argmax(axis=1)`` as the hard
            prediction. It is assumed to have shape ``(N, C)`` with ``C >= 2``
            (TODO confirm at call sites).
        truths: tensor of ``N`` ground-truth labels, presumably 0/1.

    Returns:
        ``(AUROC, AUPRC, RECK, MF1)``:
          * ``AUROC`` -- ROC AUC;
          * ``AUPRC`` -- average precision;
          * ``RECK`` -- recall@k, where ``k`` is the number of positives;
          * ``MF1`` -- macro F1 of the argmax predictions.
    """
    labels = truths.detach().cpu().numpy()
    scores = preds.detach().cpu().numpy()
    probs = scores[:, 1]

    AUROC = roc_auc_score(labels, probs)
    AUPRC = average_precision_score(labels, probs)

    # Recall@k with k = number of positives: the fraction of positives that
    # land among the k highest-scored samples. int() keeps the slice bound
    # valid for float label tensors, and ``len - k`` (unlike ``-k``) yields
    # an empty selection rather than the whole array when k == 0.
    num_pos = labels.sum()
    k = int(num_pos)
    top_k = probs.argsort()[len(probs) - k:]
    RECK = labels[top_k].sum() / num_pos

    hard_preds = scores.argmax(axis=1)
    MF1 = f1_score(labels, hard_preds, average='macro')
    return AUROC, AUPRC, RECK, MF1

def compute_priors(num1, num2, device):
    """Return the log of two class counts as a tensor on ``device``.

    A small epsilon is added to each count so that a count of zero produces
    a large negative value instead of ``-inf``.

    Args:
        num1: count for the first class.
        num2: count for the second class.
        device: target device for the returned tensor.

    Returns:
        1-D tensor ``[log(num1 + eps), log(num2 + eps)]`` that does not
        require gradients.
    """
    eps = 1e-8
    counts = torch.tensor([num1 + eps, num2 + eps], requires_grad=False)
    return counts.log().to(device)
