import torch
import os.path as osp
from torch_geometric.datasets import Planetoid, PPI, WikiCS, Coauthor, Amazon, CoraFull
import torch_geometric.transforms as T
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from deeprobust.graph.data import Dataset, PrePtbDataset
import scipy.sparse as sp
import numpy as np
from deeprobust.graph.data import Dataset
from deeprobust.graph.global_attack import NodeEmbeddingAttack
from deeprobust.graph import utils
from deeprobust.graph.utils import get_train_val_test_gcn, get_train_val_test
from torch_geometric.utils import train_test_split_edges
from torch_geometric.utils import add_remaining_self_loops, to_undirected
from ogb.nodeproppred import PygNodePropPredDataset
from sklearn.model_selection import train_test_split
from deeprobust.graph.data.pyg_dataset import Dpr2Pyg
from torch_geometric.utils import subgraph
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score
import subprocess


@torch.no_grad()
def eval_acc(y_true, y_pred):
    """Return the fraction of rows whose argmax prediction equals the label.

    Args:
        y_true: integer label tensor of shape ``(N,)`` or ``(N, 1)``.
        y_pred: score/logit tensor whose last dimension indexes classes; the
            predicted class is its argmax over ``dim=-1``.

    Returns:
        Accuracy as a NumPy float (``correct / N``).
    """
    y_true = y_true.detach().cpu().numpy()
    if y_true.ndim == 1:
        # argmax(keepdim=True) below yields an (N, 1) column; comparing it to a
        # 1-D label vector would broadcast to (N, N) and give a wrong score.
        y_true = y_true[:, None]
    y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy()
    return (y_true == y_pred).sum() / y_true.shape[0]


@torch.no_grad()
def eval_rocauc(y_true, y_pred):
    """Return the ROC-AUC averaged over the label columns of ``y_true``.

    Args:
        y_true: 0/1 label tensor of shape ``(N,)`` or ``(N, T)``. With float
            labels, NaN entries are treated as unlabeled and skipped.
        y_pred: when ``y_true`` has a single column, a logit tensor whose
            softmax column 1 is used as the positive-class score (presumably
            shape ``(N, 2)`` -- confirm with callers); otherwise a score tensor
            aligned column-by-column with ``y_true``.

    Returns:
        Mean ROC-AUC over the columns that contain both a 0 and a 1.

    Raises:
        RuntimeError: if no column contains both a 0 and a 1.
    """
    rocauc_list = []
    y_true = y_true.detach().cpu().numpy()
    if y_true.ndim == 1:
        # Accept a flat label vector as a single binary task.
        y_true = y_true[:, None]
    if y_true.shape[1] == 1:
        # Binary task: score each row by the softmax probability of class 1.
        y_pred = F.softmax(y_pred, dim=-1)[:, 1].unsqueeze(1).detach().cpu().numpy()
    else:
        y_pred = y_pred.detach().cpu().numpy()

    for i in range(y_true.shape[1]):
        column = y_true[:, i]
        # AUC is only defined when the column has both positives and negatives.
        if np.sum(column == 1) > 0 and np.sum(column == 0) > 0:
            # x == x is False only for NaN, i.e. for unlabeled entries.
            is_labeled = column == column
            score = roc_auc_score(column[is_labeled], y_pred[is_labeled, i])
            rocauc_list.append(score)
    if not rocauc_list:
        raise RuntimeError(
            'No positively labeled data available. Cannot compute ROC-AUC.')
    return sum(rocauc_list) / len(rocauc_list)


@torch.no_grad()
def eval_f1(y_true, y_pred):
    """Compute the macro-averaged F1 score of argmax predictions.

    Args:
        y_true: label tensor; it is moved to CPU and converted to NumPy as-is.
        y_pred: score/logit tensor; the predicted class is the argmax over its
            last dimension (kept as a trailing singleton axis).

    Returns:
        The value of ``f1_score(..., average='macro')``.
    """
    labels = y_true.detach().cpu().numpy()
    predicted = y_pred.argmax(dim=-1, keepdim=True)
    predicted = predicted.detach().cpu().numpy()
    return f1_score(labels, predicted, average='macro')


def reset_args(args):
    """Overwrite hyperparameters on ``args`` in place with per-dataset presets.

    Reads ``args.dataset``, ``args.model`` and ``args.method_bn``; for
    ``'ogb-arxiv'`` it also reads ``args.ood``. It sets attributes such as
    ``lr``, ``nlayers``, ``hidden``, ``weight_decay``, ``dropout``,
    ``tta_epochs``, ``lr_te``, ``bin_num``, ``tta_epochs_bns``,
    ``loss_lambda``, ``lr_mask`` and ``learn_mask_epochs``.

    Some attributes are only assigned for particular model / ``method_bn``
    combinations. For example, ``tta_epochs``/``lr_te`` are untouched for
    SAGE + BNPA on amazon-photo, and no test-time attributes are set for
    non-OOD ogb-arxiv. In those cases any value ``args`` already carried is
    kept.

    Args:
        args: mutable namespace (e.g. from argparse) to update in place.

    Raises:
        NotImplementedError: if ``args.dataset`` has no preset.
    """
    # Defaults shared by every dataset; some branches override them below.
    args.weight_decay = 1e-3
    args.dropout = 0
    if args.dataset in ['amazon-photo']:
        args.lr = 0.001
        args.nlayers = 2
        args.hidden = 32

        if args.model in ['GCN', 'GAT']:
            # BN param adapt
            if args.method_bn in ['BNPA']:
                args.tta_epochs = 40
                args.lr_te = 0.001
            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001
        if args.model in ['SAGE']:
            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 80
                args.lr_te = 0.001

        # BN statistic adapt
        args.bin_num = 100
        args.tta_epochs_bns = 10
        args.loss_lambda = 0.1
        args.lr_mask = 0.01
        args.learn_mask_epochs = 300

    elif args.dataset in ['cora']:
        args.lr = 0.001
        args.nlayers = 2
        args.hidden = 32

        if args.model in ['GCN', 'SAGE', 'GAT']:
            # BN param adapt
            if args.method_bn in ['BNPA']:
                args.tta_epochs = 80
                args.lr_te = 0.001

            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001

        # BN statistic adapt
        args.bin_num = 100
        args.tta_epochs_bns = 10
        args.loss_lambda = 0.1
        args.lr_mask = 0.01
        args.learn_mask_epochs = 300

    elif args.dataset == 'ogb-arxiv':
        if args.ood:
            args.lr = 0.01
            args.nlayers = 2
            args.hidden = 32
            args.weight_decay = 0

            if args.model in ['GCN']:
                # BN param adapt
                if args.method_bn in ['BNPA']:
                    args.tta_epochs = 70
                    args.lr_te = 0.001

                # BNSA+BNPA
                if args.method_bn in ['BNSA+BNPA']:
                    args.tta_epochs = 10
                    args.lr_te = 0.001
            if args.model in ['SAGE']:
                # BNSA+BNPA
                if args.method_bn in ['BNSA+BNPA']:
                    args.tta_epochs = 10
                    args.lr_te = 0.0001
            if args.model in ['GAT']:
                # BNSA+BNPA
                if args.method_bn in ['BNSA+BNPA']:
                    args.tta_epochs = 20
                    args.lr_te = 0.001

            # BN statistic adapt
            args.bin_num = 10
            args.tta_epochs_bns = 1
            args.loss_lambda = 1.5
            args.lr_mask = 0.2
            args.learn_mask_epochs = 300

        else:
            # Non-OOD setting: only the training hyperparameters are preset.
            args.lr = 0.01
            args.dropout = 0.5
            args.nlayers = 3
            args.hidden = 256
            args.weight_decay = 0

    elif args.dataset == 'fb100':
        args.lr = 0.01
        args.nlayers = 2
        args.hidden = 32

        if args.model in ['GCN']:
            # BN param adapt
            if args.method_bn in ['BNPA']:
                args.tta_epochs = 30
                args.lr_te = 0.001

            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 30
                args.lr_te = 0.001
        if args.model in ['SAGE', 'GAT']:
            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001

        # BN statistic adapt
        args.bin_num = 10
        args.tta_epochs_bns = 1
        args.loss_lambda = 0.8
        args.lr_mask = 0.1
        if args.model in ['GAT']:
            args.learn_mask_epochs = 10
        else:
            args.learn_mask_epochs = 300

    elif args.dataset == 'twitch-e':
        args.lr = 0.01
        args.nlayers = 2
        args.hidden = 32

        if args.model in ['GCN', 'SAGE', 'GAT']:
            # BN param adapt
            if args.method_bn in ['BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001
            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001

        # BN statistic adapt
        if args.model in ['GCN']:
            args.bin_num = 100
            args.tta_epochs_bns = 10
            args.loss_lambda = 0.1
        elif args.model in ['SAGE', 'GAT']:
            args.bin_num = 10
            args.tta_epochs_bns = 1
            args.loss_lambda = 0.8
        args.learn_mask_epochs = 300
        args.lr_mask = 0.1

    elif args.dataset in ['elliptic']:
        args.lr = 0.01
        args.nlayers = 5
        args.hidden = 32
        args.weight_decay = 0

        if args.model in ['GCN', 'SAGE', 'GAT']:
            # BN param adapt
            if args.method_bn in ['BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001

            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001

        # BN statistic adapt
        if args.model in ['GAT']:
            args.learn_mask_epochs = 100
        else:
            args.learn_mask_epochs = 300
        args.bin_num = 100
        args.tta_epochs_bns = 10
        args.loss_lambda = 0.1
        args.lr_mask = 0.01

    elif args.dataset in ['ogb-products']:
        args.lr = 0.01
        args.nlayers = 5
        args.hidden = 32
        args.weight_decay = 0

        if args.model in ['GCN', 'SAGE', 'GAT']:
            # BN param adapt
            if args.method_bn in ['BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001
            # BNSA+BNPA
            if args.method_bn in ['BNSA+BNPA']:
                args.tta_epochs = 10
                args.lr_te = 0.0001

        # BN statistic adapt
        args.bin_num = 10
        args.tta_epochs_bns = 1
        args.loss_lambda = 0.8
        args.lr_mask = 0.1
        args.learn_mask_epochs = 300
    else:
        raise NotImplementedError(
            f"reset_args: no hyperparameter preset for dataset {args.dataset!r}")


def get_gpu_memory_map():
    """Query ``nvidia-smi`` for the memory currently used on each GPU.

    Returns:
        dict mapping each GPU's position in the ``nvidia-smi`` output
        (0, 1, ...) to its ``memory.used`` value as an int (presumably MiB,
        since ``nounits`` strips the unit -- confirm against nvidia-smi).

    Raises:
        subprocess.CalledProcessError: if ``nvidia-smi`` exits non-zero.
    """
    query = [
        'nvidia-smi',
        '--query-gpu=memory.used',
        '--format=csv,nounits,noheader',
    ]
    raw = subprocess.check_output(query, encoding='utf-8')
    # nvidia-smi prints one line per GPU, in device order.
    used_per_gpu = (int(line) for line in raw.strip().split('\n'))
    return dict(enumerate(used_per_gpu))
