import hashlib
from dataclasses import dataclass
from typing import Dict, FrozenSet, Tuple, List, Optional

from datetime import datetime

from molecule_movement.logging import log_and_raise

from pytictoc import TicToc

from shapely.strtree import STRtree
from shapely import intersects  # Shapely 2.x
from shapely import prepare as _prepare_geom


try:
    from z3 import *
except ModuleNotFoundError as e:
    msg = "z3 is not installed. In order to use SATBasedScheduling, please install z3: \n\n\tpip install z3-solver\n"
    log_and_raise(e, msg)

try:
    import networkx as nx
except Exception as e:
    print(e)

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib

import numpy as np

import os

from loguru import logger

from molecule_movement import Matching, Molecule, Goal, Obstacle
from molecule_movement.scheduling import AbstractScheduling
from molecule_movement.shapes import compute_corridor
from molecule_movement.wrapper import CorridorConfiguration

@dataclass
class NestedClash:
    """A pair of molecules joined by more than one conflict edge.

    Instances are produced by ``TopologicalConflictGraph`` for every ordered
    node pair ``outer -> inner`` that has at least two parallel edges.
    """

    inner: str  # molecule name at the head of the parallel edges
    outer: str  # molecule name at the tail of the parallel edges

def point_sig(pt) -> bytes:
    """Return a 16-byte BLAKE2b digest of ``pt.wkb``.

    ``pt`` is presumably a shapely geometry (anything exposing a ``wkb``
    bytes attribute works) -- confirm against callers.
    """
    hasher = hashlib.blake2b(digest_size=16)
    hasher.update(pt.wkb)
    return hasher.digest()

@dataclass(frozen=True)
class SubproblemKey:
    """Immutable, hashable identifier of an SMT sub-problem.

    ``pairs`` is a frozen set of ``(name, signature)`` tuples; the ``bytes``
    part presumably comes from ``point_sig``. NOTE(review): the cache keys
    built by ``SATBasedScheduling._component_key`` are plain tuples, not
    instances of this class -- confirm whether this type is still needed.
    """

    pairs: FrozenSet[Tuple[str, bytes]]  # frozen so the whole key is hashable

# Cache of solved (satisfiable) SAT sub-problems, passed through
# ``SATBasedScheduling.compute_schedule(cache=...)``.
# Keys: sorted tuple of ``(molecule_id, goal_id)`` pairs as built by
#       ``SATBasedScheduling._component_key``.
# Values: molecule ids of the sub-problem, in scheduled order.
SatScheduleCache = dict[tuple[tuple[int, int], ...], list[int]]

class SATBasedScheduling(AbstractScheduling):
    """Schedule molecule moves with z3 so that no mover's corridor is blocked.

    For every matching a corridor is computed with ``compute_corridor``. Two
    kinds of pairwise conflicts are derived from it:

    * start conflict ``(mover, other)``: the corridor intersects ``other``'s
      current polygon, so ``other`` must be scheduled before ``mover``;
    * goal conflict ``(mover, other)``: the corridor intersects the goal
      polygon of ``other``'s matching, so ``mover`` must be scheduled before
      ``other``.

    Each molecule gets an integer position variable and the solver looks for
    a permutation of positions satisfying every precedence. Unsatisfiable
    instances are reported through the solver's unsat core.

    Args:
        molecules: All molecules of the instance.
        goals: All goals of the instance (used to assign goal ids for caching).
        obstacles: Stored on the instance; not used by the methods of this class.
        matching: Molecule/goal assignments to schedule.
        corridor_config: Corridor width and parking parameters.
        separate_conflicts: If True, independent groups of conflicting
            molecules are solved as separate, smaller SMT problems.
    """

    def __init__(self,
                 molecules: list[Molecule],
                 goals: list[Goal],
                 obstacles: list[Obstacle],
                 matching: list[Matching],
                 corridor_config: CorridorConfiguration,
                 separate_conflicts: bool = True
                 ):
        super().__init__(molecules, goals, matching)
        self.obstacles = obstacles
        self.corridor_config = corridor_config
        self.molecule_names = [m.name for m in self.molecules]
        self.separate_conflicts = separate_conflicts
        self.matching = matching
        self.molecules_names_within_matching = [match.molecule.name for match in self.matching]

        # Molecule name -> its Matching; turns solver results back into matchings.
        self._by_name = {m.molecule.name: m for m in self.matching}

        # Integer ids used to build cache keys (see _component_key).
        self._mol_id = {m.name: i for i, m in enumerate(self.molecules)}
        self._name_by_mol_id = {i: name for name, i in self._mol_id.items()}
        # Distinct goal positions get consecutive ids in first-seen order.
        self._goal_id = {}
        for g in self.goals:
            key = (float(g.position.x), float(g.position.y))
            self._goal_id.setdefault(key, len(self._goal_id))

        self.__compute_conflicts()
        if self.separate_conflicts:
            self.matchings = list()
            self.separated_conflicts = self.__separate_conflicts()

    def compute_schedule(self,
                         seed: Optional[int] = None,
                         graph_storage: Optional[str] = None,
                         cache: Optional[SatScheduleCache] = None,
                         cache_sat_schedules: bool = True,
                         ) -> tuple[Optional[list[Matching]], Optional[dict], dict]:
        """Compute a conflict-free order of the matchings.

        Args:
            seed: Accepted for interface compatibility; not used here.
            graph_storage: Accepted for interface compatibility; not used here.
            cache: Previously solved sub-problems to reuse; newly solved
                satisfiable sub-problems are added to it in place.
            cache_sat_schedules: Create a fresh cache when ``cache`` is None.

        Returns:
            ``(schedule, None, info)`` when every (sub-)problem is satisfiable,
            otherwise ``(None, conflicts, info)`` where ``conflicts`` maps the
            unsat-core tracking literals to the matching of the blocked mover.
            ``info`` contains ``"nested_clashes"``, ``"conflict"`` (stringified
            unsat cores), ``"clashes"`` and ``"sat_schedule_cache"``.
        """
        info = {"nested_clashes": False, "conflict": [], "clashes": []}
        if cache_sat_schedules and cache is None:
            cache = SatScheduleCache()
        info["sat_schedule_cache"] = cache

        if not self.goal_conflicts and not self.start_conflicts:
            return self.matching, None, info

        if not self.separate_conflicts:
            # Only matched molecules can be scheduled; unmatched names would
            # fail in __find_matching / _component_key.
            names = [n for n in self.molecule_names if n in self._by_name]
            schedule, conflicts, cache = self.__run_solver(
                names, self.start_conflicts, self.goal_conflicts, cache)
            info["sat_schedule_cache"] = cache
            if conflicts:
                info["conflict"].append([str(c) for c in conflicts])
            # Previously the cache was returned here in place of ``info``.
            return schedule, conflicts, info

        schedule = [self.__find_matching(name)
                    for name in self.nonconflicting_molecule_names
                    if name in self._by_name]
        conflicts = {}

        t = TicToc()
        t.tic()
        for molecule_names, start_conflicts, goal_conflicts in self.separated_conflicts:
            # sorted(): set iteration order depends on string hashing, which
            # would otherwise make the SMT encoding order vary between runs.
            sched, conf, cache = self.__run_solver(
                sorted(molecule_names), start_conflicts, goal_conflicts, cache)
            if sched:
                schedule.extend(sched)
            if conf:
                conflicts.update(conf)
                info["conflict"].append([str(c) for c in conf])
        info["sat_schedule_cache"] = cache
        time_needed = round(t.tocvalue(), 3)
        logger.trace(f"{time_needed} for SMT run")

        if conflicts:
            return None, conflicts, info
        return schedule, None, info

    @property
    def conflicting_molecules_names(self) -> list[str]:
        """Names of all molecules involved in at least one conflict (unordered)."""
        return list(self._conflicting_molecules_names_set)

    @property
    def _conflicting_molecules_names_set(self) -> set[str]:
        """Names appearing on either side of any start or goal conflict."""
        return {name
                for pair in self.start_conflicts + self.goal_conflicts
                for name in pair}

    @property
    def nonconflicting_molecule_names(self) -> list[str]:
        """Names of conflict-free molecules, in ``self.molecules`` order, deduplicated."""
        conflicting = self._conflicting_molecules_names_set
        return list(dict.fromkeys(n for n in self.molecule_names if n not in conflicting))

    def __run_solver(self, molecule_names, start_conflicts, goal_conflicts,
                     cache: Optional[SatScheduleCache] = None,
                     ) -> tuple[Optional[list[Matching]], Optional[dict], Optional[SatScheduleCache]]:
        """Solve one (sub-)problem over ``molecule_names``.

        Returns:
            ``(schedule, None, cache)`` if satisfiable (or found in ``cache``),
            ``(None, conflicts, cache)`` if unsatisfiable.

        Raises:
            RuntimeError: if the solver returns ``unknown``.
        """
        # The key needs a matching for every name, so only build it when caching.
        key = self._component_key(molecule_names) if cache is not None else None
        if key is not None and key in cache:
            ordered = [self._by_name[self._name_by_mol_id[i]] for i in cache[key]]
            return ordered, None, cache

        self.solver = Solver()
        # Tracking label -> mover name, used to decode the unsat core.
        self._tracked_movers = {}

        num_molecules = len(molecule_names)
        self.__init_booleans(molecule_names, num_molecules)
        self.__distinct_goals()
        self.__conflicting_starts(start_conflicts, molecule_names, num_molecules)
        self.__conflicting_goals(goal_conflicts, molecule_names, num_molecules)

        self.solver.set(':core.minimize', True)
        result = self.solver.check()
        if result == sat:
            schedule = self.__process_model(molecule_names, num_molecules)
            if cache is not None:
                cache[key] = [self._mol_id[m.molecule.name] for m in schedule]
            return schedule, None, cache
        if result == unsat:
            return None, self.__process_unsat_core(), cache
        # ``unknown`` has no unsat core; treating it as unsat would fail later.
        raise RuntimeError(f"SMT solver returned {result}: {self.solver.reason_unknown()}")

    def __process_model(self, molecule_names, num_molecules) -> list[Matching]:
        """Return the matchings of ``molecule_names`` ordered by their model position."""
        model = self.solver.model()
        ranked = sorted(
            zip(molecule_names, self.vars),
            key=lambda name_var: model.eval(name_var[1], model_completion=True).as_long(),
        )
        return [self.__find_matching(name) for name, _ in ranked]

    def __process_unsat_core(self) -> dict:
        """Map every tracking literal in the unsat core to the blocked mover's matching.

        Raises:
            RuntimeError: if the unsat core is empty (no conflict to report).
        """
        core = self.solver.unsat_core()
        if len(core) == 0:
            # An explicit raise instead of ``assert`` so it survives ``python -O``.
            raise RuntimeError("Unsat Core is empty")
        conflicts = {}
        for literal in core:
            # Look the mover up by label instead of splitting on "-", which
            # broke for molecule names containing "-".
            mover = self._tracked_movers[literal.decl().name()]
            conflicts[literal] = self.__find_matching(mover)
        return conflicts

    def __init_booleans(self, molecules_names, num_molecules) -> None:
        """Create one integer position variable ``G_<name>`` per molecule."""
        self.vars = [Int(f"G_{name}") for name in molecules_names]

    def __distinct_goals(self) -> None:
        """Constrain the position variables to a permutation of ``1..n``."""
        n = len(self.vars)
        for var in self.vars:
            self.solver.add(And(1 <= var, var <= n))
        if n > 1:
            # Distinct over fewer than two terms adds nothing, and z3's Python
            # wrapper cannot infer a context from an empty argument list.
            self.solver.add(Distinct(self.vars))

    def __conflicting_starts(self, start_conflicts, molecule_names, num_molecules) -> None:
        """For each start conflict ``(mover, other)`` require ``other`` to move first."""
        index = self._first_index(molecule_names)
        for mover, other in start_conflicts:
            self.__track(self.vars[index[other]] < self.vars[index[mover]],
                         f"start-{mover}-{other}", mover)

    def __conflicting_goals(self, goal_conflicts, molecule_names, num_molecules) -> None:
        """For each goal conflict ``(mover, other)`` require ``mover`` to move first."""
        index = self._first_index(molecule_names)
        for mover, other in goal_conflicts:
            self.__track(self.vars[index[mover]] < self.vars[index[other]],
                         f"goal-{mover}-{other}", mover)

    def __track(self, constraint, label: str, mover: str) -> None:
        """Assert ``constraint`` under tracking literal ``label`` and remember its mover."""
        self._tracked_movers[label] = mover
        self.solver.assert_and_track(constraint, label)

    @staticmethod
    def _first_index(names) -> dict:
        """Map each name to the index of its first occurrence (like ``list.index``)."""
        index = {}
        for i, name in enumerate(names):
            index.setdefault(name, i)
        return index

    def __compute_conflicts(self) -> None:
        """Fill ``self.start_conflicts`` and ``self.goal_conflicts``.

        Both are lists of ``(mover, other)`` name pairs, deduplicated and kept
        in discovery order.
        """
        mol_polys = [m.polygon for m in self.molecules]
        mol_names = [m.name for m in self.molecules]
        mol_tree = STRtree(mol_polys)

        goal_polys = [m.goal.polygon for m in self.matching]
        goal_names = [m.molecule.name for m in self.matching]
        goal_tree = STRtree(goal_polys)

        mol_index_by_obj = {id(m): i for i, m in enumerate(self.molecules)}

        start_conflicts, goal_conflicts = [], []
        for i, matching in enumerate(self.matching):
            mover_name = matching.molecule.name
            corridor = compute_corridor(
                matching,
                corridor_width=self.corridor_config.width,
                parking_buffer=self.corridor_config.parking_buffer,
                parking_distance=self.corridor_config.parking_distance,
            )
            # shapely.prepare works in place and returns None; shapely's
            # vectorized predicates then use the prepared corridor automatically.
            _prepare_geom(corridor)

            self_idx = mol_index_by_obj[id(matching.molecule)]
            for j in self._corridor_hits(corridor, mol_tree, mol_polys, self_idx):
                start_conflicts.append((mover_name, mol_names[j]))

            # goal_polys is indexed like self.matching, so this matching's own
            # goal is at index i (an id()-based lookup breaks for shared goals).
            for j in self._corridor_hits(corridor, goal_tree, goal_polys, i):
                goal_conflicts.append((mover_name, goal_names[j]))

        self.start_conflicts = list(dict.fromkeys(start_conflicts))
        self.goal_conflicts = list(dict.fromkeys(goal_conflicts))

    @staticmethod
    def _corridor_hits(corridor, tree, polys, exclude_idx):
        """Return indices of ``polys`` (except ``exclude_idx``) that intersect ``corridor``.

        ``tree.query`` only compares bounding boxes; the exact test is done
        with the vectorized ``shapely.intersects``.
        """
        cand_idx = tree.query(corridor)
        cand_idx = cand_idx[cand_idx != exclude_idx]
        if cand_idx.size == 0:
            return cand_idx
        hits = np.asarray(intersects(corridor, [polys[j] for j in cand_idx]), dtype=bool)
        return cand_idx[hits]

    def __separate_conflicts(self) -> list[tuple[set[str], list[tuple[str, str]], list[tuple[str, str]]]]:
        """Split the conflicts into independent groups.

        Molecules are grouped into connected components of the relation
        "appear together in some conflict" (union-find). Each component is
        returned as ``(molecule_names, start_conflicts, goal_conflicts)``,
        where each conflict is assigned to its mover's component. Components
        are ordered by their first conflict in ``start_conflicts + goal_conflicts``.
        """
        all_conflicts = self.start_conflicts + self.goal_conflicts
        parent: dict[str, str] = {}

        def find(name: str) -> str:
            parent.setdefault(name, name)
            while parent[name] != name:
                parent[name] = parent[parent[name]]  # path halving
                name = parent[name]
            return name

        for a, b in all_conflicts:
            root_a, root_b = find(a), find(b)
            if root_a != root_b:
                parent[root_b] = root_a

        components: dict[str, tuple[set[str], list, list]] = {}
        for a, b in all_conflicts:
            components.setdefault(find(a), (set(), [], []))[0].update((a, b))
        for conflict in self.start_conflicts:
            components[find(conflict[0])][1].append(conflict)
        for conflict in self.goal_conflicts:
            components[find(conflict[0])][2].append(conflict)
        return list(components.values())

    def __find_matching(self, molecule_name: str) -> Matching:
        """Return the matching of ``molecule_name``; unknown names go to ``log_and_raise``."""
        try:
            return self._by_name[molecule_name]
        except KeyError as e:
            log_and_raise(e, f"Tried to find molecule name {molecule_name=} in {[m.molecule.name for m in self.matching]}")

    def _component_key(self, molecule_names: list[str]) -> tuple:
        """Return a hashable cache key for the sub-problem over ``molecule_names``.

        The key is the sorted tuple of ``(molecule_id, goal_id)`` pairs, so it
        does not depend on the order of ``molecule_names``.

        NOTE(review): ids are indices into this instance's ``self.molecules`` and
        its distinct goal positions; current molecule positions are not part of
        the key, so reusing one cache across instances assumes the same molecule
        order and layout -- confirm with callers.
        """
        pairs = []
        for name in molecule_names:
            goal_pos = self._by_name[name].goal.position
            pairs.append((self._mol_id[name],
                          self._goal_id[(float(goal_pos.x), float(goal_pos.y))]))
        return tuple(sorted(pairs))

class TopologicalConflictGraph():
    """Directed multigraph of start/goal conflicts used to detect nested clashes.

    A conflict ``(mover, other)`` becomes the edge ``other -> mover``: red with
    ``type="start"`` for start conflicts, blue with ``type="goal"`` for goal
    conflicts. Two or more parallel edges from one node to another mean the
    pair is linked by more than one conflict.
    """

    def __init__(self,
                 conflicting_starts: list[tuple[str, str]],
                 conflicting_goals:  list[tuple[str, str]],
                 ) -> None:
        self.conflicting_starts = conflicting_starts
        self.conflicting_goals = conflicting_goals
        # Every molecule appearing on either side of any conflict.
        self.molecule_names = {
            name
            for conflicts in (self.conflicting_starts, self.conflicting_goals)
            for pair in conflicts
            for name in (pair[0], pair[1])
        }
        self.__create_graph()

    def __create_graph(self):
        """Build ``self.G`` with one edge ``other -> mover`` per conflict."""
        self.G = nx.MultiDiGraph()
        self.G.add_nodes_from(self.molecule_names)
        for conflicting_start in self.conflicting_starts:
            self.G.add_edge(conflicting_start[1], conflicting_start[0], color="r", type="start")
        for conflicting_goal in self.conflicting_goals:
            self.G.add_edge(conflicting_goal[1], conflicting_goal[0], color="b", type="goal")

    def get_topological_clashes(self) -> list[NestedClash]:
        """Return a clash for every ordered node pair with at least two parallel edges."""
        clashes = list()
        for node in self.G.nodes():
            clashes.extend(self._get_multiplicity_2_neighbors(node))
        return clashes

    def get_nested_clashes(self) -> list[NestedClash]:
        """Return clashes whose ``outer`` molecule is the ``inner`` of another clash."""
        topological_clashes = self.get_topological_clashes()
        # Set for O(1) membership instead of scanning a list per clash.
        inner_names = {tc.inner for tc in topological_clashes}
        problematic_clashes = list()
        for nested_clash in topological_clashes:
            if nested_clash.outer in inner_names:
                logger.info(f"{nested_clash=} is problematic")
                problematic_clashes.append(nested_clash)
        return problematic_clashes

    def _get_multiplicity_2_neighbors(self, node) -> list[NestedClash]:
        """Return ``NestedClash(successor, node)`` for successors reached by >1 edge."""
        clashes = list()
        for neighbor in self.G.neighbors(node):
            if self.G.number_of_edges(node, neighbor) > 1:
                clashes.append(NestedClash(neighbor, node))
                logger.trace(f"Containment: {clashes[-1].inner} - {clashes[-1].outer}")
        return clashes

    def show_graph(self, graph_storage):
        """Render ``self.G`` to a timestamped PNG inside directory ``graph_storage``."""
        # Select the non-interactive backend *before* drawing. When pyplot is
        # already imported and the backend actually changes, ``matplotlib.use``
        # switches via ``pyplot.switch_backend``, which closes all figures --
        # doing it after drawing discarded the plot and saved an empty image.
        matplotlib.use('Agg')
        plt.clf()
        ax = plt.gca()
        draw_labeled_multigraph(self.G, "type", ax)
        plt.savefig(os.path.join(graph_storage, f"{datetime.now().isoformat()}_SATBasedScheduling_Graph.png"))

def draw_labeled_multigraph(G, attr_name, ax=None, show_labels: bool = False):
    """Draw a (multi)graph with curved, colored parallel edges.

    Parallel edges are separated by giving each one its own ``arc3``
    connection style. The number of styles must be at least the largest number
    of edges between one pair of nodes (one-sided connections for directed
    graphs, total connections for undirected graphs), so it is sized from the
    graph itself, with a minimum of 4.

    Args:
        G: networkx graph; an edge's ``"color"`` attribute sets its color
            (black when missing).
        attr_name: Edge attribute rendered as ``"<attr_name>=<value>"`` when
            ``show_labels`` is True.
        ax: Matplotlib axes to draw into (networkx picks one when None).
        show_labels: Draw edge labels. The default False matches the previous
            behavior, where the computed labels were discarded.
    """
    import itertools as it
    from collections import Counter

    if G.is_directed():
        multiplicity = Counter(G.edges())
    else:
        multiplicity = Counter(frozenset(edge) for edge in G.edges())
    n_styles = max(4, max(multiplicity.values(), default=0))
    # Works with arc3 and angle3 connectionstyles
    connectionstyle = [f"arc3,rad={r}" for r in it.accumulate([0.15] * n_styles)]

    pos = nx.circular_layout(G)
    colors = [attrs.get("color", "k") for _, _, attrs in G.edges(data=True)]
    nx.draw_networkx_nodes(G, pos, ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=10, ax=ax)
    nx.draw_networkx_edges(
        G, pos, edge_color=colors, connectionstyle=connectionstyle, ax=ax
    )

    labels = {}
    if show_labels:
        # Multigraph edges need their key to be addressed individually.
        edges = G.edges(keys=True, data=True) if G.is_multigraph() else G.edges(data=True)
        labels = {
            tuple(edge): f"{attr_name}={attrs[attr_name]}"
            for *edge, attrs in edges
        }
    nx.draw_networkx_edge_labels(
        G,
        pos,
        labels,
        connectionstyle=connectionstyle,
        label_pos=0.3,
        font_color="blue",
        bbox={"alpha": 0},
        ax=ax,
    )
