import json
import logging
import multiprocessing as mp
import signal
import time
from collections import defaultdict
from datetime import datetime
from multiprocessing.pool import AsyncResult, Pool
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import networkx as nx
import numpy as np
from networkx.algorithms.isomorphism import ISMAGS
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_recall_fscore_support,
)
from torch import Tensor

from egr import graph_utils
from egr.fsg import gaston
from egr.parallel import compute_isomorphism
from egr.util import load_graph, load_indices, load_labels

# Module-level logger shared by all functions in this file.
LOG = logging.getLogger(__name__)


def read_json(data_dir: Path) -> Iterable[Dict]:
    """Lazily yield the parsed contents of every ``*.json`` file in ``data_dir``.

    Only files directly inside ``data_dir`` are matched (non-recursive glob).
    They are visited in sorted path order, so the output is deterministic.

    Args:
        data_dir: directory to scan.

    Yields:
        The decoded JSON value of each file, one per file.
    """
    LOG.debug('reading path %s', data_dir)
    for path in sorted(data_dir.glob('*.json')):
        # read_bytes() opens and closes the file immediately (the previous
        # json.load(path.open()) leaked the handle). json.loads on bytes
        # auto-detects UTF-8/16/32 instead of relying on the locale encoding.
        yield json.loads(path.read_bytes())


def create_graph(graph_data):
    """Build an undirected ``nx.Graph`` from a dict with ``'nodes'``/``'links'``.

    Each entry of ``graph_data['nodes']`` must carry an ``'id'``. Nodes are
    renumbered ``0..n-1`` in list order, and every key of the node dict is
    copied onto the node as an attribute. Each entry of
    ``graph_data['links']`` must carry ``'source'``/``'target'`` ids. The
    whole edge dict is stored under the edge attribute ``'attr'``.

    A node whose ``'self'`` field equals 1 is treated as the root. Its
    ``'original'`` and ``'label'`` values are recorded in ``G.graph`` as
    ``'explanation_for_node'`` and ``'explained_label'``. If several nodes
    are marked this way, the last one wins.

    Returns:
        ``(G, root)``, where ``root`` is the renumbered index of the root node,
        or -1 when no node is marked as root.

    Raises:
        KeyError: if a link references an id not present in ``nodes``, or the
            root node lacks ``'original'``/``'label'``.
    """
    G = nx.Graph()
    root = -1
    # Maps the JSON node id to its renumbered index. Renamed from `id`, which
    # shadowed the builtin.
    index_of = {}
    for i, node_data in enumerate(graph_data['nodes']):
        index_of[node_data['id']] = i
        G.add_node(i)
        G.nodes[i].update(node_data)
        if node_data.get('self') == 1:
            root = i
            G.graph['explanation_for_node'] = node_data['original']
            G.graph['explained_label'] = node_data['label']
    for edge_data in graph_data['links']:
        source = index_of[edge_data['source']]
        target = index_of[edge_data['target']]
        G.add_edge(source, target, attr=edge_data)
    return G, root


def load_subgraphs(root_dir: Path) -> Tuple[List, List]:
    """Load the rooted explanation graphs stored as JSON under ``root_dir``.

    ``~`` in ``root_dir`` is expanded. Graphs without a node marked as root
    (``create_graph`` returned -1) are skipped. Each kept graph has its root
    passed to ``gaston.makeRootNode`` before being collected.

    Returns:
        ``(graphs, roots)``: parallel lists of graphs and their root indices.
    """
    LOG.debug('load_subgraphs() %s', root_dir)
    rooted_graphs = []
    root_nodes = []
    for graph_data in read_json(root_dir.expanduser()):
        graph, root_node = create_graph(graph_data)
        if root_node != -1:
            gaston.makeRootNode(graph, root_node)
            rooted_graphs.append(graph)
            root_nodes.append(root_node)
    return rooted_graphs, root_nodes


def make_partitions(graphs: List[nx.Graph], roots: List[int]) -> Dict:
    """Group graphs and their roots by ``G.graph['explained_label']``.

    ``graphs`` and ``roots`` are parallel lists of equal length.

    Returns:
        Dict mapping each label to ``{'graphs': [...], 'roots': [...]}``.
        Labels appear in first-seen order, and each inner list keeps the
        input order.
    """
    assert len(graphs) == len(roots)
    by_label = {}
    for graph, root in zip(graphs, roots):
        bucket = by_label.setdefault(
            graph.graph['explained_label'], {'graphs': [], 'roots': []}
        )
        bucket['graphs'].append(graph)
        bucket['roots'].append(root)
    return by_label


def indices_for_labels(indices: np.ndarray, labels: Tensor) -> Dict:
    """Group ``indices`` by the label each one has in ``labels``.

    Args:
        indices: integer positions into ``labels``. Any object with
            ``tolist()`` works, e.g. a numpy array. (The old annotation
            ``np.array`` named a factory function, not a type.)
        labels: per-position labels. Any object with ``tolist()`` works,
            e.g. a torch Tensor or a numpy array.

    Returns:
        Plain dict mapping each label to the ascending list of indices that
        carry it. Keys appear in the order their smallest index is met.
    """
    label_list = labels.tolist()
    grouped: Dict = {}
    for idx in sorted(indices.tolist()):
        grouped.setdefault(label_list[idx], []).append(idx)
    return grouped


# Type alias: a list of pending multiprocessing AsyncResult handles.
Results = List[AsyncResult]

# Keyword arguments presumably meant for multiprocessing Pool construction.
# Each worker runs signal.signal(signal.SIGINT, signal.SIG_IGN) when it
# starts, so workers ignore SIGINT (Ctrl-C).
POOL_ARGS = {
    'initializer': signal.signal,
    'initargs': (signal.SIGINT, signal.SIG_IGN),
}


def compute_scores(
    G: nx.Graph,
    H: nx.Graph,
    training_indices: List[int],
    target_label: int,
    labels: List[int],
    args,
) -> Dict:
    """Evaluate pattern ``H`` as a one-vs-rest detector for ``target_label``.

    For each training index, the ground truth is
    ``labels[idx] == target_label``. The predictions and a timeout count
    come from ``compute_isomorphism(G, H, training_indices, ...)``, run with
    ``args.timeout_secs``.

    Args:
        G: graph passed as the first argument to ``compute_isomorphism``.
        H: candidate pattern graph.
        training_indices: indices to evaluate on.
        target_label: label treated as the positive class.
        labels: label for each index.
        args: namespace providing ``timeout_secs`` and ``average_strategy``.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1_score`` and
        ``support`` (from ``precision_recall_fscore_support`` with
        ``average=args.average_strategy``). Also includes
        ``f1_score-binary``, ``num_timeouts`` and ``num_indices``. (The old
        ``-> float`` annotation was wrong.)
    """
    true_labels = [target_label == labels[idx] for idx in training_indices]
    pred_labels, num_timeouts = compute_isomorphism(
        G, H, training_indices, timeout=args.timeout_secs
    )

    LOG.debug('Using average=%s for computing metrics', args.average_strategy)
    precision, recall, fscore, support = precision_recall_fscore_support(
        true_labels,
        pred_labels,
        average=args.average_strategy,
        zero_division=0,
        pos_label=True,
    )
    return {
        'accuracy': accuracy_score(true_labels, pred_labels),
        'precision': precision,
        'recall': recall,
        'f1_score': fscore,
        'support': support,
        # zero_division=0 matches the call above: returns 0.0 on degenerate
        # splits without emitting an UndefinedMetricWarning.
        'f1_score-binary': f1_score(
            true_labels, pred_labels, pos_label=True, zero_division=0
        ),
        'num_timeouts': num_timeouts,
        'num_indices': len(training_indices),
    }


def pick_candidates_round_robin(
    data: Dict[int, List[Dict]], data_dim: int
) -> List[nx.Graph]:
    """Select up to ``data_dim`` graphs by round-robin over labels.

    Labels are visited in sorted order. Round 0 takes the first record of
    every label, round 1 the second, and so on. Labels whose lists are
    exhausted are skipped. Each label's list is assumed to be ordered
    best-first; ``filter_graphs`` passes F1-sorted lists. A ``data_dim`` of 0
    or less yields an empty list.

    Returns:
        The ``'graph'`` entries of the chosen records.
    """
    per_label = [data[key] for key in sorted(data.keys())]
    limit = min(sum(len(entries) for entries in per_label), data_dim)
    picked = []
    depth = 0
    while len(picked) < limit:
        for entries in per_label:
            if depth < len(entries):
                picked.append(entries[depth]['graph'])
            if len(picked) == data_dim:
                break
        depth += 1
    LOG.info('labels=%s', [G.graph['label'] for G in picked])
    return picked


def get_prev_iter_fsg(dirpath: Optional[Path]) -> Dict:
    """Load the patterns saved by a previous iteration, grouped by label.

    Args:
        dirpath: directory of ``*.json`` pattern graphs, or None for "no
            previous iteration".

    Returns:
        ``defaultdict(list)`` mapping ``G.graph['label']`` to records of the
        form ``{'graph': G, **G.graph}``. The result is empty when
        ``dirpath`` is None.
    """
    prev_map = defaultdict(list)
    if dirpath is None:
        return prev_map
    LOG.info('Getting previous iteration patterns from %s', dirpath)
    # Sort so record order does not depend on the filesystem. order_scores is
    # a stable sort, so this order decides ties in F1.
    for path in sorted(dirpath.glob('*.json')):
        G = load_graph(path)
        prev_map[G.graph['label']].append({'graph': G, **G.graph})
    return prev_map


def order_scores(l: List) -> List:
    """Return a new list of score records ordered by ``'f1_score'``, best first.

    The input is not modified. The sort is stable, so records with equal F1
    keep their relative input order.
    """
    ranked = list(l)
    ranked.sort(key=lambda record: record['f1_score'], reverse=True)
    return ranked


def stringify_scores(s_list: List) -> str:
    """Render each record's ``'f1_score'`` to three decimals, comma-separated."""
    formatted = (format(record['f1_score'], '.3f') for record in s_list)
    return ','.join(formatted)


def _mine_subgraphs_by_label(partitions: Dict, args) -> List[Tuple]:
    """Run gaston on each label partition and return ``(label, subgraphs)`` pairs.

    Pairs are returned in sorted label order.
    """
    LOG.info('Computing frequent subgraphs for %d labels', len(partitions))
    mined = []
    for label, partition in sorted(partitions.items()):
        subgraphs, _roots = gaston.mineFreqRootedSubgraphs(
            graphs=partition['graphs'],
            roots=partition['roots'],
            label=label,
            args=args,
        )
        mined.append((label, subgraphs))
    return mined


def _score_label_subgraphs(
    G: nx.Graph,
    label: int,
    subgraphs: List[nx.Graph],
    scores: List[Dict],
    training_indices: List[int],
    labels: List[int],
    args,
) -> List[Dict]:
    """Score the newly mined ``subgraphs`` for ``label``; return all records by F1.

    ``scores`` holds this label's previous-iteration records and is extended
    in place. A subgraph already present among the previous graphs (per
    ``graph_utils.has_element``) is skipped. Each scored subgraph gets
    ``label`` and its metrics written into ``G_s.graph``.
    """
    num_prev = len(scores)
    LOG.debug(
        'Computing F1-Scores for label %d from %d subgraphs, %d previous',
        label,
        len(subgraphs),
        num_prev,
    )
    begin_label = datetime.now()
    prev_graphs = [item['graph'] for item in scores]
    for i, G_s in enumerate(subgraphs):
        if graph_utils.has_element(prev_graphs, G_s):
            LOG.debug(
                'Found pattern %s in previous collection, dropping duplicate',
                G_s,
            )
            continue
        begin = time.time()
        metrics = compute_scores(
            G, G_s, training_indices, label, labels, args
        )
        dur = time.time() - begin
        G_s.graph['label'] = label
        G_s.graph.update(**metrics)
        scores.append({'graph': G_s, **metrics})
        LOG.debug(
            '[%02d], L=%d, A:%.3f, P:%.3f, R:%.3f, F1=%.3f, dur=%3.3f',
            i,
            label,
            metrics['accuracy'],
            metrics['precision'],
            metrics['recall'],
            metrics['f1_score'],
            dur,
        )
    dur_label = datetime.now() - begin_label
    sorted_scores = order_scores(scores)
    # "prev" is this label's previous pattern count. The old code logged
    # len(prev_scores), which is the number of labels in the previous map.
    LOG.info(
        'Label=%d(current:%02d,prev:%02d), dur:%s, scores:%s',
        label,
        len(scores),
        num_prev,
        dur_label,
        stringify_scores(sorted_scores),
    )
    return sorted_scores


def filter_graphs(G: nx.Graph, args) -> List[nx.Graph]:
    """Mine, score and select frequent rooted subgraphs as candidate patterns.

    Steps:
      1. Load rooted graphs from ``args.data_root`` and partition them by
         explained label.
      2. Mine frequent rooted subgraphs per label with gaston.
      3. Score each newly mined pattern with ``compute_scores`` on the
         ``'train'`` indices. Patterns from ``args.prev_fsg_dir`` are carried
         over, and duplicates of them are skipped.
      4. Pick up to ``args.data_dim`` patterns round-robin across labels.

    Args:
        G: graph passed as the first argument to ``compute_scores``.
        args: namespace. Reads ``data_root``, ``index_file``,
            ``input_label_file``, ``prev_fsg_dir`` and ``data_dim``, and is
            forwarded to gaston and ``compute_scores``.

    Returns:
        The list of selected pattern graphs. (The old annotation claimed a
        Dict.)
    """
    graphs, roots = load_subgraphs(args.data_root)
    indices = load_indices(args.index_file)
    # Close the label file once it has been parsed; the old code leaked it.
    with args.input_label_file.open() as label_file:
        labels = load_labels(label_file).tolist()
    partitions = make_partitions(graphs, roots)
    prev_scores = get_prev_iter_fsg(args.prev_fsg_dir)

    LOG.debug('loaded previous scores')
    for label, v in prev_scores.items():
        LOG.debug('label:%d, f1-score:%s', label, stringify_scores(v))

    with_scores = {}
    # Use each partition's real label. The old enumerate() position was wrong
    # whenever labels were not exactly 0..k-1.
    for label, subgraphs in _mine_subgraphs_by_label(partitions, args):
        with_scores[label] = _score_label_subgraphs(
            G,
            label,
            subgraphs,
            prev_scores[label],
            indices['train'],
            labels,
            args,
        )
    return pick_candidates_round_robin(with_scores, args.data_dim)
