import os
import os.path as osp
from typing import Tuple

import numpy as np
import pandas as pd
import scipy
import torch
import torch_geometric.transforms as T
import torchmetrics
from sklearn.preprocessing import StandardScaler
from torch_geometric.data import Data, Batch, download_url
from torch_geometric.datasets import (Planetoid, WikiCS, Coauthor, Amazon,
                                      GNNBenchmarkDataset, Yelp, Flickr,
                                      Reddit2, PPI, WebKB, WikipediaNetwork)
import gdown

import sklearn
from ogb.nodeproppred import PygNodePropPredDataset, NodePropPredDataset
from torch_geometric.transforms import BaseTransform

from utils.load_data import load_twitch_gamer, load_facebook100
from utils.training import random_planetoid_splits, random_splits, even_quantile_labels
from utils.utils import dropout_edge


# Number of random splits that load_split() draws when it has to build a new
# split file; the per-run masks are stacked as columns (dim=1).
SPLIT_RUNS = 10


class StandardizeFeatures(BaseTransform):
    """Transform that rescales ``data.x`` with scikit-learn's ``StandardScaler``.

    The scaler is fitted on the features of the very graph being transformed,
    and the result is written back to ``data.x`` using the dtype the feature
    tensor had before scaling.
    """

    def __call__(self, data: Data) -> Data:
        feature_dtype = data.x.dtype
        # fit() returns the scaler itself, so fitting and transforming chain.
        scaled = StandardScaler().fit(data.x).transform(data.x)
        data.x = torch.from_numpy(scaled).type(feature_dtype)
        return data


def load_split(path: str, data, train_rate: float, val_rate: float,
               num_classes=-1, filtered=False):
    """Attach train/val/test masks to ``data``, reusing a cached split file.

    The cache lives at ``<path>/split_<train_rate>_<val_rate>[_filtered].pt``.
    If that file exists it is loaded as-is. Otherwise ``SPLIT_RUNS`` random
    splits are drawn, stacked along ``dim=1`` and saved to that file; ``path``
    is created first when it is missing so the save cannot fail on a fresh
    dataset directory.

    Args:
        path: Directory that holds the split cache.
        data: Graph object; its ``train_mask``, ``val_mask`` and ``test_mask``
            attributes are (re)assigned.
        train_rate: Forwarded to the split helper; part of the file name.
        val_rate: Forwarded to the split helper; part of the file name.
        num_classes: When positive, ``random_planetoid_splits`` is used;
            otherwise ``random_splits``.
        filtered: Appends ``_filtered`` to the cache file name.

    Returns:
        The same ``data`` object with the three mask attributes set.
    """
    suffix = '_filtered' if filtered else ''
    file = osp.join(path, f'split_{train_rate}_{val_rate}{suffix}.pt')

    if osp.exists(file):
        # NOTE: torch.load unpickles; the file is a cache written below.
        split_dict = torch.load(file)
    else:
        train_masks, val_masks, test_masks = [], [], []
        for _ in range(SPLIT_RUNS):
            if num_classes > 0:
                train_mask, val_mask, test_mask = random_planetoid_splits(
                    data, num_classes, train_rate, val_rate)
            else:
                train_mask, val_mask, test_mask = random_splits(
                    data, train_rate, val_rate)
            train_masks.append(train_mask)
            val_masks.append(val_mask)
            test_masks.append(test_mask)
        split_dict = {
            'train_mask': torch.stack(train_masks, dim=1),
            'val_mask': torch.stack(val_masks, dim=1),
            'test_mask': torch.stack(test_masks, dim=1)
        }
        # Make sure the cache directory exists before writing into it.
        if path:
            os.makedirs(path, exist_ok=True)
        torch.save(split_dict, file)

    data.train_mask = split_dict['train_mask']
    data.val_mask = split_dict['val_mask']
    data.test_mask = split_dict['test_mask']
    return data


def load_ptb_edges(path, edge_index, num_nodes, ptb_type, ptb_ratio, **kwargs):
    """Return a perturbed copy of ``edge_index``, cached on disk.

    A non-positive ``ptb_ratio`` returns ``edge_index`` untouched. Otherwise
    the result is read from ``<path>/perturbed_<ptb_type>-<ptb_ratio>_edges.pt``
    when that file exists, or computed and saved there:

    * ``'rand'``: appends ``int(ptb_ratio * num_edges)`` edges whose endpoints
      are drawn with ``torch.randint`` from ``[0, num_nodes)``.
    * ``'drop'``: delegates to ``dropout_edge(edge_index, p=ptb_ratio)``.

    Raises:
        NotImplementedError: for any other ``ptb_type`` (nothing is saved).
    """
    if ptb_ratio <= 0.:
        return edge_index

    cache_file = osp.join(path, f'perturbed_{ptb_type}-{ptb_ratio}_edges.pt')
    if osp.exists(cache_file):
        return torch.load(cache_file)

    extra_count = int(ptb_ratio * edge_index.size(1))
    if ptb_type == 'drop':
        perturbed = dropout_edge(edge_index, p=ptb_ratio)
    elif ptb_type == 'rand':
        # Two independent draws: first the source row, then the target row.
        endpoints = [
            torch.randint(low=0, high=num_nodes, size=(extra_count,))
            for _ in range(2)
        ]
        noise_edges = torch.stack(endpoints, dim=0)
        perturbed = torch.cat([edge_index, noise_edges], dim=1)
    else:
        raise NotImplementedError

    torch.save(perturbed, cache_file)
    return perturbed


def get_ogbn(root: str, name: str, **kwargs):
    """Load the OGB node-property dataset ``ogbn-<name>`` as a single graph.

    The adjacency is stored as a sparse tensor (pre-transform), features go
    through ``StandardizeFeatures`` (transform), and the split returned by
    ``dataset.get_idx_split()`` is stored on the graph under ``train_mask`` /
    ``val_mask`` / ``test_mask``.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    dataset = PygNodePropPredDataset(
        'ogbn-' + name, root,
        transform=T.Compose([StandardizeFeatures()]),
        pre_transform=T.Compose([T.ToSparseTensor()]),
    )
    official_split = dataset.get_idx_split()
    split_holder = Data(
        train_mask=official_split['train'],
        val_mask=official_split['valid'],
        test_mask=official_split['test'],
    )
    graph = dataset[0].update(split_holder)
    graph.adj_t = graph.adj_t.to_symmetric()  # very important
    graph.y = graph.y.squeeze()
    return graph, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_arxiv_year(root, train_rate=0.5, val_rate=0.25, nclass=5, **kwargs):
    """Load ogbn-arxiv with labels derived from the publication year.

    ``node_year`` is bucketed into ``nclass`` classes by
    ``even_quantile_labels`` and the masks come from ``load_split`` (cached
    under ``<root>/ogbn_arxiv/year``). Passing ``rev_adj=True`` in ``kwargs``
    transposes the sparse adjacency.

    Returns:
        ``(data, num_features, nclass, proc_dir)``.
    """
    dataset = PygNodePropPredDataset(
        'ogbn-arxiv', root,
        transform=T.Compose([]),  # no feature transform (StandardizeFeatures is not used here)
        pre_transform=T.Compose([T.ToSparseTensor()]),
    )
    official_split = dataset.get_idx_split()
    split_holder = Data(
        train_mask=official_split['train'],
        val_mask=official_split['valid'],
        test_mask=official_split['test'],
    )
    data = dataset[0].update(split_holder)

    reverse_adjacency = kwargs.pop('rev_adj', False)
    if reverse_adjacency:
        data.adj_t = data.adj_t.t()

    # NOTE: the adjacency is left as loaded (no to_symmetric()); whether
    # directed edges matter for the year task is an open question.

    years = data.node_year.numpy().flatten()
    year_classes = even_quantile_labels(years, nclass, verbose=False)
    data.y = torch.as_tensor(year_classes, dtype=torch.long)

    proc_dir = f'{root}/ogbn_arxiv/year'
    os.makedirs(proc_dir, exist_ok=True)
    data = load_split(proc_dir, data, train_rate, val_rate)

    return data, dataset.num_features, nclass, proc_dir


def get_planetoid(root: str, name: str, split: str = 'public', **kwargs) -> Tuple[Data, int, int, str]:
    """Load a Planetoid graph with normalized features and sparse adjacency.

    If ``kwargs`` carries a truthy ``train_rate``, the masks are replaced via
    ``load_split`` using ``kwargs['train_rate']`` and ``kwargs['val_rate']``.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    dataset = Planetoid(
        f'{root}/Planetoid', name, split,
        transform=T.Compose([T.NormalizeFeatures()]),
        pre_transform=T.Compose([T.ToSparseTensor()]),
    )
    graph = dataset[0]
    n_classes = dataset.num_classes

    custom_train_rate = kwargs.get('train_rate', False)
    if custom_train_rate:
        graph = load_split(f'{root}/Planetoid/{name}', graph,
                           custom_train_rate, kwargs['val_rate'],
                           num_classes=n_classes)

    return graph, dataset.num_features, n_classes, dataset.processed_dir


def get_webkb(root: str, name: str) -> Tuple[Data, int, int, str]:
    """Load a WebKB graph and keep only column 0 of each split mask.

    Features are normalized and the adjacency is stored as a sparse tensor.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    dataset = WebKB(
        f'{root}/WebKB', name,
        transform=T.Compose([T.NormalizeFeatures()]),
        pre_transform=T.Compose([T.ToSparseTensor()]),
    )
    graph = dataset[0]

    # The masks are 2-D; select the first split column from each of them.
    split_column = 0
    for mask_name in ('train_mask', 'val_mask', 'test_mask'):
        setattr(graph, mask_name, getattr(graph, mask_name)[:, split_column])

    return graph, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_amazon(root: str, name: str) -> Tuple[Data, int, int, str]:
    """Load an Amazon graph with normalized features and sparse adjacency.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    normalize_features = T.Compose([T.NormalizeFeatures()])
    to_sparse = T.Compose([T.ToSparseTensor()])
    dataset = Amazon(f'{root}/Amazon', name,
                     transform=normalize_features, pre_transform=to_sparse)
    graph = dataset[0]
    return graph, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_wikinet(root: str, name: str, train_rate=None, val_rate=None, filtered=False, **kwargs) -> Tuple[Data, int, int, str]:
    """Load a WikipediaNetwork graph, or its filtered ``.npz`` variant.

    With ``filtered=True`` the file ``<name>_filtered.npz`` is downloaded from
    the yandex-research heterophilous-graphs repository and converted to a
    ``Data`` object with a (directed) sparse adjacency. Otherwise the PyG
    ``WikipediaNetwork`` dataset is instantiated twice: the
    ``geom_gcn_preprocess=True`` version supplies the returned graph, whose
    ``adj_t`` is then overwritten with the one from the
    ``geom_gcn_preprocess=False`` version. (Original author's note: this
    follows GPRGNN's setting.)

    A truthy ``train_rate`` replaces the masks through ``load_split``.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    wiki_root = f'{root}/WikipediaNetwork'
    dataset_dir = f'{wiki_root}/{name}'

    if not filtered:
        identity = T.Compose([])
        to_sparse = T.Compose([T.ToSparseTensor()])

        raw_variant = WikipediaNetwork(wiki_root, name, geom_gcn_preprocess=False,
                                       transform=identity, pre_transform=to_sparse)
        dataset = WikipediaNetwork(wiki_root, name, geom_gcn_preprocess=True,
                                   transform=identity, pre_transform=to_sparse)
        data = dataset[0]
        data.adj_t = raw_variant[0].adj_t

        num_features = dataset.num_features
        num_classes = dataset.num_classes
    else:
        download_url(
            f'https://github.com/yandex-research/heterophilous-graphs/raw/main/data/{name}_filtered.npz',
            dataset_dir)
        archive = np.load(osp.join(dataset_dir, f'{name}_filtered.npz'))
        data = Data(
            x=torch.tensor(archive['node_features']),
            y=torch.tensor(archive['node_labels']),
            edge_index=torch.tensor(archive['edges']).t(),
            train_mask=torch.tensor(archive['train_masks']).t(),
            val_mask=torch.tensor(archive['val_masks']).t(),
            test_mask=torch.tensor(archive['test_masks']).t(),
        )
        # No T.ToUndirected() here: the graph is kept directed.
        data = T.Compose([T.ToSparseTensor()])(data)

        num_features = data.x.shape[1]
        num_classes = int(data.y.max()+1)

    if train_rate:
        data = load_split(dataset_dir, data, train_rate, val_rate,
                          num_classes, filtered)

    return data, num_features, num_classes, f'{dataset_dir}/processed'


def get_heterophilious(root: str, name: str, train_rate=None, val_rate=None, **kwargs):
    """Load one of the yandex-research heterophilous graphs from its ``.npz``.

    The archive is downloaded into ``<root>/Heterophilious/<name>`` and turned
    into a ``Data`` object with a sparse adjacency; the graph is made
    undirected first unless ``kwargs['undirected']`` is falsy. A truthy
    ``train_rate`` replaces the shipped masks through ``load_split``.

    Returns:
        ``(data, num_features, num_classes, processed_dir)`` where a two-class
        problem is reported as ``num_classes == 1``.
    """
    raw_dir = f'{root}/Heterophilious/{name}'
    download_url(
        f'https://github.com/yandex-research/heterophilous-graphs/raw/main/data/{name}.npz',
        raw_dir)

    archive = np.load(osp.join(raw_dir, f'{name}.npz'))
    data = Data(
        x=torch.tensor(archive['node_features']),
        y=torch.tensor(archive['node_labels']),
        edge_index=torch.tensor(archive['edges']).t(),
        train_mask=torch.tensor(archive['train_masks']).t(),
        val_mask=torch.tensor(archive['val_masks']).t(),
        test_mask=torch.tensor(archive['test_masks']).t(),
    )

    make_undirected = kwargs.pop('undirected', True)
    steps = [T.ToUndirected()] if make_undirected else []
    steps.append(T.ToSparseTensor())
    data = T.Compose(steps)(data)

    num_features = data.x.shape[1]
    num_classes = len(data.y.unique())

    if train_rate:
        data = load_split(raw_dir, data, train_rate, val_rate,
                          num_classes=num_classes)

    # Binary task: report a single output.
    reported_classes = 1 if num_classes == 2 else num_classes

    return data, num_features, reported_classes, f'{raw_dir}/processed'


def get_wikics(root: str, **kwargs) -> Tuple[Data, int, int, str]:
    """Load WikiCS with a symmetrized sparse adjacency.

    The dataset's ``stopping_mask`` is moved into ``val_mask`` and the
    ``stopping_mask`` attribute is then cleared.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    dataset = WikiCS(
        f'{root}/WIKICS',
        transform=T.Compose([]),
        pre_transform=T.Compose([T.ToSparseTensor()]),
    )
    graph = dataset[0]
    graph.adj_t = graph.adj_t.to_symmetric()
    graph.val_mask, graph.stopping_mask = graph.stopping_mask, None
    return graph, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_yelp(root: str, **kwargs) -> Tuple[Data, int, int, str]:
    """Load Yelp with a sparse adjacency and column-standardized features.

    Each feature column has its mean subtracted and is divided by its
    standard deviation (both taken over ``dim=0``).

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    dataset = Yelp(
        f'{root}/YELP',
        transform=T.Compose([]),
        pre_transform=T.Compose([T.ToSparseTensor()]),
    )
    graph = dataset[0]
    features = graph.x
    graph.x = (features - features.mean(dim=0)) / features.std(dim=0)
    return graph, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_flickr(root: str, **kwargs) -> Tuple[Data, int, int, str]:
    """Load Flickr with its adjacency stored as a sparse tensor.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    identity = T.Compose([])
    to_sparse = T.Compose([T.ToSparseTensor()])
    dataset = Flickr(f'{root}/Flickr', transform=identity, pre_transform=to_sparse)
    graph = dataset[0]
    return graph, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_reddit(root: str, **kwargs) -> Tuple[Data, int, int, str]:
    """Load Reddit2 with a sparse adjacency and column-standardized features.

    Each feature column has its mean subtracted and is divided by its
    standard deviation (both taken over ``dim=0``).

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    dataset = Reddit2(
        f'{root}/Reddit2',
        transform=T.Compose([]),
        pre_transform=T.Compose([T.ToSparseTensor()]),
    )
    graph = dataset[0]
    features = graph.x
    graph.x = (features - features.mean(dim=0)) / features.std(dim=0)
    return graph, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_ppi(root: str, split: str = 'train', **kwargs):
    """Load one PPI split and merge all of its graphs into a single object.

    The graphs are combined with ``Batch.from_data_list``; the ``batch`` and
    ``ptr`` bookkeeping is cleared, and an all-``True`` ``<split>_mask`` is
    attached that covers every node.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    dataset = PPI(f'{root}/PPI', split=split,
                  pre_transform=T.Compose([T.ToSparseTensor()]))
    merged = Batch.from_data_list(dataset)
    merged.batch = None
    merged.ptr = None
    all_nodes = torch.ones(merged.num_nodes, dtype=torch.bool)
    merged[f'{split}_mask'] = all_nodes
    return merged, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_sbm(root: str, name: str, **kwargs):
    """Load the ``train`` split of a GNNBenchmarkDataset as one merged graph.

    All graphs are combined with ``Batch.from_data_list`` and the ``batch`` /
    ``ptr`` bookkeeping is cleared afterwards.

    Returns:
        ``(data, num_features, num_classes, processed_dir)``.
    """
    to_sparse = T.Compose([T.ToSparseTensor()])
    dataset = GNNBenchmarkDataset(f'{root}/SBM', name, split='train',
                                  pre_transform=to_sparse)
    merged = Batch.from_data_list(dataset)
    merged.batch = None
    merged.ptr = None
    return merged, dataset.num_features, dataset.num_classes, dataset.processed_dir


def get_facebook100(root: str, name: str, train_rate=0.5, val_rate=0.25,
                    ptb_type=None, ptb_ratio=0., **kwargs):
    """Load a Facebook100 network from its ``.mat`` file.

    The title-cased ``name`` selects the file, which is downloaded from the
    CUAI Non-Homophily-Large-Scale repository. When ``ptb_type`` is given and
    ``ptb_ratio`` is positive, the edges go through ``load_ptb_edges`` and
    duplicated edges are removed before the usual transforms. Masks come from
    ``load_split``.

    Returns:
        ``(data, num_features, num_classes, dataset_dir)``.
    """
    school = name.title()
    dataset_dir = f'{root}/facebook100/{school}'
    download_url(
        f'https://github.com/CUAI/Non-Homophily-Large-Scale/raw/master/data/facebook100/{school}.mat',
        dataset_dir)

    mat = scipy.io.loadmat(f'{dataset_dir}/{school}.mat')
    features, edge_index, label = load_facebook100(mat)

    perturb = ptb_type is not None and ptb_ratio > 0
    leading_steps = []
    if perturb:
        edge_index = load_ptb_edges(dataset_dir, edge_index, features.shape[0],
                                    ptb_type, ptb_ratio)
        leading_steps.append(T.RemoveDuplicatedEdges())

    # Original note: edge_index is already undirected; ToUndirected is
    # applied regardless.
    transform = T.Compose(leading_steps + [T.ToUndirected(), T.ToSparseTensor()])

    data = transform(Data(
        x=torch.tensor(features, dtype=torch.float),
        y=torch.tensor(label, dtype=torch.long),
        edge_index=edge_index,
    ))

    data = load_split(dataset_dir, data, train_rate, val_rate)

    num_features = data.x.shape[1]
    num_classes = (data.y.max() + 1).item()

    return data, num_features, num_classes, dataset_dir


def get_pokec(root: str, train_rate=0.5, val_rate=0.25, **kwargs):
    """Load the Pokec graph from ``<root>/pokec/pokec.mat``.

    The ``.mat`` file is fetched with ``gdown`` when it is not on disk yet.
    The graph is made undirected, stored with a sparse adjacency, and gets
    its masks from ``load_split``.

    Returns:
        ``(data, num_features, num_classes, dataset_dir)``.
    """
    drive_id = '1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y'
    pokec_dir = f'{root}/pokec'
    mat_path = f'{pokec_dir}/pokec.mat'

    os.makedirs(pokec_dir, exist_ok=True)
    if not osp.exists(mat_path):
        gdown.download(id=drive_id, output=mat_path, quiet=False)
    mat = scipy.io.loadmat(mat_path)

    graph = Data(
        x=torch.tensor(mat['node_feat']).float(),
        y=torch.tensor(mat['label'].flatten(), dtype=torch.long),
        edge_index=torch.tensor(mat['edge_index'], dtype=torch.long),
    )
    graph = T.Compose([T.ToUndirected(), T.ToSparseTensor()])(graph)

    num_features = graph.x.shape[1]
    num_classes = (graph.y.max() + 1).item()

    graph = load_split(pokec_dir, graph, train_rate, val_rate)

    return graph, num_features, num_classes, pokec_dir


def get_snap_patents(root: str, train_rate=0.5, val_rate=0.25, nclass=5, **kwargs):
    """Load snap-patents from ``<root>/snap-patents/snap-patents.mat``.

    The ``.mat`` file is fetched with ``gdown`` when it is not on disk yet.
    Labels are the ``years`` field bucketed into ``nclass`` classes by
    ``even_quantile_labels``; the graph stays directed and gets its masks
    from ``load_split``.

    Returns:
        ``(data, num_features, nclass, dataset_dir)``.
    """
    drive_id = '1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia'
    patents_dir = f'{root}/snap-patents'
    mat_path = f'{patents_dir}/snap-patents.mat'

    os.makedirs(patents_dir, exist_ok=True)
    if not osp.exists(mat_path):
        gdown.download(id=drive_id, output=mat_path, quiet=False)
    mat = scipy.io.loadmat(mat_path)

    year_classes = even_quantile_labels(mat['years'].flatten(), nclass, verbose=False)
    graph = Data(
        x=torch.tensor(mat['node_feat'].todense(), dtype=torch.float),
        y=torch.tensor(year_classes, dtype=torch.long),
        edge_index=torch.tensor(mat['edge_index'], dtype=torch.long),
    )
    # No T.ToUndirected() here: the graph is kept directed.
    graph = T.Compose([T.ToSparseTensor()])(graph)

    graph = load_split(patents_dir, graph, train_rate, val_rate)

    return graph, graph.x.shape[1], nclass, patents_dir


def get_genius(root: str, train_rate=0.5, val_rate=0.25, **kwargs):
    """Load the genius graph from its ``.mat`` file.

    The file is downloaded from the CUAI Non-Homophily-Large-Scale repository
    into ``<root>/genius``. The graph is made undirected unless
    ``kwargs['undirected']`` is falsy, stored with a sparse adjacency, and
    gets its masks from ``load_split``.

    Returns:
        ``(data, num_features, 1, dataset_dir)``; the class count is fixed to
        1 because the task is scored with AUC-ROC.
    """
    dataset_dir = f'{root}/genius'
    download_url(
        'https://github.com/CUAI/Non-Homophily-Large-Scale/raw/master/data/genius.mat',
        dataset_dir)
    mat = scipy.io.loadmat(f'{dataset_dir}/genius.mat')

    make_undirected = kwargs.pop('undirected', True)

    graph = Data(
        x=torch.tensor(mat['node_feat'], dtype=torch.float),
        y=torch.tensor(mat['label'].flatten(), dtype=torch.long),
        edge_index=torch.tensor(mat['edge_index'], dtype=torch.long),
    )
    steps = [T.ToUndirected()] if make_undirected else []
    steps.append(T.ToSparseTensor())
    graph = T.Compose(steps)(graph)

    graph = load_split(dataset_dir, graph, train_rate, val_rate)

    num_features = graph.x.shape[1]
    num_classes = 1  # calculate AUC_ROC

    return graph, num_features, num_classes, dataset_dir


def get_twitch_gamer(root: str, train_rate=0.5, val_rate=0.25, task='mature', normalize=True, **kwargs):
    """Load the twitch-gamer graph from its two CSV files.

    The feature and edge CSVs are fetched with ``gdown`` into
    ``<root>/twitch-gamer`` when they are missing. Labels and features come
    from ``load_twitch_gamer(nodes, task)``. With ``normalize=True`` every
    feature column has its mean subtracted and is divided by its standard
    deviation. The graph is made undirected, stored with a sparse adjacency,
    and gets its masks from ``load_split``.

    Args:
        root: Base data directory.
        train_rate: Training fraction passed to ``load_split``.
        val_rate: Validation fraction passed to ``load_split``.
        task: Label selector forwarded to ``load_twitch_gamer``.
        normalize: Whether to standardize the feature columns.

    Returns:
        ``(data, num_features, num_classes, dataset_dir)``.
    """
    # TODO: Google Drive URL is not available
    url_feat = '1fA9VIIEI8N0L27MSQfcBzJgRQLvSbrvR'
    url_edge = '1XLETC6dG3lVl7kDmytEJ52hvDMVdxnZ0'
    dataset_dir = f'{root}/twitch-gamer'
    os.makedirs(dataset_dir, exist_ok=True)

    feat_csv = f'{dataset_dir}/twitch-gamer_feat.csv'
    edge_csv = f'{dataset_dir}/twitch-gamer_edges.csv'
    if not osp.exists(feat_csv):
        gdown.download(id=url_feat, output=feat_csv, quiet=False)
    if not osp.exists(edge_csv):
        gdown.download(id=url_edge, output=edge_csv, quiet=False)
    edges = pd.read_csv(edge_csv)
    nodes = pd.read_csv(feat_csv)

    label, features = load_twitch_gamer(nodes, task)
    node_feat = torch.tensor(features, dtype=torch.float)
    if normalize:
        node_feat = node_feat - node_feat.mean(dim=0, keepdim=True)
        node_feat = node_feat / node_feat.std(dim=0, keepdim=True)

    data = Data(
        # node_feat is already a fresh float tensor; wrapping it in
        # torch.tensor() again only copies it and triggers PyTorch's
        # copy-construct UserWarning.
        x=node_feat,
        y=torch.tensor(label),
        edge_index=torch.tensor(edges.to_numpy(), dtype=torch.long).t(),
    )
    transform = T.Compose([T.ToUndirected(), T.ToSparseTensor()])
    data = transform(data)

    data = load_split(dataset_dir, data, train_rate, val_rate)

    num_features, num_classes = data.x.shape[1], (data.y.max() + 1).item()

    return data, num_features, num_classes, dataset_dir


def get_metric(name, num_classes):
    """Pick the torchmetrics metric used for dataset ``name`` (case-insensitive).

    * ``yelp`` / ``ppi``: micro-averaged multilabel F1 over ``num_classes`` labels.
    * ``minesweeper`` / ``tolokers`` / ``questions`` / ``genius``: binary AUROC.
    * anything else: multiclass accuracy over ``num_classes`` classes.
    """
    key = name.lower()
    if key in ('yelp', 'ppi'):
        return torchmetrics.F1Score(task="multilabel", average='micro', num_labels=num_classes)
    if key in ('minesweeper', 'tolokers', 'questions', 'genius'):
        return torchmetrics.AUROC(task="binary")
    return torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)


def get_data(name: str, root: str, **kwargs):
    """
    Dispatch to the loader that matches ``name`` (compared case-insensitively).

    ``name`` is forwarded to the loader with its original casing, and
    ``kwargs`` are passed on only to the loaders listed with ``**kwargs``
    below.

    Returns: pyg Data, number of features, number of classes, processed directory

    Raises: NotImplementedError when no loader is registered for ``name``.
    """
    key = name.lower()

    # Ordered (accepted names, deferred loader call) pairs; first match wins.
    routes = (
        (('cora', 'citeseer', 'pubmed'), lambda: get_planetoid(root, name, **kwargs)),
        (('computers', 'photo'), lambda: get_amazon(root, name)),
        (('wikics',), lambda: get_wikics(root)),
        (('chameleon', 'squirrel'), lambda: get_wikinet(root, name, **kwargs)),
        (('roman_empire', 'amazon_ratings', 'minesweeper', 'tolokers', 'questions'),
         lambda: get_heterophilious(root, name, **kwargs)),
        (('cornell', 'texas', 'wisconsin'), lambda: get_webkb(root, name)),
        (('cluster', 'pattern'), lambda: get_sbm(root, name)),
        (('reddit',), lambda: get_reddit(root)),
        (('ppi',), lambda: get_ppi(root, **kwargs)),
        (('flickr',), lambda: get_flickr(root)),
        (('yelp',), lambda: get_yelp(root)),
        (('arxiv', 'products'), lambda: get_ogbn(root, name)),
        (('pokec',), lambda: get_pokec(root, **kwargs)),
        (('arxiv-year',), lambda: get_arxiv_year(root, **kwargs)),
        (('penn94', 'reed98', 'cornell5', 'amherst41'),
         lambda: get_facebook100(root, name, **kwargs)),
        (('snap-patents',), lambda: get_snap_patents(root, **kwargs)),
        (('genius',), lambda: get_genius(root, **kwargs)),
        (('twitch-gamers',), lambda: get_twitch_gamer(root, **kwargs)),
    )
    for accepted_names, load in routes:
        if key in accepted_names:
            return load()
    raise NotImplementedError
