import torch
import numpy as np
import os
import random
import sys
import logging
import dgl
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, confusion_matrix, precision_recall_fscore_support


def check_input(y_true, y_pred):
    '''Normalize metric inputs to numpy arrays and validate them.

    Args:
        y_true: numpy ndarray or torch tensor of labels. Its shape is not
            validated here (callers may pass 1-D class labels or a 2-D
            label array).
        y_pred: numpy ndarray or torch tensor of shape (num_node, num_tasks).

    Returns:
        ``(y_true, y_pred)`` as numpy arrays; torch tensors are detached and
        copied to CPU before conversion.

    Raises:
        RuntimeError: if either argument is not a numpy ndarray / torch tensor,
            or if ``y_pred`` is not 2-dimensional.
    '''
    # Convert torch tensors to numpy arrays on the CPU.
    if isinstance(y_true, torch.Tensor):
        y_true = y_true.detach().cpu().numpy()
    if isinstance(y_pred, torch.Tensor):
        y_pred = y_pred.detach().cpu().numpy()

    # Validate BOTH arguments; otherwise a non-array y_pred would fail below
    # with an AttributeError on `.ndim` instead of this clear error.
    if not (isinstance(y_true, np.ndarray) and isinstance(y_pred, np.ndarray)):
        raise RuntimeError('Arguments to Evaluator need to be either numpy ndarray or torch tensor')

    if y_pred.ndim != 2:
        raise RuntimeError('y_pred must be a 2-dim array, {}-dim array given'.format(y_pred.ndim))

    return y_true, y_pred


def _roc_auc(y_true, y_pred):
    '''Return ROC-AUC for numpy labels `y_true` and 2-D scores `y_pred`.

    For 1-D labels with a two-column score matrix, sklearn treats the task as
    binary and requires a 1-D score vector (an (n, 2) matrix raises
    ValueError), so the column for class 1 is scored. Column j is assumed to
    hold the score of class j, matching the argmax convention in `evaluate`.
    All other inputs use one-vs-rest averaging.
    '''
    if y_true.ndim == 1 and y_pred.shape[1] == 2:
        return roc_auc_score(y_true, y_pred[:, 1])
    return roc_auc_score(y_true, y_pred, multi_class='ovr')


def evaluate(y_true, y_pred, all=False):
    '''Score predictions against ground-truth labels.

    Args:
        y_true: numpy ndarray or torch tensor of labels; either 1-D class
            indices or a 2-D label array.
        y_pred: numpy ndarray or torch tensor of shape (num_node, num_classes)
            with per-class scores; column j is treated as class j.
        all: when False return only the primary metric; when True return a
            dict of metrics.

    Returns:
        If `all` is False: accuracy of ``argmax(y_pred)`` for 1-D `y_true`,
        or one-vs-rest ROC-AUC for 2-D `y_true`.
        If `all` is True: a dict with keys 'auroc', 'f1' (macro F1),
        'gmean' (geometric mean of macro precision and macro recall) and
        'acc' (the primary metric above).

    Raises:
        RuntimeError: from `check_input` on invalid inputs.
    '''
    y_true, y_pred = check_input(y_true, y_pred)
    if y_true.ndim == 1:
        acc = float(np.sum(y_true == y_pred.argmax(axis=-1))) / len(y_true)
    else:
        acc = _roc_auc(y_true, y_pred)

    if not all:
        return acc

    # For 2-D labels the primary metric already is ROC-AUC; reuse it.
    auroc = _roc_auc(y_true, y_pred) if y_true.ndim == 1 else acc
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred.argmax(axis=-1), average='macro')
    gmean = (precision * recall) ** 0.5
    return {
        'auroc': auroc,
        'f1': f1,
        'gmean': gmean,
        'acc': acc
    }


def set_seed(seed):
    '''Seed the Python, NumPy, PyTorch (CPU and CUDA) and DGL generators.

    Every generator receives the same `seed`, in a fixed order.
    '''
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        dgl.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)


def prepare_folder(name, model_name, ratio=None, root='./model_files'):
    '''Create (if needed) and return the directory for a model's files.

    Args:
        name: first path component under `root`.
        model_name: second path component.
        ratio: optional; when given, a further sub-directory named
            ``int(ratio * 10)`` is appended (e.g. 0.4 -> ``4``).
        root: base directory, default ``./model_files`` (relative to the
            current working directory).

    Returns:
        The directory path as a string. It ends with ``/`` only when `ratio`
        is None, so prefer ``os.path.join`` when appending file names.
    '''
    model_dir = f'{root}/{name}/{model_name}/'
    if ratio is not None:
        model_dir = f'{model_dir}{int(ratio * 10)}'
    # exist_ok avoids the check-then-create race: with a separate exists()
    # test, a concurrent run could create the directory in between and make
    # makedirs raise FileExistsError.
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def init_logging(log_root, models_root=None):
    '''Configure `log_root` at INFO level with "<asctime> <message>" output.

    When `models_root` is given, a file handler writing
    ``<models_root>/training.log`` is attached first; a stdout stream handler
    is always attached after it. Both share one formatter. Handlers are added
    on every call, so calling this twice on the same logger duplicates output.
    '''
    log_root.setLevel(logging.INFO)
    shared_fmt = logging.Formatter("%(asctime)s %(message)s")

    targets = []
    if models_root is not None:
        log_path = os.path.join(models_root, "training.log")
        targets.append(logging.FileHandler(log_path))
    targets.append(logging.StreamHandler(sys.stdout))

    for handler in targets:
        handler.setFormatter(shared_fmt)
        log_root.addHandler(handler)


def add_edge_noise(graph, p):
    '''Add random node pairs, absent from `graph`, in both directions.

    The number of new pairs is ``int((E // 2) * p)``, where ``E`` is the
    current edge count. This presumably assumes each undirected edge is
    already stored as two directed edges (TODO confirm with callers). Each
    accepted pair ``(s, t)`` is added as both ``s -> t`` and ``t -> s``.
    Self-loops and pairs present in either direction are skipped.

    Candidates come from a single ``random.sample`` over all
    ``num_nodes ** 2`` ordered pairs, so results depend on the `random`
    module's state (e.g. as seeded by ``random.seed``).

    Args:
        graph: graph object exposing ``edges()``, ``num_nodes()``,
            ``add_edges()`` and ``device`` (presumably a DGL graph). It is
            modified in place.
        p: ratio of new pairs to existing (undirected) edges.

    Returns:
        The same `graph` object after the edges have been added.
    '''
    edge_index = torch.stack(graph.edges())
    num_nodes = graph.num_nodes()
    edge_set = set(map(tuple, edge_index.transpose(0, 1).tolist()))
    num_of_new_edge = int((edge_index.size(1) // 2) * p)

    # Oversample so that self-loops and existing edges drawn from the pool can
    # be skipped. Cap at the population size: random.sample raises ValueError
    # when asked for more items than exist (small or dense graphs).
    population = range(num_nodes ** 2)
    pool_size = min(num_of_new_edge + len(edge_set) + num_nodes, len(population))
    candidates = random.sample(population, pool_size)

    to_add_u = []
    to_add_v = []
    added = 0
    for idx in candidates:
        if added >= num_of_new_edge:
            break
        # Decode the flat 0-based index into an ordered (source, target) pair.
        s, t = divmod(idx, num_nodes)
        if s != t and (s, t) not in edge_set and (t, s) not in edge_set:
            added += 1
            to_add_u += [s, t]
            to_add_v += [t, s]
            edge_set.add((s, t))
            edge_set.add((t, s))

    if added < num_of_new_edge:
        # The pool can run out of valid candidates; report it instead of
        # silently adding less noise than requested.
        logging.warning(
            f"requested {num_of_new_edge} new edges but only {added} could be added")
    logging.info(f"num of added edges: {len(to_add_u)}")
    graph.add_edges(torch.tensor(to_add_u, dtype=torch.long, device=graph.device), torch.tensor(to_add_v, dtype=torch.long, device=graph.device))
    return graph
