import logging
import os
import re
import subprocess
from collections import Counter
from multiprocessing import Pool
from tempfile import NamedTemporaryFile

from tqdm import tqdm

from graph_mixup.compute_ged.parser import Args, parse_args
from graph_mixup.compute_ged.typing import (
    GEDResult,
    MissingGEDException,
    MissingMappingException,
)
from graph_mixup.ged_database.handlers.ged_compute_database_handler import (
    GEDComputeDatabaseHandler,
)
from graph_mixup.ged_database.models import Graph


class GEDComputer:
    def __init__(self, timeout: int, lb_threshold: int) -> None:
        self.timeout = timeout
        self.lb_threshold = lb_threshold

    def _make_temp_file(self, g: Graph) -> str:
        with NamedTemporaryFile(mode="w", delete=False) as f:
            f.write(g.get_ged_library_format())
            return f.name

    def _remove_temp_file(self, name: str) -> None:
        os.remove(name)

    def process(self, g0: Graph, g1: Graph) -> GEDResult:
        # ===
        # Compute a lower bound. If it is above the threshold, skip GED
        # computation.
        # ===

        computed_lb = self._lower_bound(g0, g1)
        if computed_lb > self.lb_threshold:
            return GEDResult(g0.graph_id, g1.graph_id, -1, 0, None, computed_lb)

        # ===
        # Check that g0.num_nodes <= g1.num_nodes. If not, swap g0 and g1.
        # Reason: GED binary will swap graphs otherwise by itself without
        #  notification. If this occurred, the mapping would be inverse.
        # ===

        if g0.num_nodes() <= g1.num_nodes():
            return self._compute(g0, g1, computed_lb)
        else:
            return self._compute(g1, g0, computed_lb)

    def _lower_bound(self, g0: Graph, g1: Graph) -> int:
        num_nodes_diff = abs(g0.num_nodes() - g1.num_nodes())
        num_edges_diff = abs(g0.num_edges() - g1.num_edges())
        num_node_attrs_diff = self._num_node_attrs_diff(
            g0.node_attributes_with_default_value(),
            g1.node_attributes_with_default_value(),
        )

        return num_nodes_diff + num_edges_diff + num_node_attrs_diff

    @staticmethod
    def _num_node_attrs_diff(
        g0_attrs: dict[int, tuple[float, ...]],
        g1_attrs: dict[int, tuple[float, ...]],
    ) -> int:
        if len(g0_attrs) <= len(g1_attrs):
            smaller_attrs = g0_attrs
            larger_attrs = g1_attrs
        else:
            smaller_attrs = g1_attrs
            larger_attrs = g0_attrs

        lb_relabel_ops = 0

        # Count each attribute in the larger graph.
        larger_counter = Counter(larger_attrs.values())

        # ===
        # A lower bound of the number of node relabel ops is given as the
        # difference in node labels between the smaller graph and the larger
        # graph.
        # ===

        for attr in smaller_attrs.values():
            if larger_counter[attr] == 0:  # Counter is 0 for missing keys.
                lb_relabel_ops += 1
            else:
                larger_counter[attr] -= 1

        return lb_relabel_ops

    def _compute(self, g0: Graph, g1: Graph, lb: int) -> GEDResult:
        assert g0.num_nodes() <= g1.num_nodes()

        file0 = self._make_temp_file(g0)
        file1 = self._make_temp_file(g1)

        command = ["./ged", "-q", file0, "-d", file1, "-g"]

        try:
            process = subprocess.run(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=self.timeout,
            )
            output: str = process.stdout.decode()
            err_output: str = process.stderr.decode()

            # Extract GED
            ged_match = re.search(r"GED: (\d+)", output)
            if ged_match:
                ged = int(ged_match.group(1))
            else:
                raise MissingGEDException(
                    "GED value not found in output:"
                    + "\nSTDOUT:\n "
                    + output
                    + "\nSTDERR:\n "
                    + err_output
                )

            # Extract mapping
            mapping_match = re.search(r"Mapping: (.+)", output)
            if mapping_match:
                mapping: dict[int, int] = {}
                pairs = mapping_match.group(1).split(", ")
                for pair in pairs:
                    if "->" in pair:
                        q, g = map(int, pair.split(" -> "))
                        mapping[q] = g
            else:
                raise MissingMappingException(
                    "Mapping not found in output:"
                    + "\nSTDOUT:\n "
                    + output
                    + "\nSTDERR:\n "
                    + err_output
                )

            # ===
            # Extract total time. For some unknown reason, time is not always
            # present in the binary's output, hence None is also accepted here.
            # ===

            total_time_match = re.search(
                r"Total time: ([\d,]+) \(microseconds\)", output
            )
            total_time = (
                int(total_time_match.group(1).replace(",", ""))
                if total_time_match
                else None
            )

            return GEDResult(
                g0.graph_id, g1.graph_id, ged, total_time, mapping, lb
            )

        except subprocess.TimeoutExpired:
            return GEDResult(
                g0.graph_id,
                g1.graph_id,
                -1,
                self.timeout * 1_000_000,
                None,
                lb,
            )  # Time in microseconds, hence multiplied by 1e6.

        finally:
            self._remove_temp_file(file0)
            self._remove_temp_file(file1)


class GEDComputeManager:
    """Fetch batches of graph pairs without a GED, compute their GEDs in
    parallel and store the results via the database handler.
    """

    def __init__(self, args: Args) -> None:
        """Instantiate the database handler and copy the run options.

        Args:
            args: Parsed command-line arguments (from ``parse_args``).
        """
        self.db_manager = GEDComputeDatabaseHandler()
        self.dataset_name = args.dataset_name
        self.n_cpus = args.n_cpus
        self.timeout = args.timeout
        self.lb_threshold = args.lb_threshold
        self.batch_size = args.batch_size
        self.method_name = args.method_name

    def compute_geds_and_store(self) -> None:
        """Compute and persist GEDs for every batch of pending graph pairs.

        When ``method_name`` is ``None``, pairs come from
        ``get_graph_pairs_without_ged``; otherwise from
        ``get_mixup_graph_pairs_without_ged`` for that method. Each batch is
        processed with ``n_cpus`` worker processes, and its results are
        written with ``create_ged_result`` before the next batch is fetched.
        """
        if self.method_name is None:
            batches = self.db_manager.get_graph_pairs_without_ged(
                self.dataset_name, limit=self.batch_size
            )
        else:
            batches = self.db_manager.get_mixup_graph_pairs_without_ged(
                self.dataset_name, self.method_name, limit=self.batch_size
            )

        computer = GEDComputer(self.timeout, self.lb_threshold)

        # One pool for the whole run: creating it inside the loop would spawn
        # and tear down ``n_cpus`` worker processes for every batch.
        with Pool(self.n_cpus) as pool:
            for batch in tqdm(batches):
                for result in pool.starmap(computer.process, batch):
                    self.db_manager.create_ged_result(result)


if __name__ == "__main__":
    # Script entry point: parse CLI options, set the log level from the
    # ``verbose`` flag, then compute and store GEDs for the pending pairs.
    cli_args = parse_args()
    log_level = logging.INFO if cli_args.verbose else logging.WARNING
    logging.basicConfig(level=log_level)
    GEDComputeManager(cli_args).compute_geds_and_store()
