import os.path as osp
from pathlib import Path
import urllib.request
import pickle

import numpy as np
import torch

from ogb.nodeproppred import PygNodePropPredDataset
from sklearn.model_selection import ShuffleSplit
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils import add_self_loops
from torch_geometric.datasets import (
    Planetoid, WebKB, Amazon, Coauthor, WikipediaNetwork,
    AttributedGraphDataset, HeterophilousGraphDataset)


def sym_edge_index(edge_index):
    """Return an undirected copy of ``edge_index``.

    Each column (u, v) is paired with its reverse (v, u), and duplicate
    columns are then collapsed with ``torch.unique(..., dim=1)``.
    """
    src, dst = edge_index
    flipped = torch.stack((dst, src))
    both_directions = torch.cat((edge_index, flipped), dim=1)
    return torch.unique(both_directions, dim=1)


def download(url, fn, out_path):
    """Download ``url + fn`` and save it at ``out_path``.

    The data is first written to ``<out_path>.part`` and only renamed to
    ``out_path`` once the transfer has finished. A failed or interrupted
    download therefore never leaves a truncated file at ``out_path``. That
    matters because callers skip the download whenever ``out_path`` already
    exists. The parent directory of ``out_path`` is created if needed.

    Args:
        url: Base URL, expected to end with a separator (``fn`` is appended
            verbatim).
        fn: File name appended to ``url``.
        out_path: Destination path (str or path-like).

    Raises:
        Exception: If the download fails. The underlying error is chained as
            ``__cause__``. KeyboardInterrupt/SystemExit are no longer
            swallowed.
    """
    target = Path(out_path)
    partial = target.with_name(target.name + '.part')
    target.parent.mkdir(parents=True, exist_ok=True)
    src = url + fn
    try:
        print('Downloading', src)
        urllib.request.urlretrieve(src, partial)
        # Atomic rename: out_path only ever holds a complete download.
        partial.replace(target)
        print('Done!')
    except Exception as err:
        partial.unlink(missing_ok=True)
        raise Exception(
            '''Download failed! Make sure you have stable Internet connection and enter the right name''') from err

def _index_mask(num_nodes, indices):
    """Return a (num_nodes, 1) bool mask that is True exactly at ``indices``."""
    mask = torch.zeros((num_nodes, 1), dtype=torch.bool)
    mask[indices, 0] = True
    return mask


def _load_pyg_graph(name, root):
    """Return graph 0 of a stock PyG dataset; raise ValueError for unknown names."""
    if name in ['cornell', 'texas', 'wisconsin']:
        dataset = WebKB(root=root, name=name)
    elif name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(root=root, name=name)
    elif name in ['computers', 'photo']:
        dataset = Amazon(root=root, name=name)
    elif name in ['cs', 'physics']:
        dataset = Coauthor(root=root, name=name)
    elif name in ['roman-empire', 'amazon-ratings', 'minesweeper', 'tolokers', 'questions']:
        dataset = HeterophilousGraphDataset(root=root, name=name)
    else:
        raise ValueError('Unknown dataset')
    return dataset[0]


def _load_filtered_wiki(name, root):
    """Load squirrel/chameleon with the filtered graph, features, labels and
    splits published in yandex-research/heterophilous-graphs."""
    # NOTE(review): every field below is overwritten from the .npz file. The
    # PyG dataset is presumably kept so that root/<name>/ exists for the
    # download target. Confirm before removing.
    data = WikipediaNetwork(root=root, name=name)[0]
    url = 'https://github.com/yandex-research/heterophilous-graphs/raw/refs/heads/main/data/'
    fn = name + '_filtered.npz'
    path = osp.join(root, name, fn)
    if not osp.exists(path):
        download(url, fn, path)
    with np.load(path) as arrays:
        data.x = torch.tensor(arrays['node_features'], dtype=torch.float32)
        # Transposed into PyG's (2, num_edges) edge_index layout.
        data.edge_index = torch.tensor(arrays['edges'], dtype=torch.long).T
        data.y = torch.tensor(arrays['node_labels'])
        # Transposed so that each column is one split: (num_nodes, n_splits).
        data.train_mask = torch.tensor(arrays['train_masks'], dtype=torch.bool).T
        data.val_mask = torch.tensor(arrays['val_masks'], dtype=torch.bool).T
        data.test_mask = torch.tensor(arrays['test_masks'], dtype=torch.bool).T
    return data


def _load_gaug_graph(name, root):
    """Load blogcatalog/flickr with the fixed train/val/test node ids released
    by GAug (split loading adapted from opengsl)."""
    data = AttributedGraphDataset(root=root, name=name)[0]
    url = 'https://github.com/zhao-tong/GAug/raw/master/data/graphs/'
    fn = name + '_tvt_nids.pkl'
    path = osp.join(root, name, fn)
    if not osp.exists(path):
        download(url, fn, path)
    # Densify sparse features (of any sparse layout) so that later steps,
    # including feature normalization, operate on a dense float tensor.
    if data.x.layout != torch.strided:
        data.x = data.x.to_dense().to(torch.float32)
    # SECURITY: unpickling executes code embedded in the file. This file
    # comes from a third-party GitHub repository; only load trusted copies.
    with open(path, 'rb') as f:
        train_idx, val_idx, test_idx = pickle.load(f)
    data.train_mask = _index_mask(data.num_nodes, train_idx)
    data.val_mask = _index_mask(data.num_nodes, val_idx)
    data.test_mask = _index_mask(data.num_nodes, test_idx)
    return data


def _load_ogbn_arxiv(name):
    """Load ogbn-arxiv with its official split stored as (num_nodes, 1) masks."""
    # NOTE(review): no `root` is passed, so OGB's default location is used
    # instead of ~/data like every other dataset. Confirm this is intended.
    dataset = PygNodePropPredDataset(name=name)
    split_idx = dataset.get_idx_split()
    data = dataset[0]
    data.train_mask = _index_mask(data.num_nodes, split_idx['train'])
    data.val_mask = _index_mask(data.num_nodes, split_idx['valid'])
    data.test_mask = _index_mask(data.num_nodes, split_idx['test'])
    data.y = data.y.squeeze()  # remove the extra label dimension
    return data


def _random_split_masks(num_nodes, n_splits, seed):
    """Draw ``n_splits`` seeded random train/val/test splits.

    Returns:
        (train_mask, val_mask, test_mask), each a bool tensor of shape
        (num_nodes, n_splits) in which column i is split i.
    """
    train_percentage = 0.6
    val_percentage = 0.2
    # NOTE(review): the inner split takes `train_percentage` of the
    # train+val subset, not of all nodes. The effective proportions are
    # therefore about 48/32/20 (0.6 * 0.8 = 0.48), not 60/20/20. The numbers
    # are kept as-is for reproducibility. For a true 60/20/20 split, use
    # train_size=train_percentage / (train_percentage + val_percentage).
    train_mask = torch.zeros((num_nodes, n_splits), dtype=torch.bool)
    val_mask = torch.zeros_like(train_mask)
    test_mask = torch.zeros_like(train_mask)

    outer = ShuffleSplit(n_splits=n_splits,
                         train_size=train_percentage + val_percentage,
                         random_state=seed)
    # An int random_state re-seeds on every split() call, so one inner
    # splitter reused across iterations behaves as if it were rebuilt each time.
    inner = ShuffleSplit(n_splits=1, train_size=train_percentage,
                         random_state=seed)
    # ShuffleSplit only looks at the number of samples, so split a
    # placeholder instead of indexing (and copying) the feature matrix.
    placeholder = np.zeros(num_nodes)
    for i, (train_and_val_index, test_index) in enumerate(outer.split(placeholder)):
        train_pos, val_pos = next(inner.split(train_and_val_index))
        # Positions are relative to train_and_val_index; map back to node ids.
        train_mask[train_and_val_index[train_pos], i] = True
        val_mask[train_and_val_index[val_pos], i] = True
        test_mask[test_index, i] = True
    return train_mask, val_mask, test_mask


def load_data(name: str,
              loop: bool = True,
              feat_norm: bool = False,
              n_rand_splits: int = 0,
              seed: int = 42):
    """
    Load a node-classification dataset as a single PyG ``Data`` object.

    Datasets are stored under ``~/data``. The exception is ogbn-arxiv, which
    is loaded without a root (see ``_load_ogbn_arxiv``).

    Args:
        name: Dataset name (case-insensitive).
        loop: Add self-loops after making the graph undirected.
        feat_norm: Row-normalize ``x`` with ``NormalizeFeatures``. This is
            applied after every dataset-specific feature replacement, so it
            takes effect for all datasets.
        n_rand_splits: If > 0, replace the dataset's split(s) with this many
            seeded random splits. If 0, keep the dataset's own split(s); a
            dataset without masks gets one random split.
        seed: ``random_state`` used for the random splits.

    Returns:
        ``Data`` whose ``train_mask``/``val_mask``/``test_mask`` are bool
        tensors of shape (num_nodes, n_splits).
        For ogbn-arxiv the graph is returned as loaded: it is not made
        undirected, and ``loop`` and ``n_rand_splits`` are ignored.

    Raises:
        ValueError: If ``name`` is not a supported dataset.
    """
    root = osp.join(Path.home(), 'data')
    name = name.lower()

    if name in ['ogbn-arxiv']:
        data = _load_ogbn_arxiv(name)
        # NOTE(review): early return preserved; `loop`/`n_rand_splits` do
        # not apply to this dataset.
        return NormalizeFeatures()(data) if feat_norm else data

    if name in ['squirrel', 'chameleon']:
        data = _load_filtered_wiki(name, root)
    elif name in ['blogcatalog', 'flickr']:
        data = _load_gaug_graph(name, root)
    else:
        data = _load_pyg_graph(name, root)

    # Normalize only now that every branch has finalized data.x. Applying the
    # transform at load time was undone when x was replaced afterwards.
    if feat_norm:
        data = NormalizeFeatures()(data)

    data.edge_index = sym_edge_index(data.edge_index)
    if loop:
        data.edge_index, _ = add_self_loops(data.edge_index)

    if not hasattr(data, 'train_mask') and n_rand_splits == 0:
        n_rand_splits = 1

    if n_rand_splits == 0:  # use the dataset's own split(s)
        if data.train_mask.ndim == 1:
            data.train_mask = data.train_mask.unsqueeze(1)
            data.val_mask = data.val_mask.unsqueeze(1)
            data.test_mask = data.test_mask.unsqueeze(1)
        return data

    data.train_mask, data.val_mask, data.test_mask = _random_split_masks(
        data.num_nodes, n_rand_splits, seed)
    return data
