import os
import requests
import types
import json
import csv
import pickle

import numpy as np
from sklearn.preprocessing import label_binarize
import scipy.io

import torch
from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid, Amazon, Coauthor, DeezerEurope, Actor, MixHopSyntheticDataset
import torch_geometric.transforms as transforms
from torch_geometric.utils import to_undirected, add_remaining_self_loops
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

from data_utils import keep_only_largest_connected_component

# Root directory (relative to the current working directory) under which the
# loaders in this module store and look up dataset files.
DATA_PATH = 'data'

def get_dataset(name: str, use_lcc: bool = True, homophily=None):
    """Load a node-classification dataset by name and normalise its edges.

    After loading, the graph's ``edge_index`` is made undirected and self
    loops are added for every node that does not already have one.

    Args:
        name: Dataset identifier. Supported values are 'Cora', 'CiteSeer',
            'PubMed', 'Computers', 'Photo', 'CoauthorCS', 'CoauthorPhysics',
            'OGBN-Arxiv', 'OGBN-Products', 'Twitch', 'Deezer-Europe',
            'FB100', 'Actor', 'Syn-Cora' and 'MixHopSynthetic'.
        use_lcc: If True, reduce the dataset with
            ``keep_only_largest_connected_component``. This flag is forced
            to False for the OGBN, Twitch, Deezer-Europe, FB100 and Actor
            datasets regardless of the value passed in.
        homophily: Homophily level forwarded to the 'Syn-Cora' and
            'MixHopSynthetic' loaders; ignored by every other dataset.

    Returns:
        A ``(dataset, evaluator)`` tuple. ``evaluator`` is an OGB
        ``Evaluator`` for the OGBN datasets and ``None`` otherwise.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    path = os.path.join(DATA_PATH, name)
    evaluator = None

    if name in ['Cora', 'CiteSeer', 'PubMed']:
        dataset = Planetoid(path, name)
    elif name in ['Computers', 'Photo']:
        dataset = Amazon(path, name)
    elif name == 'CoauthorCS':
        dataset = Coauthor(path, 'CS')
    elif name == 'CoauthorPhysics':
        dataset = Coauthor(path, 'Physics')
    elif name in ['OGBN-Arxiv', 'OGBN-Products']:
        # OGB expects the lower-case dataset name (e.g. 'ogbn-arxiv').
        dataset = PygNodePropPredDataset(name=name.lower(), transform=transforms.ToSparseTensor(), root=path)
        evaluator = Evaluator(name=name.lower())
        use_lcc = False
    elif name == "Twitch":
        dataset = load_twitch_dataset("DE")
        use_lcc = False
    elif name == "Deezer-Europe":
        dataset = DeezerEurope(path)
        use_lcc = False
    elif name == "FB100":
        sub_dataname = 'Penn94'
        dataset = load_fb100_dataset(sub_dataname)
        use_lcc = False
    elif name == "Actor":
        dataset = Actor(path)
        use_lcc = False
    elif name == 'Syn-Cora':
        dataset = load_syn_cora(homophily)
    elif name == 'MixHopSynthetic':
        dataset = MixHopSyntheticDataset(path, homophily=homophily)
    else:
        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        raise ValueError(f'Unknown dataset: {name!r}.')

    if use_lcc:
        dataset = keep_only_largest_connected_component(dataset)

    # Make graph undirected so that we have edges for both directions and add self loops
    dataset.data.edge_index = to_undirected(dataset.data.edge_index)
    dataset.data.edge_index, _ = add_remaining_self_loops(dataset.data.edge_index, num_nodes=dataset.data.x.shape[0])
    print("Data: ", dataset.data)

    return dataset, evaluator

def load_fb100(filename):
    """Download (if not cached) and read one Facebook100 ``.mat`` file.

    The file is cached as ``<DATA_PATH>/FB100/<filename>.mat`` and only
    fetched from the Non-Homophily-Benchmarks repository when that file is
    missing.

    Args:
        filename: Name of the network without the ``.mat`` extension
            (e.g. 'Penn94').

    Returns:
        A tuple ``(A, metadata)`` holding the ``'A'`` and ``'local_info'``
        entries of the loaded ``.mat`` file.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    fb100_dir = os.path.join(DATA_PATH, 'FB100')
    # makedirs also creates DATA_PATH itself when it does not exist yet.
    os.makedirs(fb100_dir, exist_ok=True)
    mat_path = os.path.join(fb100_dir, f"{filename}.mat")

    # Check the same path we write to (including the .mat suffix); otherwise
    # the cache is never hit and the file is downloaded on every call.
    if not os.path.isfile(mat_path):
        url = f"https://github.com/CUAI/Non-Homophily-Benchmarks/raw/5b2ffa908274f9929b95402b71c9b645928f292c/data/facebook100/{filename}.mat"
        r = requests.get(url, allow_redirects=True)
        # Fail loudly instead of caching an HTTP error page as a .mat file.
        r.raise_for_status()
        with open(mat_path, "wb") as f:
            f.write(r.content)

    mat = scipy.io.loadmat(mat_path)
    A = mat['A']
    metadata = mat['local_info']
    return A, metadata

def load_fb100_dataset(filename):
    """Build a graph dataset object from a Facebook100 network.

    Column 1 of the metadata is used as the label (shifted by -1, so -1
    means unlabeled). Every remaining metadata column is one-hot encoded
    with ``label_binarize`` over that column's unique values, and the
    encodings are concatenated into the node feature matrix.

    Args:
        filename: Name of the network without the ``.mat`` extension,
            forwarded to ``load_fb100``.

    Returns:
        A ``types.SimpleNamespace`` with two attributes: ``data`` (a
        ``torch_geometric.data.Data`` with ``x``, ``edge_index`` and ``y``)
        and ``num_classes`` (largest label value plus one).
    """
    A, metadata = load_fb100(filename)
    # Stack the (row, col) index arrays into a single ndarray first:
    # building a tensor from a tuple of ndarrays is very slow.
    edge_index = torch.tensor(np.vstack(A.nonzero()), dtype=torch.long)
    # `np.int` was an alias of the builtin int and was removed in NumPy 1.24.
    metadata = metadata.astype(int)
    label = metadata[:, 1] - 1  # gender label, -1 means unlabeled

    # make features into one-hot encodings
    feature_vals = np.hstack(
        (np.expand_dims(metadata[:, 0], 1), metadata[:, 2:]))
    # Seed with an empty float block so the result keeps the original float
    # dtype and shape even if there are no feature columns; concatenate once
    # at the end rather than re-copying the matrix for every column.
    onehot_blocks = [np.empty((A.shape[0], 0))]
    for col in range(feature_vals.shape[1]):
        feat_col = feature_vals[:, col]
        onehot_blocks.append(
            label_binarize(feat_col, classes=np.unique(feat_col)))
    features = np.hstack(onehot_blocks)

    node_feat = torch.tensor(features, dtype=torch.float)

    data = Data(
            x=node_feat,
            edge_index=edge_index,
            y=torch.tensor(label),
        )

    # This allows to just have a general object to which we can assign fields
    dataset = types.SimpleNamespace()
    dataset.data = data
    dataset.num_classes = data.y.max().item() + 1

    return dataset

def load_twitch_dataset(lang):
    """Build a graph dataset object from one Twitch language sub-graph.

    Args:
        lang: Language code forwarded to ``load_twitch`` (e.g. 'DE').

    Returns:
        A ``types.SimpleNamespace`` with two attributes: ``data`` (a
        ``torch_geometric.data.Data`` with ``x``, ``edge_index`` and ``y``)
        and ``num_classes`` (largest label value plus one).
    """
    A, label, features = load_twitch(lang)
    # Stack the (row, col) index arrays into a single ndarray first:
    # building a tensor from a tuple of ndarrays is very slow.
    edge_index = torch.tensor(np.vstack(A.nonzero()), dtype=torch.long)
    node_feat = torch.tensor(features, dtype=torch.float)

    data = Data(
            x=node_feat,
            edge_index=edge_index,
            y=torch.tensor(label),
        )

    # A plain namespace is enough to expose the fields the callers read.
    dataset = types.SimpleNamespace()
    dataset.data = data
    dataset.num_classes = data.y.max().item() + 1

    return dataset

def load_twitch(lang):
    """Download (if not cached) and parse one Twitch language sub-graph.

    The three source files (``musae_<lang>_target.csv``,
    ``musae_<lang>_edges.csv`` and ``musae_<lang>_features.json``) are
    cached under ``<DATA_PATH>/Twitch/`` and fetched from the
    Non-Homophily-Benchmarks repository only when missing.

    Args:
        lang: One of 'DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'.

    Returns:
        A tuple ``(A, label, features)``: ``A`` is an ``n x n`` scipy CSR
        adjacency matrix built from the edge list, ``label`` is an integer
        array of length ``n`` and ``features`` is a dense 0/1 matrix with
        all-zero columns removed.

    Raises:
        AssertionError: If ``lang`` is not a supported language code.
        requests.HTTPError: If a download responds with an error status.
    """
    assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset'

    twitch_dir = os.path.join(DATA_PATH, 'Twitch')
    # makedirs also creates DATA_PATH itself when it does not exist yet.
    os.makedirs(twitch_dir, exist_ok=True)

    # File names and URL must follow `lang`: the readers below open
    # musae_{lang}_* files, so fetching only the DE files broke every other
    # language.
    files = [f"musae_{lang}_target.csv", f"musae_{lang}_edges.csv", f"musae_{lang}_features.json"]

    for file in files:
        file_path = os.path.join(twitch_dir, file)
        if not os.path.isfile(file_path):
            url = f"https://github.com/CUAI/Non-Homophily-Benchmarks/raw/5b2ffa908274f9929b95402b71c9b645928f292c/data/twitch/{lang}/{file}"
            r = requests.get(url, allow_redirects=True)
            # Fail loudly instead of caching an HTTP error page as data.
            r.raise_for_status()
            with open(file_path, "wb") as f:
                f.write(r.content)

    label = []
    node_ids = []
    src = []
    targ = []
    uniq_ids = set()
    with open(os.path.join(twitch_dir, f"musae_{lang}_target.csv"), 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for row in reader:
            # Column 5 is used as the node id and column 2 (compared with
            # the string "True") as the binary label.
            node_id = int(row[5])
            # handle FR case of non-unique rows
            if node_id not in uniq_ids:
                uniq_ids.add(node_id)
                label.append(int(row[2] == "True"))
                node_ids.append(node_id)

    # `np.int` was an alias of the builtin int and was removed in NumPy 1.24.
    node_ids = np.array(node_ids, dtype=int)
    with open(os.path.join(twitch_dir, f"musae_{lang}_edges.csv"), 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        for row in reader:
            src.append(int(row[0]))
            targ.append(int(row[1]))
    with open(os.path.join(twitch_dir, f"musae_{lang}_features.json"), 'r') as f:
        j = json.load(f)
    src = np.array(src)
    targ = np.array(targ)
    label = np.array(label)

    # Map node id -> row position in the target file, then build the
    # permutation that puts labels in node-id order. The lookup by `i`
    # requires the node ids to be exactly 0..n-1 (KeyError otherwise).
    inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)}
    reorder_node_ids = np.zeros_like(node_ids)
    for i in range(label.shape[0]):
        reorder_node_ids[i] = inv_node_ids[i]

    n = label.shape[0]
    A = scipy.sparse.csr_matrix((np.ones(len(src)), (src, targ)),
                                shape=(n, n))
    # 3170 is the hard-coded width of the raw feature space.
    # NOTE(review): presumably the largest feature index across the Twitch
    # files plus one -- confirm against the data.
    features = np.zeros((n, 3170))
    for node, feats in j.items():
        if int(node) >= n:
            continue
        features[int(node), np.array(feats, dtype=int)] = 1
    features = features[:, np.sum(features, axis=0) != 0]  # remove zero cols
    label = label[reorder_node_ids]

    return A, label, features

def load_syn_cora(homophily):
    """Load a pickled Syn-Cora graph for the given homophily level.

    Reads ``data/syn-cora/<homophily>-0.p`` (relative to the current
    working directory).

    Args:
        homophily: Homophily level; it is formatted into the file name.

    Returns:
        A ``types.SimpleNamespace`` with two attributes: ``data`` (the
        unpickled object) and ``num_classes`` (largest value in ``data.y``
        plus one).

    Raises:
        ValueError: If ``homophily`` is None.
    """
    if homophily is None:
        raise ValueError('Specify a level of homophily.')

    # NOTE(review): pickle.load can execute arbitrary code; only use this
    # with dataset files from a trusted source.
    # The context manager closes the file handle, which the previous
    # `pickle.load(open(...))` form left open.
    with open(f"data/syn-cora/{homophily}-0.p", "rb") as f:
        data = pickle.load(f)

    dataset = types.SimpleNamespace()
    dataset.data = data
    dataset.num_classes = data.y.max().item() + 1
    return dataset