import numpy as np
import math
from copy import deepcopy

from typing import Optional

from shapely import Point
import shapely

from molecule_movement import Molecule, MoleculeExperiment, VerticalAction, LateralAction, Goal, Movement

from loguru import logger
from .logging import log_and_raise


class Simulator:
    """Apply vertical/lateral manipulation actions to a collection of molecules.

    Molecules are bucketed into a uniform spatial hash of square cells with
    side ``cell_size`` (same units as the molecule centres). When an action is
    performed, only molecules filed in cells that overlap the axis-aligned
    square of half-width ``max_action_radius`` around the action position are
    asked to move.
    """

    def __init__(self, molecules: list[Molecule] | list[MoleculeExperiment], goals: Optional[list[Goal]] = None, cell_size: float = 10.0, max_action_radius: float = 4.0) -> None:
        """Create a simulator over *molecules*.

        Args:
            molecules: Molecules to simulate; copied into a new list.
            goals: Optional goals, returned by :meth:`get_goals`.
            cell_size: Side length of one spatial-hash cell; must be positive.
            max_action_radius: Half-width of the square around an action
                position that is searched for candidate molecules.

        Raises:
            ValueError: If ``cell_size`` is not a positive number.
        """
        self.molecules: list[Molecule] = list(molecules)
        self.cell_size = float(cell_size)
        # `not (x > 0)` also rejects NaN, which would otherwise poison bucketing.
        if not self.cell_size > 0:
            raise ValueError(f"cell_size must be positive, got {cell_size!r}")
        self.max_action_radius = float(max_action_radius)

        # cell index -> molecules whose centre lies in that cell
        self._grid: dict[tuple[int, int], set[Molecule]] = {}
        # molecule -> cell it is currently filed under (reverse index)
        self._mol_cell: dict[Molecule, tuple[int, int]] = {}
        # result of the most recent perform_action call (empty until then)
        self.moved_map: dict[Molecule, tuple[Point, Movement]] = {}

        self._build_spatial_hash()
        self.goals = goals

        self.stats_logger = logger.bind(task="stats")

    def perform_action(self, action: VerticalAction | LateralAction, current_goal_position: Optional[Point] = None) -> dict[Molecule, tuple[Point, Movement]]:
        """Apply *action* to every nearby molecule and report which ones moved.

        For each candidate molecule the action's start point is translated to
        the molecule centre, rotated by ``-molecule.orientation`` (degrees, the
        shapely default) and rounded to one decimal before being handed to
        ``molecule.move``.

        Args:
            action: The action, in world coordinates.
            current_goal_position: Currently has no effect (see the FIXME in
                ``_lateral_destination``); kept for interface compatibility.

        Returns:
            Mapping from every molecule for which ``movement.moved(0.001)`` is
            true to ``(centre_before_the_action, movement)``. The same mapping
            is also stored in ``self.moved_map``.
        """
        # `import shapely` alone does not guarantee the `affinity` submodule is
        # loaded, so import it explicitly.
        from shapely import affinity

        moved_molecules: dict[Molecule, tuple[Point, Movement]] = {}
        action_world = Point(action.xy.x, action.xy.y)
        is_vertical = isinstance(action, VerticalAction)

        for molecule in self._candidate_molecules(action_world):
            old_center = molecule.center
            position = deepcopy(old_center)
            local_xy = affinity.rotate(
                Point(action.xy.x - position.x, action.xy.y - position.y),
                -molecule.orientation,
                origin=(0, 0),
            )
            local_start = Point(round(local_xy.x, 1), round(local_xy.y, 1))

            if is_vertical:
                molecule_action = VerticalAction(local_start, action.z, action.V)
            else:
                molecule_action = LateralAction(
                    local_start,
                    self._lateral_destination(action, molecule),
                    action.z,
                    action.V,
                )

            movement = molecule.move(molecule_action)
            if movement.moved(0.001):
                moved_molecules[molecule] = (position, movement)
            # The molecule may have crossed a cell boundary; keep the hash in sync.
            self._update_molecule_cell(molecule, old_center)

        self.moved_map = moved_molecules
        return moved_molecules

    @staticmethod
    def _lateral_destination(action: LateralAction, molecule: Molecule) -> Point:
        """Return the lateral destination relative to *molecule*, rounded to 0.1."""
        # FIXME: the destination is measured from the molecule centre rather
        # than from the caller's `current_goal_position`, and unlike the start
        # point it is not rotated into the molecule frame.
        reference = molecule.center
        if reference:
            dest = Point(action.xy_dest.x - reference.x, action.xy_dest.y - reference.y)
        else:
            # An empty geometry is falsy: fall back to the absolute destination.
            dest = action.xy_dest
        return Point(round(dest.x, 1), round(dest.y, 1))

    def get_goals(self):
        """Return the goals, or an empty list when none were given."""
        return self.goals or []

    def get_molecules(self):
        """Return the list of simulated molecules."""
        return self.molecules

    def get_moved_molecules(self):
        """Return the moved-molecule mapping from the latest ``perform_action``.

        Empty until ``perform_action`` has been called at least once.
        """
        return self.moved_map

    def _cell_for_point(self, p: Point) -> tuple[int, int]:
        """Map a world-space point to integer cell coordinates."""
        cs = self.cell_size
        return (math.floor(p.x / cs), math.floor(p.y / cs))

    def _build_spatial_hash(self) -> None:
        """(Re)initialise the grid from all current molecules."""
        self._grid.clear()
        self._mol_cell.clear()
        for mol in self.molecules:
            self._insert_molecule_into_grid(mol)

    def _insert_molecule_into_grid(self, mol: Molecule) -> None:
        """File *mol* under the cell containing its centre."""
        cell = self._cell_for_point(mol.center)
        self._grid.setdefault(cell, set()).add(mol)
        self._mol_cell[mol] = cell

    def _discard_from_cell(self, mol: Molecule, cell: tuple[int, int]) -> None:
        """Remove *mol* from *cell*'s bucket, dropping the bucket once empty."""
        bucket = self._grid.get(cell)
        if bucket is None:
            return
        bucket.discard(mol)
        if not bucket:
            del self._grid[cell]  # keep the grid sparse

    def _remove_molecule_from_grid(self, mol: Molecule) -> None:
        """Remove *mol* from the grid; a no-op if it is not indexed."""
        cell = self._mol_cell.pop(mol, None)
        if cell is not None:
            self._discard_from_cell(mol, cell)

    def _update_molecule_cell(self, mol: Molecule, old_center: Point) -> None:
        """Re-file *mol* after ``mol.center`` may have changed.

        Uses the cached cell in ``_mol_cell`` and falls back to the cell of
        *old_center* if that mapping is missing.
        """
        old_cell = self._mol_cell.get(mol)
        if old_cell is None:
            old_cell = self._cell_for_point(old_center)

        new_cell = self._cell_for_point(mol.center)
        if new_cell != old_cell:
            self._discard_from_cell(mol, old_cell)
        # Set insertion is idempotent; doing it unconditionally guarantees the
        # molecule is indexed even if its reverse-index entry had gone missing.
        self._grid.setdefault(new_cell, set()).add(mol)
        self._mol_cell[mol] = new_cell

    def _candidate_molecules(self, world_xy: Point) -> set[Molecule]:
        """Return molecules filed in cells overlapping the action's search square.

        The square is centred on *world_xy* with half-width
        ``max_action_radius``. Filtering is per cell, so molecules somewhat
        farther than the radius may be included. When exactly one molecule
        exists it is always returned, regardless of distance.
        """
        if len(self.molecules) == 1:
            return set(self.molecules)

        r = self.max_action_radius
        cs = self.cell_size
        ix_min = math.floor((world_xy.x - r) / cs)
        ix_max = math.floor((world_xy.x + r) / cs)
        iy_min = math.floor((world_xy.y - r) / cs)
        iy_max = math.floor((world_xy.y + r) / cs)

        candidates: set[Molecule] = set()
        for ix in range(ix_min, ix_max + 1):
            for iy in range(iy_min, iy_max + 1):
                bucket = self._grid.get((ix, iy))
                if bucket:
                    candidates |= bucket
        # A sharp radius could be enforced by additionally filtering on
        # m.center.distance(world_xy) <= r.
        return candidates

class HardwareSimulator(Simulator):
    """Simulator exposing a hardware-style manipulation interface.

    Raw manipulation parameters are packed into ``VerticalAction`` /
    ``LateralAction`` objects and forwarded to ``perform_action``. The
    image/scan related members are stubs that return fixed values or do
    nothing.
    """

    def __init__(self, molecules: list[Molecule] | list[MoleculeExperiment], goals: Optional[list[Goal]] = None, cell_size: float = 10.0, max_action_radius: float = 4.0) -> None:
        """Create the simulator.

        ``cell_size`` and ``max_action_radius`` are forwarded to the base class.
        """
        super().__init__(molecules, goals, cell_size=cell_size, max_action_radius=max_action_radius)
        self.current_image_file = None

    @property
    def overview_image_path(self) -> str:
        """Stub: always the empty string."""
        return ""

    def perform_vertical_manipulation(self,
                                      x_position_nm: int,
                                      y_position_nm: int,
                                      z_approach_nm: int,
                                      voltage_mV: int):
        """Perform a vertical manipulation at the given world position.

        Returns the moved-molecule mapping produced by ``perform_action``.
        """
        # Positional order (xy, z, V) matches how Simulator.perform_action
        # rebuilds VerticalAction from `action.z` / `action.V`. Passing the
        # voltage first would swap z and V by the time the molecule sees them.
        action = VerticalAction(Point(x_position_nm, y_position_nm),
                                z_approach_nm,
                                voltage_mV)
        return self.perform_action(action)

    def perform_lateral_manipulation(self,
                                     x_start_position_nm: int,
                                     y_start_position_nm: int,
                                     x_end_position_nm: int,
                                     y_end_position_nm: int,
                                     voltage_mV: int,
                                     z_position_nm: int):
        """Perform a lateral manipulation from the start to the end position.

        Returns the moved-molecule mapping produced by ``perform_action``.
        """
        # Positional order (xy, xy_dest, z, V) matches how
        # Simulator.perform_action rebuilds LateralAction.
        action = LateralAction(Point(x_start_position_nm, y_start_position_nm),
                               Point(x_end_position_nm, y_end_position_nm),
                               z_position_nm,
                               voltage_mV)
        return self.perform_action(action)

    def _get_rough_lat_position_for_exact_search(self) -> Point:
        """Stub: always the origin."""
        return Point(0, 0)

    def set_filename_per_timestep(self, filename_per_timestep: str) -> None:
        """Stub: ignored in simulation."""
        pass

    def scan_topography(self, center_nm: Point, topography_size_nm: int = 4, number_of_topography_points: int = 128):
        """Stub: no topography is produced; returns None."""
        pass
