from dhg.data import CoauthorshipDBLP, News20, CoauthorshipCora, IMDB4k, DBLP4k
from dhg import Hypergraph
import numpy as np
import torch


def vertex_mask(num_vertex, idx):
    """Return a boolean tensor of length ``num_vertex`` that is True at ``idx``.

    Args:
        num_vertex: Length of the mask to create.
        idx: Index (or collection of indices) of the entries to switch on.

    Returns:
        A 1-D ``torch.bool`` tensor, False everywhere except at ``idx``.
    """
    selected = torch.zeros(num_vertex, dtype=torch.bool)
    selected[idx] = True
    return selected


def random_split(lbl, num_classes, num_per_class, num_development):
    """Randomly split the vertices into train / validation / test masks.

    Primary strategy: draw ``num_development`` vertices as a development
    pool, pick ``num_per_class`` training vertices of every class from that
    pool, use the rest of the pool for validation and every vertex outside
    the pool for testing.

    Fallback strategy (used when the primary sampling raises ``ValueError``,
    e.g. a class has fewer than ``num_per_class`` members in the pool): pick
    up to ``num_per_class`` training vertices per class from *all* vertices,
    then draw ``num_development - len(train)`` validation vertices from the
    remaining ones; everything else is test.

    Sampling uses the global NumPy random state (``np.random``).

    Args:
        lbl: Tensor of per-vertex class labels; ``lbl.shape[0]`` is the
            number of vertices.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are
            sampled.
        num_per_class: Number of training vertices requested per class.
        num_development: Size of the development pool (train + validation).

    Returns:
        Tuple ``(train_mask, val_mask, test_mask)`` of boolean tensors of
        length ``lbl.shape[0]``.

    Raises:
        ValueError: If the fallback strategy cannot draw the requested
            number of validation vertices either.
    """
    num_vertex = lbl.shape[0]
    all_idx = np.arange(num_vertex)
    # Bring the labels to host memory once instead of once per class.
    labels = np.asarray(lbl.detach().cpu())

    try:
        development_idx = np.random.choice(num_vertex, num_development, replace=False)
        dev_labels = labels[development_idx]
        train_idx = []
        for _class in range(num_classes):
            class_idx = development_idx[np.where(dev_labels == _class)[0]]
            train_idx.extend(np.random.choice(class_idx, num_per_class, replace=False))
        train_idx = np.asarray(train_idx, dtype=np.int64)
        # Vectorised set operations replace the per-element ``in`` scans.
        val_idx = development_idx[~np.isin(development_idx, train_idx)]
        test_idx = np.setdiff1d(all_idx, development_idx)
    except ValueError:
        # np.random.choice could not sample as requested (pool or class too
        # small): fall back to sampling the training set from all vertices.
        train_idx = []
        for _class in range(num_classes):
            class_idx = np.where(labels == _class)[0]
            num_taken = min(num_per_class, class_idx.shape[0])
            train_idx.extend(np.random.choice(class_idx, num_taken, replace=False))
        train_idx = np.asarray(train_idx, dtype=np.int64)
        non_train_idx = np.setdiff1d(all_idx, train_idx)
        val_idx = np.random.choice(non_train_idx, num_development - len(train_idx), replace=False)
        test_idx = np.setdiff1d(non_train_idx, val_idx)

    res_train_mask = vertex_mask(num_vertex, train_idx)
    res_val_mask = vertex_mask(num_vertex, val_idx)
    res_test_mask = vertex_mask(num_vertex, test_idx)
    return res_train_mask, res_val_mask, res_test_mask


def add_hypergraph_self_loop(hg):
    """Give every isolated vertex its own single-vertex hyperedge.

    A vertex is treated as isolated when its entry in ``hg.deg_v`` is 0.
    The hyperedges are added through ``hg.add_hyperedges``, so ``hg`` itself
    is modified.

    Args:
        hg: Hypergraph exposing ``num_v``, ``deg_v`` and ``add_hyperedges``.

    Returns:
        The same ``hg`` object that was passed in.
    """
    deg_v = hg.deg_v
    isolated = [[v] for v in range(hg.num_v) if deg_v[v] == 0]
    # Nothing to add: do not hand an empty hyperedge list to the hypergraph.
    if isolated:
        hg.add_hyperedges(isolated)
    return hg


def load_data(dataset, split='random', num_per_class=None, num_development=None):
    """Load a dataset, build its hypergraph and split its vertices.

    Args:
        dataset: Dataset name. One of ``'Cora-CA'``, ``'DBLP-CA'``,
            ``'News20'``, ``'IMDB4k-CA'``, ``'IMDB4k-CD'``, ``'DBLP4k-CC'``
            or ``'DBLP4k-CP'``.
        split: Split strategy; only ``'random'`` is supported.
        num_per_class: Forwarded to ``random_split``.
        num_development: Forwarded to ``random_split``.

    Returns:
        Tuple ``(X, HG, lbl, dim_features, num_classes, train_mask,
        val_mask, test_mask)`` where ``X`` / ``lbl`` are the dataset's
        ``"features"`` / ``"labels"`` entries, ``HG`` is the ``Hypergraph``
        built from the dataset's hyperedge list, and the masks come from
        ``random_split``.

    Raises:
        SystemExit: With exit status 1, after printing a message, when
            ``split`` or ``dataset`` is not supported.
    """
    # Reject an unsupported split before any dataset is loaded.
    if split != 'random':
        print("Error split")
        raise SystemExit(1)

    # dataset name -> (loader, key of the hyperedge list, add self-loops?)
    # Only 'Cora-CA' gets self-loops for its isolated vertices.
    dataset_specs = {
        'Cora-CA': (CoauthorshipCora, "edge_list", True),
        'DBLP-CA': (CoauthorshipDBLP, "edge_list", False),
        'News20': (News20, "edge_list", False),
        'IMDB4k-CA': (IMDB4k, "edge_by_actor", False),
        'IMDB4k-CD': (IMDB4k, "edge_by_director", False),
        'DBLP4k-CC': (DBLP4k, "edge_by_conf", False),
        'DBLP4k-CP': (DBLP4k, "edge_by_paper", False),
    }
    spec = dataset_specs.get(dataset)
    if spec is None:
        print("dataset doesn't exist")
        raise SystemExit(1)
    loader, edge_key, needs_self_loop = spec

    data = loader()
    X, lbl = data["features"], data["labels"]
    HG = Hypergraph(data["num_vertices"], data[edge_key])
    if needs_self_loop:
        HG = add_hypergraph_self_loop(HG)
    dim_features = data["dim_features"]
    num_classes = data["num_classes"]

    train_mask, val_mask, test_mask = random_split(lbl, num_classes, num_per_class, num_development)
    return X, HG, lbl, dim_features, num_classes, train_mask, val_mask, test_mask
