import numpy as np

from typing import Optional

from ..logging import log_and_raise
from loguru import logger

from .MoleculeTransitionData import MoleculeTransitionData, MoleculeActionSpace, Min, Max, StepSize

# Directory (relative to the working directory) that main() reads the example .npy files from.
EXAMPLE_DATA_DIR = "data"

# Multiplier applied to every raw translation sample in MoleculeDataProcessor.
# NOTE(review): presumably converts raw DAC units into a physical length — confirm the unit.
DAC_FACTOR = 0.0008296966552734370

class DataParser():
    """Load the raw x-translation, y-translation and rotation tables from ``.npy`` files.

    Each file must contain a single pickled Python object stored as a 0-d
    object array; it is unwrapped with ``.item()`` right after loading. All
    three files are read eagerly in the constructor.

    Args:
        x_translations: Path of the ``.npy`` file with the x-translation table.
        y_translations: Path of the ``.npy`` file with the y-translation table.
        rotations: Path of the ``.npy`` file with the rotation table.
        name: Optional label for this data set, returned by ``get_name``.
    """
    def __init__(self, x_translations: str, y_translations: str, rotations: str, name: Optional[str]=None):
        self.name = name
        self.x_translations_file = x_translations
        self.y_translations_file = y_translations
        self.rotations_file = rotations
        self.__load_files()

    def __load_files(self) -> None:
        """Read the three files and keep the unwrapped objects on the instance."""
        # NOTE: allow_pickle=True unpickles the file content, which can execute
        # arbitrary code. Only point this parser at trusted data files.
        self.x_translations = np.load(self.x_translations_file, allow_pickle=True).item()
        self.y_translations = np.load(self.y_translations_file, allow_pickle=True).item()
        self.rotations = np.load(self.rotations_file, allow_pickle=True).item()

    def get_x_translations(self) -> dict:
        """Return the object loaded from the x-translation file."""
        return self.x_translations

    def get_y_translations(self) -> dict:
        """Return the object loaded from the y-translation file."""
        return self.y_translations

    def get_rotations(self) -> dict:
        """Return the object loaded from the rotation file."""
        return self.rotations

    def get_molecular_data(self) -> None:
        """Not implemented for the parser.

        Raises:
            NotImplementedError: Always. An explicit raise is used instead of
                ``assert False`` because asserts are removed under ``python -O``,
                which would make this method silently return ``None``.
        """
        raise NotImplementedError("DataParser.get_molecular_data is not implemented")

    def get_name(self) -> Optional[str]:
        """Return the label given at construction time (may be ``None``)."""
        return self.name

class FixedVoltageDataParser(DataParser):
    """DataParser variant that exposes only the tables stored for one voltage.

    The loaded objects are indexed first by ``voltage`` and then by the key
    ``0.0``; the getters return the sub-table found there.

    Args:
        x_translations: Path of the ``.npy`` file with the x-translation table.
        y_translations: Path of the ``.npy`` file with the y-translation table.
        rotations: Path of the ``.npy`` file with the rotation table.
        voltage: First-level key selecting the voltage slice of every table.
        name: Optional label for this data set.
    """
    def __init__(self, x_translations: str, y_translations: str, rotations: str, voltage: int, name: Optional[str]=None):
        self.voltage = voltage
        super().__init__(x_translations, y_translations, rotations, name)

    def __slice_for_voltage(self, table: dict) -> dict:
        """Return ``table[voltage][0.0]``."""
        # NOTE(review): the meaning of the 0.0 key is not documented in this
        # file — possibly the z-axis offset; confirm against the data set.
        per_voltage = table[self.voltage]
        return per_voltage[0.0]

    def get_x_translations(self) -> dict:
        """Return the x-translation sub-table for the configured voltage."""
        return self.__slice_for_voltage(self.x_translations)

    def get_y_translations(self) -> dict:
        """Return the y-translation sub-table for the configured voltage."""
        return self.__slice_for_voltage(self.y_translations)

    def get_rotations(self) -> dict:
        """Return the rotation sub-table for the configured voltage."""
        return self.__slice_for_voltage(self.rotations)



class MoleculeDataProcessor():
    """Convert parsed raw tables into per-action samplers and a MoleculeTransitionData.

    The parser's tables are nested dicts keyed by the relative x action and
    then the relative y action. They are flattened to dicts keyed by
    ``(x_rel_action, y_rel_action)`` whose values are zero-argument callables
    that draw one random outcome for that action:

    * ``self.translations[action]()`` returns one ``(x, y)`` pair picked
      uniformly from the recorded samples, scaled by ``DAC_FACTOR`` and
      rounded to two decimals.
    * ``self.rotation_dist[action]()`` returns one key of the recorded
      rotation distribution, drawn with the recorded probabilities.

    Args:
        dataParser: Source of the x/y translation tables, rotation table and name.
        dimensions_x, dimensions_y: (min, max) bounds forwarded to MoleculeActionSpace.
        step_x, step_y: Step sizes forwarded to MoleculeActionSpace.
        dimensions_z, dimensions_V: Optional bounds forwarded to MoleculeActionSpace.
        step_z, step_V: Optional step sizes forwarded to MoleculeActionSpace.
    """
    def __init__(self,
                 dataParser: DataParser,
                 dimensions_x : tuple[Min, Max],
                 dimensions_y : tuple[Min, Max],
                 step_x : StepSize,
                 step_y : StepSize,
                 dimensions_z : Optional[tuple[Min, Max]] = None,
                 dimensions_V : Optional[tuple[Min, Max]] = None,
                 step_z : Optional[StepSize] = None,
                 step_V : Optional[StepSize] = None
                 ):
        self.x_translations = dataParser.get_x_translations()
        self.y_translations = dataParser.get_y_translations()
        self.rotations = dataParser.get_rotations()
        self.name = dataParser.get_name()

        self.molecule_action_space = MoleculeActionSpace(dimensions_x,
                                                         dimensions_y,
                                                         step_x,
                                                         step_y,
                                                         dimensions_z,
                                                         dimensions_V,
                                                         step_z,
                                                         step_V)
        self.__translate()

    def get_molecular_data(self) -> MoleculeTransitionData:
        """Return the MoleculeTransitionData built during construction."""
        return self.moleculeData

    def __translate(self) -> None:
        """Build both sampler dicts, then wrap them in a MoleculeTransitionData."""
        self.__translate_translation_to_tuple_dict()
        self.__translate_rotations_to_tuple_dict()
        self.moleculeData = MoleculeTransitionData(self.translations,
                                                   self.rotation_dist,
                                                   self.molecule_action_space,
                                                   self.name)

    @staticmethod
    def __flatten_scaled(nested: dict) -> dict:
        """Flatten ``{x: {y: samples}}`` to ``{(x, y): [sample * DAC_FACTOR, ...]}``."""
        return {(x_rel_action, y_rel_action): [sample * DAC_FACTOR for sample in samples]
                for x_rel_action, y_rel_actions in nested.items()
                for y_rel_action, samples in y_rel_actions.items()}

    def __translate_translation_to_tuple_dict(self) -> None:
        """Populate ``self.translations`` with one (x, y) sampler per action.

        Raises:
            ValueError: (via ``log_and_raise``) if an action has a different
                number of x samples than y samples.
            KeyError: if an action present in the x table is missing from the
                y table.
        """
        scaled_x = self.__flatten_scaled(self.x_translations)
        scaled_y = self.__flatten_scaled(self.y_translations)

        samplers = dict()
        for action, x_samples in scaled_x.items():
            y_samples = scaled_y[action]
            if len(x_samples) != len(y_samples):
                # Report the per-action sample counts that were actually compared.
                log_and_raise(ValueError(), f"Updates for {action} are not of same length: {len(x_samples)} != {len(y_samples)}")
            pairs = np.round(list(zip(x_samples, y_samples)), 2)
            # Bind `pairs` as a default argument so each sampler keeps its own
            # array instead of the loop variable's final value.
            samplers[action] = lambda translations=pairs: translations[np.random.choice(len(translations), 1)[0]]
        self.translations = samplers

    def __translate_rotations_to_tuple_dict(self) -> None:
        """Populate ``self.rotation_dist`` with one rotation sampler per action."""
        self.rotation_dist = dict()
        for x_rel_action, y_rel_actions in self.rotations.items():
            for y_rel_action, distribution in y_rel_actions.items():
                support = list(distribution.keys())
                d = list(distribution.values())
                # Default arguments freeze this action's support/probabilities
                # in the sampler (avoids late binding to the loop variables).
                self.rotation_dist[(x_rel_action, y_rel_action)] = lambda support=support, d=d : support[np.random.choice(len(support),1,p=d)[0]]





def main():
    """Demo: load the example data for 1700 (voltage key) and print one random outcome.

    Builds a FixedVoltageDataParser over the example files, feeds it to a
    MoleculeDataProcessor and samples one translation and one rotation for a
    fixed action using the processor's sampler dicts.
    """
    dataParser = FixedVoltageDataParser(f"{EXAMPLE_DATA_DIR}/dict_raw_translation_x.npy",
                                        f"{EXAMPLE_DATA_DIR}/dict_raw_translation_y.npy",
                                        f"{EXAMPLE_DATA_DIR}/dict_rotation_probabilities.npy", voltage=1700)

    # MoleculeDataProcessor requires the x/y bounds and step sizes.
    # NOTE(review): the values below are placeholders chosen so that the demo
    # action (-0.3, -0.3) lies inside the range — confirm them against the
    # recorded data set and the expectations of MoleculeActionSpace.
    dataProcessor = MoleculeDataProcessor(dataParser,
                                          dimensions_x=(-1.0, 1.0),
                                          dimensions_y=(-1.0, 1.0),
                                          step_x=0.1,
                                          step_y=0.1)
    action = (-0.3, -0.3)

    # The processor stores one zero-argument sampler per action; calling it
    # draws a single random outcome.
    x_delta, y_delta = dataProcessor.translations[action]()
    rotation = dataProcessor.rotation_dist[action]()

    print(f"Random movement when picking action {action} ")
    print(f"({x_delta}, {y_delta}, {rotation})")


# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
