
from datasets.data_utils import rand_splits
import torch
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np
import torch.nn.functional as F
from datasets.data_utils import get_rw_adj


def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    """Return the float64 cumulative sum of ``arr`` (flattened), verified against its total.

    The running sum is accumulated in float64. As a sanity check, its final
    element is compared to ``np.sum(arr)`` (also float64) with ``np.allclose``.

    Parameters
    ----------
    arr : array-like
        Values to accumulate; treated as a flat sequence.
    rtol : float
        Relative tolerance forwarded to ``np.allclose``.
    atol : float
        Absolute tolerance forwarded to ``np.allclose``.

    Returns
    -------
    numpy.ndarray
        1-D float64 array of running totals.

    Raises
    ------
    RuntimeError
        If the last running total disagrees with the directly computed sum.
    """
    running = np.cumsum(arr, dtype=np.float64)
    total = np.sum(arr, dtype=np.float64)
    agrees = np.allclose(running[-1], total, rtol=rtol, atol=atol)
    if agrees:
        return running
    raise RuntimeError('cumsum was found to be unstable: '
                       'its last element does not correspond to sum')

def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
    """Return the false-positive rate at the threshold whose recall is closest to ``recall_level``.

    Despite its name, only the FPR (false positives / all negatives) is
    computed; no false-discovery rate is returned.

    Parameters
    ----------
    y_true : array-like of shape (n,)
        Ground-truth labels. When ``pos_label`` is None they must be binary:
        {0, 1}, {-1, 1}, or a single one of the values 0, -1, 1.
    y_score : array-like of shape (n,)
        Scores; higher means "more positive".
    recall_level : float
        Target recall (true-positive rate) at which the FPR is read off.
    pos_label : scalar, optional
        Label treated as positive; defaults to 1.

    Returns
    -------
    (fpr, threshold) : tuple
        FPR at the chosen operating point and the score threshold there.
        When there are no negative samples, no false positive is possible
        and the FPR is reported as 0.0.

    Raises
    ------
    ValueError
        If the labels are not binary and ``pos_label`` is None.
    """
    # Accept lists as well as arrays: fancy indexing and the elementwise
    # ``==`` below both require ndarrays.
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    classes = np.unique(y_true)
    if (pos_label is None and
            not (np.array_equal(classes, [0, 1]) or
                 np.array_equal(classes, [-1, 1]) or
                 np.array_equal(classes, [0]) or
                 np.array_equal(classes, [-1]) or
                 np.array_equal(classes, [1]))):
        raise ValueError("Data is not binary and pos_label is not specified")
    elif pos_label is None:
        pos_label = 1.

    # make y_true a boolean vector
    y_true = (y_true == pos_label)

    # Stable sort ascending, then reverse to get descending scores.
    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
    y_score = y_score[desc_score_indices]
    y_true = y_true[desc_score_indices]

    # y_score typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

    # accumulate the true positives with decreasing threshold
    tps = stable_cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps      # add one because of zero-based indexing

    thresholds = y_score[threshold_idxs]

    recall = tps / tps[-1]

    # Keep only the curve up to the first point reaching full recall, reversed
    # so it runs from high to low recall; append the (recall=1, fps=0) endpoint.
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)      # [last_ind::-1]
    recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]

    cutoff = np.argmin(np.abs(recall - recall_level))

    n_negatives = np.sum(np.logical_not(y_true))
    if n_negatives == 0:
        # All samples are positive. Report FPR 0.0 instead of dividing 0 by 0,
        # and still return a tuple so callers that unpack two values work.
        return 0.0, thresholds[cutoff]

    return fps[cutoff] / n_negatives, thresholds[cutoff]



def extract_ood_dataset_info(cfg, dataset_ind, dataset_ood_tr, dataset_ood_te):
    """Print size statistics for the in-distribution and OOD datasets.

    For each dataset this prints the total node count, the number of
    centered nodes (``node_idx``) and the edge count. For the ID dataset it
    also prints the class count and feature dimension. ``dataset_ood_te`` may
    be a single dataset or a list of datasets; a list gets one line per
    entry, tagged with its index. Returns None.

    The class count is ``max(y.max() + 1, y.shape[1])``, so ``dataset_ind.y``
    is expected to be 2-D. Presumably (N, 1) for single-label data and
    (N, C) for multi-label data; TODO confirm.
    """
    labels = dataset_ind.y
    num_classes = max(labels.max().item() + 1, labels.shape[1])
    num_features = dataset_ind.x.shape[1]

    print(f"ind dataset {cfg['name']}: all nodes {dataset_ind.num_nodes} | centered nodes {dataset_ind.node_idx.shape[0]} | edges {dataset_ind.edge_index.size(1)} | "
        + f"classes {num_classes} | feats {num_features}")
    print(f"ood tr dataset {cfg['name']}: all nodes {dataset_ood_tr.num_nodes} | centered nodes {dataset_ood_tr.node_idx.shape[0]} | edges {dataset_ood_tr.edge_index.size(1)}")

    if not isinstance(dataset_ood_te, list):
        print(f"ood te dataset {cfg['name']}: all nodes {dataset_ood_te.num_nodes} | centered nodes {dataset_ood_te.node_idx.shape[0]} | edges {dataset_ood_te.edge_index.size(1)}")
        return

    for i, data in enumerate(dataset_ood_te):
        print(f"ood te dataset {i} {cfg['name']}: all nodes {data.num_nodes} | centered nodes {data.node_idx.shape[0]} | edges {data.edge_index.size(1)}")

def cal_max_score(id_preds_l, ood_preds_l):
    """Average prediction tensors across passes and return the per-row max.

    Each list is averaged independently (sum divided by its own length).
    Then the maximum over dim 1 is taken, detached and moved to CPU.

    Parameters
    ----------
    id_preds_l, ood_preds_l : list of torch.Tensor
        Non-empty lists of equally-shaped 2-D prediction tensors, one per
        forward pass (presumably softmax probabilities of shape
        (num_nodes, num_classes); TODO confirm against callers).

    Returns
    -------
    (id_score, ood_score) : tuple of torch.Tensor
        1-D tensors holding the max averaged prediction for each row.

    Notes
    -----
    Accumulation is out-of-place, so the tensors in the input lists are not
    modified.
    """
    def _mean_max(preds_l):
        # ``sum`` with an explicit start adds with ``+`` (not ``+=``), in the
        # same left-to-right order, so no caller tensor is overwritten.
        mean = sum(preds_l[1:], preds_l[0]) / len(preds_l)
        return torch.max(mean, dim=1).values.detach().cpu()

    return _mean_max(id_preds_l), _mean_max(ood_preds_l)
def corr_err_label(y_true, logits):
    """Mark each prediction as correct (1) or wrong (0).

    The predicted class is the argmax of ``logits`` over the last dimension
    (dimension kept). It is compared to ``y_true`` on CPU with NumPy, and the
    boolean result is returned as a flat integer array.
    """
    predicted = logits.detach().argmax(dim=-1, keepdim=True).cpu().numpy()
    labels = y_true.detach().cpu().numpy()
    hits = np.equal(labels, predicted)
    return hits.astype(int).ravel()

def err_measure(cor_label, id_max_score, recall_level=0.95):
    """Score how well confidence separates correct from wrong predictions.

    Parameters
    ----------
    cor_label : numpy.ndarray
        1 where the prediction was correct, 0 otherwise. Must support
        ``1 - cor_label`` (e.g. an ndarray).
    id_max_score : array-like or tensor
        Confidence per prediction; flattened before use.
    recall_level : float
        Recall at which the FPR is read off.

    Returns
    -------
    list
        ``[auroc, aupr_correct, aupr_error, fpr]``:
        - ``auroc``: AUROC with correct predictions as positives.
        - ``aupr_correct``: average precision with correct predictions as positives.
        - ``aupr_error``: average precision with errors as positives, using negated scores.
        - ``fpr``: FPR at ``recall_level``.
    """
    confidence = np.array(id_max_score[:]).flatten()
    auroc = roc_auc_score(cor_label, confidence)
    aupr_correct = average_precision_score(cor_label, confidence)
    aupr_error = average_precision_score(1 - cor_label, -confidence)
    fpr, _threshold = fpr_and_fdr_at_recall(cor_label, confidence, recall_level)
    return [auroc, aupr_correct, aupr_error, fpr]




# GSPDE utilities
@torch.no_grad()
def misclassification(model, dataset_ind, dataset_ood, criterion, eval_func, cfg, device):
    """Evaluate accuracy and misclassification detection over several stochastic passes.

    Runs ``cfg['samples']`` forward passes on the in-distribution (ID) graph
    and on the OOD graph. The softmax predictions of the ID test nodes are
    averaged across passes. ``err_measure`` then scores how well the averaged
    max-softmax confidence separates correctly from incorrectly classified
    test nodes. Correctness comes from the first pass only.

    Args:
        model: called as ``model(dataset, flag, device)``, with ``flag`` True
            for the ID graph and False for the OOD graph. Its
            ``encoder.ind_/ood_edge_index`` and ``encoder.ind_/ood_edge_weight``
            attributes are overwritten here.
        dataset_ind: ID graph. Must provide ``edge_index``, ``edge_attr``,
            ``num_nodes``, ``x``, ``y`` and ``splits`` with keys
            'train'/'valid'/'test'.
        dataset_ood: OOD graph. Must provide ``edge_index``, ``edge_attr``,
            ``num_nodes``, ``x`` and ``node_idx``.
        criterion: applied to log-softmax outputs and ``y.squeeze(1)``. Looks
            like an NLL-style loss (TODO confirm). Unlike
            ``evaluate_detection``, there is no multi-label branch for
            'proteins'/'ppi' here.
        eval_func: ``eval_func(y_true, logits)`` -> score.
        cfg: config dict; reads ``'samples'`` and ``'self_loop_weight'``.
        device: device that the adjacency tensors and labels are moved to.

    Returns:
        ``(acc_res, misclass_res)``:
        - ``acc_res`` is ``[train_score, valid_score, test_score, valid_loss]``
          from the first pass.
        - ``misclass_res`` is the list returned by ``err_measure``.

    NOTE(review): if ``cfg['samples'] == 0``, ``test_correct`` is never bound
    and ``cal_max_score`` receives empty lists, so this raises.
    """
    model.eval()
    acc_res = []
    tms = cfg['samples']  # number of stochastic forward passes
    id_predss, ood_predss = [], []
    # Adjacency for the ID and OOD graphs via get_rw_adj. Presumably
    # row-normalized (norm_dim=1) with self-loops of weight
    # cfg['self_loop_weight']; see get_rw_adj.
    edge_index_in, edge_weight_in = get_rw_adj(dataset_ind.edge_index, edge_weight=dataset_ind.edge_attr, norm_dim=1,
                                  fill_value=cfg['self_loop_weight'],
                                  num_nodes=dataset_ind.num_nodes,
                                  dtype=dataset_ind.x.dtype)
    #x_in, edge_index_in = dataset_ind.x.to(device), edge_index_in.to(device)
    edge_index_ood, edge_weight_ood = get_rw_adj(dataset_ood.edge_index, edge_weight=dataset_ood.edge_attr, norm_dim=1,
                                   fill_value=cfg['self_loop_weight'],
                                   num_nodes=dataset_ood.num_nodes,
                                   dtype=dataset_ood.x.dtype)
    edge_index_in, edge_index_ood, edge_weight_in, edge_weight_ood = edge_index_in.to(device), \
        edge_index_ood.to(device), edge_weight_in.to(device), edge_weight_ood.to(device)
    # id_idx is the same index set as test_idx (both are splits['test']).
    id_idx, ood_idx = dataset_ind.splits['test'], dataset_ood.node_idx
    train_idx, valid_idx, test_idx = dataset_ind.splits['train'], dataset_ind.splits['valid'], dataset_ind.splits['test']
    y = dataset_ind.y.to(device)
    # The model reads its graph structure from these encoder attributes rather
    # than from the forward() arguments.
    model.encoder.ind_edge_index = edge_index_in
    model.encoder.ood_edge_index = edge_index_ood
    model.encoder.ind_edge_weight = edge_weight_in
    model.encoder.ood_edge_weight = edge_weight_ood
    # st = time.time()
    for t in range(tms):
        flag = True
        id_logits = model(dataset_ind, flag, device)
        id_preds = torch.softmax(id_logits, dim=1)[id_idx]
        id_predss.append(id_preds)
        if t == 0:
            # ACC: accuracy/loss and per-node correctness from the first pass only.
            train_score = eval_func(y[train_idx], id_logits[train_idx])
            valid_score = eval_func(y[valid_idx], id_logits[valid_idx])
            test_score = eval_func(y[test_idx], id_logits[test_idx])
            valid_out = F.log_softmax(id_logits[valid_idx], dim=1)
            valid_loss = criterion(valid_out, y[valid_idx].squeeze(1))
            test_correct = corr_err_label(y[test_idx], id_logits[test_idx])
            acc_res = [train_score, valid_score, test_score, valid_loss.item()]
            
        
        flag = False
        ood_logits = model(dataset_ood, flag, device)
        ood_preds = torch.softmax(ood_logits, dim=1)[ood_idx]
        # OOD predictions are stacked after this pass's ID test predictions.
        # The resulting ood_max_score is not used below.
        ood_preds = torch.cat([id_preds, ood_preds], dim=0)
        ood_predss.append(ood_preds)
        
    id_max_score, ood_max_score = cal_max_score(id_predss, ood_predss)
    # max_auroc, max_aupr, max_fpr, _ = get_measures(id_max_score, ood_max_score)
    misclass_res = err_measure(test_correct, id_max_score)
    return acc_res, misclass_res


def detection(pos, neg, num_thresholds=10):
    """Return the minimum detection error between two score populations.

    For a threshold ``delta``, the detection error is the mean of two rates:
    the fraction of ``pos`` scores strictly below ``delta`` (missed
    positives) and the fraction of ``neg`` scores strictly above ``delta``
    (false alarms). ``num_thresholds`` evenly spaced thresholds over
    ``[min, max)`` of the pooled scores are tried. The smallest error is
    returned, and it is never more than 1.0.

    Parameters
    ----------
    pos : array-like
        Scores of the positive samples (any shape; all elements are counted).
    neg : array-like
        Scores of the negative samples (any shape; all elements are counted).
    num_thresholds : int, default 10
        Number of thresholds on the search grid.

    Returns
    -------
    float
        Minimum detection error. If every pooled score is identical, no
        threshold can separate the two groups and 0.5 (chance level) is
        returned.

    Raises
    ------
    ValueError
        If ``num_thresholds`` < 1, or (from NumPy) if ``pos`` or ``neg`` is empty.
    """
    if num_thresholds < 1:
        raise ValueError(f"num_thresholds must be >= 1, got {num_thresholds}")
    Y1 = np.asarray(neg)
    X1 = np.asarray(pos)
    print(f"X1: {X1.shape}, Y1: {Y1.shape}")
    end = np.max([np.max(X1), np.max(Y1)])
    start = np.min([np.min(X1), np.min(Y1)])
    gap = (end - start) / num_thresholds
    if gap == 0:
        # All scores are equal, so np.arange would get a zero step. Any
        # threshold below or above the common value gives error 0.5; only the
        # exact tie would score 0, which is an artifact of strict inequalities.
        return 0.5

    error_base = 1.0
    for delta in np.arange(start, end, gap):
        miss_rate = np.mean(X1 < delta)         # positives scored below delta
        false_alarm_rate = np.mean(Y1 > delta)  # negatives scored above delta
        error_base = np.minimum(error_base, (miss_rate + false_alarm_rate) / 2.0)

    return error_base

def get_measures(_pos, _neg, recall_level=0.95):
    """Compute detection metrics with ``_pos`` as the positive class.

    Scores from both groups are pooled, with label 1 for ``_pos`` and 0 for
    ``_neg``. Some diagnostic counts are printed along the way.

    Returns
    -------
    tuple
        ``(auroc, aupr, aupr_n, fpr, detection_acc)``:
        - ``auroc``: AUROC with ``_pos`` as positives.
        - ``aupr``: average precision with ``_pos`` as positives.
        - ``aupr_n``: average precision with ``_neg`` as positives and negated scores.
        - ``fpr``: FPR at ``recall_level``.
        - ``detection_acc``: ``1 - detection(...)``.
    """
    pos = np.array(_pos[:]).reshape((-1, 1))
    neg = np.array(_neg[:]).reshape((-1, 1))
    examples = np.squeeze(np.vstack((pos, neg)))
    labels = np.concatenate((np.ones(len(pos), dtype=np.int32),
                             np.zeros(len(neg), dtype=np.int32)))
    labels_neg = 1 - labels
    # Negated scores so that higher means "more negative" for aupr_n.
    n_examples = np.squeeze(np.vstack((-pos, -neg)))

    for line in (f"number of examples : {len(examples)}",
                 f"number of negative examples : {len(neg)}",
                 f"number of positive examples : {len(pos)}",
                 f"n_examples : {len(n_examples)}"):
        print(line)

    auroc = roc_auc_score(labels, examples)
    aupr = average_precision_score(labels, examples)
    aupr_n = average_precision_score(labels_neg, n_examples)
    detection_acc = 1 - detection(pos, neg)
    fpr, _threshold = fpr_and_fdr_at_recall(labels, examples, recall_level)

    return auroc, aupr, aupr_n, fpr, detection_acc



@torch.no_grad()
def evaluate_detection(model, dataset_ind, dataset_ood, criterion, eval_func, cfg, device):
    """Evaluate ID classification and logsumexp-based OOD detection.

    Runs ``cfg['samples']`` stochastic passes. In each pass:
    - Train/valid/test scores and losses are computed on the in-distribution
      (ID) graph.
    - A per-node ``logsumexp`` over the logits (the negative energy at
      temperature 1) is computed for the ID test nodes and the OOD nodes.
    - If ``cfg['propagation'] == True``, the scores are first passed through
      ``model.propagation``.

    The per-node scores are averaged over passes and handed to
    ``get_measures`` with the ID nodes as the positive class.

    Args:
        model: called as ``model(dataset, flag, device)``, with ``flag`` True
            for the ID graph and False for the OOD graph. Its ``encoder``
            edge attributes are overwritten here. Must also expose
            ``propagation`` when ``cfg['propagation']`` is set.
        dataset_ind: ID graph with ``edge_index``, ``edge_attr``,
            ``num_nodes``, ``x``, ``y`` and ``splits``.
        dataset_ood: OOD graph (or a list of them, in which case only the last
            one is used) with ``edge_index``, ``edge_attr``, ``num_nodes``,
            ``x`` and ``node_idx``.
        criterion: loss. For 'proteins'/'ppi' it is applied to raw logits and
            float labels; otherwise to log-softmax outputs and
            ``y.squeeze(1)``.
        eval_func: ``eval_func(y_true, logits)`` -> score.
        cfg: config dict; reads ``'samples'``, ``'self_loop_weight'``,
            ``'dataset'`` and ``'propagation'``.
        device: device that the adjacency tensors and labels are moved to.

    Returns:
        A tuple of three lists:
        - ``[train_score, valid_score, test_score]``
        - ``[train_loss, valid_loss, test_loss]``
        - ``[auroc, aupr_in, aupr_out, fpr, detection_acc]``

        Scores and losses come from the LAST pass only, and the losses are
        returned as tensors (no ``.item()``). The detection metrics are all
        0.0 when either averaged score vector is empty.

    NOTE(review): with ``cfg['samples'] == 0`` the scores and losses are never
    bound and ``id_scores`` stays the int 0, so this raises.
    """
    model.eval()
    
    if isinstance(dataset_ood, list):
        dataset_ood = dataset_ood[-1] # NOTE: for now just take the last one, this seems to be the behavior for gnn-safe
    
    # Adjacency for the ID and OOD graphs via get_rw_adj. Presumably
    # row-normalized (norm_dim=1) with self-loops of weight
    # cfg['self_loop_weight']; see get_rw_adj.
    edge_index_in, edge_weight_in = get_rw_adj(dataset_ind.edge_index, edge_weight=dataset_ind.edge_attr, norm_dim=1,
                                  fill_value=cfg['self_loop_weight'],
                                  num_nodes=dataset_ind.num_nodes,
                                  dtype=dataset_ind.x.dtype)
    edge_index_ood, edge_weight_ood = get_rw_adj(dataset_ood.edge_index, edge_weight=dataset_ood.edge_attr, norm_dim=1,
                                   fill_value=cfg['self_loop_weight'],
                                   num_nodes=dataset_ood.num_nodes,
                                   dtype=dataset_ood.x.dtype)
    edge_index_in, edge_index_ood, edge_weight_in, edge_weight_ood = edge_index_in.to(device), \
        edge_index_ood.to(device), edge_weight_in.to(device), edge_weight_ood.to(device)
    # id_idx is the same index set as test_idx (both are splits['test']).
    id_idx, ood_idx = dataset_ind.splits['test'], dataset_ood.node_idx
    train_idx, valid_idx, test_idx = dataset_ind.splits['train'], dataset_ind.splits['valid'], dataset_ind.splits['test']
    y = dataset_ind.y.to(device)
    # The model reads its graph structure from these encoder attributes rather
    # than from the forward() arguments.
    model.encoder.ind_edge_index = edge_index_in
    model.encoder.ood_edge_index = edge_index_ood
    model.encoder.ind_edge_weight = edge_weight_in
    model.encoder.ood_edge_weight = edge_weight_ood
    # Running sums of per-node scores across passes (start as int 0, become tensors).
    id_scores, ood_scores = 0,0
    for i in range(cfg['samples']):
        id_logits = model(dataset_ind, True, device)
        # Recomputed every pass; only the last pass's values are returned.
        train_score = eval_func(y[train_idx], id_logits[train_idx])
        valid_score = eval_func(y[valid_idx], id_logits[valid_idx])
        test_score = eval_func(y[test_idx], id_logits[test_idx])

        if cfg["dataset"] in ('proteins', 'ppi'):
            # Multi-label datasets: criterion takes raw logits and float targets.
            train_loss = criterion(id_logits[train_idx], y[train_idx].to(torch.float))
            valid_loss = criterion(id_logits[valid_idx], y[valid_idx].to(torch.float))
            test_loss = criterion(id_logits[test_idx], y[test_idx].to(torch.float))
        else:   
            train_out = F.log_softmax(id_logits[train_idx], dim=1)
            train_loss = criterion(train_out, y[train_idx].squeeze(1))
            valid_out = F.log_softmax(id_logits[valid_idx], dim=1)
            valid_loss = criterion(valid_out, y[valid_idx].squeeze(1))
            test_out = F.log_softmax(id_logits[test_idx], dim=1)
            test_loss = criterion(test_out, y[test_idx].squeeze(1))
        
        # Per-node score = logsumexp of logits; higher means "more in-distribution".
        id_score = torch.logsumexp(id_logits, dim=-1)
        if cfg["propagation"] == True:
            # Smooth scores over the graph. The 4 and 0.3 are presumably the
            # number of propagation steps and a mixing weight; TODO confirm
            # against model.propagation.
            id_score = model.propagation(id_score, edge_index_in, 4, 0.3)
        id_score = id_score[id_idx].detach().cpu()
        id_scores += id_score
        
        ood_logits = model(dataset_ood, False, device)
        ood_score = torch.logsumexp(ood_logits, dim=-1)
        if cfg["propagation"] == True:
            ood_score = model.propagation(ood_score, edge_index_ood, 4, 0.3)
        ood_score = ood_score[ood_idx].detach().cpu()
        ood_scores += ood_score
    id_score = id_scores/cfg['samples']
    ood_score = ood_scores/cfg['samples']
    # Same as "both lengths > 0", since len() is never negative.
    if len(id_score) and len(ood_score) > 0:
        auroc, aupr_in, aupr_out, fpr, detection_acc = get_measures(id_score, ood_score)
    else:
        auroc, aupr_in, aupr_out, fpr, detection_acc = 0.0, 0.0, 0.0, 0.0, 0.0
    return [train_score, valid_score, test_score], [train_loss, valid_loss, test_loss],\
    [auroc, aupr_in, aupr_out, fpr, detection_acc] 