from typing import Optional
import numpy as np
from scipy.optimize import linear_sum_assignment

from loguru import logger
from molecule_movement.logging import log_and_raise

from molecule_movement.matching.AbstractMatching import AbstractMatching
from molecule_movement import Matching, Goal, Obstacle, Molecule, MoleculeExperiment, Pair

from molecule_movement.exceptions import InfeasibleError

from molecule_movement.wrapper.CorridorWrapper import CorridorConfiguration

from pytictoc import TicToc

class HungarianMatching(AbstractMatching):
    """Match molecules to goals by solving a linear sum assignment problem.

    The cost matrix is ``self.distance_matrix``. Its rows are indexed like
    ``self.molecules`` and its columns like ``self.goals``: that is how the
    solver's row/column indices are consumed in ``compute_matching``.
    """

    def __init__(self, molecules: list[Molecule] | list[MoleculeExperiment],
                 goals: list[Goal],
                 obstacles: Optional[list[Obstacle]] = None,
                 respect_obstacles: bool = True,
                 circumvent_obstacles: bool = True,
                 check_positioned: bool = False,
                 corridor_config: Optional[CorridorConfiguration] = None):
        """Initialise the base matcher and compute the distance matrix.

        All arguments are forwarded unchanged to ``AbstractMatching.__init__``.
        """
        super().__init__(
            molecules,
            goals,
            obstacles,
            respect_obstacles,
            circumvent_obstacles=circumvent_obstacles,
            corridor_config=corridor_config,
            check_positioned=check_positioned,
        )
        # NOTE(review): not defined in this class; presumably inherited and
        # expected to populate self.distance_matrix -- confirm in the base.
        self._compute_distance_matrix()

    def compute_matching(self, seed: Optional[int] = None) -> list[Matching]:
        """Compute a minimum-cost molecule-to-goal assignment.

        Args:
            seed: Not used by this implementation.

        Returns:
            The entries of ``self.molecule_goal_paths`` for every assigned
            (molecule, goal) pair; the same list is stored in
            ``self.matching``.

        Raises:
            InfeasibleError: If ``linear_sum_assignment`` rejects the distance
                matrix with a ``ValueError``. The original error is attached
                as the cause.
        """
        logger.trace(f"Computing matching using {self.cls}")
        self.matching = list()
        try:
            row_ind, col_ind = linear_sum_assignment(self.distance_matrix)
        except ValueError as e:
            logger.info(f"\n{self.distance_matrix}")
            reason = self._describe_infeasibility()
            logger.error(f"Could not provide a matching: {e}\n{reason}")
            # Chain the solver error so the traceback keeps the root cause.
            raise InfeasibleError(f"{e}. {reason}") from e

        self.matching_molecule_indices = row_ind
        self.matching_goal_indices = col_ind
        self.matching.extend(
            self.molecule_goal_paths[Pair(self.molecules[row], self.goals[col])]
            for row, col in zip(row_ind, col_ind)
        )
        return self.matching

    def _describe_infeasibility(self) -> str:
        """Name the molecules and goals whose costs are all infinite.

        Best effort only: any failure while building the description is
        reported in the returned text instead of being raised, so it can never
        mask the infeasibility error itself. Both lists may be empty when the
        solver failed for a different reason.
        """
        try:
            blocked = np.asarray(self.distance_matrix) == np.inf
            # Rows correspond to molecules and columns to goals, so a fully
            # blocked row (reduce over axis=1) is a molecule without any
            # reachable goal, and a fully blocked column (reduce over axis=0)
            # is a goal that no molecule can reach.
            stuck_molecules = set(np.flatnonzero(np.all(blocked, axis=1)).tolist())
            stuck_goals = set(np.flatnonzero(np.all(blocked, axis=0)).tolist())
            molecule_names = [
                m.name for i, m in enumerate(self.molecules) if i in stuck_molecules
            ]
            goal_positions = [
                (round(g.position.x, 2), round(g.position.y, 2))
                for i, g in enumerate(self.goals)
                if i in stuck_goals
            ]
            return f"Cannot match molecules {molecule_names} and goals {goal_positions}"
        except Exception as f:
            return f"Could not resolve reason: {f}."

    def get_matching_moiety_indices(self) -> np.ndarray:
        """Return the cost-matrix row (molecule) indices of the last matching.

        The attribute is assigned by ``compute_matching``.
        """
        return self.matching_molecule_indices

    def get_matching_goal_indices(self) -> np.ndarray:
        """Return the cost-matrix column (goal) indices of the last matching.

        The attribute is assigned by ``compute_matching``.
        """
        return self.matching_goal_indices
