import datetime as dt
import json
import logging
import multiprocessing as mp
import os
import time
import typing as ty
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

import networkx as nx
import networkx.algorithms.isomorphism as iso
import numpy as np

from egr.fsg import filtering, gaston
from egr.data.io import EgrDenseData
from egr.graph_utils import get_neighborhood_subgraph
from egr.util import normalize_path, save, save_features

LOG = logging.getLogger(__name__)


def read_json(data_dir: Path) -> Iterable[ty.Dict]:
    """Lazily yield the decoded contents of every ``*.json`` file in a directory.

    Args:
        data_dir: Directory scanned (non-recursively) for ``*.json`` files.

    Yields:
        The object decoded from each file, in whatever order ``Path.glob``
        produces the paths.
    """
    LOG.info('reading path %s', data_dir)
    for path in data_dir.glob('*.json'):
        # JSON text is UTF-8 by specification; do not depend on the locale.
        with path.open(encoding='utf-8') as f:
            yield json.load(f)


def create_graph(j):
    """Build an undirected graph from a mapping with ``'nodes'`` and ``'links'``.

    Each entry of ``j['nodes']`` is an attribute mapping carrying an ``'id'``;
    nodes are renumbered by their position in that sequence and keep all of
    their attributes.  Each entry of ``j['links']`` names its endpoints by
    ``'source'`` / ``'target'`` id and is stored whole under the edge
    attribute ``attr``.

    Returns:
        ``(G, root)`` where ``root`` is the index of the last node whose
        attributes contain the key ``'self'``, or ``-1`` when no node does.
    """
    graph = nx.Graph()
    index_of = {}
    root = -1
    for index, attrs in enumerate(j['nodes']):
        index_of[attrs['id']] = index
        graph.add_node(index)
        graph.nodes[index].update(attrs)
        if 'self' in attrs:
            root = index
    for link in j['links']:
        source, target = link['source'], link['target']
        graph.add_edge(index_of[source], index_of[target], attr=link)
    return graph, root


def load_subgraphs(root_dir: Path) -> Tuple[List, List]:
    """Load every JSON graph under ``root_dir`` that has a root node.

    Graphs for which ``create_graph`` reports no root (``-1``) are dropped.

    Returns:
        Two parallel lists: the graphs and the root node index of each.
    """
    LOG.info('load_subgraphs() %s', root_dir)
    rooted = [
        (graph, root)
        for graph, root in map(create_graph, read_json(root_dir.expanduser()))
        if root != -1
    ]
    graphs = [graph for graph, _ in rooted]
    roots = [root for _, root in rooted]
    return graphs, roots


def apply_gaston_labels(G: nx.Graph, node_id: int):
    """Label ``node_id`` with 1 and every other node of ``G`` with 0.

    The label is written in place under the ``gaston.gastonLabelAttr`` node
    attribute.
    """
    label_key = gaston.gastonLabelAttr
    # Reset all nodes first, then single out the requested one.
    nx.set_node_attributes(G, 0, label_key)
    G.nodes[node_id][label_key] = 1


def node_match(n1, n2) -> bool:
    """Return whether two node-attribute mappings carry the same gaston label."""
    label_key = gaston.gastonLabelAttr
    return n1[label_key] == n2[label_key]


def root_matcher(n1: Dict, n2: Dict) -> bool:
    """Return whether two node-attribute mappings agree on their root flag.

    Used as the ``node_match`` callback of the graph matcher in
    ``make_feature``, so ``n1`` and ``n2`` are node attribute dictionaries.

    Two nodes match when either both lack ``gaston.rootAttr`` or both carry it
    with equal values.  A node that has the attribute never matches one that
    does not.
    """
    key = gaston.rootAttr
    has_first = key in n1
    # The original compared ``n1`` with itself here, which is always true and
    # let a missing attribute surface as a KeyError below.
    if has_first != (key in n2):
        return False
    if not has_first:
        return True
    return n1[key] == n2[key]


def make_feature(G1: nx.Graph, G2: nx.Graph) -> float:
    """Return 1.0 if a subgraph of ``G1`` is isomorphic to ``G2``, else 0.0.

    Nodes are only allowed to correspond when ``root_matcher`` accepts their
    attribute dictionaries, so a root can only be mapped onto a root.

    Args:
        G1: The host graph that is searched.
        G2: The pattern graph looked for inside ``G1``.
    """
    # NOTE: this is the subgraph-isomorphism test of networkx's GraphMatcher;
    # ``subgraph_is_monomorphic`` would be the looser alternative.
    matcher = iso.GraphMatcher(G1, G2, node_match=root_matcher)
    return 1.0 if matcher.subgraph_is_isomorphic() else 0.0


# class Annotator:
#     ba_node_count: int = 300

#     def __init__(self, G: nx.Graph, subgraphs: List[nx.Graph], dim: int):
#         self.G = G
#         self.subgraphs = subgraphs[:dim]
#         self.dim = dim
#         self.radii: List[int] = [min(nx.radius(s), 3) for s in subgraphs]

#     def __call__(self, node_id: int):
#         return self.annotate(node_id)

#     def annotate(self, node_id: int):
#         begin = time.time()
#         features = self._annotate(node_id)
#         dur = time.time() - begin
#         pid = os.getpid()
#         LOG.debug('[PID:%5d] annotated n=%3d in %.4fs', pid, node_id, dur)
#         return features

#     def _annotate(self, node_id: int):
#         x = np.zeros(self.dim) + 0.1
#         for i, G_s in enumerate(self.subgraphs):
#             G_c = get_neighborhood_subgraph(
#                 self.G.copy(), node_id, self.radii[i]
#             )
#             gaston.makeRootNode(G_c, node_id)
#             x[i] = make_feature(G_c, G_s)
#         return x


@dataclass
class Annotation:
    """Outcome of testing one subgraph pattern against one node."""

    # Node of the input graph that was tested.
    node_id: int
    # Index of the pattern (feature column) that was tested.
    dim: int
    # Match indicator stored in the feature matrix.
    iso: float


class AsyncAnnotator:
    """Callable that tests a single (node, pattern) pair per invocation.

    ``perform_annotations`` submits instances of this class to a
    ``multiprocessing`` pool, one call per node and pattern index.

    Args:
        G: The graph whose nodes are annotated.
        subgraphs: Candidate pattern graphs; only the first ``dim`` are kept.
            Each pattern must store its root node under
            ``graph[gaston.rootAttr]``.
        dim: Number of patterns (feature dimensions) to use.
    """

    # Not referenced anywhere in this class; kept for external users.
    ba_node_count: int = 300

    def __init__(self, G: nx.Graph, subgraphs: List[nx.Graph], dim: int):
        self.G = G
        self.subgraphs = subgraphs[:dim]
        self.dim = dim
        # Neighborhood size per retained pattern, capped at 3 hops.  Computed
        # once here so ``annotate`` does not recompute the radius per call,
        # and only for patterns that can actually be requested.
        self.radii: List[int] = [min(nx.radius(s), 3) for s in self.subgraphs]

    def __call__(self, node_id: int, dim: int):
        """Alias for :meth:`annotate` so the instance can be used as a task."""
        return self.annotate(node_id, dim)

    def annotate(self, node_id: int, dim: int):
        """Test whether pattern ``dim`` occurs around ``node_id``.

        The pattern's nodes and the nodes of the extracted neighborhood get a
        boolean ``gaston.rootAttr`` flag (true only on the respective root)
        before matching, so ``make_feature`` can pin root onto root.  The
        pattern graph is modified in place.

        Returns:
            An ``Annotation`` whose ``iso`` is the value of ``make_feature``.
        """
        H: nx.Graph = self.subgraphs[dim]
        nx.set_node_attributes(H, False, gaston.rootAttr)
        pattern_root = H.graph[gaston.rootAttr]
        H.nodes[pattern_root][gaston.rootAttr] = True

        # Same value as min(nx.radius(H), 3); the flags set above do not
        # change the pattern's structure.
        hops = self.radii[dim]
        G: nx.Graph = get_neighborhood_subgraph(self.G.copy(), node_id, hops)
        nx.set_node_attributes(G, False, gaston.rootAttr)
        G.nodes[node_id][gaston.rootAttr] = True

        gaston.makeRootNode(G, node_id)
        return Annotation(node_id, dim, make_feature(G, H))


def save_intermediate(args, sg):
    """Write every selected subgraph to ``args.fsg_dir`` as ``NNNN.json``.

    The directory is created if needed.  Files are numbered by the position
    of the subgraph in ``sg``; each subgraph's ``label`` and ``f1_score``
    graph attributes are reported at debug level.
    """
    total: int = len(sg)
    LOG.info('Generating %d intermediate subgraphs', total)
    out_dir: Path = normalize_path(args.fsg_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for index, subgraph in enumerate(sg):
        target: Path = out_dir / f'{index:04d}.json'
        meta = subgraph.graph
        LOG.debug(
            'Saving intermediate to %s (label:%d, score:%.4f)',
            target,
            meta['label'],
            meta['f1_score'],
        )
        save(subgraph, target)


def perform_annotations(G: nx.Graph, sg: List[nx.Graph], args):
    """Compute the node-by-pattern feature matrix of ``G`` and save it.

    One task per (node, pattern index) pair is submitted to a process pool of
    ``args.nproc`` workers.  Every result is collected exactly once and
    written to row ``node_id``, column ``dim`` of a float32 matrix of shape
    ``(number of nodes, args.data_dim)``, which is then stored at
    ``args.output_feature_file``.

    Args:
        G: Graph to annotate.  Node ids are used directly as row indices, so
            they are expected to be integers in ``range(G.number_of_nodes())``.
        sg: Pattern graphs, indexed positionally by ``AsyncAnnotator``.
        args: Namespace providing ``data_dim``, ``nproc`` and
            ``output_feature_file``.

    Raises:
        Any exception raised inside a worker is re-raised here by
        ``AsyncResult.get``.
    """
    annotator = AsyncAnnotator(G, sg, args.data_dim)

    N: int = G.number_of_nodes()
    n_dims: int = args.data_dim
    LOG.info('Beginning annotations on %dx%d', N, n_dims)
    num_annotations: int = N * n_dims
    X: np.ndarray = np.zeros(shape=(N, n_dims), dtype=np.float32)
    progress_interval = dt.timedelta(minutes=5)

    begin = datetime.now()

    with mp.Pool(processes=args.nproc) as p:
        results = [
            p.apply_async(annotator, args=(node_id, dim))
            for node_id in G.nodes()
            for dim in range(n_dims)
        ]

        # Block on each result in submission order.  The earlier polling loop
        # re-counted results that were already collected on every pass, so it
        # could stop (and terminate the pool) before all tasks had finished.
        last_checkpoint = begin
        for done, result in enumerate(results, start=1):
            annot: Annotation = result.get()
            X[annot.node_id, annot.dim] = annot.iso

            now_time = datetime.now()
            dur = now_time - last_checkpoint
            if dur > progress_interval:
                LOG.info(
                    'Annotated %d/%d, last:%s, total:%s',
                    done,
                    num_annotations,
                    dur,
                    now_time - begin,
                )
                last_checkpoint = now_time

    LOG.info('Finished annotations in %s', datetime.now() - begin)
    LOG.info('Saving features %s to %s', X.shape, args.output_feature_file)
    args.output_feature_file.parent.mkdir(parents=True, exist_ok=True)
    save_features(args.output_feature_file, X)


# def perform_annotations(G: nx.Graph, sg: Dict[str, List[nx.Graph]], args):
#     annotator = Annotator(G, sg, args.data_dim)

#     N: int = G.number_of_nodes()
#     X = np.zeros([N, args.data_dim])
#     LOG.info('Beginning annotations on %dx%d', N, args.data_dim)
#     begin = datetime.now()
#     with Pool(processes=args.nproc) as p:
#         X_fut = p.map(annotator.annotate, G.nodes())
#         for n in G.nodes():
#             try:
#                 X[n, :] = X_fut[n]
#             except ValueError as err:
#                 LOG.error('%s, num subgraphs=%d', err, len(sg))

#     LOG.info('Finished annotations in %s', datetime.now() - begin)
#     LOG.info('Saving features %s to %s', X.shape, args.output_feature_file)
#     args.output_feature_file.parent.mkdir(parents=True, exist_ok=True)
#     save_features(args.output_feature_file, X)


def main(args) -> Dict[str, Dict[str, datetime]]:
    """Run filtering, optional intermediate saving and annotation.

    Args:
        args: Namespace providing ``input_graph_file`` and ``intermediate``,
            plus whatever ``filtering.filter_graphs``, ``save_intermediate``
            and ``perform_annotations`` read from it.

    Returns:
        Wall-clock timings as
        ``{'filtering': {'begin', 'end'}, 'annotation': {'begin', 'end'}}``
        with ``datetime`` values.
    """
    G: nx.Graph = EgrDenseData.load_graph(args.input_graph_file)

    filter_begin = datetime.now()
    # Indexed positionally by save_intermediate and AsyncAnnotator below.
    selected: List[nx.Graph] = filtering.filter_graphs(G, args)
    filter_end = datetime.now()

    LOG.info('Finished filtering in %s', filter_end - filter_begin)

    if args.intermediate:
        save_intermediate(args, selected)

    annotation_begin = datetime.now()
    perform_annotations(G, selected, args)
    annotation_end = datetime.now()

    return {
        'filtering': {'begin': filter_begin, 'end': filter_end},
        'annotation': {'begin': annotation_begin, 'end': annotation_end},
    }
