import numpy as np

from typing import Optional
from numpy.typing import NDArray

import shapely
from shapely import Point

from loguru import logger

from molecule_movement.envs.MoleculeEnvironment import MoleculeEnvironment
from molecule_movement import Molecule, Goal, Matching, VerticalAction, AngleSymmetry
from molecule_movement.parsing import PickleDataParser, LateralMoleculeDataProcessor
from molecule_movement.shapes import FePc
from molecule_movement.sampling import Sampler
from molecule_movement.constants import MoietyType

class TestingMeasurementsEnv(MoleculeEnvironment):
    """Test environment driven by pickled measurement data.

    A single FePc molecule is matched, in turn, against four goals located at
    (5, 5), (25, 5), (5, 25) and (25, 25). All four goals share one orientation
    of 30 * a with a drawn from {0, 1, 2}. The molecule starts at (k, 5), with
    k drawn from 0..19, and with an orientation from the same set that differs
    from the goal orientation.

    Required keyword arguments (all kwargs, including these, are also
    forwarded to ``MoleculeEnvironment``):
        orientation_map: passed to ``PickleDataParser``.
        translation_map: passed to ``PickleDataParser``.
    """

    def __init__(
            self,
            **kwargs
            ):
        # Stored before super().__init__ in case the base constructor already
        # triggers _parse_molecule_data, which reads them.
        self.orientation_map = kwargs["orientation_map"]
        self.translation_map = kwargs["translation_map"]
        super().__init__(**kwargs)

    def _parse_molecule_data(self):
        """Load transition data from pickles into ``self.molecule_transition_data``.

        The lateral grid, destination grid, z range and voltage range are fixed
        here (units not stated in this file -- TODO confirm).
        """
        self.data_processor = LateralMoleculeDataProcessor(PickleDataParser(orientation_map=self.orientation_map, translation_map=self.translation_map),
                                                           dimensions_x=(-2.1, 2.1),
                                                           dimensions_y=(-2.1, 2.1),
                                                           step_x=0.3,
                                                           step_y=0.3,
                                                           dimensions_dest_x=(-1.2, 1.2),
                                                           dimensions_dest_y=(-1.2, 1.2),
                                                           step_dest_x=0.3,
                                                           step_dest_y=0.3,
                                                           dimensions_z=(0.55,0.6),
                                                           dimensions_V=(-500, 500),
                                                           step_z=0.05,
                                                           step_V=250)
        self.molecule_transition_data = self.data_processor.get_molecular_data()

    def _create_initial_distribution(self, seed: Optional[int] = None):
        """Create the single FePc molecule at (k, 5), with k drawn from 0..19.

        The orientation is 30 * a, where a is drawn from {0, 1, 2} by rejection
        sampling so that it differs from ``self.goal_angle``. This requires
        ``_set_goals`` to have run first (presumably ensured by the base class
        -- TODO confirm). ``seed`` is unused; the global NumPy RNG is used.
        """
        self.initial_pos = Point(np.random.choice(20), 5)
        angle = np.random.choice(3, 1)[0]
        while angle == self.goal_angle:
            angle = np.random.choice(3, 1)[0]
        # The orientation is passed by keyword after the transition data and the
        # sensor count, i.e. Molecule(position, shape, data, num_sensors,
        # orientation=...). Passing it positionally in third place would put it
        # in the data slot.
        self.molecules = [Molecule(self.initial_pos,
                                   FePc,
                                   self.molecule_transition_data,
                                   self.num_sensors,
                                   orientation=30 * angle,
                                   name=self._sample_random_name())]

    def _set_goals(self, seed: Optional[int] = None) -> None:
        """Place four FePc goals at the corners of a square with a shared orientation.

        The shared angle index is stored in ``self.goal_angle`` so that the
        initial molecule can be given a different orientation.
        """
        angle = np.random.choice(3,1)[0]
        self.goal_angle = angle
        self.goals = [Goal(Point( 5, 5), FePc, 30 * angle),
                      Goal(Point(25, 5), FePc, 30 * angle),
                      Goal(Point( 5,25), FePc, 30 * angle),
                      Goal(Point(25,25), FePc, 30 * angle)]

    def _set_obstacles(self, seed: Optional[int] = None) -> None:
        """This environment has no obstacles."""
        self.obstacles = list()

    def reached_goal_position(self) -> bool:
        """Return True when the current distance to the goal is below 0.3."""
        return self.current_distance < 0.3

    def reached_goal_orientation(self) -> bool:
        """Return True if molecule and goal orientations differ by at most 7.5.

        The values are presumably degrees. No wrap-around or symmetry
        reduction is applied. Both values are converted to float rather than
        truncated to int, so that the 7.5 tolerance is applied exactly.
        """
        molecule_orientation = float(self.current_molecule.orientation)
        goal_orientation = float(self.current_matching.goal.rotation)
        return np.abs(molecule_orientation - goal_orientation) <= 7.5

    def _get_matching(self, options: Optional[dict] = None, seed: Optional[int] = None):
        """Pair the single molecule with every goal, in goal order."""
        m = self.molecules[0]
        self.matching = [Matching(m, g) for g in self.goals]

    def increment_matching(self) -> tuple[dict[str, NDArray], dict]:
        """Advance to the next matching, cycling back to the first after the last.

        Resets the step counter and always returns ``(observation, info)``.
        """
        self.steps = 0
        self.current_matching_index += 1
        # Wrap around, then fall through so the wrap case also returns a value.
        if self.current_matching_index >= len(self.matching):
            self.current_matching_index = 0
        return self.observation(), {}


class OrientationPositioningEnv(MoleculeEnvironment):
    """Single-molecule task: reach one goal position and orientation.

    The moiety type and shape are taken from the first key of
    ``system_configuration``. The molecule starts at a position sampled from
    the central third of the surface (on both axes), with a random orientation.
    The single goal is an FePc target at (5, 5) whose orientation comes from
    ``AngleSymmetry(4, 6).random_angle()``.
    """

    def __init__(
            self,
            **kwargs
            ):
        super().__init__(**kwargs)
        first_type = list(self.system_configuration.keys())[0]
        self.type = first_type
        self.shape = self.system_configuration[first_type].moiety.shape

    def _create_initial_distribution(self, seed: Optional[int] = None):
        """Sample the start position from the middle third and build the molecule."""
        # Indicator weights: 1 strictly inside (0.33, 0.66), otherwise 0.
        def x_weight(x):
            return 1 if 0.33 < x < 0.66 else 0

        def y_weight(y):
            return 1 if 0.33 < y < 0.66 else 0

        sampler = Sampler(x_weight,
                          y_weight,
                          width=self.surface_width,
                          height=self.surface_height,
                          seed=seed)
        start = sampler.sample_position([])
        config = self.system_configuration[self.type]
        moiety = config.moiety
        start_point = Point(start.x, start.y)
        self.molecules = [
            Molecule(start_point,
                     moiety.shape,
                     config.response_map,
                     self.num_sensors,
                     orientation="random",
                     name="i",
                     type=moiety.type,
                     molecule_point_symmetry=moiety.point_symmetry,
                     substrate_point_symmetry=self.substrate_point_symmetry,
                     action_space=config.action_space)
        ]

    def _parse_molecule_data(self):
        """No-op: response data comes from ``system_configuration``."""

    def _set_goals(self, seed: Optional[int] = None) -> None:
        """Create the single FePc goal at (5, 5) with a symmetry-sampled orientation."""
        # NOTE(review): the symmetry orders (4, 6) are hard-coded here, whereas
        # the molecule uses self.substrate_point_symmetry -- confirm they agree.
        symmetry = AngleSymmetry(4, 6)
        target = Goal(Point(5, 5), FePc, symmetry.random_angle(), type=MoietyType.FePc)
        self.goals = [target]

    def _set_obstacles(self, seed: Optional[int] = None) -> None:
        """This environment has no obstacles."""
        self.obstacles = []

    def reached_goal_position(self) -> bool:
        """Return True when the current distance to the goal is below 0.25."""
        return self.current_distance < 0.25

    def _get_matching(self, options: Optional[dict] = None, seed: Optional[int] = None):
        """Pair the single molecule with each goal."""
        molecule = self.molecules[0]
        self.matching = [Matching(molecule, goal) for goal in self.goals]

    def reached_goal_orientation(self) -> bool:
        """Return True if orientations agree to within 7.5, modulo 90.

        Both angles are reduced modulo 90 (a hard-coded 4-fold period), and the
        shorter of the two ways around that period is compared to 7.5.
        """
        mol_angle = self.current_molecule.orientation % 90
        goal_angle = self.current_matching.goal.rotation % 90
        gap = np.abs(mol_angle - goal_angle)
        shortest_gap = np.min([gap, 90 - gap])
        return shortest_gap <= 7.5





class LateralFePcStarTrainingEnv(MoleculeEnvironment):
    """Training environment with FePc goals arranged as a star around the surface centre.

    One FePc molecule starts at the surface centre with a random orientation.
    The goals sit on a circle about that centre, one per arm. Each episode the
    molecule is matched to a single goal chosen uniformly at random.
    """

    def __init__(
            self,
            **kwargs
            ):
        super().__init__(**kwargs)

    def _parse_molecule_data(self):
        """No-op here; ``molecule_transition_data`` is presumably provided elsewhere (TODO confirm)."""
        pass

    def _create_initial_distribution(self, seed: Optional[int] = None):
        """Place a single FePc molecule at the surface centre with a random orientation."""
        self.initial_pos = Point(self.surface_width / 2, self.surface_height / 2)
        self.molecules = [Molecule(self.initial_pos, FePc, self.molecule_transition_data, self.num_sensors, orientation="random", name=self._sample_random_name())]

    def _set_goals(self, seed: Optional[int] = None) -> None:
        """Create ``num_arms`` FePc goals equally spaced around the surface centre.

        The first goal sits at (width / 2, 0.15 * height / 2) with an
        orientation from ``AngleSymmetry(...).random_angle()``. The remaining
        goals are that position rotated about the centre by k * 360 / num_arms
        degrees (k = 1 .. num_arms - 1). Each of them gets an orientation of
        30 * a, with a drawn from {0, 1, 2}.
        """
        # Explicit submodule import: `import shapely` alone does not guarantee
        # that `shapely.affinity` is loaded.
        from shapely.affinity import rotate

        num_arms = 15
        # Rotate about the same centre used for the molecule start; true
        # division keeps this correct for odd or non-integer surface sizes.
        center = (self.surface_width / 2, self.surface_height / 2)

        angle_symmetry = AngleSymmetry(molecule_point_symmetry=4, substrate_point_symmetry=self.substrate_point_symmetry)
        initial_goal_position = Point(self.surface_width / 2, 0.15 * self.surface_height / 2)
        self.goals = [Goal(initial_goal_position, FePc, angle_symmetry.random_angle())]

        # k * 360 / num_arms gives exactly num_arms evenly spaced arms even when
        # num_arms does not divide 360.
        for arm in range(1, num_arms):
            goal_position = rotate(initial_goal_position, arm * 360 / num_arms, origin=center)
            orientation = np.random.randint(3, size=1)[0]
            self.goals.append(Goal(goal_position, FePc, orientation * 30))

    def _set_obstacles(self, seed: Optional[int] = None) -> None:
        """This environment has no obstacles."""
        self.obstacles = list()

    def _get_matching(self, options: Optional[dict] = None, seed: Optional[int] = None):
        """Match the single molecule to one goal chosen uniformly at random."""
        self.matching = [Matching(self.molecules[0], self.goals[np.random.choice(len(self.goals))])]
