import numpy as np

from typing import Optional, Callable
from itertools import product
from dataclasses import dataclass, fields
from functools import cached_property
from decimal import Decimal

from molecule_movement.Objects import VerticalAction, LateralAction

from loguru import logger

# Readable aliases for the annotations in this module: a ``(Min, Max)`` tuple
# gives the lower/upper bound of one action-space axis, and ``StepSize`` is the
# spacing of that axis' grid.
Min = float
Max = float
StepSize = float

def stable_linspace_step(start: float, stop: float, step: float) -> np.ndarray:
    """
    Return the inclusive grid ``start, start + step, ..., stop`` without float drift.

    All three inputs are converted to exact decimals via ``str`` and scaled by
    ``10 ** decimal_places`` to integers. Here ``decimal_places`` is the largest
    number of fractional digits among ``start``, ``stop`` and ``step``. The grid
    is generated with integer arithmetic and converted back to floats only at the
    end, so no rounding error accumulates across steps.

    :param start: First grid value.
    :param stop: Last grid value (inclusive); must be >= ``start``.
    :param step: Positive grid spacing.
    :return: 1-D array running from ``start`` to ``stop``. When ``start == stop``
        the result is ``np.array([start])`` and ``step`` is not validated.
    :raises ValueError: If an input is not finite, ``stop < start``, ``step`` is
        not positive, or ``stop - start`` is not a whole multiple of ``step``.
    """
    if stop - start == 0:
        return np.array([start])

    if not np.all(np.isfinite([start, stop, step])):
        raise ValueError(f"{start=}, {stop=} and {step=} must all be finite")
    if stop < start:
        # Without this check the grid below would silently come out empty.
        raise ValueError(f"{stop=} must not be smaller than {start=}")

    d_start, d_stop, d_step = (Decimal(str(v)) for v in (start, stop, step))

    # Use a resolution fine enough for all three values. Deriving it from
    # ``step`` alone would silently round endpoints such as -0.25 / 0.25 when
    # ``step`` is 0.1.
    decimal_places = max(
        max(0, -d.normalize().as_tuple().exponent) for d in (d_start, d_stop, d_step)
    )
    scale = 10 ** decimal_places

    # These conversions are exact: each value has at most ``decimal_places``
    # fractional digits.
    s = int((d_start * scale).to_integral_value())
    e = int((d_stop * scale).to_integral_value())
    h = int((d_step * scale).to_integral_value())

    if h <= 0:
        raise ValueError("Step must be positive when stop - start > 0")

    diff = e - s
    if diff % h != 0:
        raise ValueError(f"{stop=} is not reachable from start {start=} given step size {step=}")

    n = diff // h
    arr_int = s + h * np.arange(n + 1, dtype=np.int64)

    arr = arr_int / scale
    if decimal_places:
        # Snap each value to the nearest float with at most ``decimal_places`` digits.
        arr = np.round(arr, decimal_places)
    return arr

@dataclass(frozen=True)
class MoleculeActionSpace:
    """
    Frozen, discretised action space built from per-axis ``(min, max)`` ranges and step sizes.

    For every ``dimensions_<key>`` field that is not ``None``, ``__post_init__`` builds a 1-D
    grid with ``stable_linspace_step`` and stores it as ``<key>_space`` (e.g. ``x_space``).
    The grids are also collected in ``all_spaces`` in field-declaration order
    (x, y, dest_x, dest_y, z, V). Their Cartesian product defines :attr:`actions`.
    """

    dimensions_x: tuple[Min, Max]
    dimensions_y: tuple[Min, Max]
    step_x: StepSize
    step_y: StepSize
    dimensions_dest_x: Optional[tuple[Min, Max]] = None
    dimensions_dest_y: Optional[tuple[Min, Max]] = None
    step_dest_x: Optional[StepSize] = None
    step_dest_y: Optional[StepSize] = None
    dimensions_z: Optional[tuple[Min, Max]] = None
    dimensions_V: Optional[tuple[Min, Max]] = None
    step_z: Optional[StepSize] = None
    step_V: Optional[StepSize] = None

    def __post_init__(self):
        """
        Build the per-axis grids and attach them as ``<key>_space`` and ``all_spaces``.

        :raises ValueError: If a ``dimensions_<key>`` field is set but ``step_<key>`` is not,
            or if the grid for an axis cannot be built. In the latter case the message is
            prefixed with ``<key>``.
        """
        all_spaces = []
        for f in fields(self):
            name = f.name
            if not name.startswith("dimensions_"):
                continue

            dimensions = getattr(self, name)
            if dimensions is None:
                continue  # optional axis not configured

            key = name.removeprefix("dimensions_")
            step_name = f"step_{key}"
            step = getattr(self, step_name, None)
            if step is None:
                raise ValueError(f"{step_name} must be provided when {name} is set.")

            low, high = dimensions
            try:
                space = stable_linspace_step(low, high, step)
            except ValueError as e:
                raise ValueError(f"{key}: {e}") from e

            # The dataclass is frozen, so bypass its generated __setattr__ to attach
            # derived attributes.
            object.__setattr__(self, f"{key}_space", space)
            all_spaces.append(space)

        # Set once, after the loop, so the attribute exists even if no axis is configured.
        object.__setattr__(self, "all_spaces", all_spaces)

    @property
    def a_max(self):
        """
        Return the largest absolute bound of the action space in xy.

        :return: ``max(|v|)`` over the bounds in ``dimensions_x`` and ``dimensions_y``.
        """
        return np.max(np.concatenate((np.abs(self.dimensions_x), np.abs(self.dimensions_y))))

    @property
    def a_max_dest(self):
        """
        Return the largest absolute bound of the destination action space in xy.

        Requires both ``dimensions_dest_x`` and ``dimensions_dest_y`` to be set.

        :return: ``max(|v|)`` over the bounds in ``dimensions_dest_x`` and ``dimensions_dest_y``.
        """
        return np.max(np.concatenate((np.abs(self.dimensions_dest_x), np.abs(self.dimensions_dest_y))))

    @cached_property
    def actions(self):
        """
        Return every action as a tuple holding one grid value per configured axis.

        Tuples follow ``all_spaces`` order; the list is computed once and then cached.
        """
        return list(product(*self.all_spaces))

    @cached_property
    def action_indices(self):
        """
        Return, for each entry of :attr:`actions`, the tuple of grid indices that produces it.

        The i-th element corresponds to ``actions[i]``, because both lists are products over
        the same axes in the same order. The list is computed once and then cached.
        """
        index_ranges = [range(len(space)) for space in self.all_spaces]
        return list(product(*index_ranges))

@dataclass
class MoleculeTransitionData:
    """
    Mutable container bundling a molecule's movement callables with its action space.

    ``translations``, ``rotations`` and the optional ``distant_translations`` map action keys
    (``VerticalAction`` or ``LateralAction``) to callables. What those callables compute is
    defined by the code that builds this object (TODO: document it).
    """

    translations: dict[VerticalAction, Callable] | dict[LateralAction, Callable]
    rotations: dict[VerticalAction, Callable] | dict[LateralAction, Callable]
    action_space: MoleculeActionSpace
    # Optional label; reset by ``clear_name``.
    name: Optional[str] = None
    # NOTE(review): units/meaning not visible here; presumably an upper bound on movement — confirm.
    maximum_movement: float = 1.0
    # Defaults to None, so the annotation must admit None.
    distant_translations: dict[VerticalAction, Callable] | dict[LateralAction, Callable] | None = None

    def clear_name(self) -> None:
        """Reset ``name`` to ``None``."""
        self.name = None
