import os
import pickle
import numpy as np
import math
import pickle

import torch
import torch_geometric
from torch_geometric.data import download_url
from torch_geometric.datasets import (
    Planetoid,
    MixHopSyntheticDataset,
    Reddit2,
    LINKXDataset,
    WikipediaNetwork,
    Actor,
    WebKB,
    CitationFull,
)
from torch_geometric_signed_directed.data.directed.load_directed_real_data import Telegram
import torch_geometric.transforms as transforms
import torch_scatter
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
from torch_geometric.utils.random import erdos_renyi_graph

from src.directed_heterophilous_graphs import DirectedHeterophilousGraphDataset
from src.data_utils import get_mask
from src.third_party import (
    load_snap_patents_mat,
    even_quantile_labels,
)


def _load_or_create_syn_dir(path):
    """Return the synthetic "syn-dir" dataset, generating and caching it on first use.

    The dataset is cached as a pickle at ``<path>/syn-dir/processed/dataset.pkl``.
    When that file is missing, a ``WikipediaNetwork`` "chameleon" dataset object is
    used as a container and its ``edge_index``, ``x`` and ``y`` are overwritten
    with a random directed Erdos-Renyi graph on 5000 nodes, one scalar feature per
    node, and a binary label derived from the two neighbourhood means below.
    """
    processed_dir = os.path.join(path, "syn-dir", "processed")
    dataset_path = os.path.join(processed_dir, "dataset.pkl")

    if os.path.isfile(dataset_path):
        # NOTE: unpickling can execute arbitrary code. This is only acceptable
        # because the file is a local cache written by this very function.
        with open(dataset_path, "rb") as cache_file:
            return pickle.load(cache_file)

    os.makedirs(processed_dir, exist_ok=True)
    dataset = WikipediaNetwork(root=path, name="chameleon")
    num_nodes = 5000
    dataset.data.edge_index = erdos_renyi_graph(num_nodes=num_nodes, edge_prob=0.001, directed=True)
    row, col = dataset.data.edge_index
    # One random scalar feature per node, drawn uniformly from [-1, 1)
    x = torch.rand(num_nodes) * 2 - 1
    # For every node: mean feature of the opposite endpoints over the edges in
    # which the node appears in the first (resp. second) row of edge_index.
    x_mean_in = torch_scatter.scatter_mean(x[col], row, dim_size=num_nodes)
    x_mean_out = torch_scatter.scatter_mean(x[row], col, dim_size=num_nodes)
    dataset.data.y = (x_mean_in > x_mean_out).long().unsqueeze(-1)
    dataset.data.x = x.unsqueeze(dim=-1)

    # Use a context manager so the cache is flushed and closed deterministically.
    with open(dataset_path, "wb") as cache_file:
        pickle.dump(dataset, cache_file)
    return dataset


def get_dataset(name: str, root_dir: str, homophily=None, undirected=False, self_loops=False, transpose=False):
    """Load the dataset called ``name`` and optionally post-process its edges.

    Args:
        name: Dataset identifier. Supported values: "chameleon", "squirrel",
            "ogbn-arxiv", "directed-roman-empire", "snap-patents", "arxiv-year",
            "syn-dir", "cora_ml", "citeseer_full" and "telegram".
        root_dir: Directory under which datasets are stored / downloaded.
        homophily: Accepted for interface compatibility; not used by this function.
        undirected: If True, replace ``edge_index`` with its undirected version.
        self_loops: If True, add self loops to ``edge_index``.
        transpose: If True, swap the two rows of ``edge_index`` (reverse every edge).

    Returns:
        A ``(dataset, evaluator)`` tuple. ``evaluator`` is an OGB ``Evaluator``
        for "ogbn-arxiv" and "arxiv-year", and ``None`` for every other dataset.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    path = f"{root_dir}/"
    evaluator = None

    if name in ["chameleon", "squirrel"]:
        dataset = WikipediaNetwork(root=path, name=name, transform=transforms.NormalizeFeatures())
        dataset.data.y = dataset.data.y.unsqueeze(-1)
    elif name in ["ogbn-arxiv"]:
        dataset = PygNodePropPredDataset(name=name, transform=transforms.ToSparseTensor(), root=path)
        evaluator = Evaluator(name=name)
        split_idx = dataset.get_idx_split()
        dataset.data.train_mask = get_mask(split_idx["train"], dataset.data.num_nodes)
        dataset.data.val_mask = get_mask(split_idx["valid"], dataset.data.num_nodes)
        dataset.data.test_mask = get_mask(split_idx["test"], dataset.data.num_nodes)
    elif name in ["directed-roman-empire"]:
        dataset = DirectedHeterophilousGraphDataset(name=name, transform=transforms.NormalizeFeatures(), root=path)
    elif name == "snap-patents":
        dataset = load_snap_patents_mat(n_classes=5, root=path)
    elif name == "arxiv-year":
        # arxiv-year uses the same graph and features as ogbn-arxiv, but with different labels
        dataset = PygNodePropPredDataset(name="ogbn-arxiv", transform=transforms.ToSparseTensor(), root=path)
        evaluator = Evaluator(name="ogbn-arxiv")
        y = even_quantile_labels(dataset.data.node_year.flatten().numpy(), nclasses=5, verbose=False)
        dataset.data.y = torch.as_tensor(y).reshape(-1, 1)
        # Train, val and test masks are required during preprocessing. Setting them here to dummy values as
        # they are overwritten later for this dataset (see get_dataset_split function below)
        dataset.data.train_mask, dataset.data.val_mask, dataset.data.test_mask = 0, 0, 0
        # Create directory for this dataset
        os.makedirs(os.path.join(path, name.replace("-", "_"), "raw"), exist_ok=True)
    elif name == "syn-dir":
        dataset = _load_or_create_syn_dir(path)
    elif name in ["cora_ml", "citeseer_full"]:
        # CitationFull is given "citeseer" for the dataset this module calls "citeseer_full"
        citation_name = "citeseer" if name == "citeseer_full" else name
        dataset = CitationFull(path, citation_name)
    elif name == "telegram":
        dataset = Telegram(path)
    else:
        raise ValueError("Unknown dataset.")

    # Edge post-processing is applied in this fixed order: symmetrise, add self loops, reverse.
    if undirected:
        dataset.data.edge_index = torch_geometric.utils.to_undirected(dataset.data.edge_index)
    if self_loops:
        dataset.data.edge_index, _ = torch_geometric.utils.add_self_loops(dataset.data.edge_index)
    if transpose:
        dataset.data.edge_index = torch.stack([dataset.data.edge_index[1], dataset.data.edge_index[0]])

    return dataset, evaluator


def get_sign_dataset(name, root_dir, split_number, data):
    """Build the train / validation / test ``SIGNDataset`` objects for one split.

    Args:
        name: Dataset name. Used to locate the cached file (dashes replaced by
            underscores) and forwarded to ``get_dataset_split``.
        root_dir: Root directory holding the per-dataset folders.
        split_number: Index of the split, forwarded to ``get_dataset_split``.
        data: Preprocessed data exposing the keys "xs", "y", "num_classes",
            "num_node_features" and "evaluator". If ``None``, it is unpickled
            from ``<root_dir>/<name with '-' -> '_'>/sign_dataset.pkl``.

    Returns:
        A tuple ``(train_dataset, val_dataset, test_dataset, evaluator,
        num_classes, num_features)``.
    """
    if data is None:
        cache_path = os.path.join(root_dir, name.replace("-", "_"), "sign_dataset.pkl")
        # Context manager closes the handle; the original left it open.
        # NOTE: unpickling executes arbitrary code - only load files this project produced.
        with open(cache_path, "rb") as cache_file:
            data = pickle.load(cache_file)

    train_mask, val_mask, test_mask = get_dataset_split(name, data, root_dir, split_number)
    # assumes xs and y can be indexed by the masks along their first dimension - TODO confirm
    xs, y = data["xs"], data["y"]
    # NOTE(review): SIGNDataset is not imported in the visible part of this module -
    # confirm it is in scope where this function is called.
    train_dataset = SIGNDataset(data={"xs": xs[train_mask], "y": y[train_mask]})
    val_dataset = SIGNDataset(data={"xs": xs[val_mask], "y": y[val_mask]})
    test_dataset = SIGNDataset(data={"xs": xs[test_mask], "y": y[test_mask]})
    num_classes = data["num_classes"]
    num_features = data["num_node_features"]
    evaluator = data["evaluator"]

    return (
        train_dataset,
        val_dataset,
        test_dataset,
        evaluator,
        num_classes,
        num_features,
    )


def get_dataset_split(name, data, root_dir, split_number):
    """Return the ``(train_mask, val_mask, test_mask)`` triple for one split of a dataset.

    Depending on ``name`` the masks are obtained in one of four ways:
    a column of pre-computed multi-split masks, the single pre-assigned split,
    split indices downloaded from the Non-Homophily-Large-Scale repository, or a
    seeded uniform random 50/25/25 split.

    Args:
        name: Dataset name; selects the strategy.
        data: Dataset object. Must support ``data["train_mask"]`` etc. for the
            mask-based strategies, ``data["y"]`` for the downloaded splits, and
            is passed through unchanged for the uniform random split.
        root_dir: Root directory; downloaded split files are stored under
            ``<root_dir>/<name with '-' -> '_'>/raw``.
        split_number: Split index. Selects the mask column, the downloaded split
            (modulo the number of available splits), or the random seed.

    Returns:
        A tuple ``(train_mask, val_mask, test_mask)``.

    Raises:
        ValueError: If no split strategy is defined for ``name``. Previously
            this case fell through and returned ``None``, which only failed
            later when the caller tried to unpack the result.
    """
    if name in [
        "snap-patents",
        "chameleon",
        "squirrel",
        "actor",
        "cora",
        "pubmed",
        "citeseer",
        "penn94",
        "texas",
        "wisconsin",
        "cornell",
        "telegram",
        "directed-roman-empire",
        "directed-amazon-ratings",
        "Minesweeper",
        "Tolokers",
        "directed-questions",
    ]:
        # Masks hold one column per split; pick the requested column.
        return (
            data["train_mask"][:, split_number],
            data["val_mask"][:, split_number],
            data["test_mask"][:, split_number],
        )

    if name in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M", "yelp-chi"]:
        # OGBN datasets have a single pre-assigned split
        return data["train_mask"], data["val_mask"], data["test_mask"]

    if name in ["arxiv-year", "genius"]:
        # Datasets from https://arxiv.org/pdf/2110.14446.pdf have five splits stored
        # in https://github.com/CUAI/Non-Homophily-Large-Scale/tree/82f8f05c5c3ec16bd5b505cc7ad62ab5e09051e6/data/splits
        # NOTE(review): the download below targets the master branch, not that pinned commit.
        num_nodes = data["y"].shape[0]
        github_url = "https://github.com/CUAI/Non-Homophily-Large-Scale/raw/master/data/splits/"
        split_file_name = f"{name}-splits.npy"
        local_dir = os.path.join(root_dir, name.replace("-", "_"), "raw")

        # URLs are not filesystem paths: build them by concatenation, not os.path.join.
        download_url(f"{github_url}{split_file_name}", local_dir, log=False)
        # NOTE: allow_pickle=True unpickles objects from a downloaded file, which can
        # execute arbitrary code if the remote content is not trusted.
        splits = np.load(os.path.join(local_dir, split_file_name), allow_pickle=True)
        # Wrap around so split numbers beyond the available splits reuse earlier ones.
        split_idx = splits[split_number % len(splits)]

        train_mask = get_mask(split_idx["train"], num_nodes)
        val_mask = get_mask(split_idx["valid"], num_nodes)
        test_mask = get_mask(split_idx["test"], num_nodes)

        return train_mask, val_mask, test_mask

    if name in ["syn-dir", "cora_ml", "citeseer_full"]:
        # Datasets from "New Benchmarks for Learning on Non-Homophilous Graphs". They use uniform 50/25/25 split
        # The split number doubles as the random seed, so each split is reproducible.
        return set_uniform_train_val_test_split(split_number, data, train_ratio=0.5, val_ratio=0.25)

    raise ValueError(f"No split strategy is defined for dataset '{name}'.")


def set_uniform_train_val_test_split(seed, data, train_ratio=0.5, val_ratio=0.25):
    """Create a seeded random train/val/test partition of the labeled nodes.

    Args:
        seed: Seed for the ``numpy.random.RandomState`` that shuffles the nodes.
        data: Object exposing the labels as ``data.y``.
        train_ratio: Fraction of labeled nodes put in the training set.
        val_ratio: Fraction of labeled nodes put in the validation set; all
            remaining labeled nodes form the test set.

    Returns:
        A tuple ``(train_mask, val_mask, test_mask)`` built by ``get_mask`` over
        all ``data.y.shape[0]`` nodes.
    """
    rng = np.random.RandomState(seed)
    total_nodes = data.y.shape[0]

    # Label -1 marks an unlabeled node; such nodes are kept out of every split.
    labeled = torch.where(data.y != -1)[0]
    labeled_count = labeled.shape[0]

    # Boundaries of the three consecutive slices of the shuffled order.
    train_end = math.floor(labeled_count * train_ratio)
    val_end = train_end + math.floor(labeled_count * val_ratio)

    # Positions into `labeled`, shuffled in place.
    order = list(range(labeled_count))
    rng.shuffle(order)

    slices = (order[:train_end], order[train_end:val_end], order[val_end:])
    # Translate shuffled positions back to actual node indices.
    node_groups = [labeled[positions] for positions in slices]

    train_mask, val_mask, test_mask = [get_mask(nodes, total_nodes) for nodes in node_groups]
    return train_mask, val_mask, test_mask
