from typing import (
    Tuple,
    List,
    Dict,
    Optional,
)

import networkx as nx
from loguru import logger
from pysat.solvers import Solver

from src.schema import (
    ConflictGraph,
    ColoredGraph,
)


class GraphColoringStage:
    """Pipeline stage that assigns a color id to every node of a conflict graph.

    Adjacent nodes always receive different colors. A greedy DSATUR coloring
    is always computed. When requested, a SAT-based binary search then looks
    for a coloring with fewer colors, bounded below by a known clique size.
    """

    def __init__(self) -> None:
        logger.info("Initializing GraphColoringStage.")

    @staticmethod
    def _var(node_idx: int, color: int, k: int) -> int:
        """Return the SAT variable meaning "node at position node_idx has color".

        Variables are numbered from 1 because SAT literals must be non-zero.
        The same numbering is used when encoding (``_is_k_colorable``) and
        decoding (``_build_colored_graph_from_model``).
        """
        return node_idx * k + color + 1

    def _dsatur_coloring(self, G: nx.Graph) -> Dict[int, int]:
        """Return a greedy DSATUR coloring of G as a node -> color id mapping."""
        return nx.coloring.greedy_color(G, strategy='saturation_largest_first')

    def _is_k_colorable(
        self, G: nx.Graph, k: int
    ) -> Tuple[bool, Optional[List[int]]]:
        """Decide with a SAT solver whether G admits a proper k-coloring.

        Node ``i`` in the encoding is the i-th element of ``list(G.nodes())``.

        Returns:
            ``(True, model)`` when G is k-colorable, where ``model`` is the
            solver's list of signed literals; ``(False, None)`` otherwise.
        """
        node_index = {node: idx for idx, node in enumerate(G.nodes())}
        n = len(node_index)
        var = self._var

        clauses: List[List[int]] = []
        for i in range(n):
            # Each node takes at least one color ...
            clauses.append([var(i, c, k) for c in range(k)])
            # ... and at most one color.
            clauses.extend(
                [-var(i, c1, k), -var(i, c2, k)]
                for c1 in range(k)
                for c2 in range(c1 + 1, k)
            )

        # Adjacent nodes never share a color.
        for u, v in G.edges():
            i, j = node_index[u], node_index[v]
            clauses.extend([-var(i, c, k), -var(j, c, k)] for c in range(k))

        with Solver(name="minisat22") as solver:
            solver.append_formula(clauses)
            if solver.solve():
                return True, solver.get_model()
            return False, None

    def _find_min_k(
        self, G: nx.Graph, lower_bound: int, upper_bound: int
    ) -> Tuple[Optional[int], Optional[List[int]]]:
        """Binary-search the smallest SAT-colorable k in [lower_bound, upper_bound).

        ``upper_bound`` itself is assumed achievable (e.g. by DSATUR), so it is
        never tested.

        Returns:
            ``(k, model)`` for the smallest colorable k below ``upper_bound``, or
            ``(None, None)`` when none exists (including when
            ``lower_bound >= upper_bound``). In that case ``upper_bound`` is
            already the minimum within the searched range.
        """
        final_k: Optional[int] = None
        final_model: Optional[List[int]] = None
        while lower_bound < upper_bound:
            mid = (lower_bound + upper_bound) // 2
            logger.info(f"Checking if graph is {mid}-colorable.")
            success, model = self._is_k_colorable(G, mid)
            if success:
                logger.info(f"Found {mid}-colorable graph.")
                upper_bound = mid
                final_k = mid
                final_model = model
            else:
                logger.info(f"{mid}-colorable graph not found, trying higher k.")
                lower_bound = mid + 1
        return final_k, final_model

    @staticmethod
    def _make_colored_graph(
        conflict_graph: ConflictGraph,
        n_colors: int,
        node_id_to_color_id: Dict[int, int],
    ) -> ColoredGraph:
        """Wrap a coloring and the conflict graph's metadata into a ColoredGraph."""
        return ColoredGraph(
            n_nodes=conflict_graph.n_nodes,
            answer_id_to_node_id=conflict_graph.answer_id_to_node_id,
            node_id_to_answer_id=conflict_graph.node_id_to_answer_id,
            adjacency_dict=conflict_graph.adjacency_dict,
            n_colors=n_colors,
            node_id_to_color_id=node_id_to_color_id,
        )

    def _build_colored_graph_from_model(
        self,
        conflict_graph: ConflictGraph,
        min_k: int,
        model: List[int],
        node_ids: Optional[List[int]] = None,
    ) -> ColoredGraph:
        """Decode a SAT model from ``_is_k_colorable`` into a ColoredGraph.

        Args:
            conflict_graph: graph whose metadata is copied into the result.
            min_k: number of colors the model was encoded with.
            model: signed literals returned by the solver.
            node_ids: node order used when the model was encoded, i.e.
                ``list(G.nodes())`` of the graph passed to ``_is_k_colorable``.
                It must match that order exactly. When omitted, the node order
                of ``conflict_graph.to_networkx_graph()`` is used (the previous
                behavior).
        """
        if node_ids is None:
            node_ids = list(conflict_graph.to_networkx_graph().nodes())

        # Set lookup instead of scanning the model list for every (node, color).
        true_literals = {lit for lit in model if lit > 0}

        node_id_to_color_id: Dict[int, int] = {}
        for i, node_id in enumerate(node_ids):
            for c in range(min_k):
                if self._var(i, c, min_k) in true_literals:
                    node_id_to_color_id[node_id] = c
                    break

        return self._make_colored_graph(conflict_graph, min_k, node_id_to_color_id)

    def run(
        self,
        conflict_graph: ConflictGraph,
        known_clique_size: int,
        *,
        use_sat: bool = False,
    ) -> ColoredGraph:
        """Color ``conflict_graph`` with as few colors as this stage can find.

        Args:
            conflict_graph: graph to color; ``adjacency_dict`` defines its edges.
            known_clique_size: size of a known clique, used as the lower bound
                on the number of colors for the SAT search.
            use_sat: when True, try to beat the DSATUR color count with a SAT
                binary search. Defaults to False, matching the previous
                behavior where the SAT search was disabled.

        Returns:
            A ColoredGraph holding the chosen coloring and its color count.
        """
        G = nx.Graph(conflict_graph.adjacency_dict)

        dsatur_result = self._dsatur_coloring(G=G)
        # default=-1 makes an empty graph yield 0 colors instead of raising.
        dsatur_k = max(dsatur_result.values(), default=-1) + 1

        if not use_sat:
            logger.info(f"SAT search disabled; using DSATUR coloring with {dsatur_k} colors.")
            return self._make_colored_graph(conflict_graph, dsatur_k, dsatur_result)

        # A non-empty graph needs at least one color; this also avoids encoding k=0.
        lower_bound = max(known_clique_size, 1 if G.number_of_nodes() > 0 else 0)
        logger.info(f"Finding minimum k-coloring between {lower_bound} and {dsatur_k}.")
        min_k, model = self._find_min_k(G=G, lower_bound=lower_bound, upper_bound=dsatur_k)

        if model is None:
            logger.info(
                f"No coloring with fewer than {dsatur_k} colors found; using DSATUR coloring."
            )
            return self._make_colored_graph(conflict_graph, dsatur_k, dsatur_result)

        logger.info(f"Minimum k-coloring found: {min_k}")
        # Decode with the exact node order the model was encoded with.
        return self._build_colored_graph_from_model(
            conflict_graph=conflict_graph,
            min_k=min_k,
            model=model,
            node_ids=list(G.nodes()),
        )
