# %%
import os

import numpy as np
import torch
import torch.nn.functional as F
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import ZINC, Planetoid, Coauthor, Amazon

'''

python dataset_stats.py --name AmazonComputers --category Node
python dataset_stats.py --name AmazonPhoto --category Node
python dataset_stats.py --name GitHub --category Node
python dataset_stats.py --name FacebookPagePage --category Node
python dataset_stats.py --name CoauthorPhysics --category Node
python dataset_stats.py --name TwitchEN --category Node
python dataset_stats.py --name CoauthorCS --category Node
python dataset_stats.py --name DBLP --category Node
python dataset_stats.py --name PubMed --category Node
python dataset_stats.py --name Cora --category Node
python dataset_stats.py --name CiteSeer --category Node


python dataset_stats.py --name ogbn-arxiv --category Node
python dataset_stats.py --name ogbn-products --category Node
python dataset_stats.py --name ogbn-proteins --category Node

'''
import argparse

def _str2bool(value):
    """Parse a command-line token into a real boolean.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy. This converter maps the usual
    spellings explicitly and rejects anything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        f'Expected a boolean value, got {value!r}'
    )


# Command-line interface: which dataset to summarise, the label used for it in
# the first table column, and whether to echo the individual statistics.
parser = argparse.ArgumentParser(
    description='Compute summary statistics for a graph dataset'
)
parser.add_argument(
    '--name',
    dest='name',
    help='Dataset name',
    required=True,
    type=str
)

parser.add_argument(
    '--category',
    dest='category',
    help='Dataset category',
    required=True,
    type=str
)

# Accepts `--print`, `--print True` and `--print False`; a bare `--print`
# means True (nargs='?' + const), omitting the flag means False.
parser.add_argument(
    '--print',
    dest='print',
    help='If print',
    type=_str2bool,
    nargs='?',
    const=True,
    default=False
)

args = parser.parse_args()
name = args.name  # ogbg-molhiv | ogbg-molpcba | Cora | CS
category = args.category


def pprint(my_str):
    """Print ``my_str`` only when the ``--print`` option is truthy."""
    if not args.print:
        return
    print(my_str)


def get_nodes_per_graph(dataset):
    """Return the ``num_nodes`` of each graph in ``dataset`` as a list.

    Only the graphs at positions 0 through 10000 (at most 10001 graphs) are
    inspected, which bounds the work on very large datasets.
    """
    # range(10001) pairs up with positions 0..10000 and then stops the zip.
    return [graph.num_nodes for _, graph in zip(range(10001), dataset)]


def get_edges_per_graph(dataset):
    """Return the edge count of each graph in ``dataset`` as a list.

    The count is the second dimension of each graph's ``edge_index``. Only
    the graphs at positions 0 through 10000 (at most 10001 graphs) are
    inspected.
    """
    # range(10001) pairs up with positions 0..10000 and then stops the zip.
    return [
        graph.edge_index.shape[1]
        for _, graph in zip(range(10001), dataset)
    ]


def get_num_graphs(dataset_raw_all, dataset_raw):
    """Return the number of graphs in the dataset.

    ``dataset_raw_all`` (the full dataset) takes precedence whenever it is
    not ``None``; otherwise the length of ``dataset_raw`` is reported.
    """
    counted = dataset_raw if dataset_raw_all is None else dataset_raw_all
    return len(counted)


def get_node_feats(x):
    """Describe the node-feature width of ``x`` as a table-cell string.

    Returns ``'-'`` when there are no node features (``x is None``), ``'1'``
    for a one-dimensional ``x``, and otherwise the size of dimension 1.
    """
    if x is None:
        return '-'
    dims = x.shape
    return '1' if len(dims) == 1 else str(dims[1])


def get_edge_feats(edge_attr):
    """Describe the edge-feature width of ``edge_attr`` as a table-cell string.

    Returns ``'-'`` when there are no edge attributes, ``'1'`` for a
    one-dimensional attribute, and otherwise the size of dimension 1.
    """
    if edge_attr is None:
        return '-'
    # Only the first 40000 rows are looked at; the width is read from them.
    head = edge_attr[:40000]
    return str(head.shape[1]) if len(head.shape) > 1 else '1'


def get_num_tasks(y):
    """Return the number of prediction targets in ``y`` as a string.

    Args:
        y: Label container exposing ``.shape``, or ``None`` when the dataset
            has no labels.

    Returns:
        ``'1'`` for one-dimensional labels, the size of dimension 1
        otherwise, and ``'N/A'`` when ``y`` has no usable shape.
    """
    try:
        dims = y.shape
        return '1' if len(dims) == 1 else str(dims[1])
    # Only "no usable shape" situations are handled: a missing .shape
    # (e.g. y is None), a scalar with no dimension 1, or an unsized shape.
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    except (AttributeError, IndexError, TypeError):
        return 'N/A'


def get_task_type(y):
    """Classify the prediction task from the label tensor ``y``.

    Args:
        y: Label tensor (must support ``.shape``, ``.flatten()`` and
            ``.unique()``), or ``None`` when the dataset has no labels.

    Returns:
        ``'Multi-task'`` when ``y`` has more than one column,
        ``'Binary clf.'`` for exactly two distinct label values,
        ``'<k>-class clf.'`` for ``k > 2`` distinct values, and ``'N/A'``
        when the labels are missing, unusable, or hold fewer than two
        distinct values.
    """
    # NOTE(review): distinct values are counted regardless of dtype, so
    # continuous targets would be reported as a many-class problem — confirm
    # whether regression datasets need their own label.
    try:
        single_target = len(y.shape) == 1 or y.shape[1] == 1
        if not single_target:
            return 'Multi-task'
        num_classes = len(y.flatten().unique())
    # Missing/ill-shaped labels or a tensor that cannot be deduplicated.
    # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
    except (AttributeError, IndexError, TypeError, RuntimeError):
        return "N/A"

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if num_classes < 2:
        return "N/A"
    if num_classes == 2:
        return 'Binary clf.'
    return f'{num_classes}-class clf.'


# Load the dataset selected by `--name` into `dataset_raw`.
#
# `dataset_raw_all` stays None unless a branch keeps the full dataset next to
# a sampled `dataset_raw` (only the ogbg-mol* branch does so).
#
# Branch order matters: the exact-name and substring checks are evaluated top
# to bottom, and any name that matches none of them falls through to the
# TUDataset loader at the end.
dataset_raw_all = None
if name in ['ZINC']:

    root = os.path.join('.', 'datasets', 'ZINC')

    # Loaded with subset=True and the 'train' split only.
    dataset_raw = ZINC(root=root, subset=True, split='train')
elif name in ['MalNetTiny']:
    from torch_geometric.datasets import MalNetTiny

    root = os.path.join('.', 'datasets', name)
    dataset_raw = MalNetTiny(root=root)
elif name in ['PPI']:
    from torch_geometric.datasets import PPI

    root = os.path.join('.', 'datasets', name)
    # Only the 'train' split is loaded.
    dataset_raw = PPI(root=root, split='train')
elif name in ['CIFAR10']:
    from torch_geometric.datasets import GNNBenchmarkDataset

    root = os.path.join('.', 'datasets', name)
    # Only the 'train' split is loaded.
    dataset_raw = GNNBenchmarkDataset(root=root, name=name, split='train')
elif name in ['GitHub']:
    from torch_geometric.datasets import GitHub

    root = os.path.join('.', 'datasets', name)
    dataset_raw = GitHub(root=root)
elif name in ['FacebookPagePage']:
    from torch_geometric.datasets import FacebookPagePage

    root = os.path.join('.', 'datasets', name)
    dataset_raw = FacebookPagePage(root=root)
elif name in ['CoauthorCS']:
    # The script-level name 'CoauthorCS' maps to Coauthor's 'CS' dataset.
    root = os.path.join('.', 'datasets', 'CoauthorCS')
    dataset_raw = Coauthor(root=root, name='CS')
elif name in ['CoauthorPhysics']:
    # The script-level name 'CoauthorPhysics' maps to Coauthor's 'Physics' dataset.
    root = os.path.join('.', 'datasets', 'CoauthorPhysics')
    dataset_raw = Coauthor(root=root, name='Physics')
elif 'ogbn' in name:
    # Currently an identity transform: both preprocessing steps below are
    # commented out, so the data is returned unchanged.
    def my_transform(data):
        # data.edge_index = tgutils.to_undirected(data.edge_index)
        # data.y = data.y.flatten()
        return data


    dataset_raw = PygNodePropPredDataset(name=name, transform=my_transform)
elif 'ogbl' in name:
    from ogb.linkproppred import PygLinkPropPredDataset


    # Currently an identity transform: both preprocessing steps below are
    # commented out, so the data is returned unchanged.
    def my_transform(data):
        # data.edge_index = tgutils.to_undirected(data.edge_index)
        # data.y = data.y.flatten()
        return data


    dataset_raw = PygLinkPropPredDataset(name=name, transform=my_transform)

elif 'Twitch' in name:
    from torch_geometric.datasets import Twitch

    # Strip the 'Twitch' prefix to get the sub-dataset name, e.g. 'TwitchEN' -> 'EN'.
    tmp = name.replace('Twitch', '')
    root = os.path.join('.', 'datasets', 'Twitch')
    dataset_raw = Twitch(root=root, name=tmp)
elif 'Airports' in name:
    from torch_geometric.datasets import Airports

    # Strip the 'Airports' prefix; the remainder is passed as the sub-dataset name.
    tmp = name.replace('Airports', '')
    root = os.path.join('.', 'datasets', 'Airports')
    dataset_raw = Airports(root=root, name=tmp)
elif 'Amazon' in name:
    root = os.path.join('.', 'datasets', name)
    # Any 'Amazon*' name that does not contain 'Computers' loads 'Photo'.
    if 'Computers' in name:
        dataset_raw = Amazon(root, name='Computers')
    else:
        dataset_raw = Amazon(root, name='Photo')
elif name in ['Cora', 'CiteSeer', 'PubMed']:
    root = os.path.join('.', 'datasets', name)
    dataset_raw = Planetoid(root=root, name=name)
elif name in ['DBLP']:
    from torch_geometric.datasets import CitationFull

    root = os.path.join('.', 'datasets', name)
    dataset_raw = CitationFull(root=root, name=name)
elif name in ['ogbg-molpcba', 'ogbg-molhiv']:
    # Keep the full dataset (used later for the graph count) and compute the
    # remaining statistics on a random sample of at most 10000 graphs.
    # No seed is set in this file, so the sample can differ between runs.
    dataset_raw_all = PygGraphPropPredDataset(name=name)
    original_len = len(dataset_raw_all)
    new_len = 10000
    perm = torch.randperm(original_len)
    dataset_raw = dataset_raw_all[perm[:new_len]]
elif name in ['QM7b']:
    from torch_geometric.datasets import QM7b

    root = os.path.join('.', 'datasets', name)
    dataset_raw = QM7b(root=root)
elif name in ['QM9']:
    from torch_geometric.datasets import QM9

    root = os.path.join('.', 'datasets', name)
    dataset_raw = QM9(root=root)
else:
    # Fallback: treat every other name as a TUDataset name.
    # 'IMDB' is accepted as shorthand for 'IMDB-MULTI'.
    if name == 'IMDB':
        tmp = 'IMDB-MULTI'
    else:
        tmp = name
    root = os.path.join('.', 'datasets', f'TU_{tmp}')

    dataset_raw = TUDataset(root, tmp,
                            use_node_attr=True, use_edge_attr=True)

# Compute the raw statistics that feed the LaTeX row built further down.
from torch_geometric.utils import degree, is_undirected

# Read `dataset_raw.data` once and reuse it instead of repeating the lookup.
data_all = dataset_raw.data
y = data_all.y
x = data_all.x

# Best-effort one-hot encoding of the labels; None when the labels cannot be
# one-hot encoded (e.g. missing or non-integer labels).
try:
    y_one_hot = F.one_hot(y, num_classes=-1)
    if len(y_one_hot.shape) == 3:
        y_one_hot = y_one_hot.squeeze(1)
# Narrow handler: a bare `except:` would also swallow KeyboardInterrupt and
# SystemExit.
except (AttributeError, RuntimeError, TypeError, ValueError):
    y_one_hot = None

pprint(f"Dataset: {name}")
n_graphs = get_num_graphs(dataset_raw_all, dataset_raw)
num_nodes = get_nodes_per_graph(dataset_raw)
num_edges = get_edges_per_graph(dataset_raw)

# Directedness is judged from the first graph only.
graph_is_undirected = is_undirected(dataset_raw[0].edge_index)

num_node_feats = get_node_feats(x)
num_edge_feats = get_edge_feats(data_all.edge_attr)

num_tasks = get_num_tasks(y)

task_type = get_task_type(y)

# Rows 0 and 1 of edge_index (the two endpoints of every edge).
# The second variable previously re-read row 0 by mistake.
edge_index_rnd_1 = data_all.edge_index[0, :]
edge_index_rnd_2 = data_all.edge_index[1, :]
# Optional sub-sampling of the edges, currently disabled:
# if len(edge_index_rnd_1) > 40000:
#     perm  = torch.randperm(len(edge_index_rnd_1))
#     edge_index_rnd_1 = edge_index_rnd_1[perm[:40000]]
#     edge_index_rnd_2 = edge_index_rnd_2[perm[:40000]]

# Mean node degree computed from each edge_index row in turn; only
# avg_degree_1 is reported in the table.
avg_degree_1 = degree(data_all.edge_index[0, :], num_nodes=data_all.num_nodes).mean().item()
avg_degree_2 = degree(data_all.edge_index[1, :], num_nodes=data_all.num_nodes).mean().item()

# %%

def _mean_std_cell(values):
    """Format per-graph counts as the content of one LaTeX table cell.

    A constant count (zero standard deviation) is rendered as an integer with
    thousands separators; otherwise as ``mean $\\pm$ std`` with two decimals.
    """
    mean, std = np.mean(values), np.std(values)
    if std == 0:
        return f'{int(mean):,}'
    # `\\pm` emits the same `\pm` as before without the invalid `\p` escape.
    return f'{mean:.2f} $\\pm$ {std:.2f}'


# Build the LaTeX header (`my_str_title`) and the data row (`my_str`) column
# by column, echoing each statistic through pprint along the way.
my_str_title = 'Category & Name '
my_str = f' {category}  & {name}'

pprint(f"Num. graphs: {n_graphs}")
my_str_title += ' &  Graphs'
my_str += f'  & {n_graphs}'

pprint(f"Avg. num nodes: {np.mean(num_nodes):.2f}+-{np.std(num_nodes):.2f}")
my_str_title += ' & \\#Nodes'
my_str += f'  & {_mean_std_cell(num_nodes)}'

pprint(f"Avg. num edges: {np.mean(num_edges):.2f} +- {np.std(num_edges):.2f}")
my_str_title += '  & \\#Edges'
my_str += f'  & {_mean_std_cell(num_edges)}'

pprint(f"Avg. degree: {avg_degree_1}")
my_str_title += '  & Avg. degree'
my_str += f'  & {avg_degree_1:.2f}'

# The column is titled "Directed", so an undirected graph gets the cross mark.
pprint(f"Is undirected? {graph_is_undirected}")
mark = '\\xmark' if graph_is_undirected else '\\cmark'
my_str_title += ' & Directed'
my_str += f'  & {mark}'

pprint(f"Num. node feats: {num_node_feats}")
my_str_title += '  & \\#Node feats.'
my_str += f'  & {num_node_feats}'

pprint(f"Num. edge feats: {num_edge_feats}")
my_str_title += '  & \\#Edge feats.'
my_str += f'  & {num_edge_feats}'

pprint(f"Num. tasks: {num_tasks}")
my_str_title += '  & \\#Tasks'
my_str += f'  & {num_tasks}'

pprint(f"Task type: {task_type}")
my_str_title += '  & Task Type'
my_str += f'  & {task_type}'

# Always shown, regardless of --print: column count, header and row.
print(f"Number of cols in the table: {len(my_str_title.split('&'))}\n\n")
print(f"{my_str_title} \\\\ \\midrule")
print(f"{my_str} \\\\")

# Append (never overwrite) so successive runs accumulate rows in one file.
with open('dataset.tex', 'a') as f:
    f.write(f"{my_str} \\\\\n")

print('\n\n')
