#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

import torch

import os
import os.path as osp
import pickle
import torch_geometric.transforms as T

# from cSBM_dataset import dataset_ContextualSBM
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import Actor
from torch_sparse import coalesce
from torch_geometric.data import InMemoryDataset, download_url, Data
from torch_geometric.utils.undirected import to_undirected
from utils import sparse_tensor_to_edge_index
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.utils import homophily
from dataset_iclr23 import HeterophiliousDataset
from dataset_www21 import NonHomophiliousDataset


class dataset_heterophily(InMemoryDataset):
    """PyG in-memory dataset that loads a pre-built, pickled graph.

    ``process`` unpickles the file ``<root>/<name>/raw/<name>``, applies
    ``pre_transform`` when one is given, and caches the collated result at
    ``<root>/<name>/processed/data.pt``. ``download`` is a no-op, so the raw
    file must already be on disk.

    Args:
        root: Base directory holding ``<name>/raw`` and ``<name>/processed``.
            Created if it does not exist.
        name: Dataset name; one of ``'chameleon'``, ``'film'``, ``'squirrel'``.
        p2raw: Optional directory path, validated to exist when given and
            stored on ``self.p2raw``. NOTE(review): ``process`` does not read
            from it -- the raw file is always taken from ``raw_dir``.
        train_percent: Value used for ``self.train_percent`` only when the
            stored graph carries no ``train_percent`` attribute of its own.
        transform: PyG transform applied on every access.
        pre_transform: PyG transform applied once before caching to disk.

    Raises:
        ValueError: if ``name`` is unsupported, or ``p2raw`` is given but is
            not an existing directory.
    """

    def __init__(self, root='data/', name=None,
                 p2raw=None,
                 train_percent=0.01,
                 transform=None, pre_transform=None):
        existing_dataset = ['chameleon', 'film', 'squirrel']
        if name not in existing_dataset:
            raise ValueError(
                f'name of heterophily dataset must be one of: {existing_dataset}')
        self.name = name

        self._train_percent = train_percent

        # None is allowed; any other value must be an existing directory.
        if p2raw is not None and not osp.isdir(p2raw):
            raise ValueError(
                f'path to raw heterophily dataset "{p2raw}" does not exist!')
        self.p2raw = p2raw

        # exist_ok avoids failing if the directory appears between a check
        # and the creation (e.g. another process preparing the same root).
        os.makedirs(root, exist_ok=True)

        self.root = root

        super(dataset_heterophily, self).__init__(
            root, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0])
        # Prefer the ratio stored with the graph; fall back to the
        # constructor argument instead of raising AttributeError.
        self.train_percent = getattr(
            self.data, 'train_percent', self._train_percent)

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        # The raw file is named after the dataset itself.
        return [self.name]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Nothing to fetch: the raw pickle must be placed in raw_dir manually.
        pass

    def process(self):
        """Unpickle the raw graph, apply ``pre_transform``, and cache it."""
        p2f = osp.join(self.raw_dir, self.name)
        # SECURITY NOTE: pickle.load can execute arbitrary code; only load
        # raw files obtained from a trusted source.
        with open(p2f, 'rb') as f:
            data = pickle.load(f)
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        return '{}()'.format(self.name)


class WebKB(InMemoryDataset):
    r"""The WebKB datasets used in the
    `"Geom-GCN: Geometric Graph Convolutional Networks"
    <https://openreview.net/forum?id=S1e2agrFvS>`_ paper.
    Nodes represent web pages and edges represent hyperlinks between them.
    Node features are the bag-of-words representation of web pages.
    The task is to classify the nodes into one of the five categories, student,
    project, course, staff, and faculty.

    Args:
        root (string): Root directory where the dataset should be saved.
        name (string): The name of the dataset (:obj:`"Cornell"`,
            :obj:`"Texas"`, :obj:`"Washington"`, :obj:`"Wisconsin"`).
            Matching is case-insensitive.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
    """

    url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/'
           'master/new_data')

    def __init__(self, root, name, transform=None, pre_transform=None):
        self.name = name.lower()
        assert self.name in ['cornell', 'texas', 'washington', 'wisconsin'], \
            f'unsupported WebKB dataset: {name}'

        super(WebKB, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self):
        return ['out1_node_feature_label.txt', 'out1_graph_edges.txt']

    @property
    def processed_file_names(self):
        return 'data.pt'

    def download(self):
        for name in self.raw_file_names:
            download_url(f'{self.url}/{self.name}/{name}', self.raw_dir)

    @staticmethod
    def _read_rows(path):
        """Return the non-blank lines of ``path`` that follow its header line.

        Unlike ``split('\\n')[1:-1]``, this keeps the final data row when the
        file has no trailing newline and ignores blank lines.
        """
        with open(path, 'r') as f:
            lines = f.read().splitlines()[1:]
        return [line for line in lines if line.strip()]

    def process(self):
        # Node file rows are tab-separated: the 2nd column is a
        # comma-separated feature vector, the 3rd is an integer label.
        # Rows are used in file order (the 1st column is not read) --
        # presumably they are sorted by node id; verify against the data.
        node_rows = [r.split('\t') for r in self._read_rows(self.raw_paths[0])]
        x = torch.tensor(
            [[float(v) for v in cols[1].split(',')] for cols in node_rows],
            dtype=torch.float)
        y = torch.tensor([int(cols[2]) for cols in node_rows],
                         dtype=torch.long)

        # Edge file rows are tab-separated integer node ids, transposed
        # into a [2, num_edges] edge_index.
        edge_rows = self._read_rows(self.raw_paths[1])
        edge_index = torch.tensor(
            [[int(v) for v in r.split('\t')] for r in edge_rows],
            dtype=torch.long).t().contiguous()
        edge_index = to_undirected(edge_index)
        edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))

        data = Data(x=x, edge_index=edge_index, y=y)
        if self.pre_transform is not None:
            data = self.pre_transform(data)
        torch.save(self.collate([data]), self.processed_paths[0])

    def __repr__(self):
        return '{}()'.format(self.name)


def DataLoader(name):
    """Load a node-classification dataset by name.

    Args:
        name: Dataset identifier. Names containing ``'cSBM'`` (case-sensitive)
            are loaded from ``'../data/'`` via the optional ``cSBM_dataset``
            module; every other name is lower-cased and matched against the
            supported names below.

    Returns:
        ``(dataset, data)``: the dataset object and the graph to use. For most
        datasets ``data`` is ``dataset[0]``; for chameleon/squirrel and
        ogbn_arxiv it is ``dataset[0]`` with its edges replaced (see inline
        comments).

    Raises:
        ValueError: if ``name`` is not a supported dataset.
        ImportError: if a cSBM dataset is requested but ``cSBM_dataset``
            cannot be imported.
    """
    home = os.environ.get('HOME')
    if home is None:
        # HOME may be unset; os.path.join(None, ...) would raise TypeError.
        home = osp.expanduser('~')
    root = os.path.join(home, 'datasets/graph')

    if 'cSBM' in name:
        # Imported lazily: cSBM_dataset is an optional local module that is
        # not imported at module level, so the bare name would be undefined.
        from cSBM_dataset import dataset_ContextualSBM
        path = '../data/'
        dataset = dataset_ContextualSBM(path, name=name)
        return dataset, dataset[0]

    name = name.lower()

    if name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(root, name, transform=T.Compose(
            [T.NormalizeFeatures(), T.LargestConnectedComponents()]))
    elif name in ['computers', 'photo']:
        # NOTE: Amazon data is stored under ./data/<name>, not under root.
        root_path = './'
        path = osp.join(root_path, 'data', name)
        dataset = Amazon(path, name, T.NormalizeFeatures())
    elif name in ['chameleon', 'squirrel']:
        # Everything except edge_index (features, labels, any other
        # attributes) comes from the geom_gcn_preprocess=True version; only
        # edge_index is taken from the geom_gcn_preprocess=False version.
        preProcDs = WikipediaNetwork(
            root=root, name=name, geom_gcn_preprocess=False, transform=T.NormalizeFeatures())
        dataset = WikipediaNetwork(
            root=root, name=name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
        data = dataset[0]
        data.edge_index = preProcDs[0].edge_index
        return dataset, data
    elif name in ['film']:
        dataset = Actor(
            root=osp.join(root, name), transform=T.NormalizeFeatures())
    elif name in ['texas', 'cornell']:
        dataset = WebKB(root=root,
                        name=name, transform=T.NormalizeFeatures())
    elif name in ['ogbn_arxiv']:
        path = osp.join(root, 'OGB')
        print(path)
        dataset = PygNodePropPredDataset(root=path, name='ogbn-arxiv',
                                         transform=T.ToSparseTensor())

        data = dataset[0]
        # Symmetrize the sparse adjacency, convert it to an edge_index, and
        # drop the sparse tensor; squeeze away y's singleton dimension(s).
        data.adj_t = data.adj_t.to_symmetric()
        data.edge_index = sparse_tensor_to_edge_index(data.adj_t)
        data.adj_t = None
        data.y = data.y.squeeze()
        return dataset, data
    elif name in ['chameleon_filtered', 'squirrel_filtered', 'wiki_cooc', 'roman_empire', 'amazon_ratings', 'minesweeper', 'workers', 'questions']:
        dataset = HeterophiliousDataset(root=root, name=name)
    elif name in ['snap_patents', 'arxiv_year']:
        dataset = NonHomophiliousDataset(root=root, name=name)
    else:
        raise ValueError(f'dataset {name} not supported in dataloader')

    return dataset, dataset[0]


def print_statistics(dataset, data):
    """Print size statistics and three homophily measures for a graph.

    Args:
        dataset: Object exposing ``num_classes`` and ``num_features``.
        data: Graph object with ``x``, ``edge_index`` and ``y`` attributes.
    """
    print(f'Num classes: {dataset.num_classes}')
    print(f'Num features: {dataset.num_features}')
    num_nodes = data.x.size(0)
    print(f'Num nodes: {num_nodes}')
    num_edges = data.edge_index.size(1)
    print(f'Num_edges (Processed): {num_edges}')

    # (printed label, torch_geometric homophily method) pairs; all ratios are
    # computed before any of them is printed.
    measures = (
        ('Node', 'node'),
        ('Edge', 'edge'),
        ('Edge insensitive', 'edge_insensitive'),
    )
    ratios = [
        (label, homophily(data.edge_index, data.y, method=method))
        for label, method in measures
    ]
    for label, ratio in ratios:
        print(f"{label} homophily_ratio: {ratio:.4f}")