import subprocess
from collections import defaultdict

import numpy as np
import torch
import torch.nn.functional as F
from scipy import sparse as sp
from torch_sparse import SparseTensor


def rand_train_test_idx(
    label, train_prop=0.5, valid_prop=0.25, ignore_negative=True
):
    """Randomly partition nodes into train / validation / test index sets.

    The shuffle uses NumPy's global random state
    (``np.random.permutation``), so seed NumPy for reproducible splits.

    Args:
        label: Tensor of node labels.
        train_prop: Fraction of the candidates assigned to training.
        valid_prop: Fraction of the candidates assigned to validation.
        ignore_negative: When true, only entries of ``label`` different
            from ``-1`` are candidates; otherwise every position along the
            first dimension of ``label`` is a candidate.

    Returns:
        A ``(train, valid, test)`` tuple of index tensors. The test set
        receives whatever is left once the (truncated) train and validation
        counts have been taken. With ``ignore_negative`` the indices are the
        first-dimension positions of the non-``-1`` entries; without it they
        are drawn from ``range(label.shape[0])``.
    """
    candidates = torch.where(label != -1)[0] if ignore_negative else label

    total = candidates.shape[0]
    n_train = int(total * train_prop)
    n_valid = int(total * valid_prop)
    cut = n_train + n_valid

    shuffled = torch.as_tensor(np.random.permutation(total))
    train_pos = shuffled[:n_train]
    valid_pos = shuffled[n_train:cut]
    test_pos = shuffled[cut:]

    if ignore_negative:
        # Positions refer to the filtered list; map them back to node ids.
        return (
            candidates[train_pos],
            candidates[valid_pos],
            candidates[test_pos],
        )
    return train_pos, valid_pos, test_pos


def load_fixed_splits(data_dir, dataset, name, protocol):
    """Load the predefined train / valid / test splits for a dataset.

    Args:
        data_dir: Root directory that contains ``geom-gcn/splits``.
        dataset: Object exposing ``train_idx``, ``valid_idx`` and
            ``test_idx``; only read for the ``"semi"`` protocol.
        name: Dataset name.
        protocol: ``"semi"`` selects the split stored on ``dataset`` (only
            for cora / citeseer / pubmed); any other value falls through to
            the ten split files on disk.

    Returns:
        A list of dicts with keys ``"train"``, ``"valid"`` and ``"test"``.
        The ``"semi"`` protocol yields one dict of index tensors; the
        file-based branch yields ten dicts of boolean mask tensors.

    Raises:
        NotImplementedError: If ``name`` is not one of the supported
            datasets.
    """
    planetoid_names = ("cora", "citeseer", "pubmed")
    file_split_names = planetoid_names + (
        "chameleon",
        "squirrel",
        "film",
        "cornell",
        "texas",
        "wisconsin",
    )

    if name in planetoid_names and protocol == "semi":
        return [
            {
                "train": torch.as_tensor(dataset.train_idx),
                "valid": torch.as_tensor(dataset.valid_idx),
                "test": torch.as_tensor(dataset.test_idx),
            }
        ]

    if name not in file_split_names:
        raise NotImplementedError

    all_splits = []
    for run in range(10):
        npz_path = f"{data_dir}/geom-gcn/splits/{name}_split_0.6_0.2_{run}.npz"
        with np.load(npz_path) as archive:
            all_splits.append(
                {
                    "train": torch.BoolTensor(archive["train_mask"]),
                    "valid": torch.BoolTensor(archive["val_mask"]),
                    "test": torch.BoolTensor(archive["test_mask"]),
                }
            )
    return all_splits


def class_rand_splits(
    label, label_num_per_class, valid_num=500, test_num=1000
):
    """Draw a class-balanced random training set plus valid / test sets.

    For every distinct class in ``label`` the nodes of that class are
    shuffled and the first ``label_num_per_class`` become training nodes.
    All remaining nodes are pooled, shuffled again, and the first
    ``valid_num`` / next ``test_num`` of them become the validation / test
    sets. Shuffling uses ``torch.randperm`` (the global torch RNG).

    Args:
        label: Tensor of node labels; it is squeezed before classes are
            compared, and its first dimension defines the node count.
        label_num_per_class: Number of training nodes taken per class
            (fewer if the class is smaller).
        valid_num: Maximum size of the validation set. Defaults to 500.
        test_num: Maximum size of the test set. Defaults to 1000.

    Returns:
        ``(train_idx, valid_idx, test_idx)`` as ``torch.long`` index tensors.
        The validation / test sets are shorter than requested when not
        enough non-training nodes remain.
    """
    # Squeeze once instead of on every loop iteration.
    flat_label = label.squeeze()
    positions = torch.arange(label.shape[0])

    train_idx, non_train_idx = [], []
    for cls in flat_label.unique():
        members = positions[flat_label == cls]
        shuffled = members[torch.randperm(members.shape[0])]
        train_idx += shuffled[:label_num_per_class].tolist()
        non_train_idx += shuffled[label_num_per_class:].tolist()

    # An explicit dtype keeps the result usable as an index even when a list
    # is empty (``torch.as_tensor([])`` would otherwise be float32).
    train_idx = torch.as_tensor(train_idx, dtype=torch.long)
    non_train_idx = torch.as_tensor(non_train_idx, dtype=torch.long)
    non_train_idx = non_train_idx[torch.randperm(non_train_idx.shape[0])]

    valid_idx = non_train_idx[:valid_num]
    test_idx = non_train_idx[valid_num : valid_num + test_num]
    return train_idx, valid_idx, test_idx


def even_quantile_labels(vals, nclasses, verbose=True):
    """Bin values into ``nclasses`` classes delimited by even quantiles.

    Class ``k`` covers the half-open interval between the ``k / nclasses``
    and ``(k + 1) / nclasses`` quantiles of ``vals``; the first interval
    starts at ``-inf`` and the last one extends to ``+inf``.

    Args:
        vals: One-dimensional NumPy array of values to bin.
        nclasses: Number of classes to produce.
        verbose: When true, print the interval assigned to every class.

    Returns:
        Integer NumPy array with one class index per entry of ``vals``.
        Entries that fall in no interval (NaN compares false everywhere)
        keep the placeholder label ``-1``.
    """
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the same
    # dtype the alias used to resolve to.
    label = np.full(vals.shape[0], -1, dtype=int)
    interval_lst = []
    lower = -np.inf
    for k in range(nclasses - 1):
        upper = np.quantile(vals, (k + 1) / nclasses)
        interval_lst.append((lower, upper))
        label[(vals >= lower) & (vals < upper)] = k
        lower = upper
    # The last class is open-ended so the maximum value is included.
    label[vals >= lower] = nclasses - 1
    interval_lst.append((lower, np.inf))
    if verbose:
        print("Class Label Intervals:")
        for class_idx, (low, high) in enumerate(interval_lst):
            print(f"Class {class_idx}: [{low}, {high})")
    return label


def to_planetoid(dataset):
    """Convert a dataset into ``x/tx/allx/y/ty/ally/graph`` style objects.

    Args:
        dataset: Object providing ``get_idx_split("random", 0.25)`` (must
            return a mapping with ``"train"`` and ``"test"`` index tensors)
            and ``dataset[0]`` returning ``(graph, label)``, where ``graph``
            is a mapping with ``"node_feat"`` and ``"edge_index"`` tensors.
            NOTE(review): the meaning of the ``0.25`` argument is defined by
            ``get_idx_split`` and is not visible here.

    Returns:
        A tuple ``(x, tx, allx, y, ty, ally, graph, split_idx)``:

        * ``x`` / ``tx`` / ``allx``: SciPy CSR feature matrices for the
          training nodes, the test nodes and all nodes.
        * ``y`` / ``ty`` / ``ally``: one-hot label arrays for the same node
          sets, all with ``label.max() + 1`` columns.
        * ``graph``: ``defaultdict(list)`` mapping each node to a list that
          starts with the node itself followed by the targets of its
          outgoing edges.
        * ``split_idx``: the mapping returned by ``get_idx_split``.
    """
    split_idx = dataset.get_idx_split("random", 0.25)
    train_idx, test_idx = split_idx["train"], split_idx["test"]

    graph, label = dataset[0]
    label = torch.squeeze(label)

    print("generate x")
    node_feat = graph["node_feat"]
    x = sp.csr_matrix(node_feat[train_idx].numpy())
    tx = sp.csr_matrix(node_feat[test_idx].numpy())
    allx = sp.csr_matrix(node_feat.numpy())

    # Pin the one-hot width to the full label range. Letting ``one_hot``
    # infer it per subset gives ``y`` / ``ty`` fewer columns than ``ally``
    # whenever a subset happens to miss the highest class.
    num_classes = int(label.max()) + 1
    y = F.one_hot(label[train_idx], num_classes).numpy()
    ty = F.one_hot(label[test_idx], num_classes).numpy()
    ally = F.one_hot(label, num_classes).numpy()

    adjacency = defaultdict(list)
    # Every node lists itself first, so isolated nodes still get an entry.
    for node in range(label.shape[0]):
        adjacency[node].append(node)
    # One bulk conversion to Python ints instead of ``.item()`` per endpoint.
    for src, dst in graph["edge_index"].T.tolist():
        adjacency[src].append(dst)
    return x, tx, allx, y, ty, ally, adjacency, split_idx


def to_sparse_tensor(edge_index, edge_feat, num_nodes):
    """Build a transposed ``SparseTensor`` adjacency from an edge list.

    Args:
        edge_index: Two-row tensor unpacked as ``(row, col)``.
        edge_feat: Per-edge values, indexed in the same order as the edges.
        num_nodes: Number of nodes; the result is ``num_nodes`` square.

    Returns:
        A ``SparseTensor`` whose rows are the original ``col`` entries and
        whose columns are the original ``row`` entries, with ``edge_feat``
        as values.
    """
    src, dst = edge_index
    # Order edges by (dst, src) so the transposed layout is already
    # row-sorted, which is what ``is_sorted=True`` promises.
    order = (dst * num_nodes + src).argsort()
    adj_t = SparseTensor(
        row=dst[order],
        col=src[order],
        value=edge_feat[order],
        sparse_sizes=(num_nodes, num_nodes),
        is_sorted=True,
    )
    # Return values are discarded; presumably these calls pre-compute and
    # cache the layouts inside the storage object -- TODO confirm.
    storage = adj_t.storage
    storage.rowptr()
    storage.csr2csc()
    return adj_t


def normalize(edge_index):
    """Symmetrically normalise an adjacency by its row sums.

    Args:
        edge_index: Despite the name, an adjacency object that supports
            ``set_diag()``, ``sum(dim=1)`` and element-wise multiplication
            with dense tensors (presumably a ``torch_sparse.SparseTensor``
            -- confirm against callers).

    Returns:
        ``D^{-1/2} A D^{-1/2}``, where ``A`` is the result of
        ``edge_index.set_diag()`` and ``D`` holds its row sums.
    """
    adj_with_diag = edge_index.set_diag()
    degree = adj_with_diag.sum(dim=1).to(torch.float)
    scale = degree.pow(-0.5)
    # A zero row sum gives ``0 ** -0.5 == inf``; zero it so that row and
    # column vanish instead of propagating infinities.
    scale[scale == float("inf")] = 0
    row_scale, col_scale = scale.view(-1, 1), scale.view(1, -1)
    return row_scale * adj_with_diag * col_scale


def gen_normalized_adjs(dataset):
    """Return three degree-normalised versions of the dataset's adjacency.

    Args:
        dataset: Object whose ``graph`` mapping provides ``"edge_index"``
            (unpacked as ``row, col``) and ``"num_nodes"``.

    Returns:
        ``(DAD, DA, AD)`` where, with ``A`` the adjacency and ``D`` its row
        sums, ``DAD = D^{-1/2} A D^{-1/2}``, ``DA = D^{-1} A`` and
        ``AD = A D^{-1}``. Rows or columns with zero degree are scaled by 0.
    """
    src, dst = dataset.graph["edge_index"]
    num_nodes = dataset.graph["num_nodes"]
    adj = SparseTensor(row=src, col=dst, sparse_sizes=(num_nodes, num_nodes))

    inv_sqrt_deg = adj.sum(dim=1).to(torch.float).pow(-0.5)
    # Zero-degree nodes produce inf; replace with 0 so they drop out.
    inv_sqrt_deg[inv_sqrt_deg == float("inf")] = 0

    as_col = inv_sqrt_deg.view(-1, 1)
    as_row = inv_sqrt_deg.view(1, -1)

    sym_norm = as_col * adj * as_row
    # Applying the inverse square root twice on one side yields D^{-1}.
    row_norm = as_col * as_col * adj
    col_norm = adj * as_row * as_row
    return sym_norm, row_norm, col_norm


def convert_to_adj(edge_index, n_node):
    """Materialise an edge list as a dense 0/1 adjacency matrix.

    Args:
        edge_index: Two-row index tensor unpacked as ``(row, col)``.
        n_node: Number of nodes; the result is ``n_node x n_node``.

    Returns:
        A dense tensor (default float dtype, created on the default device)
        with ``1`` at every ``(row, col)`` pair and ``0`` elsewhere.
    """
    src, dst = edge_index
    dense = torch.zeros((n_node, n_node))
    dense[src, dst] = 1
    return dense


def adj_mul(adj_i, adj, N):
    """Return the edge indices of the product of two sparse adjacencies.

    Both inputs are interpreted as ``N x N`` sparse matrices with a value of
    1 per listed edge; they are multiplied with ``torch.sparse.mm``.

    Args:
        adj_i: Two-row index tensor of the left operand.
        adj: Two-row index tensor of the right operand. Its device is used
            for the value vectors of both operands.
        N: Number of nodes (matrix side length).

    Returns:
        Two-row index tensor of the coalesced product's stored entries.
    """
    target_device = adj.device

    def _unit_weight_matrix(indices):
        # Every listed edge gets weight 1.0.
        weights = torch.ones(indices.shape[1], dtype=torch.float).to(
            target_device
        )
        return torch.sparse_coo_tensor(indices, weights, (N, N))

    left = _unit_weight_matrix(adj_i)
    right = _unit_weight_matrix(adj)
    product = torch.sparse.mm(left, right)
    return product.coalesce().indices()


def get_gpu_memory_map():
    """Query ``nvidia-smi`` for the memory currently used on each GPU.

    Returns:
        NumPy integer array with one entry per output line of
        ``nvidia-smi --query-gpu=memory.used``, in the order printed.

    Raises:
        subprocess.CalledProcessError: If ``nvidia-smi`` exits with a
            non-zero status.
        FileNotFoundError: If ``nvidia-smi`` is not on the ``PATH``.
    """
    command = [
        "nvidia-smi",
        "--query-gpu=memory.used",
        "--format=csv,nounits,noheader",
    ]
    raw_output = subprocess.check_output(command, encoding="utf-8")
    # With ``nounits,noheader`` each line is a bare integer.
    readings = [int(line) for line in raw_output.strip().split("\n")]
    return np.array(readings)


# Identifier string per dataset name. Going by the variable name these are
# presumably Google Drive file IDs for the dataset files -- TODO confirm
# against the code that consumes this mapping.
dataset_drive_url = {
    "snap-patents": "1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia",
    "pokec": "1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y",
    "yelp-chi": "1fAXtTVQS4CfEk4asqrFw9EPmlUPGbGtJ",
}

# Identifier string per dataset name for split files (presumably Google Drive
# file IDs as well -- TODO confirm). Note there is no entry for "yelp-chi".
splits_drive_url = {
    "snap-patents": "12xbBRqd8mtG_XkNLH8dRRNZJvVM4Pw-N",
    "pokec": "1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_",
}
