import numpy as np
import scipy.sparse
import scipy.io
import csv
import json
import torch


class NCDataset:
    """Container for a single-graph node-classification dataset.

    Adapted from
    https://github.com/CUAI/Non-Homophily-Benchmarks/blob/main/dataset.py

    Attributes are filled in by the loader functions after construction:
    ``graph`` is a dict of graph data and ``label`` holds the node labels.
    """

    def __init__(self, name):
        self.name = name
        # Populated later by a loader; starts out empty / unset.
        self.graph = {}
        self.label = None

    def __getitem__(self, idx):
        # Only index 0 is valid: the dataset holds exactly one graph.
        assert idx == 0, 'This dataset has only one graph'
        return self.graph, self.label

    def __len__(self):
        return 1

    def __repr__(self):
        return f"{type(self).__name__}({len(self)})"


def load_nc_dataset(args):
    """Dispatch to the loader matching ``args.dataset``.

    Reads ``args.dataset``, ``args.sub_dataset`` (used only for ``twitch-e``,
    where it selects the language) and ``args.DATAPATH`` (a path prefix that
    the loaders concatenate directly with sub-paths).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    name = args.dataset
    sub_name = args.sub_dataset
    data_root = args.DATAPATH

    common_names = ('chameleon', 'squirrel', 'texas', 'cornell', 'actor',
                    'cora', 'citeseer', 'pubmed')
    critical_look_names = ('amazon_ratings', 'minesweeper', 'roman_empire',
                           'tolokers')

    if name in common_names:
        return load_common_dataset(data_root, name)
    if name == 'twitch-e':
        return load_twitch_dataset(data_root, sub_name)
    if name in critical_look_names:
        return load_criticalLook_dataset(data_root, name)
    raise ValueError('Invalid dataname')


def load_common_dataset(DATAPATH, datname):
    """Load a preprocessed benchmark saved with ``torch.save`` into an NCDataset.

    Reads ``{DATAPATH}processed/{datname}_dataDic.pt`` and expects a dict with
    keys ``edge_index``, ``node_feat``, ``label``, ``num_classes`` and
    ``num_nodes``. Edge indices and labels are cast to long, features to float.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted data.
    saved = torch.load(f"{DATAPATH}processed/{datname}_dataDic.pt")

    # Keys are read in the original order, so a missing key raises the same
    # KeyError as before. num_classes is read but not used here.
    edges = saved["edge_index"]
    feats = saved["node_feat"]
    labels = saved["label"]
    _num_classes = saved["num_classes"]
    n_nodes = saved["num_nodes"]

    dataset = NCDataset(datname)
    dataset.label = labels.long()
    dataset.graph = dict(
        edge_index=edges.long(),
        edge_feat=None,
        node_feat=feats.float(),
        num_nodes=n_nodes,
    )
    return dataset


def load_twitch_dataset(DATAPATH, lang):
    """Load one Twitch language graph (via ``load_twitch``) into an NCDataset.

    Args:
        DATAPATH: path prefix passed through to ``load_twitch``.
        lang: one of 'DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'.

    Returns:
        NCDataset whose ``graph`` holds a long ``edge_index`` built from the
        nonzero entries of the adjacency matrix, a float ``node_feat`` tensor
        and ``num_nodes``; ``label`` is a tensor of the node labels.
    """
    assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset'
    A, label, features = load_twitch(DATAPATH, lang)
    dataset = NCDataset(lang)
    # A.nonzero() returns a (rows, cols) tuple of arrays. Stacking them into a
    # single (2, nnz) array first avoids torch's slow tensor-from-list-of-
    # ndarrays path (and its UserWarning); the resulting tensor is identical.
    edge_index = torch.as_tensor(np.vstack(A.nonzero()), dtype=torch.long)
    node_feat = torch.tensor(features, dtype=torch.float)
    num_nodes = node_feat.shape[0]
    dataset.graph = {'edge_index': edge_index,
                     'edge_feat': None,
                     'node_feat': node_feat,
                     'num_nodes': num_nodes}
    dataset.label = torch.tensor(label)
    return dataset


def load_criticalLook_dataset(DATAPATH, name):
    """Load a "critical look" benchmark from ``{DATAPATH}criticalLook_dataset/{name}.npz``.

    The archive is expected to contain ``node_labels``, ``edges``,
    ``node_features`` and the boolean ``train_masks`` / ``val_masks`` /
    ``test_masks`` stacks (one row per split).

    Returns:
        NCDataset whose ``graph`` has ``edge_index`` (the transpose of
        ``edges``, as long), float ``node_feat``, ``num_nodes`` and
        ``fixed_split_dic_ls``: a list of dicts with numpy index arrays under
        ``trn_idx`` / ``val_idx`` / ``tst_idx``, one per provided split.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open; use it
    # as a context manager so the file handle is released once every array has
    # been copied out (torch.tensor copies, and NpzFile[...] reads fully).
    with np.load(f"{DATAPATH}criticalLook_dataset/{name}.npz") as data:
        label = torch.tensor(data['node_labels'], dtype=torch.long)
        edge_index = torch.tensor(data['edges'], dtype=torch.long).T
        node_feat = torch.tensor(data['node_features'], dtype=torch.float)
        train_masks = data["train_masks"]
        val_masks = data["val_masks"]
        test_masks = data["test_masks"]
    num_nodes = node_feat.shape[0]

    # Random splits (x 10) provided by the dataset authors.
    assert train_masks.shape[0] == val_masks.shape[0]
    assert val_masks.shape[0] == test_masks.shape[0]
    split_dic_ls = [
        {"trn_idx": np.where(trn_mask)[0],
         "val_idx": np.where(val_mask)[0],
         "tst_idx": np.where(tst_mask)[0]}
        for trn_mask, val_mask, tst_mask in zip(train_masks, val_masks, test_masks)
    ]
    assert len(split_dic_ls) == 10

    dataset = NCDataset(name)
    dataset.graph = {'edge_index': edge_index,
                     'node_feat': node_feat,
                     'edge_feat': None,
                     'num_nodes': num_nodes,
                     'fixed_split_dic_ls': split_dic_ls}
    dataset.label = label
    return dataset


def _read_twitch_targets(filepath, lang):
    """Return ``(labels, node_ids)`` lists from ``musae_{lang}_target.csv``, first row per id wins."""
    labels = []
    node_ids = []
    seen = set()
    with open(f"{filepath}/musae_{lang}_target.csv", 'r',
              newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            node_id = int(row[5])
            # Handle the FR file, which contains non-unique rows: keep the first.
            if node_id in seen:
                continue
            seen.add(node_id)
            labels.append(int(row[2] == "True"))
            node_ids.append(node_id)
    return labels, node_ids


def _read_twitch_edges(filepath, lang):
    """Return ``(src, targ)`` numpy arrays from the first two columns of ``musae_{lang}_edges.csv``."""
    src = []
    targ = []
    with open(f"{filepath}/musae_{lang}_edges.csv", 'r',
              newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            src.append(int(row[0]))
            targ.append(int(row[1]))
    return np.array(src), np.array(targ)


def _read_twitch_features(filepath, lang, n, num_feats=3170):
    """Return an (n, k) binary feature matrix from ``musae_{lang}_features.json`` with all-zero columns dropped."""
    with open(f"{filepath}/musae_{lang}_features.json", 'r', encoding='utf-8') as f:
        feat_lists = json.load(f)
    features = np.zeros((n, num_feats))
    for node, feats in feat_lists.items():
        row = int(node)
        if row >= n:  # ignore feature entries for nodes without a target row
            continue
        features[row, np.array(feats, dtype=int)] = 1
    return features[:, np.sum(features, axis=0) != 0]  # remove zero cols


def load_twitch(DATAPATH, lang):
    """Read the raw Twitch (MUSAE) files for one language.

    Files are read from ``{DATAPATH}twitch/{lang}/``: the target CSV (label in
    column 2 as "True"/"False", node id in column 5), the edge CSV (source,
    target) and the features JSON (node id -> list of active feature indices).

    Returns:
        A: (n, n) scipy CSR adjacency matrix with a 1 per listed edge
            (duplicate edges are summed by scipy).
        label: numpy int array where ``label[i]`` is the label of node id ``i``.
        features: (n, k) float array of binary features, zero columns removed.

    Raises:
        KeyError: if the node ids in the target file do not cover 0..n-1.
    """
    assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset'
    filepath = "{}twitch/{}".format(DATAPATH, lang)

    labels, node_id_list = _read_twitch_targets(filepath, lang)
    src, targ = _read_twitch_edges(filepath, lang)
    label = np.array(labels)
    n = label.shape[0]
    features = _read_twitch_features(filepath, lang, n)

    # `np.int` was removed in NumPy 1.24; it was an alias of the builtin `int`.
    node_ids = np.array(node_id_list, dtype=int)
    # Labels are stored in file order; permute them so position i holds the
    # label of node id i.
    inv_node_ids = {node_id: idx for idx, node_id in enumerate(node_id_list)}
    reorder_node_ids = np.array([inv_node_ids[i] for i in range(n)],
                                dtype=node_ids.dtype)

    A = scipy.sparse.csr_matrix((np.ones(len(src)), (src, targ)),
                                shape=(n, n))
    return A, label[reorder_node_ids], features
