import argparse
import torch
import numpy as np
import hashlib
import networkx as nx
from joblib import Parallel, delayed
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from torch_geometric.datasets import TUDataset
from torch_geometric.utils import to_networkx

from evaluate_embedding import evaluate_embedding


def parameter_parser():
    """
    Build the Graph2Vec command line interface and parse the process arguments.
    :return: Namespace holding dataset, dimensions, workers, epochs, min_count,
        wl_iterations, learning_rate and down_sampling.
    """
    # One (flag, add_argument keyword arguments) entry per supported option.
    option_table = (
        ('--dataset', dict(type=str, default='MUTAG')),
        ("--dimensions", dict(type=int, default=128, help="Number of dimensions. Default is 128.")),
        ("--workers", dict(type=int, default=4, help="Number of workers. Default is 4.")),
        ("--epochs", dict(type=int, default=10, help="Number of epochs. Default is 10.")),
        ("--min_count", dict(type=int, default=5, help="Minimal structural feature count. Default is 5.")),
        ("--wl_iterations", dict(type=int, default=2, help="Number of Weisfeiler-Lehman iterations. Default is 2.")),
        ("--learning_rate", dict(type=float, default=0.025, help="Initial learning rate. Default is 0.025.")),
        ("--down_sampling", dict(type=float, default=0.0001, help="Down sampling rate of features. Default is 0.0001.")),
    )

    cli = argparse.ArgumentParser(description="Run Graph2Vec.")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


class WeisfeilerLehmanMachine:
    """
    Collects the initial node labels of a graph plus the hashed labels produced
    by repeated Weisfeiler-Lehman relabelling, as one flat list of strings.
    """

    def __init__(self, graph, features, iterations):
        """
        Store the inputs and run the feature extraction right away.
        :param graph: The Nx graph object (must offer nodes() and neighbors()).
        :param features: Feature hash table mapping each node to its initial label.
        :param iterations: Number of WL iterations.
        """
        self.iterations = iterations
        self.graph = graph
        self.features = features
        self.nodes = self.graph.nodes()
        # The initial labels, stringified, are the first extracted features.
        self.extracted_features = [str(label) for label in features.values()]
        self.do_recursions()

    def do_a_recursion(self):
        """
        Relabel every node once from its own label and its neighbours' labels.
        :return relabelled: The hash table with the new WL label of each node.
        """
        relabelled = {}
        for node in self.nodes:
            # Neighbour labels are compared as strings, so the order is lexicographic.
            neighbour_labels = sorted(
                str(self.features[neighbour]) for neighbour in self.graph.neighbors(node)
            )
            signature = "_".join([str(self.features[node]), *neighbour_labels])
            relabelled[node] = hashlib.md5(signature.encode()).hexdigest()
        # Rebind (rather than extend in place) so earlier references stay untouched.
        self.extracted_features = [*self.extracted_features, *relabelled.values()]
        return relabelled

    def do_recursions(self):
        """
        Apply the WL relabelling `iterations` times, each round feeding the next.
        """
        for _round in range(self.iterations):
            self.features = self.do_a_recursion()


def dataset_reader(data):
    """
    Function to convert a dataset entry into a graph and its initial node labels.
    :param data: Graph data object accepted by to_networkx.
    :return graph: The graph object.
    :return features: Features hash table (node id as int -> node degree).
    """
    graph = to_networkx(data)
    # NOTE: node attributes in data.x are not used; the initial label of every
    # node is its degree as reported by networkx for the converted graph.
    # NOTE(review): to_networkx is called with its defaults - confirm whether the
    # resulting graph is directed, since that affects what "degree" counts.
    degree_by_node = {}
    for node, degree in nx.degree(graph):
        degree_by_node[int(node)] = degree
    return graph, degree_by_node


def feature_extractor(data, rounds, idx):
    """
    Function to turn one graph into a tagged document of WL features.
    :param data: Graph data object handed to dataset_reader.
    :param rounds: Number of WL iterations.
    :param idx: Position of the graph in the dataset; the document is tagged "g_<idx>".
    :return doc: Document collection object.
    """
    graph, features = dataset_reader(data)
    wl_words = WeisfeilerLehmanMachine(graph, features, rounds).extracted_features
    return TaggedDocument(words=wl_words, tags=[f"g_{idx}"])


def main(args=None):
    """
    Run the Graph2Vec pipeline: WL feature extraction, Doc2Vec training and
    evaluation of the resulting graph embeddings.
    :param args: Parsed argument namespace as produced by parameter_parser.
        When None, the command line is parsed here instead.
    """
    # Respect the namespace handed in by the caller; previously it was
    # overwritten by a second, unconditional parse of the command line.
    if args is None:
        args = parameter_parser()
    print(args)
    print("---------------------------------------------")

    dataset = TUDataset('data', name=args.dataset)
    print("Feature extraction started.")
    # One tagged document per graph, tagged "g_<position in dataset>".
    document_collections = Parallel(n_jobs=args.workers)(
        delayed(feature_extractor)(data, args.wl_iterations, idx)
        for idx, data in enumerate(dataset)
    )
    print("Optimization started.")

    model = Doc2Vec(document_collections,
                    vector_size=args.dimensions,
                    window=0,
                    min_count=args.min_count,
                    dm=0,
                    sample=args.down_sampling,
                    workers=args.workers,
                    epochs=args.epochs,
                    alpha=args.learning_rate)

    # Gather the learned document vector and the label of every graph, in
    # dataset order, for the downstream evaluation.
    embeds = np.zeros((len(dataset), args.dimensions))
    labels = np.zeros((len(dataset), 1))
    for idx, data in enumerate(dataset):
        vectors = model.dv[f"g_{idx}"]
        embeds[idx] = np.array(vectors)
        labels[idx] = data.y
    # evaluate_embedding is unpacked as (mean accuracy, accuracy std).
    acc_mean, acc_std = evaluate_embedding(embeds, labels.ravel())
    print(f'test acc: {acc_mean:.4f} +- {acc_std:.4f}')

# Script entry point: parse the command line and hand the namespace to main().
if __name__ == "__main__":
    args = parameter_parser()
    main(args)
