import gymnasium as gym
import numpy as np

import os
from typing import Optional, Type

from loguru import logger

import shapely
from shapely import Point

from molecule_movement.envs.MoleculeEnvironment import MoleculeEnvironment
from molecule_movement import Molecule, Matching, Goal, AngleSymmetry
from molecule_movement.parsing import FixedVoltageDataParser, MoleculeDataProcessor, MockUpData
from molecule_movement.sampling import Sampler
from molecule_movement.shapes import DDNB, TRIANGLE, RECTANGLE, resize, FePc


class PositioningTrainingEnv(MoleculeEnvironment):
    """Single-molecule positioning task built on ``MoleculeEnvironment``.

    One FePc molecule is sampled inside the central band of the surface and
    is matched to a single FePc goal placed at the surface centre. The goal
    rotation is drawn from an ``AngleSymmetry`` with 4-fold molecular symmetry.
    """

    def __init__(
            self,
            molecule_transition_data_x: str | os.PathLike,
            molecule_transition_data_y: str | os.PathLike,
            molecule_rotation_data: str | os.PathLike,
            **kwargs
            ):
        """Remember the measurement-data locations, then run the base init.

        Args:
            molecule_transition_data_x: transition data file for x moves.
            molecule_transition_data_y: transition data file for y moves.
            molecule_rotation_data: rotation data file.
            **kwargs: forwarded unchanged to ``MoleculeEnvironment.__init__``.
        """
        # Assigned before super().__init__; presumably the base constructor
        # triggers _parse_molecule_data, which reads these -- confirm.
        (self.molecule_transition_data_x,
         self.molecule_transition_data_y,
         self.molecule_rotation_data) = (molecule_transition_data_x,
                                         molecule_transition_data_y,
                                         molecule_rotation_data)
        super().__init__(**kwargs)

    def _parse_molecule_data(self):
        """Parse the fixed-voltage measurements onto the simulation grid.

        Sets ``self.data_processor`` and ``self.molecule_transition_data``.
        """
        # NOTE(review): the parser is named "DDNB" while the molecule and goal
        # shapes below are FePc -- confirm which dataset is intended.
        raw_parser = FixedVoltageDataParser(
            self.molecule_transition_data_x,
            self.molecule_transition_data_y,
            self.molecule_rotation_data,
            voltage=1700,
            name="DDNB",
        )
        self.data_processor = MoleculeDataProcessor(
            raw_parser,
            dimensions_x=(-2.1, 2.1),
            dimensions_y=(-2.1, 2.1),
            step_x=0.3,
            step_y=0.3,
        )
        self.molecule_transition_data = self.data_processor.get_molecular_data()

    def _create_initial_distribution(self, seed: Optional[int] = None):
        """Place one FePc molecule at a sampled position with random orientation.

        Args:
            seed: forwarded to ``Sampler`` for the position draw.
        """
        # Start from an empty population; the sampler receives it as the set of
        # already-placed molecules.
        self.molecules = []

        def in_central_band(frac):
            # Presumably a fractional coordinate of the surface -- confirm
            # against Sampler.
            return 1 / 3 < frac < 2 / 3

        position_sampler = Sampler(
            in_central_band,
            in_central_band,
            # rejection=self._overlaps_goal_func(min_distance=0.5),
            width=self.surface_width,
            height=self.surface_height,
            seed=seed,
        )
        start = position_sampler.sample_position(self.molecules)

        self.molecules = [
            Molecule(
                start,
                FePc,
                self.molecule_transition_data,
                self.num_sensors,
                name=self._sample_random_name(),
                orientation="random",
            )
        ]

    def _set_goals(self, seed: Optional[int] = None) -> None:
        """Create the single FePc goal at the centre of the surface.

        Args:
            seed: accepted for interface compatibility.
        """
        # NOTE(review): `seed` is unused, so random_angle() is not seeded here
        # -- confirm whether goal rotations should be reproducible.
        symmetry = AngleSymmetry(
            molecule_point_symmetry=4,
            substrate_point_symmetry=self.substrate_point_symmetry,
        )
        centre = Point(self.surface_width / 2, self.surface_height / 2)
        self.goals = [Goal(centre, FePc, symmetry.random_angle())]

    def _set_obstacles(self, seed: Optional[int] = None) -> None:
        """Delegate obstacle placement to ``MoleculeEnvironment`` unchanged."""
        return super()._set_obstacles(seed)

    def _get_matching(self, options: Optional[dict] = None, seed: Optional[int] = None) -> None:
        """Pair the only molecule with the only goal."""
        only_molecule = self.molecules[0]
        only_goal = self.goals[0]
        self.matching = [Matching(only_molecule, only_goal)]

    def increment_matching(self) -> None:
        """Always raise: there is never a next matching in this environment.

        Raises:
            StopIteration: unconditionally.
        """
        raise StopIteration("Called increment_matching for environment with a single molecule.")

    def reached(self) -> bool:
        """Return whether the molecule is within 0.3 of the goal position and
        its orientation differs from the goal rotation by at most 1.
        """
        molecule = self.current_molecule
        goal = self.current_matching.goal
        if not molecule.center.distance(goal.position) < 0.3:
            return False
        # NOTE(review): plain difference -- no wrap-around and no use of the
        # 4-fold symmetry, so symmetry-equivalent or wrap-adjacent angles count
        # as far apart. Angle units are not visible here -- confirm.
        return np.abs(molecule.orientation - goal.rotation) <= 1


class MockUpPositioningTrainingEnv(PositioningTrainingEnv):
    """Positioning environment driven by synthetic (mock-up) transition data.

    Instead of parsing measurement files, the transition data comes from a
    ``MockUpData`` class instantiated on the grid x, y in [-2.1, 2.1] with
    step 0.3.
    """

    # Parent constructor arguments that only PositioningTrainingEnv's
    # _parse_molecule_data reads; that method is overridden below.
    _UNUSED_DATA_PATH_ARGS = (
        "molecule_transition_data_x",
        "molecule_transition_data_y",
        "molecule_rotation_data",
    )

    def __init__(self,
                 mockup_data: Type[MockUpData],
                 **kwargs):
        """Create the mock-up environment.

        Args:
            mockup_data: a ``MockUpData`` class (not an instance). It is
                instantiated in ``_parse_molecule_data``.
            **kwargs: forwarded to ``PositioningTrainingEnv.__init__``. The
                measurement-file arguments ``molecule_transition_data_x``,
                ``molecule_transition_data_y`` and ``molecule_rotation_data``
                are optional here and default to ``None``.
        """
        # The parent makes the data-file paths mandatory, but this subclass
        # never parses them. Default them so callers need not pass
        # placeholder paths; explicitly supplied values are kept.
        for arg_name in self._UNUSED_DATA_PATH_ARGS:
            kwargs.setdefault(arg_name, None)
        self.mockup_data = mockup_data
        super().__init__(**kwargs)

    def _parse_molecule_data(self):
        """Build the mock-up data processor and extract its transition data.

        Sets ``self.data_processor`` and ``self.molecule_transition_data``.
        """
        self.data_processor = self.mockup_data(dimensions_x=(-2.1, 2.1),
                                               dimensions_y=(-2.1, 2.1),
                                               step_x=0.3,
                                               step_y=0.3)
        self.molecule_transition_data = self.data_processor.get_molecular_data()
