from collections import defaultdict
import json
import logging
import re
from pathlib import Path
from typing import Dict

import networkx as nx
import pandas as pd

from egr.config import Config
from egr.data.io import EgrDenseData
from egr.util import load_graph, save_json
from egr.graph_utils import find_house_handles
from apps.hist.groundtruth import HouseMaker

# Module logger; uses the 'prefect' logger name (presumably so messages are
# captured alongside the Prefect flow output -- confirm).
LOG = logging.getLogger(name='prefect')

# Explainer output files carry an ``id-<digits>`` fragment in their name;
# the ``node_id`` group captures those digits (the node index, as a string).
JSON_PATH_PATTERN = r'id-(?P<node_id>\d+)'
JSON_PATH_REGEX = re.compile(pattern=JSON_PATH_PATTERN)


def run_hist_maker(args):
    """Entry point: compute and save per-node match counts for one config.

    Args:
        args: Parsed command-line arguments; ``args.config_file`` is passed
            to ``Config.load``.
    """
    cfg: Config = Config.load(args.config_file)
    LOG.info('Searching %s', cfg.explainer_outputs)
    # The collector writes its results under cfg.hist_output_dir itself, so
    # its return value is not needed here.
    collect_per_node_match_counts(cfg)


def get_num_matching_nodes(g1: nx.Graph, g2: nx.Graph) -> int:
    """Return how many nodes of ``g1`` are also nodes of ``g2``.

    Args:
        g1: Graph whose nodes are checked.
        g2: Graph queried with ``has_node``.

    Returns:
        The number of shared nodes (0 when ``g1`` is empty).
    """
    return sum(1 for node in g1.nodes if g2.has_node(node))


def get_num_matching_edges(g1: nx.Graph, g2: nx.Graph) -> int:
    """Return how many edges of ``g1`` also appear in ``g2``.

    An edge ``(u, v)`` of ``g1`` counts if ``g2`` has it in either
    orientation, so the result is the same whether ``g2`` is directed or not.

    Args:
        g1: Graph whose edges are checked.
        g2: Graph queried with ``has_edge``.

    Returns:
        The number of shared edges (0 when ``g1`` has no edges).
    """
    return sum(
        1
        for src, dst in g1.edges
        if g2.has_edge(src, dst) or g2.has_edge(dst, src)
    )


def make_hist(cfg: Config, data: Dict):
    """Compare explanation graphs with ground-truth houses and save overlaps.

    Every ``*.json`` file under ``cfg.input_dir`` (searched recursively,
    processed in sorted path order) whose path contains an ``id-<digits>``
    fragment is treated as an explainer output. The following are skipped:
    files whose graph metadata has ``label != pred``, and nodes whose index
    is below ``data['random_count']``. For each remaining node, the
    explanation graph is compared with the house built by ``HouseMaker``,
    and the numbers of shared nodes and edges are recorded.

    Args:
        cfg: Run configuration; ``input_dir``, ``output_dir``, ``run_id``
            and ``variant`` are read.
        data: Mapping with at least ``'input_graph'`` (passed to
            ``load_graph``) and ``'random_count'`` (int threshold).

    The result is a DataFrame with columns ``node_matches`` and
    ``edge_matches``. It is written as JSON to
    ``<cfg.output_dir>/<cfg.run_id>_<cfg.variant>.json``.
    """
    G_o = load_graph(data['input_graph'])
    random_count: int = data['random_count']
    G_o.graph['random_count'] = random_count
    hm = HouseMaker(G_o.number_of_nodes(), random_count)
    node_match_count, edge_match_count = [], []
    # Sorted so the output row order does not depend on filesystem order.
    for path in sorted(cfg.input_dir.rglob('*.json')):
        # Only files with an ``id-<digits>`` fragment are explainer outputs.
        if JSON_PATH_REGEX.search(str(path)) is None:
            continue

        G = load_graph(path)
        if G.graph['label'] != G.graph['pred']:
            LOG.warning('Node for %s was mispredicted, skipping', path)
            continue
        # The node index is taken from the graph metadata, not the filename.
        node_id = int(G.graph['node_idx'])
        if node_id < random_count:
            continue
        handles = find_house_handles(G_o, node_id)
        G_exp = make_graph(path)
        G_gt = hm.make_house_with_handle(node_id, handles)
        node_match_count.append(get_num_matching_nodes(G_exp, G_gt))
        edge_match_count.append(get_num_matching_edges(G_exp, G_gt))
    df: pd.DataFrame = pd.DataFrame(
        data={
            'node_matches': node_match_count,
            'edge_matches': edge_match_count,
        }
    )
    cfg.output_dir.mkdir(parents=True, exist_ok=True)

    json_path = cfg.output_dir / f'{cfg.run_id}_{cfg.variant}.json'
    LOG.info('Saving to %s', json_path)
    with json_path.open('w') as f:
        df.to_json(f)


def make_graph(path):
    """Build an undirected graph from an explainer JSON file.

    The file must contain two lists:
    - ``nodes``: each entry has an ``id`` and an ``original`` key.
    - ``links``: each entry has ``source`` and ``target`` keys referring to
      node ``id`` values.

    Nodes and edges of the returned graph are labelled with the
    ``original`` values.

    Args:
        path: ``pathlib.Path`` or str path to the JSON file.

    Returns:
        nx.Graph: The relabelled graph.

    Raises:
        KeyError: If an entry lacks a required key, or a link references an
            ``id`` not listed in ``nodes``.
    """
    # Parse, then release the file handle before building the graph.
    with Path(path).open(encoding='utf-8') as f:
        data = json.load(f)

    id_map = {d['id']: d['original'] for d in data['nodes']}
    G: nx.Graph = nx.Graph()
    G.add_nodes_from(id_map.values())
    G.add_edges_from(
        (id_map[e['source']], id_map[e['target']]) for e in data['links']
    )
    return G


def collect_per_node_match_counts(cfg):
    """Compute match counts for every explainer output and save them.

    Scans ``cfg.explainer_outputs`` (non-recursively) for ``id*.json`` files
    in sorted name order. For each file, the node index is parsed from the
    ``id-<digits>`` fragment of the file stem, and nodes with an index below
    ``cfg.ba_count`` are skipped. For every other node, one record is
    appended containing ``cfg.sample_id`` plus the ``compute_match_counts``
    result.

    Args:
        cfg: Run configuration; ``explainer_outputs``, ``ba_count``,
            ``sample_id`` and ``hist_output_dir`` are read here, and it is
            forwarded to ``compute_match_counts``.

    Returns:
        dict: Node index -> list of records. The same data is written via
        ``save_json`` to ``<cfg.hist_output_dir>/<cfg.sample_id>.json``.

    Raises:
        RuntimeError: If a globbed file stem has no ``id-<digits>`` fragment.
    """
    data = defaultdict(list)

    for path in sorted(cfg.explainer_outputs.glob('id*.json')):
        match = JSON_PATH_REGEX.search(path.stem)
        if match is None:
            raise RuntimeError(
                f'name {path.stem} did not match {JSON_PATH_PATTERN}'
            )
        node_id: int = int(match.group('node_id'))
        # NOTE(review): compute_match_counts builds houses with
        # cfg.random_count, but the skip threshold here is cfg.ba_count;
        # confirm these are meant to be the same quantity.
        if node_id < cfg.ba_count:
            continue

        data[node_id].append(
            {
                'sample_id': cfg.sample_id,
                **compute_match_counts(node_id=node_id, path=path, cfg=cfg),
            }
        )

    cfg.hist_output_dir.mkdir(parents=True, exist_ok=True)
    hist_file_path = cfg.hist_output_dir / f'{cfg.sample_id}.json'
    LOG.info('Saving metrics to %s', hist_file_path)
    save_json(data, hist_file_path)
    # Plain dict so callers do not get auto-inserting defaultdict semantics.
    return dict(data)


def compute_match_counts(node_id: int, path: Path, cfg) -> Dict:
    """Measure the overlap of one explanation with its ground-truth house.

    Args:
        node_id: Index of the explained node in the input graph.
        path: Explainer output JSON file, parsed by ``make_graph``.
        cfg: Run configuration. ``input_graph_path``, ``input_label_path``
            and ``input_feature_path`` are passed to
            ``EgrDenseData.read_new``; ``random_count`` is also read.

    Returns:
        Dict with int entries ``'node_match_count'`` and
        ``'edge_match_count'``.
    """
    # NOTE(review): the full input data is re-read on every call. Consider
    # caching it once per config if find_house_handles / HouseMaker do not
    # mutate the graph -- confirm before changing.
    dense = EgrDenseData.read_new(
        cfg.input_graph_path, cfg.input_label_path, cfg.input_feature_path
    )
    original: nx.Graph = dense.G
    original.graph['random_count'] = cfg.random_count
    house_maker = HouseMaker(original.number_of_nodes(), cfg.random_count)

    house_handles = find_house_handles(original, node_id)
    explanation = make_graph(path)
    ground_truth = house_maker.make_house_with_handle(node_id, house_handles)

    return {
        'node_match_count': get_num_matching_nodes(explanation, ground_truth),
        'edge_match_count': get_num_matching_edges(explanation, ground_truth),
    }
