# %%
import os

import numpy as np
import torch
import torch.nn.functional as F
from ogb.graphproppred import PygGraphPropPredDataset
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import ZINC, Planetoid, Coauthor, Amazon

'''

python dataset_stats_hetero.py --name ogbn-mag --category Node

'''
import argparse


def _str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty token -- including ``"False"`` -- becomes ``True``.
    This converter interprets the usual spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised
            true/false spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value, got {value!r}'
    )


# Command-line interface: dataset name and category are mandatory, verbose
# printing is optional.
parser = argparse.ArgumentParser(
    description='Compute summary statistics for a heterogeneous OGB node dataset'
)
parser.add_argument(
    '--name',
    dest='name',
    help='Dataset name',
    required=True,
    type=str
)

parser.add_argument(
    '--category',
    dest='category',
    help='Dataset category',
    required=True,
    type=str
)

# ``--print`` alone enables printing; ``--print True`` / ``--print False``
# are parsed as real booleans.
parser.add_argument(
    '--print',
    dest='print',
    help='If print',
    type=_str2bool,
    nargs='?',
    const=True,
    default=False
)

# Parse sys.argv once and expose the two mandatory options as module-level
# names used by the rest of the script.
# Example dataset names: ogbg-molhiv | ogbg-molpcba | Cora | CS
args = parser.parse_args()
name, category = args.name, args.category


def pprint(my_str):
    """Print ``my_str`` only when the ``--print`` option is truthy."""
    if not args.print:
        return
    print(my_str)


def get_nodes_per_graph(dataset):
    """Return ``num_nodes`` of each graph in ``dataset``.

    Only the graphs at indices 0..10000 (inclusive) are inspected, so at
    most 10001 entries are returned.
    """
    max_graphs = 10001  # indices 0 through 10000
    return [data.num_nodes for _, data in zip(range(max_graphs), dataset)]


def get_edges_per_graph(dataset):
    """Return the edge count of each graph in ``dataset``.

    The count is read from the second dimension of ``data.edge_index``.
    Only the graphs at indices 0..10000 (inclusive) are inspected.
    """
    max_graphs = 10001  # indices 0 through 10000
    return [
        data.edge_index.shape[1]
        for _, data in zip(range(max_graphs), dataset)
    ]


def get_num_graphs(dataset_raw_all, dataset_raw):
    """Return the number of graphs.

    The length of ``dataset_raw_all`` is used when it is given; otherwise
    the length of ``dataset_raw``.
    """
    source = dataset_raw if dataset_raw_all is None else dataset_raw_all
    return len(source)


def get_node_feats(x):
    """Return the node-feature count as a string.

    ``'-'`` when ``x`` is ``None``, ``'1'`` when ``x`` is one-dimensional,
    otherwise the size of its second dimension.
    """
    if x is None:
        return '-'
    return '1' if len(x.shape) == 1 else str(x.shape[1])


def get_edge_feats(edge_attr):
    """Return the edge-feature count as a string.

    ``'-'`` when ``edge_attr`` is ``None``, ``'1'`` when it is
    one-dimensional, otherwise the size of its second dimension.
    """
    if edge_attr is None:
        return '-'
    # Only the first 40000 rows are looked at; the feature width is the same.
    head = edge_attr[:40000]
    if len(head.shape) > 1:
        return str(head.shape[1])
    return '1'


def get_num_tasks(y):
    """Return the number of prediction tasks as a string.

    ``'1'`` when ``y`` is one-dimensional, otherwise the size of its second
    dimension. ``'N/A'`` is returned when ``y`` has no usable ``shape``
    (e.g. ``None`` or a zero-dimensional value).
    """
    try:
        shape = y.shape
        return '1' if len(shape) == 1 else str(shape[1])
    except (AttributeError, IndexError, TypeError):
        # Only "no usable shape" failures map to 'N/A'; KeyboardInterrupt
        # and SystemExit are no longer swallowed.
        return 'N/A'


def get_task_type(y):
    """Describe the prediction task implied by the label tensor ``y``.

    Returns:
        ``'Multi-task'`` when ``y`` has more than one label column;
        ``'Binary clf.'`` for exactly two distinct labels;
        ``'<k>-class clf.'`` for ``k > 2`` distinct labels;
        ``'N/A'`` when ``y`` cannot be inspected or holds fewer than two
        distinct labels.
    """
    try:
        if len(y.shape) != 1 and y.shape[1] != 1:
            return 'Multi-task'
        num_classes = len(y.flatten().unique())
    except (AttributeError, IndexError, TypeError, RuntimeError):
        # Labels that are missing or do not support these tensor operations.
        return "N/A"

    # Explicit check instead of ``assert``: asserts vanish under ``python -O``,
    # which would let a single-label dataset be reported as "1-class clf.".
    if num_classes < 2:
        return "N/A"
    if num_classes == 2:
        return 'Binary clf.'
    return f'{num_classes}-class clf.'


dataset_raw_all = None

# Only dataset names containing 'ogbn' are handled by this script.
if 'ogbn' not in name:
    raise NotImplementedError


def my_transform(data):
    """Return ``data`` unchanged (no-op transform passed to the dataset)."""
    return data


dataset_raw = PygNodePropPredDataset(name=name, transform=my_transform)

# The rest of the script works on a single graph.
assert len(dataset_raw) == 1
graph = dataset_raw[0]

# Template of the statistics gathered for one dataset; every field starts
# unset except the graph count, which is fixed to 1.
stats = dict(
    category=None,
    name=None,
    num_graphs=1,
    num_nodes=None,
    avg_degree=None,
    directed=None,
    num_node_features=None,
    num_edge_features=None,
    num_tasks=None,
    task_type=None,
)
# %%
# Pull the per-type dictionaries off the graph object.
num_nodes_dict, edge_index_dict = graph.num_nodes_dict, graph.edge_index_dict
x_dict, y_dict = graph.x_dict, graph.y_dict
node_year, edge_reltype = graph.node_year, graph.edge_reltype

# Dump every graph attribute: its type, then one line per entry showing the
# value (ints), the shape (tensors) or the type (nested dicts). Entries of any
# other kind are skipped.
for key, value in graph:
    print(f"{key}: {type(value)}")
    for key2, value2 in value.items():
        if isinstance(value2, int):
            summary = value2
        elif isinstance(value2, torch.Tensor):
            summary = value2.shape
        elif isinstance(value2, dict):
            summary = type(value2)
        else:
            continue
        print(f"\t{key2}: {summary}")


pprint(f"Dataset: {name}")

# %% num graphs
n_graphs = 1
# %%
# Labels of the 'paper' node type.
y = graph['y_dict']['paper']

# Best-effort one-hot encoding of the labels; the number of classes is
# inferred from the data (num_classes=-1). Falls back to None on failure.
try:
    y_one_hot = torch.nn.functional.one_hot(y, num_classes=-1)
    # Drop the middle axis when the encoding comes out three-dimensional.
    if len(y_one_hot.shape) == 3:
        y_one_hot = y_one_hot.squeeze(1)
except Exception:
    # ``Exception`` rather than a bare ``except`` so that Ctrl-C and
    # SystemExit still interrupt the script during this step.
    y_one_hot = None

# %% Num nodes

# Node count summed over all node types, kept as a one-element list so it can
# be passed to np.mean / np.std like a per-graph list.
num_nodes = [sum(graph['num_nodes_dict'].values())]

# %% Num edges


# Edge count summed over all relations, again as a one-element list.
# The loop variable keeps the name ``edge_index``: it remains bound at module
# level after the loop finishes.
edge_total = 0
for key, edge_index in graph['edge_index_dict'].items():
    edge_total += edge_index.shape[1]
num_edges = [edge_total]



# %%



from torch_geometric.utils import is_undirected

# Test directedness on the edges of *every* relation, concatenated along the
# edge axis. Previously this read whatever ``edge_index`` happened to be left
# over from an earlier loop, i.e. only the last relation in the dictionary.
# NOTE(review): the concatenation puts the ids of different node types in one
# index space -- confirm this flattened view is the intended one.
graph_is_undirected = is_undirected(
    torch.cat(list(graph['edge_index_dict'].values()), dim=1)
)

# %%
# Node features of the 'paper' node type.
x = graph['x_dict']['paper']
num_node_feats = get_node_feats(x)

# %%

num_edge_feats = 0  # placeholder, replaced further down
# Stack the edge lists of all relations side by side into one tensor and
# recount the edges from it.
edge_index_list = list(graph['edge_index_dict'].values())
edge_index = torch.cat(edge_index_list, dim=1)
num_edges = [edge_index.size(1)]

# %%

# Stand-in edge-attribute tensor with one column per entry of
# graph['edge_reltype']; only its width is used.
edge_attr = torch.zeros(100, len(graph['edge_reltype']))


num_edge_feats = get_edge_feats(edge_attr)


# %%
num_tasks = get_num_tasks(y)

# %%
task_type = get_task_type(y)

from torch_geometric.utils import degree

# Row 0 and row 1 of the concatenated edge list. ``edge_index_rnd_2``
# previously also pointed at row 0 (copy-paste slip).
edge_index_rnd_1 = edge_index[0, :]
edge_index_rnd_2 = edge_index[1, :]

# Mean degree over all edges (no subsampling): ``avg_degree_1`` from the first
# row, ``avg_degree_2`` from the second.
avg_degree_1 = degree(edge_index_rnd_1).mean().item()
avg_degree_2 = degree(edge_index_rnd_2).mean().item()

# %%

def _format_count(values):
    """Format a list of per-graph counts as one LaTeX table cell.

    When the standard deviation is zero the mean is shown as an integer with
    thousands separators; otherwise as mean and standard deviation with two
    decimals, joined by a LaTeX plus-minus sign.
    """
    mean, std = np.mean(values), np.std(values)
    if std == 0:
        return f'{int(mean):,}'
    # '\\pm' is escaped explicitly: a bare '\p' is an invalid escape sequence.
    return f'{mean:.2f} $\\pm$ {std:.2f}'


# Build the LaTeX header (``my_str_title``) and the data row (``my_str``)
# column by column, echoing each statistic through pprint on the way.
my_str_title = 'Category & Name '
my_str = f' {category}  & {name}'

pprint(f"Num. graphs: {n_graphs}")
my_str_title += ' &  Graphs'
my_str += f'  & {n_graphs}'

pprint(f"Avg. num nodes: {np.mean(num_nodes):.2f}+-{np.std(num_nodes):.2f}")
my_str_title += ' & \\#Nodes'
my_str += f'  & {_format_count(num_nodes)}'

pprint(f"Avg. num edges: {np.mean(num_edges):.2f} +- {np.std(num_edges):.2f}")
my_str_title += '  & \\#Edges'
my_str += f'  & {_format_count(num_edges)}'

pprint(f"Avg. degree: {avg_degree_1}")
my_str_title += '  & Avg. degree'
my_str += f'  & {avg_degree_1:.2f}'

pprint(f"Is undirected? {graph_is_undirected}")
# The column is titled "Directed", so an undirected graph gets the cross.
mark = '\\xmark' if graph_is_undirected else '\\cmark'
my_str_title += ' & Directed'
my_str += f'  & {mark}'

pprint(f"Num. node feats: {num_node_feats}")
my_str_title += '  & \\#Node feats.'
my_str += f'  & {num_node_feats}'

pprint(f"Num. edge feats: {num_edge_feats}")
my_str_title += '  & \\#Edge feats.'
my_str += f'  & {num_edge_feats}'

pprint(f"Num. tasks: {num_tasks}")
my_str_title += '  & \\#Tasks'
my_str += f'  & {num_tasks}'

pprint(f"Task type: {task_type}")
my_str_title += '  & Task Type'
my_str += f'  & {task_type}'

# Always show the header and row on stdout, regardless of --print.
print(f"Number of cols in the table: {len(my_str_title.split('&'))}\n\n")
print(f"{my_str_title} \\\\ \\midrule")
print(f"{my_str} \\\\")

# Append (never overwrite) the row so repeated runs accumulate a table body.
with open('dataset.tex', 'a') as f:
    f.write(f"{my_str} \\\\\n")

print('\n\n')
