from abc import ABC, abstractmethod
from typing import Optional
from numpy.typing import NDArray

import numpy as np

from shapely import LineString
import shapely
from scipy.spatial.distance import cdist, pdist, squareform
from scipy.spatial import KDTree

from shapely.strtree import STRtree

from molecule_movement import Matching, Molecule, MoleculeExperiment, Goal, Obstacle, Pair

from molecule_movement.wrapper.CorridorWrapper import CorridorConfiguration
from molecule_movement.exceptions import AssumptionsError, InfeasibleError

from .ShortestPathObstacles import ShortestPathObstacles

from ..logging import log_and_raise, cond_log_and_raise

from loguru import logger

from pytictoc import TicToc

# Multiplied with the corridor width to obtain AbstractMatching.buffer_size
# when a CorridorConfiguration is supplied.
BUFFER_FACTOR: float = 0.5
# Join style for shapely buffering.
# NOTE(review): not referenced in this file's visible code; presumably
# consumed elsewhere — confirm before removing.
BUFFER_JOIN_STYLE = shapely.BufferJoinStyle.mitre
# Added on top of every AbstractMatching.buffer_size (corridor-based or
# molecule-size-based). Units assumed to match molecule coordinates — confirm.
LATTICE_CONSTANT: float = 0.3

class AbstractMatching(ABC):
    """Base class for strategies that assign molecules to goals.

    Holds the molecules, goals and obstacles and derives the buffer size passed
    to ``ShortestPathObstacles``. It also builds ``self.distance_matrix``, where
    row ``i`` is ``self.molecules[i]`` and column ``j`` is ``self.goals[j]``.
    Concrete subclasses implement :meth:`compute_matching`.
    """

    def __init__(self,
                 molecules: list[Molecule] | list[MoleculeExperiment],
                 goals: list[Goal],
                 obstacles: Optional[list[Obstacle]] = None,
                 respect_obstacles: bool = True,
                 circumvent_obstacles: bool = True,
                 check_positioned: bool = False,
                 corridor_config: Optional[CorridorConfiguration] = None):
        """Store the matching inputs and prepare obstacle-aware path computation.

        Args:
            molecules: Molecules to be matched. The list is copied.
            goals: Goal positions. Must be non-empty. The list is copied.
            obstacles: Static obstacles. ``None`` means no obstacles. The list is copied.
            respect_obstacles: If False, the distance matrix uses straight-line
                distances even when obstacles exist.
            circumvent_obstacles: If True and there is at least one obstacle,
                build a ``ShortestPathObstacles`` helper for obstacle-aware
                distances.
            check_positioned: If True, molecules within 0.1 of any goal are
                turned into obstacles and removed. Goals that have a molecule
                within that radius are removed as well. Types are not compared.
            corridor_config: If given, the buffer size is derived from the
                corridor width, and ``_compute_distance_matrix`` checks the
                clearance assumptions.

        Raises:
            ValueError: If ``goals`` is empty. Also if ``molecules`` is empty
                and no ``corridor_config`` is given, because the buffer size is
                then taken from the first molecule.
        """
        if not goals:
            log_and_raise(ValueError("Tried to compute matching without specifying goals."), msg="Tried to compute matching without specifying goals.")
        # Copy the inputs. add_molecule/add_goal append to these lists, and the
        # check_positioned branch extends the obstacle list. Those changes must
        # not leak into (or accumulate across reuses of) the caller's lists.
        self.molecules = list(molecules)
        self.goals = list(goals)
        self.obstacles = list(obstacles) if obstacles else list()
        self.respect_obstacles = respect_obstacles
        self.circumvent_obstacles = circumvent_obstacles
        self.check_positioned = check_positioned
        self.matching = list()

        if corridor_config:
            self.corridor_config = corridor_config
            self.buffer_size = corridor_config.width * BUFFER_FACTOR
        else:
            self.corridor_config = None
            if not self.molecules:
                # Without a corridor the buffer is sized from a molecule, so at least one is required.
                log_and_raise(ValueError("Tried to compute matching without molecules and without a corridor configuration."),
                              msg="Tried to compute matching without molecules and without a corridor configuration.")
            # NOTE(review): only the first molecule's footprint is used; assumes
            # all molecules have comparable size — confirm.
            molecule = self.molecules[0]
            self.buffer_size = max(molecule.size_x, molecule.size_y)

        self.buffer_size += LATTICE_CONSTANT

        # Pair(molecule, goal) -> Matching, filled as a side effect of the distance functions.
        self.molecule_goal_paths = dict()
        # (molecule, goal) pairs forced to an infinite distance via set_inf_weight.
        self.inf_distances = list()

        if check_positioned:
            mol_mask, goal_mask = self.__positioned_masks(radius=0.1)
            # Molecules already within 0.1 of a goal become obstacles for the others.
            positioned = [Obstacle(m.center, m.shape, int(m.orientation)) for m, keep in zip(self.molecules, mol_mask) if keep]
            self.obstacles.extend(positioned)
            self.molecules = [m for m, keep in zip(self.molecules, ~mol_mask) if keep]
            self.goals = [g for g, keep in zip(self.goals, ~goal_mask) if keep]

        if circumvent_obstacles and len(self.obstacles) > 0:
            self.shortest_paths = ShortestPathObstacles(self.molecules, self.goals, self.obstacles, self.buffer_size)

    @abstractmethod
    def compute_matching(self, seed: Optional[int] = None) -> list[Matching]:
        """Compute and return the molecule-goal matching.

        ``seed`` is presumably for randomized strategies; its semantics are up
        to the subclass.
        """
        pass

    def _select_distance_function(self):
        """Return the pairwise distance function for the configured obstacle handling.

        Straight-line distance is used when obstacles are ignored or absent.
        Otherwise the obstacle-aware shortest path is used. Sharing this
        selection keeps incremental updates (add_molecule/add_goal) consistent
        with ``_compute_distance_matrix``.

        NOTE(review): the obstacle-aware function needs ``self.shortest_paths``.
        ``__init__`` builds that only when ``circumvent_obstacles`` is True. So
        ``respect_obstacles=True`` with ``circumvent_obstacles=False`` and
        obstacles present fails with AttributeError on the first same-type pair.
        """
        if not self.respect_obstacles or len(self.obstacles) == 0:
            return self.__compute_molecule_goal_distance_without_obstacles
        return self.__compute_molecule_goal_distance

    def _compute_distance_matrix(self) -> None:
        """Build ``self.distance_matrix`` (float32, shape ``(len(molecules), len(goals))``).

        Pairs of different type get ``np.inf``. ``self.molecule_goal_paths`` is
        populated for every pair as a side effect.

        Raises:
            AssumptionsError: If a corridor configuration is set and any goal,
                obstacle or molecule clearance check fails.
        """
        logger.trace("Computing distance matrix")
        t = TicToc()
        t.tic()
        if self.corridor_config:
            assumption_violated, goal_violations, obstacle_violations, molecule_violations = self.__check_assumptions()
            if assumption_violated:
                # Build the message once so the log entry and the exception agree.
                message = (f"Cannot provide SAT-based schedule: Initial assumptions are violated for corridor width of {self.corridor_config.width}:\n" +
                           f"{goal_violations=}\n" +
                           f"{obstacle_violations=}\n" +
                           f"{molecule_violations=}\n")
                logger.error(message)
                raise AssumptionsError(message)

        dist = self._select_distance_function()
        distance_matrix = np.empty((len(self.molecules), len(self.goals)), dtype=np.float32, order='C')
        for i, mol in enumerate(self.molecules):
            distance_matrix[i, :] = [dist(mol, g) for g in self.goals]
        self.distance_matrix = distance_matrix
        time_needed = round(t.tocvalue(), 3)
        logger.bind(task="scheduling", compute_distance_matrix_time_needed=time_needed).trace(f"Computing distance matrix took {time_needed} seconds.")

    def __compute_molecule_goal_distance_without_obstacles(self, molecule: Molecule, goal: Goal) -> float:
        """Return the straight-line center-to-goal distance, or ``np.inf`` if types differ.

        Records a bare ``Matching`` for the pair in ``self.molecule_goal_paths``.
        """
        pair = Pair(molecule, goal)
        self.molecule_goal_paths[pair] = Matching(*pair)
        if molecule.type != goal.type:
            return np.inf
        return molecule.center.distance(goal.position)

    def __compute_molecule_goal_distance(self, molecule: Molecule, goal: Goal) -> float:
        """Return the obstacle-aware path length, or ``np.inf`` if types differ.

        Stores the computed ``Matching`` (or a bare one for mismatched types)
        in ``self.molecule_goal_paths``. A ``TypeError`` or ``InfeasibleError``
        from the path search is logged and re-raised via ``log_and_raise``.
        """
        pair = Pair(molecule, goal)
        if molecule.type != goal.type:
            self.molecule_goal_paths[pair] = Matching(*pair)
            return np.inf
        try:
            matching = self.shortest_paths._compute_path(molecule, goal)
        except (TypeError, InfeasibleError) as e:
            log_and_raise(e, msg=f"Could not compute molecule goal distance matrix: {e}.")
        self.molecule_goal_paths[pair] = matching
        return matching.length

    def add_molecule(self, molecule: Molecule) -> None:
        """Append ``molecule`` and a matching row to ``self.distance_matrix``.

        Uses the same distance function as ``_compute_distance_matrix``, which
        must have been called first.
        """
        self.molecules.append(molecule)
        dist = self._select_distance_function()
        distances = np.array([dist(molecule, goal) for goal in self.goals], dtype=np.float64)
        self.distance_matrix = np.vstack([self.distance_matrix, distances])

    def add_goal(self, goal: Goal) -> None:
        """Append ``goal`` and a matching column to ``self.distance_matrix``.

        Uses the same distance function as ``_compute_distance_matrix``, which
        must have been called first.
        """
        self.goals.append(goal)
        dist = self._select_distance_function()
        distances = np.array([dist(molecule, goal) for molecule in self.molecules], dtype=np.float64).reshape(-1, 1)
        self.distance_matrix = np.hstack([self.distance_matrix, distances])

    def set_inf_weight(self, molecule: Molecule, goal: Goal) -> None:
        """Forbid the ``(molecule, goal)`` assignment by setting its distance to ``np.inf``.

        Also records the pair in ``self.inf_distances``. Raises ``ValueError``
        (from ``list.index``) if either object is not currently tracked.
        """
        self.inf_distances.append((molecule, goal))
        self.distance_matrix[self.molecules.index(molecule), self.goals.index(goal)] = np.inf

    def _flattened_index_to_col_row(self, index: np.intp) -> tuple[int, int]:
        """Convert a flat index into ``self.distance_matrix`` to 2-D indices.

        Despite the name, the result is ``(row, column)``, i.e.
        ``(molecule_index, goal_index)``, as produced by ``np.unravel_index``.
        """
        return np.unravel_index(index, self.distance_matrix.shape)

    @property
    def cls(self) -> str:
        """Name of the concrete matching class (used in log/statistics messages)."""
        return type(self).__name__

    @property
    def length(self) -> float:
        """Sum of the lengths of all entries in ``self.matching``."""
        cond_log_and_raise(self.matching, ValueError("Tried to compute length of matching, but matching is None"), msg="")
        return sum(m.length for m in self.matching)

    @property
    def statistics(self) -> str:
        """Log and return a one-line summary of the total matching length."""
        # Evaluate the (validating, summing) property once instead of three times.
        total_length = self.length
        logger.bind(task="stats", matching_length=total_length).trace(f"{self.cls} computed matching of total length: {total_length}")
        return f"{self.cls} computed matching of total length: {total_length}"

    def __check_assumptions(self) -> tuple[bool, list, list, list]:
        """Run all corridor clearance checks.

        Returns:
            ``(any_violation, goal_violations, obstacle_violations,
            molecule_violations)``. The goal and obstacle lists are ``None``
            when there are no violations.
        """
        goal_assumption, goal_violations = self.__check_goal_clearance()
        obstacles_assumption, obstacle_violations = self.__check_obstacle_clearance()
        molecule_assumption, molecule_violations = self.__check_molecule_clearance()
        return goal_assumption or obstacles_assumption or molecule_assumption, goal_violations, obstacle_violations, molecule_violations

    def __check_molecule_clearance(self) -> tuple[bool, list[str]]:
        """
        Returns (has_violation, [molecule_name, ...]). A violation means some
        molecule's *center* is closer than width/2 to another molecule's *polygon*.
        Each violating molecule is listed at most once.
        """
        r = self.corridor_config.width / 2.0
        molecules = self.molecules

        centers = [m.center for m in molecules]
        polys = [m.polygon for m in molecules]

        # The STRtree query is an envelope pre-filter; exact distances are checked below.
        tree = STRtree(polys)

        violating_names = []
        for i, center in enumerate(centers):
            cand_idx: np.ndarray = tree.query(center.buffer(r))  # ndarray[int]
            if cand_idx.size == 0:
                continue
            for j in cand_idx:
                if j == i:
                    continue
                distance = center.distance(polys[j])
                if distance < r:
                    violating_names.append(molecules[i].name)
                    break

        return (len(violating_names) > 0), violating_names

    def __check_goal_clearance(self) -> tuple[bool, Optional[list]]:
        """
        Returns (has_violation, [(goal.position, obstacle.position, "d < r"), ...] or None).
        A violation means goal.position (a point) is closer than r = width/2 to
        obstacle.polygon. The third tuple element is a human-readable string
        comparing the rounded distance with r.
        """
        r = self.corridor_config.width / 2.0
        goals = self.goals
        obstacles = self.obstacles

        goal_positions = [g.position for g in goals]
        obs_polygons = [o.polygon for o in obstacles]

        tree = STRtree(obs_polygons)

        pairs = []
        for goal_position in goal_positions:
            candidate_idx: np.ndarray = tree.query(goal_position.buffer(r))
            if candidate_idx.size == 0:
                continue
            for oj in candidate_idx:
                distance = goal_position.distance(obs_polygons[oj])
                if distance < r:
                    pairs.append((goal_position, obstacles[oj].position, f"{round(distance, 3)} < {round(r, 3)}"))

        return (True, pairs) if pairs else (False, None)

    def __check_obstacle_clearance(self) -> tuple[bool, Optional[list]]:
        """
        Returns (has_violation, [(molecule.name, obstacle.position, "d < r"), ...] or None).
        A violation means molecule.center is closer than r = width/2 to
        obstacle.polygon. The third tuple element is a human-readable string
        comparing the rounded distance with r.
        """
        r = self.corridor_config.width / 2.0
        molecules = self.molecules
        obstacles = self.obstacles

        centers = [m.center for m in molecules]
        obs_polys = [o.polygon for o in obstacles]

        tree = STRtree(obs_polys)

        hits = []
        for i, center in enumerate(centers):
            cand_idx: np.ndarray = tree.query(center.buffer(r))
            if cand_idx.size == 0:
                continue
            for oj in cand_idx:
                distance = center.distance(obs_polys[oj])
                if distance < r:
                    hits.append((molecules[i].name, obstacles[oj].position, f"{round(distance, 3)} < {round(r, 3)}"))

        return (True, hits) if hits else (False, None)

    def __positioned_masks(self, radius: float):
        """
        Returns:
            positioned_molecules: (N,) bool — molecule i has >=1 goal within `radius`
            satisfied_goals:      (M,) bool — goal j has >=1 molecule within `radius`

        Molecule and goal types are not compared.
        """
        N, M = len(self.molecules), len(self.goals)
        if N == 0 or M == 0:
            return np.zeros(N, bool), np.zeros(M, bool)

        mol_coords = np.array([m.center.coords[0] for m in self.molecules], dtype=np.float32)
        goal_coords = np.array([g.position.coords[0] for g in self.goals], dtype=np.float32)

        mol_tree = KDTree(mol_coords, balanced_tree=False, compact_nodes=True)
        # hits[j] lists the indices of molecules within `radius` of goal j.
        hits = mol_tree.query_ball_point(goal_coords, r=radius, workers=-1)

        satisfied_goals = np.fromiter((len(h) > 0 for h in hits), count=M, dtype=bool)
        positioned_molecules = np.zeros(N, dtype=bool)
        for mol_idxs in hits:
            positioned_molecules[mol_idxs] = True

        return positioned_molecules, satisfied_goals
