import numpy as np
import math

from numpy.typing import NDArray
import gymnasium as gym

import string
from typing import Optional, SupportsFloat, Callable
from abc import ABC, abstractmethod
from typing import Type

from loguru import logger
from ..logging import log_and_raise, deprecated, warn
from ..colour_utils import Highlight

from molecule_movement import Molecule, MoleculeExperiment, Matching, VerticalAction, LateralAction, Movement
from molecule_movement.Simulator import Simulator
from molecule_movement.PyGameRenderer import PyGameRenderer
from molecule_movement.matching import RandomMatching
from molecule_movement.constants import SystemConfig

from shapely import Point
import shapely

# Unicode code points of "A" and "Z" as numpy int32 scalars. They are the
# bounds used when sampling random molecule names, whose int32 samples are
# reinterpreted as unicode characters.
A, Z = np.array([ord("A"), ord("Z")], dtype=np.int32)

# Number of characters in a randomly generated molecule name.
RANDOM_NAME_LEN = 3

# Maximum sensor distance. NOTE(review): units are not established in this module — confirm.
MAX_SENSOR_DISTANCE = 50

class MoleculeEnvironment(ABC, gym.Env):
    """Abstract gymnasium environment for moving molecules towards goals.

    Subclasses set up an episode through the abstract hooks
    (:meth:`_set_goals`, :meth:`_set_obstacles`,
    :meth:`_create_initial_distribution`, :meth:`_get_matching`). The
    environment then works through ``self.matching`` one entry at a time:
    :meth:`step` acts with the goal of the current matching and
    :meth:`increment_matching` advances to the next (non-crashed) one.
    """
    metadata = {
        "render_modes": ["none", "human", "rgb_array"],
    }
    # A molecule whose polygon is at most this far from another molecule or an
    # obstacle is marked as crashed (see _crashed).
    CRASH_DISTANCE = 0.0
    # The goal position counts as reached when the molecule centre is strictly
    # closer than this to it (see reached_goal_position).
    SUCCESS_DISTANCE = 0.3

    def __init__(
            self,
            system_configuration: Type[SystemConfig] | list[SystemConfig],
            render_mode: str = "human",
            scale: int = 5,
            render_grid: bool = True,
            draw_names: bool = True,
            render_sensors: str = "all",
            surface_width: int = 200,
            surface_height: int = 200,
            num_sensors: int = 16,
            max_steps: int = 400,
            origin_offset: tuple[int, int] = (0,0),
            **kwargs
            ):
        """Configure rendering, the action/observation spaces and system configurations.

        Args:
            system_configuration: A single configuration or a non-empty list of
                configurations that all share the same substrate. They are
                stored in ``self.system_configuration`` keyed by
                ``moiety.type`` (later entries win on duplicate types).
            render_mode: One of ``metadata["render_modes"]``. A renderer is only
                created (by :meth:`reset`) for ``"human"``.
            scale: Factor applied to the surface size to obtain ``window_size``.
            render_grid: Forwarded to the renderer.
            draw_names: Forwarded to the renderer.
            render_sensors: Forwarded to the renderer.
            surface_width: Surface width; also the upper bound of the x action components.
            surface_height: Surface height; also the upper bound of the y action components.
            num_sensors: Stored as ``self.num_sensors``.
            max_steps: The episode is truncated once ``self.steps`` reaches this.
            origin_offset: Forwarded to the renderer.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``system_configuration`` is an empty list.
            AssertionError: If listed configurations have different substrates.
        """
        # reset() and increment_matching() set this to 0; step() increments it first.
        self.steps = 0
        self.max_steps = max_steps
        self.render_mode = render_mode

        self.surface_height, self.surface_width = surface_height, surface_width
        self.window_size = scale * np.array((self.surface_width, self.surface_height))
        self.render_grid = render_grid
        self.render_sensors = render_sensors
        self.draw_names = draw_names
        self.origin_offset = origin_offset

        if isinstance(system_configuration, list):
            if not system_configuration:
                log_and_raise(ValueError("Empty list of SystemConfigurations"), "At least one SystemConfig is required to create the environment.")
            for a, b in zip(system_configuration, system_configuration[1:]):
                # Explicit raise (instead of ``assert``) so the check survives ``python -O``.
                if a.substrate != b.substrate:
                    raise AssertionError(f"Cannot run experiment for SystemConfigurations of unequal substrate: {a.substrate.type} != {b.substrate.type}")
            self.system_configuration = {config.moiety.type: config for config in system_configuration}
            self.substrate_point_symmetry = system_configuration[0].substrate_point_symmetry
        else:
            self.system_configuration = {system_configuration.moiety.type: system_configuration}
            self.substrate_point_symmetry = system_configuration.substrate_point_symmetry

        self.num_sensors = num_sensors

        self.vertical_action_space = gym.spaces.Box(np.zeros((4,)), np.full((4,), [surface_width, surface_height, 0, 0]))
        self.lateral_action_space = gym.spaces.Box(np.zeros((6,)), np.full((6,), [surface_width, surface_height, surface_width, surface_height, 0, 0]))
        self.action_space = gym.spaces.Dict({"vertical": self.vertical_action_space, "lateral": self.lateral_action_space})

        # Observations are sequences of molecule names (see observation()).
        self.observation_space = gym.spaces.Sequence(gym.spaces.Text(9,charset=string.printable))

        # Created lazily by reset() (only in "human" render mode).
        self.renderer = None

    def _parse_molecule_data(self) -> None:
        """Placeholder hook; currently does nothing."""
        pass

    @abstractmethod
    def _create_initial_distribution(self, seed: Optional[int] = None) -> None:
        """Populate ``self.molecules`` for a new episode (called by :meth:`reset`)."""
        pass

    @abstractmethod
    def _set_goals(self, seed: Optional[int] = None) -> None:
        """Populate ``self.goals`` for a new episode (called by :meth:`reset`)."""
        pass

    @abstractmethod
    def _set_obstacles(self, seed: Optional[int] = None) -> None:
        """Populate ``self.obstacles``; the base implementation ensures an empty list exists."""
        if not hasattr(self, "obstacles"):
            self.obstacles = list()

    @abstractmethod
    def _get_matching(self, options: Optional[dict] = None, seed: Optional[int] = None) -> None:
        """Compute the molecule/goal matching for the episode (called by :meth:`reset`)."""
        pass

    def _set_matching(self, matching: list[Matching]) -> None:
        """Store the list of matchings the environment iterates over."""
        self.matching = matching

    def increment_matching(self) -> tuple[tuple[str, ...], dict]:
        """Advance to the next matching whose molecule has not crashed.

        Resets the step counter. Crashed molecules are skipped.

        Returns:
            ``(observation, {})`` for the new current matching.

        Raises:
            StopIteration: When no matching is left. A final frame with the
                matching highlight and goals suppressed is rendered first if a
                renderer exists.
        """
        self.steps = 0
        self.current_matching_index += 1
        num_matchings = len(self.matching)
        while self.current_matching_index < num_matchings and self.current_molecule.crashed:
            logger.info("Skipping crashed molecule")
            self.current_matching_index += 1
        # ``>=`` (not ``==``) so calling this again after exhaustion keeps raising
        # instead of returning an observation for an out-of-range index.
        if self.current_matching_index >= num_matchings:
            if self.renderer:
                self.render(surpress_matching=True, surpress_goals=True)
            log_and_raise(StopIteration("Maximum number of matchings reached."), f"Incremented the matching index to {self.current_matching_index} which exceeds the number of matchings {num_matchings}.", log=False)
        return self.observation(), {}

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[tuple[str, ...], dict]:
        """Start a new episode.

        Builds goals, obstacles and the initial molecule distribution through
        the subclass hooks, creates a new Simulator (and, in ``"human"`` mode,
        a renderer), computes the matching and returns the first observation.

        Args:
            seed: Seeds the name generator; forwarded to the hooks and to
                ``gym.Env.reset``.
            options: Stored on ``self.options`` when truthy and forwarded to
                :meth:`_get_matching`.

        Returns:
            ``(observation, {})``.
        """
        if options:
            self.options = options
        self.seed = seed
        self.steps = 0
        self.name_generator = np.random.default_rng(seed=seed)
        self.current_matching_index = 0

        self._set_goals(seed)
        self._set_obstacles(seed)
        self._create_initial_distribution(seed=seed)
        self.name_to_molecule_map = {molecule.name: molecule for molecule in self.molecules}
        self.simulator = Simulator(self.molecules, self.goals)
        self._initialize_renderer(origin_offset=self.origin_offset, draw_names=self.draw_names)
        if self.renderer:
            self.renderer.clear(force=True)
            self.render(None)
            self.renderer.update()
        logger.trace("compute matching in reset")
        self._get_matching(seed=seed, options=options)

        super().reset(seed=seed)

        if self.renderer:
            self.render(None)
            self.renderer.update()
        return self.observation(), {}

    def _action_from_array(self, action) -> VerticalAction | LateralAction:
        """Convert a flat action vector into a VerticalAction or LateralAction.

        Supported lengths: 2 -> ``(x, y)``; 4 -> ``(x, y, z, V)``;
        6 -> ``(x, y, x_dest, y_dest, z, V)``.

        Raises:
            ValueError: For any other shape.
        """
        action = np.asarray(action)
        if action.shape == (2,):
            return VerticalAction(xy=Point(*action))
        if action.shape == (4,):
            return VerticalAction(xy=Point(action[0], action[1]), z=action[2], V=action[3])
        if action.shape == (6,):
            return LateralAction(xy=Point(action[0], action[1]), xy_dest=Point(action[2], action[3]), z=action[4], V=action[5])
        msg=f"Cannot handle actions of the form {action.shape=} (e.g.: {action=}) yet"
        log_and_raise(ValueError("Cannot interpret action"), msg)

    def step(self, action: VerticalAction | LateralAction):
        """Apply one action targeting the current matching's goal.

        Args:
            action: A VerticalAction/LateralAction, or a flat vector of length
                2, 4 or 6 (see :meth:`_action_from_array`).

        Returns:
            ``(observation, reward, terminated, truncated, info)``; ``info``
            holds the flags computed by :meth:`terminated`.

        Raises:
            AssertionError: If there are no molecules (``reset`` not called).
            ValueError: If an action vector has an unsupported shape.
        """
        molecules = getattr(self, "molecules", None)
        if molecules is None or len(molecules) == 0:
            msg = "The list of molecules is empty, need to call env.reset() first"
            logger.warning(msg)
            # Explicit raise (instead of ``assert False``) so it survives ``python -O``.
            raise AssertionError(msg)
        self.steps += 1
        info = dict()

        if not isinstance(action, (VerticalAction, LateralAction)):
            action = self._action_from_array(action)

        self.distance_before_action = self.current_molecule.center.distance(self.current_matching.goal.position)
        self.current_molecule_position_before = self.current_molecule.center
        self.current_molecule_rotation_before = self.current_molecule.orientation
        moved_molecules = self.simulator.perform_action(action, self.current_matching.goal.position)
        for moved in moved_molecules:
            self._crashed(moved)

        self.render(action=action, moved_molecules=moved_molecules)
        obs = self.observation()
        terminated, reason = self.terminated()
        info.update(reason)

        logger.bind(task="stats",
                    distance=self.current_molecule.center.distance(self.current_matching.goal.position),
                    molecule_x_before=self.current_molecule_position_before.x,
                    molecule_y_before=self.current_molecule_position_before.y,
                    molecule_rotation_before=self.current_molecule_rotation_before,
                    goal=self.current_matching.goal).trace("")

        return obs, self.reward(), terminated, self.truncated(), info

    def render(self, action: Optional[VerticalAction|LateralAction] = None, moved_molecules: Optional[dict[Molecule, tuple[Point, Movement]]] = None, matching_highlight: Optional[dict[Matching, Highlight]] = None, surpress_matching: bool = False, surpress_goals: bool = False):
        """Draw the current state.

        Does nothing unless ``render_mode`` is ``"human"`` or ``"rgb_array"``
        and a renderer exists. A renderer is only created by :meth:`reset`, and
        only in ``"human"`` mode.

        Args:
            action: Drawn via ``render_STM`` when it is a VerticalAction/LateralAction.
            moved_molecules: Forwarded to ``render_movements``.
            matching_highlight: Currently unused.
            surpress_matching: If True, do not highlight the current molecule.
            surpress_goals: Forwarded to ``render_background``.
        """
        if self.render_mode not in ("human", "rgb_array") or self.renderer is None:
            return
        self.renderer.clear()
        if isinstance(action, (VerticalAction, LateralAction)):
            self.renderer.render_STM(action)
        self.renderer.render_movements(moved_molecules)
        self.renderer.render_background(surpress_goals)
        matching = getattr(self, "matching", None)
        highlighted = None
        # Only highlight when the index points into an existing matching; an
        # empty matching would otherwise raise IndexError while rendering.
        if not surpress_matching and matching is not None and 0 <= self.current_matching_index < len(matching):
            highlighted = self.current_molecule
        self.renderer.render_molecules(current_molecule=highlighted)
        self.renderer.update()

    def get_image(self):
        """Return the renderer's current image, or ``None`` (with a warning) if there is none."""
        if self.render_mode is None or self.render_mode == "none":
            warn(
                "You are calling get_image method without specifying any render mode. "
                "You can specify the render_mode at initialization"
            )
            return
        if self.renderer is None:
            warn(
                "No renderer is available. A renderer is only created by reset() "
                "when render_mode is 'human'."
            )
            return
        return self.renderer.get_image()

    def observation(self) -> tuple[str, ...]:
        """Return the names of all molecules."""
        return tuple(m.name for m in self.molecules)

    def terminated(self) -> tuple[bool, dict[str, bool]]:
        """Decide whether the current matching is finished.

        Returns:
            ``(done, info)``. ``done`` is True when the current molecule crashed
            or reached both the goal position and orientation. ``info`` holds
            the flags ``crashed``, ``reached_goal``, ``reached_goal_orientation``
            and ``destroyed`` (always False here).
        """
        crashed = bool(self.current_molecule.crashed)
        reached_position = self.reached_goal_position()
        # The orientation only counts once the position has been reached.
        reached_orientation = reached_position and self.reached_goal_orientation()
        info = {"crashed": crashed, "reached_goal": reached_position, "reached_goal_orientation": reached_orientation, "destroyed": False}
        logger.bind(task="stats", crashed=int(info["crashed"])).trace("")
        logger.bind(task="stats", reached=int(info["reached_goal"])).trace("")
        logger.bind(task="stats", reached_orientation_at_goal=int(info["reached_goal_orientation"])).trace("")
        return crashed or reached_orientation, info

    def truncated(self) -> bool:
        """Return True once ``max_steps`` steps have been taken for the current matching."""
        return self.steps >= self.max_steps

    def _crashed(self, molecule: Molecule) -> bool:
        """Mark ``molecule`` as crashed if it collides with a molecule or obstacle.

        A collision is another molecule or obstacle whose polygon lies within
        ``CRASH_DISTANCE`` of ``molecule.polygon``. Molecules it collides with
        are marked as crashed too; obstacles are not modified.

        Returns:
            True if any collision was detected.
        """
        molecule_distances = np.asarray([
            molecule.polygon.distance(other.polygon) if other != molecule else np.inf
            for other in self.molecules
        ])
        obstacle_distances = np.asarray([
            molecule.polygon.distance(obstacle.polygon) if obstacle != molecule else np.inf
            for obstacle in self.obstacles
        ])
        molecule_hits = molecule_distances <= self.CRASH_DISTANCE
        crashed_into_molecule = bool(np.any(molecule_hits))
        crashed_into_obstacles = bool(np.any(obstacle_distances <= self.CRASH_DISTANCE))
        if crashed_into_molecule or crashed_into_obstacles:
            molecule.set_crashed()
            for molecule_index in np.where(molecule_hits)[0]:
                self.molecules[molecule_index].set_crashed()
        return crashed_into_obstacles or crashed_into_molecule

    def reached_goal_position(self) -> bool:
        """Return True if the current molecule is closer than ``SUCCESS_DISTANCE`` to its goal."""
        return bool(self.current_distance < self.SUCCESS_DISTANCE)

    def reached_goal_orientation(self) -> bool:
        """Return True if the current molecule's orientation matches the goal rotation.

        Both angles are truncated to ``int``. They match when their difference
        is a multiple of the moiety's symmetry angle.
        """
        difference = int(self.current_molecule.orientation) - int(self.current_matching.goal.rotation)
        return bool(difference % self.current_molecule.angle_symmetry._symmetry_angle_moiety == 0)

    def reward(self) -> float:
        """Return the step reward (always 0 in the base class)."""
        return 0

    def _initialize_renderer(self, origin_offset: tuple[float, float] = (0,0), flip_y: bool = True, flip_x: bool = False, draw_names: bool = True) -> None:
        """Create a fresh PyGameRenderer for the current simulator (``"human"`` mode only)."""
        if not self.render_mode == "human":
            return
        # ``spec`` is only set for environments created via gym.make().
        environment_name = self.spec.id if self.spec is not None else type(self).__name__
        # NOTE: a new renderer is built on every call, i.e. on every reset().
        self.renderer = PyGameRenderer(self.simulator, obstacles=self.obstacles, environment_name=environment_name, window_size=self.window_size, size=(self.surface_width, self.surface_height), render_grid=self.render_grid, render_sensors=self.render_sensors, origin_offset=origin_offset, flip_y=flip_y, flip_x=flip_x, draw_names=draw_names, store_images=True)
        self.renderer.clear_movement()

    def _sample_random_name(self, seed: Optional[int] = None) -> str:
        """Draw a random name of ``RANDOM_NAME_LEN`` uppercase letters A-Z.

        Uses ``self.name_generator`` (seeded in :meth:`reset`); ``seed`` is unused.
        """
        # ``endpoint=True`` makes ``Z`` inclusive; by default Generator.integers
        # excludes ``high``, so "Z" could never be drawn.
        codes = self.name_generator.integers(low=A, high=Z, size=RANDOM_NAME_LEN, dtype="int32", endpoint=True)
        # Reinterpret the int32 code points as a single unicode string.
        return codes.view(f"U{RANDOM_NAME_LEN}")[0]

    def _sample_n_random_names(self, size: int, seed: Optional[int] = None) -> list[str]:
        """Return ``size`` distinct random names in sorted order.

        Names are collected in a set until enough distinct ones exist, then
        sorted. ``seed`` is unused.

        Raises:
            ValueError: If more names are requested than distinct names exist.
        """
        max_unique = int(Z - A + 1) ** RANDOM_NAME_LEN
        if size > max_unique:
            log_and_raise(ValueError("Too many names requested"), f"Requested {size} unique names but only {max_unique} distinct names of length {RANDOM_NAME_LEN} exist.")
        names = set()
        while len(names) < size:
            names.add(self._sample_random_name())
        return sorted(names)

    def _overlaps_obstacle_func(self, min_distance: float = 0.0) -> Callable[[Point], bool]:
        """Return a predicate that is True if a point is within ``min_distance`` of any obstacle polygon."""
        def molecule_overlaps(p: Point) -> bool:
            return bool(np.any([p.distance(obstacle.polygon) <= min_distance for obstacle in self.obstacles]))
        return molecule_overlaps

    def _overlaps_goal_func(self, min_distance: float = 0.0) -> Callable[[Point], bool]:
        """Return a predicate that is True if a point is within ``min_distance`` of any goal polygon."""
        def molecule_overlaps(p: Point) -> bool:
            return bool(np.any([p.distance(goal.polygon) <= min_distance for goal in self.goals]))
        return molecule_overlaps

    def get_molecules(self) -> list[Molecule]:
        """Return the molecules of the current episode."""
        return self.molecules

    @property
    def current_molecule(self) -> Molecule | MoleculeExperiment:
        """The molecule of the current matching (same errors as :attr:`current_matching`)."""
        return self.current_matching.molecule

    @property
    def current_matching(self) -> Matching:
        """The matching at ``current_matching_index``.

        Raises:
            IndexError: If the index is out of range.
            TypeError: If no matching has been computed yet.
        """
        try:
            return self.matching[self.current_matching_index]
        except IndexError as e:
            log_and_raise(e, f"Could not get current matching: {e}")
        except TypeError as e:
            log_and_raise(e, f"Need to compute matching first: {e}")

    @property
    def current_distance(self) -> float:
        """Distance from the current molecule's centre to its goal position."""
        try:
            return self.current_matching.molecule.center.distance(self.current_matching.goal.position)
        except IndexError as e:
            log_and_raise(e, f"Could not compute current distance: {e}")

    @property
    def goal_distance(self) -> float:
        """Distance from the current molecule's starting position to its goal position."""
        try:
            return self.current_matching.molecule.starting_position.distance(self.current_matching.goal.position)
        except IndexError as e:
            log_and_raise(e, f"Could not compute goal distance: {e}")

    @property
    def goal_position(self) -> Point:
        """Goal position of the current matching."""
        try:
            return self.current_matching.goal.position
        except IndexError as e:
            log_and_raise(e, f"Could not get goal position: {e}")

    @property
    def goal_orientation(self) -> float:
        """Goal rotation of the current matching."""
        try:
            return self.current_matching.goal.rotation
        except IndexError as e:
            log_and_raise(e, f"Could not get goal orientation: {e}")
