import numpy as np

from typing import Optional

from ..logging import log_and_raise
from loguru import logger

from shapely import Point

from molecule_movement.Objects import LateralAction

from .MoleculeTransitionData import MoleculeTransitionData, MoleculeActionSpace, Min, Max, StepSize
from .DataParser import MoleculeDataProcessor

# Relative directory name for example data files.
# NOTE(review): not referenced in this module's visible code — confirm where it is consumed.
EXAMPLE_DATA_DIR = "data"

class PickleDataParser():
    def __init__(self, translation_map_file: str, orientation_map_file: str, name: Optional[str]=None):
        self.name = name
        self.translation_map_file= translation_map_file
        self.orientation_map_file= orientation_map_file

    def load_files(self) -> None:
        try:
            self.translation_map = np.load(self.translation_map_file, allow_pickle=True)
        except Exception as e:
            logger.warning(f"Reading of {self.translation_map_file=} not yet implemented")
            self.translation_map = None
        self.orientation_map = np.load(self.orientation_map_file, allow_pickle=True)

    def get_translation_map(self) -> dict:
        return self.translation_map

    def get_orientation_map(self) -> dict:
        return self.orientation_map

    def get_molecular_data(self) -> None:
        assert False, "exception handling TODO"

    def get_name(self) -> str | None:
        return self.name


class LateralMoleculeDataProcessor(MoleculeDataProcessor):
    """Build a ``MoleculeTransitionData`` from a parser's orientation map.

    For every key of the parser's orientation map, two zero-argument samplers
    are registered under a ``LateralAction``: one returning a recorded
    translation and one returning a recorded rotation. The result is computed
    on the first :meth:`get_response_map` call and reused afterwards.
    """

    def __init__(self,
                 dataParser: PickleDataParser,
                 dimensions_x: tuple[Min, Max],
                 dimensions_y: tuple[Min, Max],
                 step_x: StepSize,
                 step_y: StepSize,
                 dimensions_dest_x: Optional[tuple[Min, Max]] = None,
                 dimensions_dest_y: Optional[tuple[Min, Max]] = None,
                 step_dest_x: Optional[StepSize] = None,
                 step_dest_y: Optional[StepSize] = None,
                 dimensions_z: Optional[tuple[Min, Max]] = None,
                 dimensions_V: Optional[tuple[Min, Max]] = None,
                 step_z: Optional[StepSize] = None,
                 step_V: Optional[StepSize] = None
                 ):
        """Store the parser and build the action space from the dimension/step arguments.

        All dimension and step arguments are forwarded positionally, in
        signature order, to ``MoleculeActionSpace``.
        """
        self.dataParser = dataParser
        self.name = dataParser.get_name()
        # Filled by the private map-translation step; keyed by LateralAction.
        self.translations = {}
        self.rotations = {}
        # Cached result of get_response_map(); None until first computed.
        self.moleculeData = None
        space_args = (
            dimensions_x, dimensions_y, step_x, step_y,
            dimensions_dest_x, dimensions_dest_y, step_dest_x, step_dest_y,
            dimensions_z, dimensions_V, step_z, step_V,
        )
        self.action_space = MoleculeActionSpace(*space_args)

    def get_response_map(self) -> MoleculeTransitionData:
        """Return the transition data, loading and translating the maps on first use."""
        cached = getattr(self, "moleculeData", None)
        if cached is not None:
            return cached
        parser = self.dataParser
        parser.load_files()
        self.distant_translations = parser.get_translation_map()
        self.orientation_map = parser.get_orientation_map()
        self.__translate()
        return self.moleculeData

    def __translate(self) -> None:
        """Build the sampler dicts, then wrap them into ``self.moleculeData``."""
        # TODO: fold translation_map into orientation_map and do everything in one pass?
        self.__translate_map()
        self.moleculeData = MoleculeTransitionData(
            self.translations,
            self.rotations,
            self.action_space,
            self.name,
            distant_translations=self.distant_translations,
        )

    def __translate_map(self) -> None:
        """Register a translation sampler and a rotation sampler per orientation-map key.

        Each key is indexed as key[0]..key[5]: two Points are built from
        key[0:2] and key[2:4], and key[4], key[5] are passed on to
        ``LateralAction``. Each movement entry contributes entry[0] as a
        translation and entry[1] as a rotation. Rotations are rounded to the
        nearest multiple of 5 with ``np.round`` (round-half-to-even). Keys
        with no movements get constant samplers returning (0,0) and 0.
        """
        # Variable names are kept: the f"{...=}" log lines print them verbatim.
        all_movements = 0
        no_movements = 0
        for key, movements in self.orientation_map.items():
            first = Point(key[0], key[1])
            second = Point(key[2], key[3])
            action = LateralAction(first, second, key[4], key[5])

            count = len(movements)
            all_movements += count

            if count == 0:
                no_movements += 1
                self.translations[action] = lambda : (0,0)
                self.rotations[action] = lambda : 0
                continue

            recorded_shifts = [entry[0] for entry in movements]
            recorded_turns = np.round(np.array([entry[1] for entry in movements]) / 5) * 5
            # Default arguments freeze the per-key data (avoids late binding).
            # NOTE: a disabled variant added np.random.normal(0, 0.27**2) jitter
            # to each axis of the sampled translation.
            self.translations[action] = lambda translations=recorded_shifts: translations[np.random.choice(len(translations), 1)[0]]
            self.rotations[action] = lambda rotations=recorded_turns: rotations[np.random.choice(len(rotations), 1)[0]]
        logger.info(f"{all_movements=}")
        logger.info(f"{no_movements=}")
