import numpy as np

import gymnasium as gym
from gymnasium.core import ActType, ObsType, WrapperActType

from molecule_movement.parsing import DataParser

from shapely import LineString
import shapely

from loguru import logger

class ActionSpaceClipper(gym.ActionWrapper[ObsType, WrapperActType, ActType], gym.utils.RecordConstructorArgs):
    """Restrict a discrete action space to moves that measurably change the molecule.

    Candidate actions form a fixed 15 x 15 grid of relative ``(x_rel, y_rel)``
    moves, each coordinate running from -2.1 to 2.1 in steps of 0.3. For every
    candidate, the per-action measurements held by ``data_processor`` are
    inspected. The action is kept when either of these holds:

    * its first recorded rotation value is not 1.0;
    * at least one recorded ``(x, y)`` translation has a Euclidean length of at
      least ``min_distance_travelled``.

    Kept actions are stored in grid order in ``self.action_set``, and the
    wrapper exposes ``Discrete(len(self.action_set))`` as its action space.

    NOTE(review): :meth:`action` forwards the agent's index unchanged. The
    wrapped env therefore receives an index in ``[0, len(self.action_set))``,
    not the original index of the kept action. Confirm the wrapped env
    interprets indices that way. Otherwise, ``action`` should translate
    ``self.action_set[action]`` into the env's own action encoding.

    NOTE(review): ``gym.utils.RecordConstructorArgs.__init__`` is never called,
    so the constructor kwargs are not recorded. Confirm that nothing relies on
    this wrapper's ``spec``.
    """

    # Relative-move grid used to enumerate candidate actions. The resulting
    # (x_rel, y_rel) tuples are used directly as keys into the data
    # processor's dicts, so they presumably must match its keys exactly.
    _GRID_START = -2.1
    _GRID_STEP = 0.3
    _GRID_SIZE = 15

    def __init__(self, env: gym.Env[ObsType, ActType], data_processor: DataParser, action_space: gym.spaces.Space, min_distance_travelled: float = 0.3):
        """Wrap ``env`` and crop its discrete action space.

        Args:
            env: The environment to wrap.
            data_processor: Supplies ``refactored_x_translations``,
                ``refactored_y_translations`` and ``refactored_rotations``.
                Each is indexed by an ``(x_rel, y_rel)`` tuple.
            action_space: The un-cropped action space. Only its ``n`` is
                read, and only for logging and for the attribute
                ``original_action_set``.
            min_distance_travelled: Minimum translation length an action must
                reach at least once to be kept (when it does not rotate).

        Raises:
            ValueError: If no candidate action passes the filter, because an
                empty ``Discrete`` space cannot be built.
            AssertionError: If an action's x and y translation sequences
                differ in length.
        """
        gym.Wrapper.__init__(self, env)

        # Size of the un-cropped action space (an int count, despite the name).
        self.original_action_set = action_space.n
        self.min_distance_travelled = min_distance_travelled
        self.data_processor = data_processor
        # Kept (x_rel, y_rel) actions, in grid order; filled by _crop_action_space.
        self.action_set = []
        self._crop_action_space()
        if not self.action_set:
            raise ValueError(
                f"No action reaches min_distance_travelled={self.min_distance_travelled}; "
                f"cannot build an empty action space"
            )
        self.action_space = gym.spaces.Discrete(len(self.action_set))

    def action(self, action: WrapperActType) -> ActType:
        """Return ``action`` unchanged (see the NOTE(review) in the class docstring)."""
        return action

    def _crop_action_space(self) -> None:
        """Rebuild ``self.action_set`` from the grid actions that pass the filter."""
        x_translations_dict = self.data_processor.refactored_x_translations
        y_translations_dict = self.data_processor.refactored_y_translations
        rotations_dict = self.data_processor.refactored_rotations

        # round(..., 1) removes accumulated float error from the step
        # arithmetic so the values can serve as exact dict keys.
        action_steps = [round(self._GRID_START + self._GRID_STEP * i, 1) for i in range(self._GRID_SIZE)]

        # Build a fresh list so that repeated calls do not duplicate entries.
        kept = []
        for x_rel in action_steps:
            for y_rel in action_steps:
                action = (x_rel, y_rel)
                if self._keeps_action(
                    x_translations_dict[action],
                    y_translations_dict[action],
                    rotations_dict[action],
                ):
                    kept.append(action)
        self.action_set = kept

        logger.info(f"Filtered {self.original_action_set - len(self.action_set)}/{self.original_action_set} actions for min_distance: {self.min_distance_travelled}")

    def _keeps_action(self, x_translations, y_translations, rotations) -> bool:
        """Decide whether a single grid action is kept.

        The action is kept when ``rotations[0]`` is not 1.0. Otherwise, it is
        kept when at least one ``(x, y)`` translation measured from the origin
        has a length of at least ``self.min_distance_travelled``.
        """
        # NOTE(review): 1.0 appears to act as the "no rotation" marker here;
        # confirm against the data processor. Rotating actions skip the length
        # check entirely.
        if rotations[0] != 1.0:
            return True
        assert len(x_translations) == len(y_translations), (
            f"Measurements for x and y translations are not of same length: "
            f"{len(x_translations)} != {len(y_translations)}"
        )
        return any(
            LineString([[0, 0], [x, y]]).length >= self.min_distance_travelled
            for x, y in zip(x_translations, y_translations)
        )
