import numpy as np
import shapely
from loguru import logger
from pytictoc import TicToc
from scipy.sparse import csr_array
from scipy.sparse.csgraph import csgraph_from_dense, shortest_path
from scipy.spatial import distance_matrix
from scipy.spatial.distance import pdist, squareform
from shapely import Point, Polygon, GeometryCollection, LineString
from shapely.prepared import prep
from shapely.strtree import STRtree

from molecule_movement import Molecule, Goal, Obstacle, Pair, Matching, MoleculeExperiment
from molecule_movement.exceptions import InfeasibleError
from molecule_movement.logging import log_and_raise

# Node-type codes: used in the type array produced by _pack_points and stored
# (as floats) in the third column of ShortestPathObstacles.all_objects.
MOLECULE = 0
GOAL = 1
OBST_PT = 2

def _pack_points(molecules, goals, obstacle_points):
    """Pack every graph node into parallel coordinate and type arrays.

    Node order is fixed: molecule centers first, then goal positions, then
    the obstacle points, so a node's index determines which group it is in.

    Args:
        molecules: Objects exposing ``center.x`` / ``center.y``.
        goals: Objects exposing ``position.x`` / ``position.y``.
        obstacle_points: Sized collection of ``(x, y)`` pairs.

    Returns:
        ``(coords, types)``: ``coords`` is a float64 array of shape ``(n, 2)``,
        ``types`` an int8 array of shape ``(n,)`` holding ``MOLECULE``,
        ``GOAL`` or ``OBST_PT`` for each node.
    """
    group_sizes = (len(molecules), len(goals), len(obstacle_points))

    xy = [(mol.center.x, mol.center.y) for mol in molecules]
    xy.extend((goal.position.x, goal.position.y) for goal in goals)
    # Unpacking enforces exactly two coordinates per obstacle point.
    xy.extend((px, py) for px, py in obstacle_points)

    # reshape keeps the (0, 2) shape when there are no nodes at all.
    coords = np.asarray(xy, dtype=np.float64).reshape(-1, 2)
    group_codes = np.array((MOLECULE, GOAL, OBST_PT), dtype=np.int8)
    types = np.repeat(group_codes, group_sizes)
    return coords, types


class ShortestPathObstacles:
    """Shortest obstacle-avoiding paths from molecules to goals.

    A visibility graph is built whose nodes are, in this order, the molecule
    centers, the goal positions and the exterior vertices of every buffered
    obstacle. Two nodes are joined by an edge weighted with their Euclidean
    distance unless both are molecules, both are goals, or the straight
    segment between them crosses a buffered obstacle (merely touching an
    obstacle's boundary is allowed). Single-source shortest paths are then
    computed from every molecule node.
    """

    def __init__(self,
                 molecules: list[Molecule] | list[MoleculeExperiment],
                 goals: list[Goal],
                 obstacles: list[Obstacle],
                 buffer_size: float) -> None:
        """Buffer the obstacles, build the visibility graph and solve it.

        Args:
            molecules: Path sources; used as dict keys, so they must be hashable.
            goals: Path targets; used as dict keys, so they must be hashable.
            obstacles: Obstacles to avoid. NOTE: each obstacle is mutated via
                ``set_buffer`` to store its buffered polygon.
            buffer_size: Distance by which every obstacle polygon is grown
                (bevelled joins) before the graph is built.

        Errors raised while computing the shortest paths are handed to
        ``log_and_raise`` together with a descriptive message.
        """
        self.molecules = molecules
        self.goals = goals
        self.obstacles = obstacles
        for o in self.obstacles:
            o.set_buffer(o.polygon.buffer(buffer_size, shapely.BufferJoinStyle.bevel))
        self.all_obstacles_with_buffer = GeometryCollection([o.buffer for o in self.obstacles])
        self.merged_obstacles_with_buffer = shapely.union_all(self.all_obstacles_with_buffer)
        # Prepared in place so repeated predicates against it (e.g. in
        # __compute_distance) are fast.
        shapely.prepare(self.merged_obstacles_with_buffer)

        # Extra graph nodes: exterior vertices of each buffered obstacle. The
        # last ring coordinate repeats the first one and is dropped.
        self.obstacle_points = []
        for o in self.all_obstacles_with_buffer.geoms:
            self.obstacle_points.extend(o.exterior.coords[:-1])

        timer = TicToc()
        timer.tic()
        self._compute_graph_points()
        time_needed = round(timer.tocvalue(), 3)
        logger.bind(task="scheduling", compute_graph_points_time_needed=time_needed).trace(f"Computing graph points took {time_needed} seconds.")
        try:
            timer.tic()
            self._compute_shortest_paths()
            time_needed = round(timer.tocvalue(), 3)
            logger.bind(task="scheduling", compute_shortest_paths_time_needed=time_needed).trace(f"Computing shortest paths took {time_needed} seconds.")
        except Exception as e:
            log_and_raise(e, msg=f"Could not compute shortest circumventing paths: {e}.")

        # Node index of each molecule / goal in the graph (goals are offset by
        # the number of molecules, see _compute_path).
        self._mol_idx = {mol: i for i, mol in enumerate(self.molecules)}
        self._goal_idx = {g: j for j, g in enumerate(self.goals)}
        self._n_mol = len(self.molecules)
        self._point_cache = {}

    def _point_from_idx(self, idx: int) -> Point:
        """Return the shapely Point of graph node ``idx``, cached per index."""
        point = self._point_cache.get(idx)
        if point is None:
            x, y, _ = self.all_objects[idx]
            point = Point(x, y)
            self._point_cache[idx] = point
        return point

    def _compute_graph_points(self) -> None:
        """Build the node table and the dense edge-weight matrix of the graph.

        Sets ``all_objects`` (rows ``(x, y, type_code)``) and
        ``distance_matrix`` (Euclidean distances with ``inf`` marking
        forbidden edges).
        """
        logger.trace("Computing graph points")

        coords, types = _pack_points(self.molecules, self.goals, self.obstacle_points)
        self.all_objects = np.column_stack([coords, types.astype(np.float64, copy=False)])
        dist = squareform(pdist(coords, metric="euclidean"))

        # Molecule-molecule and goal-goal edges are never allowed.
        same_type = types[:, None] == types[None, :]
        is_endpoint = (types == MOLECULE) | (types == GOAL)
        dist[same_type & is_endpoint[:, None]] = np.inf

        # Without obstacles no segment can be blocked.
        obstacle_geoms = list(self.all_obstacles_with_buffer.geoms)
        if obstacle_geoms:
            self._block_obstacle_crossings(coords, dist, STRtree(obstacle_geoms))

        self.distance_matrix = dist
        logger.trace("Computing graph points - DONE")

    def _block_obstacle_crossings(self, coords: np.ndarray, dist: np.ndarray, tree: STRtree) -> None:
        """Set ``dist`` to ``inf`` (symmetrically) for every still-finite node pair
        whose connecting segment crosses a buffered obstacle in ``tree``.

        A segment that only touches an obstacle (runs along its boundary or
        ends on it) is not blocked.
        """
        rows, cols = np.triu_indices(coords.shape[0], k=1)
        finite = np.isfinite(dist[rows, cols])
        rows, cols = rows[finite], cols[finite]
        if rows.size == 0:
            return

        start, end = coords[rows], coords[cols]
        seg_min, seg_max = np.minimum(start, end), np.maximum(start, end)
        min_x, min_y, max_x, max_y = self.merged_obstacles_with_buffer.bounds
        # Cheap pre-filter: a segment whose bounding box misses the obstacles'
        # overall bounding box cannot intersect any obstacle.
        near = ~(
            (seg_max[:, 0] < min_x) | (seg_min[:, 0] > max_x) |
            (seg_max[:, 1] < min_y) | (seg_min[:, 1] > max_y)
        )
        if not near.any():
            return
        rows, cols = rows[near], cols[near]

        # One LineString per candidate pair: coordinates of shape (m, 2, 2).
        segments = shapely.linestrings(np.stack([start[near], end[near]], axis=1))
        # The "intersects" predicate is evaluated exactly on the geometries,
        # so every returned (segment, obstacle) pair really intersects.
        seg_idx, obs_idx = tree.query(segments, predicate="intersects")
        if seg_idx.size == 0:
            return
        crossing = ~shapely.touches(segments[seg_idx], tree.geometries[obs_idx])
        blocked = np.unique(seg_idx[crossing])
        dist[rows[blocked], cols[blocked]] = np.inf
        dist[cols[blocked], rows[blocked]] = np.inf

    def _compute_shortest_paths(self) -> None:
        """Run single-source shortest paths from every molecule node.

        Sets ``path_distance_matrix`` and ``predecessors`` as returned by
        ``scipy.sparse.csgraph.shortest_path``: one row per molecule, one
        column per graph node; negative predecessors mean "none".
        """
        logger.trace(f"Computing shortest paths")
        # null_value=inf makes inf the "no edge" marker, so zero-length edges
        # between coincident nodes (e.g. a molecule already on a goal) stay
        # real edges; converting the dense matrix with csr_array drops them.
        graph = csgraph_from_dense(self.distance_matrix, null_value=np.inf)
        self.path_distance_matrix, self.predecessors = shortest_path(
            csgraph=graph,
            directed=False,
            return_predecessors=True,
            indices=range(0, len(self.molecules)),
        )
        logger.trace(f"Computing shortest paths - DONE")

    def _compute_path(self, molecule: Molecule, goal: Goal) -> Matching:
        """Return the matching of ``molecule`` to ``goal`` with its waypoints.

        The waypoints are the intermediate graph nodes of the shortest path in
        travel order, excluding the molecule and goal themselves. An empty list
        means the straight segment is the shortest path, or that the goal is
        unreachable from the molecule in the graph.
        """
        mol_i = self._mol_idx[molecule]
        goal_j = self._goal_idx[goal] + self._n_mol
        preds = self.predecessors[mol_i]

        # Walk the predecessor tree back from the goal. Any node can be an
        # intermediate one (including other molecules with a smaller index),
        # so stop only at the source itself or at scipy's negative
        # "no predecessor" sentinel.
        path_idx = []
        p = int(preds[goal_j])
        while p != mol_i and p >= 0:
            path_idx.append(p)
            p = int(preds[p])

        if p < 0:
            # No obstacle-free route exists (e.g. the obstacle buffer /
            # corridor_width is too large). As before, fall back to a matching
            # without waypoints instead of raising InfeasibleError.
            logger.bind(task="scheduling").debug(
                f"No obstacle-free path from molecule node {mol_i} to goal node {goal_j}; using a direct matching."
            )
            return Matching(molecule, goal, [])

        points_on_path = [self._point_from_idx(k) for k in reversed(path_idx)]
        return Matching(molecule, goal, points_on_path)

    def compute_circumventing_matching(self, molecule, goal) -> Matching:
        """Public entry point: the obstacle-avoiding matching of ``molecule`` to ``goal``."""
        return self._compute_path(molecule, goal)

    def __compute_distance(self, a, b):
        """Return the edge weight between two ``all_objects`` rows ``(x, y, type)``.

        ``inf`` when both rows are molecules or both are goals, or when the
        segment between them crosses the merged buffered obstacles; otherwise
        the Euclidean length of that segment.
        """
        type_a = a[2]
        type_b = b[2]

        # Compare against the numeric node-type codes, not against the
        # Goal/Molecule classes (a number never equals a class).
        if type_a == type_b and (type_a == GOAL or type_a == MOLECULE):
            return np.inf
        vector = LineString([(a[0], a[1]), (b[0], b[1])])
        if vector.intersects(self.merged_obstacles_with_buffer) and not vector.touches(self.merged_obstacles_with_buffer):
            return np.inf
        return vector.length


