import numpy as np
import torch
import pickle
from scipy import sparse
"""
Copyright 2020 Twitter, Inc.
SPDX-License-Identifier: Apache-2.0
"""
import os
import math
import numpy as np
import torch
from torch_geometric.datasets import Planetoid, MixHopSyntheticDataset, Amazon, WikiCS, WebKB, Actor, Coauthor, CitationFull
import torch_geometric.transforms as transforms
from torch_geometric.utils import to_undirected, add_remaining_self_loops
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj
from torch_geometric.data import Data, InMemoryDataset
from utils import get_symmetrically_normalized_adjacency, get_row_normalized_adjacency, knn_fast
from seeds import development_seed
from utils import get_mask

# Root directory for dataset files; a relative path, so it is resolved against
# the current working directory.
DATA_PATH = "data"

def get_dataset(name: str, use_lcc: bool = True, homophily=None, initial_filling=None, corr=0):
    """Load a preprocessed tabular dataset from its pickle file.

    Only the tabular datasets ``adult``, ``drybean``, ``adni`` and ``abide`` are
    supported. Their pickles are read from ``<DATA_PATH>/<name>/processed/``.

    Args:
        name: Dataset name; must be one of the supported tabular datasets.
        use_lcc: Unused; kept so existing call sites keep working.
        homophily: Unused; kept so existing call sites keep working.
        initial_filling: If ``'mode'`` or ``'median'``, load the file variant
            suffixed with ``_<initial_filling>``; otherwise load the plain file.
        corr: If positive, load the file variant suffixed with ``_corr_<corr>``.

    Returns:
        A ``(dataset, evaluator)`` tuple. ``dataset`` is whatever object the
        pickle contains; ``evaluator`` is always ``None``.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    tabular_datasets = ("adult", "drybean", "adni", "abide")
    if name not in tabular_datasets:
        # Previously an unknown name fell through and crashed with an
        # UnboundLocalError on `dataset`; fail with an explicit message instead.
        raise ValueError(f"Unknown dataset: {name!r}")

    evaluator = None  # kept for the (dataset, evaluator) return contract

    corr_suffix = f"_corr_{corr}" if corr > 0.0 else ""
    fill_suffix = f"_{initial_filling}" if initial_filling in ("mode", "median") else ""
    file_path = os.path.join(DATA_PATH, name, "processed", f"data{corr_suffix}{fill_suffix}.pkl")

    # SECURITY: pickle.load can execute arbitrary code while deserializing;
    # only point this at trusted, locally produced files.
    with open(file_path, "rb") as f:
        dataset = pickle.load(f)

    return dataset, evaluator

def keep_only_largest_connected_component(dataset):
    """Restrict ``dataset.data`` to the largest connected component of its graph.

    Nodes of the component returned by ``get_largest_connected_component`` are
    relabelled ``0..k-1`` in the order they appear in that array. Only edges
    whose two endpoints both lie in the component are kept. The train,
    validation and test masks are reset to all-``False``. Any other attribute
    of the original ``Data`` object is dropped.

    Args:
        dataset: Dataset whose ``data`` provides ``x``, ``y`` and ``edge_index``.

    Returns:
        The same ``dataset`` object, with ``dataset.data`` replaced.
    """
    lcc = get_largest_connected_component(dataset)

    x_new = dataset.data.x[lcc]
    y_new = dataset.data.y[lcc]
    num_nodes = y_new.size()[0]

    # Hash-set membership is O(1). The previous `i in lcc` on the ndarray
    # scanned the whole component for every edge endpoint: O(|E| * |LCC|).
    lcc_nodes = set(lcc.tolist())
    row, col = dataset.data.edge_index.numpy()
    edges = [[i, j] for i, j in zip(row, col) if i in lcc_nodes and j in lcc_nodes]
    edges = remap_edges(edges, get_node_mapper(lcc))

    dataset.data = Data(
        x=x_new,
        edge_index=torch.LongTensor(edges),
        y=y_new,
        train_mask=torch.zeros(num_nodes, dtype=torch.bool),
        test_mask=torch.zeros(num_nodes, dtype=torch.bool),
        val_mask=torch.zeros(num_nodes, dtype=torch.bool),
    )

    return dataset

def get_component(dataset: InMemoryDataset, start: int = 0) -> set:
    """Return the set of nodes reachable from ``start`` in ``dataset``'s graph.

    Edges are followed in the ``row -> col`` direction of
    ``dataset.data.edge_index``. For an undirected (symmetric) edge list this
    is the connected component containing ``start``.

    Args:
        dataset: Dataset whose ``data.edge_index`` is a 2 x E tensor on the CPU.
        start: Node id to start the traversal from.

    Returns:
        Set of visited node ids, always including ``start`` itself.
    """
    row, col = dataset.data.edge_index.numpy()
    # Group edge targets by source once (CSR-style), so each neighbour lookup is
    # a slice. The previous `np.where(row == node)` scanned all E edges for
    # every visited node: O(|V| * |E|).
    order = np.argsort(row, kind="stable")
    sorted_row = row[order]
    sorted_col = col[order]

    visited_nodes = set()
    queued_nodes = {start}
    while queued_nodes:
        current_node = queued_nodes.pop()
        visited_nodes.add(current_node)
        lo = np.searchsorted(sorted_row, current_node, side="left")
        hi = np.searchsorted(sorted_row, current_node, side="right")
        queued_nodes.update(n for n in sorted_col[lo:hi] if n not in visited_nodes)
    return visited_nodes


def get_largest_connected_component(dataset: InMemoryDataset) -> np.ndarray:
    """Return the node ids of the largest component of ``dataset``'s graph.

    Components are discovered with ``get_component``, starting each traversal
    from the lowest-numbered node not yet covered by an earlier component.
    Among equally large components, the first one discovered wins.

    Args:
        dataset: Dataset whose ``data.x`` has one row per node.

    Returns:
        Array of the node ids in the largest component, in the iteration order
        of the set returned by ``get_component``. Returns an empty int64 array
        when the graph has no nodes (previously this crashed in ``np.argmax``).
    """
    num_nodes = dataset.data.x.shape[0]
    covered = set()
    largest = None
    # An ascending scan that skips covered nodes visits the same start nodes as
    # repeatedly taking `min(remaining_nodes)`, without the O(|V|) rescans.
    for start in range(num_nodes):
        if start in covered:
            continue
        comp = get_component(dataset, start)
        covered.update(comp)
        # Strict `>` keeps the earliest of equally sized components, matching
        # np.argmax's first-maximum tie-breaking.
        if largest is None or len(comp) > len(largest):
            largest = comp

    if largest is None:
        return np.array([], dtype=np.int64)
    return np.array(list(largest))


def get_node_mapper(lcc: np.ndarray) -> dict:
    """Map every node id in ``lcc`` to its position in ``lcc``.

    If an id occurs more than once, it is mapped to the position of its last
    occurrence.

    Args:
        lcc: Sequence of node ids.

    Returns:
        Dict ``{node_id: position}``.
    """
    return {node: position for position, node in enumerate(lcc)}


def remap_edges(edges: list, mapper: dict) -> list:
    """Translate an edge list through ``mapper`` into ``[sources, targets]`` form.

    Args:
        edges: Iterable of ``[source, target]`` pairs.
        mapper: Dict from old node id to new node id.

    Returns:
        Two lists ``[mapped_sources, mapped_targets]``.

    Raises:
        KeyError: If an endpoint is missing from ``mapper``.
    """
    heads = [edge[0] for edge in edges]
    tails = [edge[1] for edge in edges]
    mapped_heads = [mapper[h] for h in heads]
    mapped_tails = [mapper[t] for t in tails]
    return [mapped_heads, mapped_tails]


def set_train_val_test_split(seed: int, data: Data, dataset: str, split_idx: dict = None, train_ratio=0.0) -> Data:
    """Assign ``train_mask``, ``val_mask`` and ``test_mask`` on ``data`` for ``dataset``.

    Split policy by dataset name:

    * Cora/CiteSeer/PubMed/Photo/Computers/cs/physics/wikics/dblp: a uniform
      split with ``train_ratio`` training and 20% validation when
      ``train_ratio > 0``. Otherwise, the per-class split from "Diffusion
      Improves Graph Learning" (20 training nodes per class, drawn from a
      development set of 1500 nodes, or 5000 for ``cs``).
    * Names containing ``OGBN``: the masks stored in ``split_idx``.
    * Benchmarks whose source paper prescribes a split: fixed uniform ratios.
    * Anything else (e.g. the tabular datasets): a uniform split with
      ``train_ratio`` training and 10% validation.

    Args:
        seed: Seed for the random parts of the split.
        data: Graph/data object; its masks are set (and it may be modified).
        dataset: Dataset name that selects the split policy.
        split_idx: Precomputed split with ``"train"``, ``"valid"`` and
            ``"test"`` entries. Required for OGBN datasets and forwarded to
            the per-class split otherwise.
        train_ratio: Training fraction for the uniform-split policies.

    Returns:
        ``data`` with the three masks assigned.
    """
    per_class_datasets = {
        "Cora", "CiteSeer", "PubMed", "Photo", "Computers", "cs", "physics", "wikics", "dblp",
    }
    # (train_ratio, val_ratio) fixed by each benchmark; the remainder is test.
    fixed_uniform_ratios = {
        "cora_full": (0.1, 0.1),
        # "New Benchmarks for Learning on Non-Homophilous Graphs": uniform 50/25/25.
        "Twitch": (0.5, 0.25),
        "Deezer-Europe": (0.5, 0.25),
        "FB100": (0.5, 0.25),
        "Actor": (0.5, 0.25),
        # "Beyond Homophily in Graph Neural Networks: Current Limitations and
        # Effective Designs": uniform 25/25/50.
        "Syn-Cora": (0.25, 0.25),
        # "MixHop: Higher-Order Graph Convolutional Architectures via Sparsified
        # Neighborhood Mixing": uniform 33/33/33.
        "MixHopSynthetic": (0.33, 0.33),
    }

    if dataset in per_class_datasets:
        if train_ratio > 0.0:
            return set_uniform_train_val_test_split(seed, data, train_ratio=train_ratio, val_ratio=0.2)
        num_val = 5000 if dataset == "cs" else 1500
        return set_per_class_train_val_test_split(
            seed=seed, data=data, num_val=num_val, num_train_per_class=20, split_idx=split_idx,
        )

    if "OGBN" in dataset:
        # OGBN datasets ship with a pre-assigned split.
        data.train_mask = split_idx["train"]
        data.val_mask = split_idx["valid"]
        data.test_mask = split_idx["test"]
        return data

    if dataset in fixed_uniform_ratios:
        fixed_train, fixed_val = fixed_uniform_ratios[dataset]
        return set_uniform_train_val_test_split(seed, data, train_ratio=fixed_train, val_ratio=fixed_val)

    return set_uniform_train_val_test_split(seed, data, train_ratio=train_ratio, val_ratio=0.1)


def set_per_class_train_val_test_split(
    seed: int, data: Data, num_val: int = 1500, num_train_per_class: int = 20, split_idx: dict = None,
) -> Data:
    """Assign a per-class train/val/test split to ``data``.

    Without ``split_idx``:

    * ``num_val`` nodes form a development set, drawn with
      ``development_seed``; every other node goes to the test set.
    * For each class, ``num_train_per_class`` development nodes are drawn with
      ``seed`` as training nodes.
    * The remaining development nodes form the validation set.

    With ``split_idx``, its ``"train"``/``"valid"``/``"test"`` entries are used
    as the masks directly.

    Args:
        seed: Seed for choosing training nodes inside the development set.
        data: Object with integer labels ``y``; its masks are overwritten.
        num_val: Size of the development (train + validation) set.
        num_train_per_class: Training nodes drawn per class.
        split_idx: Optional precomputed split.

    Returns:
        ``data`` with ``train_mask``, ``val_mask`` and ``test_mask`` set.

    Raises:
        ValueError: From ``RandomState.choice`` if a class has fewer than
            ``num_train_per_class`` nodes in the development set.
    """
    if split_idx is not None:
        data.train_mask = split_idx["train"]
        data.val_mask = split_idx["valid"]
        data.test_mask = split_idx["test"]
        return data

    # The development/test partition uses `development_seed`, not `seed`, so it
    # does not change when only `seed` changes.
    rnd_state = np.random.RandomState(development_seed)
    num_nodes = data.y.shape[0]
    development_idx = rnd_state.choice(num_nodes, num_val, replace=False)
    # Set lookups keep both filters below linear. Testing membership against the
    # ndarray / list was O(num_val) / O(#train) per element.
    development_set = set(development_idx.tolist())
    test_idx = [i for i in np.arange(num_nodes) if i not in development_set]

    train_idx = []
    rnd_state = np.random.RandomState(seed)
    for c in range(data.y.max() + 1):
        class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]]
        train_idx.extend(rnd_state.choice(class_idx, num_train_per_class, replace=False))

    train_set = set(train_idx)
    val_idx = [i for i in development_idx if i not in train_set]

    data.train_mask = get_mask(train_idx, num_nodes)
    data.val_mask = get_mask(val_idx, num_nodes)
    data.test_mask = get_mask(test_idx, num_nodes)

    return data


def set_uniform_train_val_test_split(seed: int, data: Data, train_ratio: float = 0.1, val_ratio: float = 0.2) -> Data:
    """Randomly partition the labelled nodes of ``data`` into train/val/test masks.

    Nodes whose label is ``-1`` count as unlabelled and are left out of all
    three masks. After shuffling with ``np.random.RandomState(seed)``:

    * ``floor(n_labelled * train_ratio)`` labelled nodes go to training;
    * ``floor(n_labelled * val_ratio)`` go to validation;
    * the rest go to test.

    Side effect: ``data.y`` is modified in place, with every ``-1`` label
    overwritten by ``0``.

    Args:
        seed: Seed for the shuffle.
        data: Object with a label tensor ``y``; its masks are overwritten.
        train_ratio: Fraction of labelled nodes used for training.
        val_ratio: Fraction of labelled nodes used for validation.

    Returns:
        ``data`` with ``train_mask``, ``val_mask`` and ``test_mask`` set.
    """
    total_nodes = data.y.shape[0]
    labeled = torch.where(data.y != -1)[0]
    n_labeled = labeled.shape[0]

    train_end = math.floor(n_labeled * train_ratio)
    val_end = train_end + math.floor(n_labeled * val_ratio)

    order = list(range(n_labeled))
    np.random.RandomState(seed).shuffle(order)

    splits = (
        ("train_mask", order[:train_end]),
        ("val_mask", order[train_end:val_end]),
        ("test_mask", order[val_end:]),
    )
    for mask_name, positions in splits:
        setattr(data, mask_name, get_mask(labeled[positions], total_nodes))

    # Label propagation one-hot encodes every label, so -1 would break it.
    # These nodes are in no mask, so the placeholder 0 never affects results.
    data.y[data.y == -1] = 0

    return data
