import numpy as np
import gymnasium as gym
import string

import shapely
from shapely import Point, Polygon, box
import cv2
import keyboard

import os
from typing import Optional
from numpy.typing import NDArray
import datetime

from loguru import logger
from molecule_movement.logging import log_and_raise, deprecated, log_shapely_point

from shapely import LineString

from molecule_movement.colour_utils import Highlight

from molecule_movement.envs.Empty import EmptyEnvironment
from molecule_movement import Goal, Obstacle, MoleculeExperiment, LateralAction, VerticalAction, Molecule, Movement, Matching
from molecule_movement.parsing import FixedVoltageDataParser, MoleculeDataProcessor
from molecule_movement.parsing import MoleculeActionSpace, LateralActionsMockUpData
from molecule_movement.sampling import CircularSampler, Sampler
from molecule_movement.matching import RandomMatching, GreedyMatching, HungarianMatching
from molecule_movement.scheduling import SATBasedScheduling, SortScheduling
from molecule_movement.shapes import RECTANGLE, TRIANGLE, DDNB, resize
from molecule_movement.Simulator import HardwareSimulator
from molecule_movement.statistics import enable_statistics_logger, dump_statistics
from molecule_movement.Molecule import find_nearest

try:
    from object_recognition.post_processing.analyse_moieties import AnalyseMoietyCreatec, AnalyseMoietyNanonis
    from object_recognition.post_processing.analyse_moieties_simulation import AnalyseMoietySimulation
    from stm_remote_control.stm import CreatecSTM, NanonisSTM
except ImportError:
    pass

class ExperimentalEnvironment(EmptyEnvironment):
    def __init__(
            self,
            molecules: list[MoleculeExperiment],
            goals: list[Goal],
            obstacles: list[Obstacle],
            hardware: str,
            measurement_task: Optional[str],
            num_sensors: int = 16,
            surface_width: int = 30,
            surface_height: int = 30,
            **kwargs
            ):
        """Set up an environment backed by a real STM or by the hardware simulator.

        Args:
            molecules: Initial molecules. Only forwarded to ``AnalyseMoietySimulation``
                in simulation mode; the environment itself starts with an empty list.
            goals: Initial goals. Same treatment as ``molecules``.
            obstacles: Initial obstacles, forwarded to ``EmptyEnvironment``.
            hardware: One of "createc", "nanonis" or "simulation".
            measurement_task: Goal-generation mode used by ``_set_goals`` (may be None).
            num_sensors: Number of sensor entries appended to the observation vector.
            surface_width: Forwarded to ``EmptyEnvironment``.
            surface_height: Forwarded to ``EmptyEnvironment``.
            **kwargs: Forwarded to ``EmptyEnvironment``.

        Raises:
            ValueError: If ``hardware`` is not one of the supported values.
        """
        self.include_surface_orientation = False

        self.molecules = molecules
        self.goals = goals
        self.obstacles = obstacles

        self.measurement_task = measurement_task

        self.start_tip_position = None
        self.end_tip_position = None

        self.running_simulation = False
        if hardware == "createc":
            self.analyse_moieties = AnalyseMoietyCreatec()
            self.device = CreatecSTM()
        elif hardware == "nanonis":
            self.analyse_moieties = AnalyseMoietyNanonis()
            self.device = NanonisSTM()
        elif hardware == "simulation":
            self.running_simulation = True
            self.analyse_moieties = AnalyseMoietySimulation(self.molecules, self.goals, self.obstacles)
            # NOTE(review): the simulator class itself (not an instance) is stored here,
            # unlike the STM branches above -- confirm this is intended.
            self.device = HardwareSimulator
        else:
            # Fail fast: otherwise the missing analyse_moieties/device attributes only
            # surface later as an unrelated AttributeError.
            raise ValueError(
                f"Unknown hardware {hardware!r}; expected 'createc', 'nanonis' or 'simulation'."
            )

        self.input_yaml_file = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        self.experiment_name = self.analyse_moieties.get_yaml_data(self.input_yaml_file, "experiment_name")

        self.CRASH_DISTANCE = self.analyse_moieties.get_yaml_data(self.input_yaml_file, "obstacle_distance_threshold_nm")
        self.SUCCESS_DISTANCE = self.analyse_moieties.get_yaml_data(self.input_yaml_file, "success_distance_threshold_nm")

        log_dir = os.path.join(os.getcwd(), "log_experiments")
        os.makedirs(log_dir, exist_ok=True)
        self.log_file = os.path.join(log_dir, self.experiment_name + '.log')

        # The environment starts empty; molecules and goals are populated later
        # (see _create_initial_distribution / _set_goals).
        self.molecules = []
        self.goals = []

        super().__init__(molecules=self.molecules,
                         goals=self.goals,
                         obstacles=self.obstacles,
                         surface_width=surface_width,
                         surface_height=surface_height,
                         **kwargs)

        # Observation: 5 leading entries with per-entry upper bounds, then one
        # unbounded entry per sensor; all lower bounds are 0.
        obs_size = 5 + num_sensors
        obs_high = np.full((obs_size,), [360, 360, 360, np.inf, 1] + [np.inf] * num_sensors)
        self.observation_space = gym.spaces.Box(np.zeros((obs_size,)), obs_high, shape=(obs_size,))

    def _parse_molecule_data(self):
        """No-op hook: this environment does not parse molecule data files."""
        return None

    def _create_initial_distribution(self, seed: Optional[int] = None, options: Optional[dict] = None) -> None:
        """Build ``self.molecules`` and ``self.obstacles`` from the object detection.

        Runs ``init_moiety_information`` on ``self.device.overview_image_path`` (with
        plotting enabled), then creates:

        * one ``MoleculeExperiment`` per detected moiety whose type name contains
          "molecule" or "atom". Its ``index`` is the moiety's position in
          ``analyse_moieties.moiety_types`` (not its position in ``self.molecules``).
          The symmetry values and action spaces are read from ``input.yaml`` in the
          current working directory;
        * one ``Obstacle`` per outline in ``analyse_moieties._obstacles_nm``, positioned
          at the centre of the outline's bounding box.

        ``seed`` and ``options`` are accepted for interface compatibility and ignored.
        """
        moieties = self.analyse_moieties
        moieties.init_moiety_information(self.device.overview_image_path, plot=True)

        # Developer switch: set to True to pick the moiety to manipulate interactively.
        user_input = False
        if user_input:
            moiety_index = int(input("Enter the index of the object to manipulate: "))
            moieties._set_current_moiety(moiety_index)

        input_yaml_file = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        self.molecules = [
            MoleculeExperiment(Point(moieties.moiety_position_nm[i]),
                               shape=Polygon(moieties.load_reference_contour(moiety_type, at_origin=True).reshape(-1, 2)),
                               rotation=0,
                               name=f"{moiety_type}_{i}",
                               index=i,
                               type=moiety_type,
                               num_sensors=0,
                               molecule_point_symmetry=moieties.get_yaml_data(input_yaml_file, moiety_type, "moiety_point_symmetry"),
                               substrate_point_symmetry=moieties.get_yaml_data(input_yaml_file, moiety_type, "substrate_point_symmetry"),
                               action_space=self.set_action_space_for_moiety(moiety_type))
            for i, moiety_type in enumerate(moieties.moiety_types)
            if "molecule" in moiety_type or "atom" in moiety_type
        ]

        # Set obstacles based on the object detection.
        self.obstacles = []
        for outline in moieties._obstacles_nm:
            shape = Polygon(outline)
            # Anchor at the bounding-box centre rather than the polygon centroid.
            self.obstacles.append(Obstacle(position=box(*shape.bounds).centroid, shape=shape, rotation=0))
        print("Created initial distribution of molecules and obstacles")

    def set_action_space_for_moiety(self, moiety_type: str) -> dict[str, MoleculeActionSpace]:
        """Build the action spaces of ``moiety_type`` from ``input.yaml`` in the cwd.

        Returns:
            ``{"translation": ..., "orientation": ...}``. Each value is a
            ``MoleculeActionSpace`` read from the matching section of the moiety's YAML
            entry (see ``_read_action_space_section``).
        """
        input_yaml_file = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        return {
            section: self._read_action_space_section(input_yaml_file, moiety_type, section)
            for section in ("translation", "orientation")
        }

    def _read_action_space_section(self, input_yaml_file: str, moiety_type: str, section: str) -> MoleculeActionSpace:
        """Build one ``MoleculeActionSpace`` from ``<moiety_type>/<section>`` in the YAML file.

        The ``*_dest_*`` keys are optional as a group. If reading any of them raises
        ``KeyError``, all six destination values are set to None. Every other key is
        required, and a missing one propagates the lookup error.
        """
        def read(key):
            return self.analyse_moieties.get_yaml_data(input_yaml_file, moiety_type, section, key)

        min_x, max_x, step_x = read("min_x_nm"), read("max_x_nm"), read("step_x_nm")
        min_y, max_y, step_y = read("min_y_nm"), read("max_y_nm"), read("step_y_nm")
        try:
            min_dest_x, max_dest_x, step_dest_x = read("min_dest_x_nm"), read("max_dest_x_nm"), read("step_dest_x_nm")
            min_dest_y, max_dest_y, step_dest_y = read("min_dest_y_nm"), read("max_dest_y_nm"), read("step_dest_y_nm")
        except KeyError:
            min_dest_x = max_dest_x = step_dest_x = None
            min_dest_y = max_dest_y = step_dest_y = None
        min_z, max_z, step_z = read("min_z_nm"), read("max_z_nm"), read("step_z_nm")
        min_V, max_V, step_V = read("min_v_mV"), read("max_v_mV"), read("step_v_mV")

        return MoleculeActionSpace(dimensions_x=(min_x, max_x),
                                   dimensions_y=(min_y, max_y),
                                   step_x=step_x,
                                   step_y=step_y,
                                   dimensions_dest_x=(min_dest_x, max_dest_x),
                                   dimensions_dest_y=(min_dest_y, max_dest_y),
                                   step_dest_x=step_dest_x,
                                   step_dest_y=step_dest_y,
                                   dimensions_z=(min_z, max_z),
                                   dimensions_V=(min_V, max_V),
                                   step_z=step_z,
                                   step_V=step_V)

    def _set_goals(self, seed: Optional[int] = None, number_of_goals=100) -> None:
        """Populate ``self.goals`` according to ``self.measurement_task``.

        * ``None``: two fixed triangular placeholder goals.
        * ``"manual"``: interactive. Press 'r' to add the current tip position as a
          goal and 'q' to stop. Input also stops once there are as many goals as
          molecules.
        * ``"exp_orientation"``: one goal at the detected position of every moiety whose
          type contains "molecule" or "atom".
        * ``"response_data"``: ``int(0.9 * number_of_goals)`` goals drawn from a 2-D
          Gaussian (std 1.2) around a fixed centre, followed by ``number_of_goals``
          goals on a circle of radius ``2 * std`` around the same centre.
        * ``"response_data_long_translations"``: only the circular goals.

        Any other value leaves ``self.goals`` unchanged. ``seed`` is unused; the
        response-data tasks always seed NumPy's global RNG with 45, so their goal sets
        are reproducible.
        """
        # Fixed centre (nm) and spread of the response-data goal distributions.
        loc_x, loc_y = np.array([-44.6, -52.4])
        std = 1.2

        if self.measurement_task is None:
            self.goals = [Goal(Point(-6, -4), shape=resize(TRIANGLE, 1, 3), rotation=60),
                          Goal(Point(-6, 100), shape=resize(TRIANGLE, 1, 3), rotation=60)]

        elif self.measurement_task == "manual":
            self._collect_manual_goals()

        elif self.measurement_task == "exp_orientation":
            moiety_types = self.analyse_moieties.moiety_types
            self.goals = [Goal(Point(self.analyse_moieties.moiety_position_nm[i]),
                               shape=resize(TRIANGLE, 1, 1),
                               rotation=0)
                          for i, moiety_type in enumerate(moiety_types)
                          if "molecule" in moiety_type or "atom" in moiety_type]

        elif self.measurement_task == "response_data":
            np.random.seed(45)
            # RNG draw order per goal (x, y, rotation) must stay fixed for reproducibility.
            gaussian_goals = [Goal(Point(np.random.normal(loc=loc_x, scale=std), np.random.normal(loc=loc_y, scale=std)),
                                   shape=resize(TRIANGLE, 1, 1),
                                   rotation=np.random.randint(0, 2) * 30)
                              for _ in range(int(number_of_goals * 0.9))]
            self.goals = gaussian_goals + self._circle_goals(loc_x, loc_y, std, number_of_goals)

        elif self.measurement_task == "response_data_long_translations":
            np.random.seed(45)
            self.goals = self._circle_goals(loc_x, loc_y, std, number_of_goals)

            plot_goals = False
            if plot_goals:
                from matplotlib import pyplot as plt
                image_size = self.analyse_moieties.size_overview_image_nm
                fig = plt.figure()
                ax = fig.add_subplot(111)
                # Overview-image outline: black border, no fill.
                ax.add_patch(plt.Rectangle((loc_x - image_size[0] / 2, loc_y - image_size[1] / 2), image_size[0], image_size[1], edgecolor='black', facecolor='none', lw=1))
                ax.scatter([goal.position.x for goal in self.goals], [goal.position.y for goal in self.goals], c='k', label='goals', s=2)
                ax.scatter(loc_x, loc_y, c='r', label='center image', s=10, marker='x')
                # Circle of radius 2*std on which the goals are placed.
                circle = plt.Circle((loc_x, loc_y), 2 * std, color='xkcd:mango', fill=False)
                ax.add_artist(circle)
                ax.set_aspect('equal')
                # Save the figure with a transparent background.
                plt.savefig(os.path.join(os.getcwd(), "log_experiments", self.experiment_name+'_goals.png'), transparent=True, dpi=1200)
                plt.show()

    def _collect_manual_goals(self) -> None:
        """Collect goals from the keyboard: 'r' adds the current tip position, 'q' stops.

        The tip position from the device is scaled by 1e9 (presumably metres -> nm;
        TODO confirm). Hotkeys are always unregistered afterwards. Otherwise later key
        presses would keep appending goals while the experiment runs.
        """
        def on_q_key():
            print("The 'q' key was pressed! Exiting...")
            self.manual_goal_input_done = True

        def on_r_key():
            tip_position = self.device.get_current_tip_position()
            self.goals.append(
                Goal(
                    Point(tip_position[0]*1e9, tip_position[1]*1e9),
                    shape=resize(TRIANGLE, 1, 1),
                    rotation=0
                )
            )
            print("The 'r' key was pressed! Setting current tip position as goal...")

        self.goals = []
        self.manual_goal_input_done = False

        hotkeys = []
        try:
            hotkeys.append(keyboard.add_hotkey('r', on_r_key))
            hotkeys.append(keyboard.add_hotkey('q', on_q_key))

            print("Manual goal input mode: Press 'r' to add a goal, 'q' to finish.")

            while True:
                keyboard.read_event(suppress=False)  # Waits for any key event (non-blocking to hotkeys)
                if len(self.goals) >= len(self.molecules) or self.manual_goal_input_done:
                    break
        finally:
            for hotkey in hotkeys:
                keyboard.remove_hotkey(hotkey)

        print("Manual goal input completed.")

    def _circle_goals(self, loc_x, loc_y, std, number_of_goals) -> list[Goal]:
        """Return ``number_of_goals`` goals on a circle of radius ``2 * std`` around ``(loc_x, loc_y)``.

        Angles and rotations come from NumPy's global RNG, so the caller's seeding
        determines the result.
        """
        # This draw is discarded. It only advances the global RNG, and is kept so the
        # seeded goal sequence stays unchanged.
        np.random.uniform(0, 2*np.pi, int(number_of_goals))
        circular_goals = [Goal(Point(loc_x + np.cos(angle) * 2 * std, loc_y + np.sin(angle) * 2 * std),
                               shape=resize(TRIANGLE, 1, 1),
                               rotation=np.random.randint(0, 2) * 30)
                          for angle in np.random.uniform(0, 2*np.pi, int(number_of_goals))]

        # Every goal is built at distance 2*std from the centre, so this sort does not
        # spread the goals apart. It only reorders them by floating-point rounding noise,
        # and is kept to preserve the established goal order.
        return sorted(circular_goals,
                      key=lambda goal: np.sqrt((goal.position.x - loc_x) ** 2 + (goal.position.y - loc_y) ** 2),
                      reverse=True)


    def generate_point_grid(self, n, m, spacing=4, start_x=0, start_y=0) -> list[Goal]:
        """
        Generate a regular ``n`` x ``m`` grid of goals starting at ``(start_x, start_y)``.

        The goal in row ``i`` and column ``j`` is placed at
        ``(start_x + j * spacing, start_y + i * spacing)``. x grows along a row and
        y grows from one row to the next.

        Parameters:
            n (int): Number of rows.
            m (int): Number of columns.
            spacing (int or float): Distance between neighbouring points (default 4).
            start_x (int or float): X-coordinate of the first point.
            start_y (int or float): Y-coordinate of the first point.

        Returns:
            list[Goal]: ``n * m`` goals in row-major order. Each has shape
            ``resize(TRIANGLE, 1, 1)`` and rotation 0.
        """
        return [
            Goal(Point(start_x + col * spacing, start_y + row * spacing),
                 shape=resize(TRIANGLE, 1, 1),
                 rotation=0)
            for row in range(n)
            for col in range(m)
        ]

    def bin_to_hex_lattice(self, xy, a, angle_deg=0):
        """
        Bin xy coordinates to the nearest site on a rotated hexagonal (111) lattice.

        The lattice is spanned by ``b1 = (a, 0)`` and ``b2 = (a/2, a*sqrt(3)/2)``, both
        rotated counterclockwise by ``angle_deg``. One lattice site sits at the origin.

        Parameters
        ----------
        xy : array_like
            Array of shape (N, 2) with xy coordinates, or a single point of shape (2,).
        a : float
            Lattice constant (distance between nearest neighbors).
        angle_deg : float
            Rotation angle of the lattice in degrees (counterclockwise).

        Returns
        -------
        binned_xy : np.ndarray
            Float array with the same shape as ``xy``, containing the nearest
            lattice site of every input point.
        """
        xy = np.asarray(xy, dtype=float)
        points = np.atleast_2d(xy)

        theta = np.deg2rad(angle_deg)
        rotation = np.array([[np.cos(theta), -np.sin(theta)],
                             [np.sin(theta),  np.cos(theta)]])
        # Columns are the rotated basis vectors b1 and b2, so a site is basis @ (i, j).
        basis = rotation @ np.array([[a, a / 2],
                                     [0.0, a * np.sqrt(3) / 2]])

        # Fractional lattice coordinates. For row vectors, p = c @ basis.T, which gives
        # c = p @ inv(basis).T.
        fractional = points @ np.linalg.inv(basis).T

        # Rounding oblique coordinates does not always pick the nearest site. With a 60
        # degree basis, each cell splits into two equilateral triangles, so the nearest
        # site is always one of the cell's four corners. Test all four.
        cell_origin = np.floor(fractional)
        corner_offsets = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=float)
        candidates = cell_origin[:, None, :] + corner_offsets[None, :, :]      # (N, 4, 2)
        candidate_xy = candidates @ basis.T                                     # (N, 4, 2)
        sq_dist = np.sum((candidate_xy - points[:, None, :]) ** 2, axis=-1)     # (N, 4)
        nearest = np.argmin(sq_dist, axis=1)

        binned_xy = candidate_xy[np.arange(len(points)), nearest]
        return binned_xy.reshape(xy.shape)

    def _set_obstacles(self, seed: Optional[int] = None) -> None:
        """No-op: obstacles are built from the object detection in ``_create_initial_distribution``.

        ``seed`` is accepted for interface compatibility and ignored.
        """
        return None

    def _get_matching(self, seed: Optional[int] = None):
        """Pair molecules with goals and select the moiety of the current pair.

        Computes a Hungarian matching of ``self.molecules`` to ``self.goals``, stores it
        in ``self.matching`` and logs every pair. It then reads ``active_moiety_index``,
        whose getter calls ``analyse_moieties._set_current_moiety`` for the pair at
        ``self.current_matching_index``. ``seed`` is unused.
        """
        self.matching = HungarianMatching(self.molecules, self.goals).compute_matching()
        for pair in self.matching:
            logger.info(f"Moving {pair.molecule} to {pair.goal.position}")

        # Read only for the property's side effect (selecting the current moiety).
        _ = self.active_moiety_index

    def increment_matching(self) -> tuple[dict[str, NDArray], dict]:
        """Advance to the next molecule/goal pair of ``self.matching``.

        Resets ``self.steps`` and increments ``self.current_matching_index``.

        Returns:
            ``(self.observation(), scan_info)``, where ``scan_info`` is the second value
            returned by ``_get_current_stm_state`` for the new pair.

        Raises:
            StopIteration: Once the index reaches ``len(self.matching)``. If a renderer
                is set, the matching is cleared and one final frame is rendered first.
        """
        self.steps = 0
        self.current_matching_index += 1

        all_pairs_done = self.current_matching_index == len(self.matching)
        if all_pairs_done:
            if self.renderer:
                self.matching = None
                self.render(surpress_matching=True)
            raise StopIteration("Done")

        # The next steps are order dependent and must stay together. The property read
        # selects the new pair's moiety in analyse_moieties as a side effect.
        _ = self.active_moiety_index
        self.center_of_scan_position = self.current_molecule.center
        intersects = self.analyse_moieties.does_action_intersect_moiety(self.start_tip_position, self.end_tip_position)
        _, scan_info = self._get_current_stm_state(action_intersects_moiety=intersects)
        self._update_current_moiety()

        logger.bind(task="stats", current_matching_index=self.current_matching_index).trace(f"Incrementing matching index = {self.current_matching_index}")
        logger.warning('ExpEnv.increment_matching called.')
        return self.observation(), scan_info

    def set_agent_task(self, agent_task):
        """Record the agent's task label in ``self._agent_task``."""
        self._agent_task = agent_task

    @property
    def active_moiety_index(self) -> int:
        """Moiety index of the molecule in the current match.

        Side effect: also makes that moiety current in ``self.analyse_moieties`` via
        ``_set_current_moiety``. Other methods read this property just for that effect.
        """
        current_pair = self.matching[self.current_matching_index]
        moiety_index = current_pair.molecule.index
        self.analyse_moieties._set_current_moiety(moiety_index)
        return moiety_index

    def perform_action_on_stm(self, action, manipulation_type):
        """Dispatch ``action`` to the vertical or lateral manipulation routine.

        Raises:
            ValueError: If ``manipulation_type`` is neither "vertical" nor "lateral".
        """
        if manipulation_type == "vertical":
            self._perform_vertical(action)
            return
        if manipulation_type == "lateral":
            self._perform_lateral(action)
            return
        raise ValueError(f"Unknown manipulation type: {manipulation_type}")



    def _perform_vertical(self, action):
        """Run a vertical manipulation next to the current molecule.

        ``action`` is read as ``(x, y, z, V)``:

        * The ``(x, y)`` offset is rotated about the origin by the molecule's
          ``rotation`` (shapely's default unit is degrees) and shifted to the molecule
          centre. The result becomes ``self.start_tip_position``.
        * ``z`` and ``V`` are forwarded as ``z_approach_nm`` and ``voltage_mV``.

        ``self.center_of_scan_position`` is set halfway between the molecule centre and
        the tip.
        """
        molecule = self.unwrapped.get_wrapper_attr("current_molecule")
        offset = Point(action[0], action[1])
        rotated_offset = shapely.affinity.rotate(offset, molecule.rotation, origin=(0, 0))
        tip = shapely.affinity.translate(rotated_offset, molecule.center.x, molecule.center.y)
        z_approach, voltage = action[2], action[3]
        self.start_tip_position = Point(tip.x, tip.y)

        logger.bind(task="stats", action=VerticalAction(xy=offset, V=voltage, z=z_approach)).trace("")
        logger.bind(task="stats", start_tip_position=self.start_tip_position).trace("")

        # Scan window centred halfway between the molecule and the tip.
        centre = self.current_molecule.center
        tip_position = self.start_tip_position
        self.center_of_scan_position = Point(centre.x + (tip_position.x - centre.x) / 2,
                                             centre.y + (tip_position.y - centre.y) / 2)

        self.device.perform_vertical_manipulation(
            x_position_nm=tip_position.x,
            y_position_nm=tip_position.y,
            z_approach_nm=z_approach,
            voltage_mV=voltage,
        )

    def _perform_lateral(self, action, plotting=False):
        """Run a lateral manipulation on the device.

        ``action`` must contain exactly six values ``(x_start, y_start, x_end, y_end, z, V)``.
        The start and end points are used directly as tip positions, and ``z``/``V`` are
        forwarded as ``z_position_nm``/``voltage_mV``.

        The agent's start/end points are replaced by positions computed from the goal
        (``_compute_end_position`` / ``_compute_start_offset_position``) when all of
        these hold:

        * the measurement task does not contain "response";
        * the agent task is "positioning";
        * the molecule is farther from its goal than the moiety's
          ``_DISTANCE_THRESHOLD_FOR_SPECIFIC_LATERAL_MANIPULATION_NM`` from input.yaml.

        Afterwards ``self.center_of_scan_position`` is taken from the device.

        Args:
            action: Sequence of six numbers (see above).
            plotting: If True, show a matplotlib figure of the planned manipulation
                before executing it.

        Raises:
            ValueError: If ``action`` does not have exactly six entries.
        """
        current_molecule = self.unwrapped.get_wrapper_attr("current_molecule")
        current_position = current_molecule.center

        # NOTE(review): this overrides whatever set_agent_task() stored, which makes the
        # 'rotation' branch below unreachable -- confirm this is intended.
        self._agent_task = "positioning"
        if self._agent_task == 'positioning':
            current_goal = self.unwrapped.get_wrapper_attr("goal_position")
            goal_angle_rad = np.arctan2(current_goal.y - current_position.y, current_goal.x - current_position.x)
        elif self._agent_task == 'rotation':
            current_goal = current_position
            goal_angle_rad = np.deg2rad(current_molecule.rotation)

        if len(action) != 6:
            raise ValueError(
                f"Lateral action must have 6 entries (x_start, y_start, x_end, y_end, z, V), got {len(action)}"
            )

        action_start = Point(action[0], action[1])
        action_end = Point(action[2], action[3])
        action_z = action[4]
        action_V = action[5]
        logger.bind(task="stats", action=LateralAction(action_start, action_end, V=action_V, z=action_z)).trace("")
        self.start_tip_position = action_start
        self.end_tip_position = action_end

        input_yaml_file = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        distance_threshold_nm = self.analyse_moieties.get_yaml_data(input_yaml_file, current_molecule.type, "_DISTANCE_THRESHOLD_FOR_SPECIFIC_LATERAL_MANIPULATION_NM")

        # measurement_task may be None (handled by _set_goals). Treat that as a
        # non-response task instead of failing with TypeError on the `in` test.
        is_response_task = 'response' in (self.measurement_task or "")
        if not is_response_task and self._agent_task == 'positioning' and self.__distance_to_goal() > distance_threshold_nm:
            # Far from the goal: ignore the agent's action and derive start/end positions
            # from the goal position and the current molecule position.
            action_end = self._compute_end_position()
            action_start = self._compute_start_offset_position(action_end)
            self.end_tip_position = action_end
            self.start_tip_position = action_start

        if plotting:
            from matplotlib import pyplot as plt
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.scatter(self.current_molecule.center.x, self.current_molecule.center.y, c='k', label='molecule')
            # Square around the molecule, rotated by the molecule's orientation.
            max_action_space_x = max(self.current_molecule.action_space_translation.dimensions_x)*2
            max_action_space_y = max(self.current_molecule.action_space_translation.dimensions_y)*2
            ax.add_patch(plt.Rectangle((self.current_molecule.center.x - max_action_space_x / 2, self.current_molecule.center.y - max_action_space_y / 2), max_action_space_x, max_action_space_y, angle=self.current_molecule.rotation, rotation_point='center', color='b', alpha=0.5))
            ax.add_patch(plt.Rectangle((current_goal.x - max_action_space_x / 2, current_goal.y - max_action_space_y / 2), max_action_space_x, max_action_space_y, angle=np.rad2deg(goal_angle_rad), rotation_point='center', color='r', alpha=0.5))
            ax.scatter(self.goal_position.x, self.goal_position.y, c='r',  marker='x', label='goal')
            ax.scatter(self.start_tip_position.x, self.start_tip_position.y, c='g', label='start position')
            # Arrow for the moiety orientation.
            ax.arrow(self.current_molecule.center.x, self.current_molecule.center.y, 0.5 * self.current_molecule.size_x * np.cos(np.deg2rad(self.current_molecule.rotation)), 0.5 * self.current_molecule.size_x * np.sin(np.deg2rad(self.current_molecule.rotation)), head_width=0.1, head_length=0.1, fc='k', ec='k')
            # Manipulation arrow from start to end tip position.
            ax.arrow(self.start_tip_position.x, self.start_tip_position.y, self.end_tip_position.x - self.start_tip_position.x, self.end_tip_position.y - self.start_tip_position.y, head_width=0.1, head_length=0.1, fc='k', ec='k')
            ax.scatter(self.end_tip_position.x, self.end_tip_position.y, c='b', label='end position')
            ax.set_aspect('equal')
            plt.legend()
            plt.show()

        logger.bind(task="stats", start_tip_position=log_shapely_point(self.start_tip_position), end_tip_position=log_shapely_point(self.end_tip_position)).trace("")

        self.device.perform_lateral_manipulation(
            x_start_position_nm=self.start_tip_position.x,
            y_start_position_nm=self.start_tip_position.y,
            x_end_position_nm=self.end_tip_position.x,
            y_end_position_nm=self.end_tip_position.y,
            voltage_mV=action_V,
            z_position_nm=action_z
        )

        self.center_of_scan_position = self.device._get_rough_lat_position_for_exact_search()

    def _compute_start_offset_position(self, end_position):
        """Compute a tip start position on the far side of the current molecule.

        Starting from ``end_position``, the start point lies in the direction of
        the molecule centre, at a distance of
        ``|end_position - centre| + max_moiety_size_nm / 2``. That is, before
        the angular offset it would sit ``max_moiety_size_nm / 2`` past the
        centre. The direction is then rotated by a small angle whose tangent is
        ``(min_moiety_size_nm / 3) / |end_position - centre|``.

        Both sizes are read from ``input.yaml`` (current working directory) for
        the current molecule type.

        Args:
            end_position: Tip end position (shapely ``Point``).

        Returns:
            shapely ``Point`` with the tip start position.
        """
        current_molecule = self.unwrapped.get_wrapper_attr("current_molecule")
        current_position = current_molecule.center
        input_yaml_file = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        max_moiety_size_nm = self.analyse_moieties.get_yaml_data(input_yaml_file, current_molecule.type, "max_moiety_size_nm")
        min_moiety_size_nm = self.analyse_moieties.get_yaml_data(input_yaml_file, current_molecule.type, "min_moiety_size_nm")

        dx = end_position.x - current_position.x
        dy = end_position.y - current_position.y
        distance_nm = np.linalg.norm([dx, dy])

        length_from_end = max_moiety_size_nm / 2 + distance_nm
        # Direction from end_position back towards (and past) the molecule centre.
        angle_in_reverse = np.arctan2(-dy, -dx)
        # Small angular offset so the tip path does not run straight through the centre.
        angle_in_reverse += np.arctan2(min_moiety_size_nm / 3, distance_nm)
        return Point(
            length_from_end * np.cos(angle_in_reverse) + end_position.x,
            length_from_end * np.sin(angle_in_reverse) + end_position.y,
        )

    def _compute_end_position(self):
        """Return the tip end position for a goal-directed lateral manipulation.

        The end point is the goal position shifted by
        ``lateral_offset_for_end_position_nm`` along the direction from the
        current molecule centre to its matched goal. The offset is read from
        ``input.yaml`` in the current working directory for the molecule type.

        Returns:
            shapely ``Point`` with the tip end position.
        """
        molecule = self.unwrapped.get_wrapper_attr("current_molecule")
        yaml_path = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        offset_nm = self.analyse_moieties.get_yaml_data(yaml_path, molecule.type, "lateral_offset_for_end_position_nm")
        goal = self.unwrapped.get_wrapper_attr("goal_position")

        target = self.current_matching.goal.position
        centre = self.current_molecule.center
        heading_rad = np.arctan2(target.y - centre.y, target.x - centre.x)
        end_x = offset_nm * np.cos(heading_rad) + goal.x
        end_y = offset_nm * np.sin(heading_rad) + goal.y
        return Point(end_x, end_y)

    def _compute_start_position(self, end_position, start_offset=1.25):
        """Compute a tip start position on the opposite side of the current molecule.

        The start point lies on the line from ``end_position`` through the
        molecule centre, on the side away from ``end_position``. Its distance
        from the centre is ``start_offset * max(shape_size_x, shape_size_y)``.

        Args:
            end_position: Tip end position (shapely ``Point``).
            start_offset: Distance from the molecule centre, in units of the
                molecule's largest shape dimension.

        Returns:
            shapely ``Point`` with the tip start position.
        """
        molecule = self.unwrapped.get_wrapper_attr("current_molecule")
        molecule_pos = molecule.center
        molecule_max_size = max(molecule.shape_size_x, molecule.shape_size_y)
        # np.arctan2 rather than np.atan2: the ``atan2`` alias only exists in NumPy >= 2.0.
        goal_angle = np.arctan2(end_position.y - molecule_pos.y, end_position.x - molecule_pos.x)
        start_angle = goal_angle - np.pi  # opposite direction
        radius = start_offset * molecule_max_size
        return Point(
            molecule_pos.x + radius * np.cos(start_angle),
            molecule_pos.y + radius * np.sin(start_angle),
        )

    def get_tip_positions(self, action):
        """Return the ``(start, end)`` tip positions encoded by ``action``.

        * ``len(action) >= 6``: ``action[0:2]`` is the tip start position and
          ``action[2:4]`` is the tip end position.
        * ``len(action) < 6``: fixed-start lateral manipulation. ``action[0:2]``
          is the tip end position, and the start position is derived from the
          current molecule via ``_compute_start_position``.

          Previously this branch computed the start position and then
          discarded it, returning the same values as the 6-value case.

        Returns:
            Tuple ``(start_tip_position, end_tip_position)`` of shapely ``Point``.
        """
        if len(action) < 6:
            end_tip_position = Point(action[0], action[1])
            start_tip_position = self._compute_start_position(end_tip_position)
        else:
            start_tip_position = Point(action[0], action[1])
            end_tip_position = Point(action[2], action[3])
        return start_tip_position, end_tip_position


    def terminated(self):
        """Evaluate the episode termination conditions.

        Side effects:
            Calls ``set_crashed()`` / ``set_destroyed()`` on the current molecule
            when those conditions hold. Logs a warning per satisfied condition
            and emits one ``stats`` trace record with all flags.

        Returns:
            Tuple ``(done, info)``.

            ``done`` is True when the molecule crashed, or when it reached both
            the goal position and the goal orientation.

            ``info`` holds the flags ``crashed``, ``reached_goal``,
            ``reached_goal_orientation`` and ``destroyed``. Destruction is
            reported in ``info`` but does not by itself end the episode.
        """
        info = {"crashed": False, "reached_goal": False, "reached_goal_orientation": False, "destroyed": False}
        if self._crashed():
            logger.warning("crashed!")
            self.current_molecule.set_crashed()
            info["crashed"] = True
        if self.reached_goal_position():
            logger.warning("reached goal!")
            info["reached_goal"] = True
        if self.reached_goal_orientation():
            logger.warning("reached goal orientation!")
            info["reached_goal_orientation"] = True
        if self._destroyed():
            logger.warning("destroyed!")
            self.current_molecule.set_destroyed()
            info["destroyed"] = True
        logger.bind(task="stats", crashed=info["crashed"], reached_position=info["reached_goal"], reached_orientation=info["reached_goal_orientation"], destroyed=info["destroyed"]).trace("")
        # Reuse the flags gathered above instead of re-running the checks. The
        # crash check is a full polygon-distance scan, and re-evaluating it after
        # set_crashed() could disagree with what was reported in ``info``.
        done = info["crashed"] or (info["reached_goal"] and info["reached_goal_orientation"])
        return done, info

    def _destroyed(self):
        """ Check if molecule is destroyed via object recognition and verify via user input"""
        return self.analyse_moieties.is_molecule_destroyed()

    def _crashed(self) -> bool:
        distances = np.asarray([self.current_molecule.polygon.distance(m.polygon) if m != self.current_molecule else np.inf for m in self.molecules + self.obstacles])
        return bool(np.any(np.where(distances < self.CRASH_DISTANCE)))

    def reached_goal_position(self) -> bool:
        """Return True once ``current_distance`` is strictly below ``SUCCESS_DISTANCE``."""
        remaining = self.current_distance
        threshold = self.SUCCESS_DISTANCE
        return remaining < threshold

    def reached_goal_orientation(self) -> bool:
        """Return True if the molecule rotation is within 7 degrees of the goal rotation.

        Both rotations are truncated with ``int()`` before they are compared.
        No angular wrap-around is applied, so 359 vs 1 counts as 358 degrees apart.
        """
        molecule_deg = int(self.current_molecule.rotation)
        goal_deg = int(self.current_matching.goal.rotation)
        return np.abs(molecule_deg - goal_deg) <= 7

    def truncated(self) -> bool:
        """Return True when the step budget is used up and the run should not continue.

        The budget is checked against ``self.current_step``, which ``step()``
        increments and ``reset()`` zeroes. The former check used
        ``self.steps``. That counter is reset in ``reset()`` but never advanced
        by ``step()``, so the episode could never be truncated.
        """
        return (self.current_step >= self.max_steps) and not self.continue_after_done()

    def continue_after_done(self) -> bool:
        """Return the ``continue_after_done`` setting from ``input.yaml`` in the CWD.

        NOTE(review): other ``get_yaml_data`` calls pass a molecule type before
        the key. This one passes only the key — verify that ``get_yaml_data``
        supports this top-level lookup.
        """
        yaml_path = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        setting = self.analyse_moieties.get_yaml_data(yaml_path, "continue_after_done")
        return setting

    def get_latest_episode(self) -> int:
        """Return the most recent episode number recorded in ``self.log_file``.

        The log is scanned from the end for the last line that contains
        ``"Episode:"`` and whose text after its final ``":"`` parses as an int.
        Lines that contain ``"Episode:"`` but do not end in an integer are
        skipped instead of aborting the search.

        Returns:
            The episode number. Returns 0 if the log file does not exist or has
            no parseable episode line; a warning is logged in both cases.
        """
        try:
            # Explicit encoding with errors="replace": only ASCII digits are
            # parsed, so undecodable bytes elsewhere in the log must not abort
            # the lookup. The platform-default encoding is not used.
            with open(self.log_file, "r", encoding="utf-8", errors="replace") as file:
                lines = file.readlines()
        except FileNotFoundError:
            logger.warning("Could not determine latest_episode setting it to 0")
            logger.bind(task="stats", warn="log not found episode, starts at 0").trace("")
            return 0

        episode = None
        for line in reversed(lines):
            if "Episode:" not in line:
                continue
            try:
                episode = int(line.split(":")[-1].strip())
            except ValueError:
                continue  # malformed episode line; keep searching earlier lines
            break

        if episode is None:
            logger.warning(f"Could not determine latest_episode from {self.log_file}, setting it to 0")
            logger.bind(task="stats", warn="could not determine episode starts at 0").trace("")
            episode = 0
        return episode

    def action_intersect_moiety(self, action):
        """Ask the analysis backend whether the tip path of ``action`` intersects a moiety.

        ``action[0:2]`` is used as the tip start and ``action[2:4]`` as the tip end.
        """
        x_start, y_start = action[0], action[1]
        x_end, y_end = action[2], action[3]
        return self.analyse_moieties.does_action_intersect_moiety(
            start_tip_position=Point(x_start, y_start),
            end_tip_position=Point(x_end, y_end),
        )

    def step(self, action):
        """Perform one lateral manipulation, rescan, and return the Gymnasium step tuple.

        Args:
            action: Six values, unpacked below as ``(x, y, destx, desty, z, V)``.
                ``(x, y)`` is the tip start, ``(destx, desty)`` the tip end, and
                ``z`` and ``V`` are passed on to ``LateralAction``. The raw
                ``action`` is first handed to
                ``perform_action_on_stm(action, 'lateral')``.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``info`` merges the
            scan info from ``_get_current_stm_state`` with the termination flags
            from ``terminated()``.
        """
        info = dict()
        #obs, reward, terminated, truncated, info = super().step(action)
        # NOTE(review): only ``current_step`` is advanced here. ``self.steps``
        # (read by truncated() and used in scan file names) is zeroed in reset()
        # but not incremented in this method — confirm this is intended.
        self.current_step += 1


        self.perform_action_on_stm(action, 'lateral')
        # Rescan and re-run object recognition. The intersection flag uses the
        # tip positions stored on ``self`` (presumably set by
        # perform_action_on_stm — verify).
        # NOTE(review): in simulation mode _get_current_stm_state returns a single
        # element, which must itself be a 2-tuple for this unpacking to succeed.
        moved_molecules, scan_info = self._get_current_stm_state(action_intersects_moiety = self.analyse_moieties.does_action_intersect_moiety(self.start_tip_position, self.end_tip_position))
        info.update(scan_info)
        # Disabled retry loop around the rescan, kept for reference:
        # measurement_incomplete = True
        # while measurement_incomplete:
        #     try:
        #         moved_molecules = self._get_current_stm_state(action_intersects_moiety = self.analyse_moieties.does_action_intersect_moiety(self.start_tip_position, self.end_tip_position))
        #         measurement_incomplete = False
        #     except:
        #         measurement_incomplete = True

        # Sync the active molecule with the new recognition result. The dict this
        # returns is discarded; rendering below uses ``moved_molecules`` from the scan.
        self._update_current_moiety()
        previous_position = self.current_molecule.previous_position
        # Stats: pose before the move, pose after the move, and the resulting Movement delta.
        logger.bind(task="stats",
                    molecule_x_before=previous_position.x,
                    molecule_y_before=previous_position.y,
                    molecule_orientation_before=self.current_molecule.previous_orientation).trace("")
        logger.bind(task="stats",
                    molecule_x=self.current_molecule.center.x,
                    molecule_y=self.current_molecule.center.y,
                    molecule_orientation=self.current_molecule.orientation).trace("")
        logger.bind(task="stats", movement=Movement(self.current_molecule.center.x - previous_position.x, self.current_molecule.center.y - previous_position.y, self.current_molecule.orientation - self.current_molecule.previous_orientation)).trace("")

        obs = self.observation()
        # Rebuild the raw action as a LateralAction; it is only used for rendering.
        x,y, destx, desty, z, V = action
        action = LateralAction(Point(x,y), Point(destx, desty), z, V)
        if self.render_mode == "human":
            self.render(action=action, moved_molecules=moved_molecules)

        terminated, reason = self.terminated()
        logger.bind(task="stats",
                    current_episode=self.current_episode,
                    steps=self.current_step,
                    goal=self.current_matching.goal,
                    current_distance=self.current_distance).trace("")
        info.update(reason)

        return obs, self.reward(), terminated, self.truncated(), info

    def reset(self, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode and return the initial observation.

        The method:

        * resets the step counters and ``self.name_generator``;
        * rebuilds the initial molecule distribution and the goals;
        * creates the simulator/renderer (render mode ``"human"``) and the
          simulation analysis backend (``running_simulation``);
        * calls ``_get_matching()``;
        * continues the episode counter from the log file;
        * refreshes the STM/analysis state centred on the current molecule;
        * renders if a renderer exists.

        Args:
            seed: Seed for ``self.name_generator`` (``np.random.default_rng``).
            options: Accepted for Gymnasium API compatibility; not used here.

        Returns:
            ``(obs, info)``, where ``info`` is always an empty dict.
        """
        #obs, info = super().reset()
        info = dict()

        self.steps = 0
        self.name_generator = np.random.default_rng(seed=seed)
        self.current_matching_index = 0


        self._create_initial_distribution()
        self._set_goals()

        if self.render_mode == "human":
            self.simulator = HardwareSimulator(self.molecules, self.goals)
            # NOTE(review): hard-coded renderer origin offset. The commented-out
            # line below derives it from the surface size instead.
            self.unwrapped._initialize_renderer(origin_offset=(33,2),flip_y=True)
            #self.unwrapped._initialize_renderer(origin_offset=(self.surface_width // 2, -self.surface_height // 2),flip_y=True)
        # NOTE(review): self.simulator is only created above in "human" render
        # mode. Confirm it is also set elsewhere (e.g. __init__) for headless
        # simulation runs.
        if self.running_simulation:
            self.analyse_moieties = AnalyseMoietySimulation(self.molecules, self.goals, self.obstacles)
            self.device = self.simulator

        self._get_matching()
        # Episode numbering continues from the last "Episode:" entry in the log file.
        self.current_episode = self.get_latest_episode()+1
        self.current_step = 0

        # Centre the first scan on the active molecule.
        self.center_of_scan_position = self.current_molecule.center
        self._get_current_stm_state(reset=True)
        self._update_current_moiety()

        if self.renderer:
            self.render(None)
            self.renderer.update()

        obs = self.observation()
        return obs, info

    def observation(self):
        """Build the flat float32 observation vector for the current molecule.

        Layout, in order:

        1. surface orientation (0 unless ``include_surface_orientation``);
        2. goal rotation minus molecule rotation (degrees);
        3. angle to the goal relative to the molecule rotation (radians);
        4. distance to the goal;
        5. ``moiety_type_changes`` from the analysis backend;
        6. one reading per sensor.

        Sensor readings: if the current molecule has sensors, each entry is the
        smallest distance from the molecule polygon to another molecule
        intersected by that sensor, or ``np.inf`` if there is none. Otherwise a
        zero vector of length ``self.unwrapped.num_sensors`` is used.

        Raises:
            AssertionError: if the analysis backend reports a different number
                of moieties than ``self.molecules`` holds.
        """
        (moiety_types,
         moiety_positions_nm,
         moiety_orientations_rad,
         moiety_bbox_width_height_nm,
         confidence,
         moiety_target_contours_nm,
         moiety_matching_reference_contours_nm,
         moiety_types_before,
         moiety_type_changes
         ) = self.analyse_moieties.get_observation()

        assert len(moiety_types) == len(self.molecules), f"Number of molecules in observation {len(moiety_types)} does not match number of molecules in environment {len(self.molecules)}"

        surface_orientation = 0 if not self.include_surface_orientation else self.current_molecule.rotation
        orientation_to_goal = self.__orientation_to_goal()
        # The unit direction vector to the goal is not part of the observation.
        # It is no longer computed, because it divides by zero (nan) when the
        # molecule sits on the goal.
        angle_to_goal = self.__angle_to_goal()
        distance_to_goal = self.__distance_to_goal()

        if self.current_molecule.num_sensors > 0:
            sensor_readings = np.full((self.current_molecule.num_sensors,), np.inf)

            for j, sensor in enumerate(self.current_molecule.sensors):
                for other in self.unwrapped.molecules:
                    if self.current_molecule == other:  # don't compare moiety with itself
                        continue
                    if sensor.intersects(other.polygon):
                        intersection = shapely.intersection(sensor, other.polygon)
                        distance = self.current_molecule.polygon.distance(intersection)
                        sensor_readings[j] = min(sensor_readings[j], distance)
        else:
            sensor_readings = np.zeros(self.unwrapped.num_sensors)

        # flatten() only iterates the parts, so a plain list is enough; no
        # ragged dtype=object array is needed.
        obs = [surface_orientation, orientation_to_goal, angle_to_goal, distance_to_goal, moiety_type_changes, sensor_readings]
        return self.flatten(obs)

    def flatten(self, obs):
        """Concatenate the parts of ``obs`` into one 1-D float32 array.

        Parts that are ``np.ndarray`` or ``list`` are flattened in row-major
        order. Every other part is converted with ``np.float32`` as a single value.
        """
        values = []
        for item in obs:
            if not isinstance(item, (np.ndarray, list)):
                values.append(np.float32(item))
                continue
            values.extend(np.ravel(np.asarray(item, dtype=np.float32)))
        return np.asarray(values, dtype=np.float32)

    def __orientation_to_goal(self) -> float:
        """Calculates the orientation difference between the molecule and the goal in degrees.
        Returns:
            float: The difference in orientation between the molecule and the goal in degrees.
        """
        return self.current_matching.goal.rotation - self.current_molecule.rotation

    def __directional_vector_to_goal(self) -> NDArray:
        molecule_goal_vector = np.array([self.current_matching.goal.position.x - self.current_molecule.center.x, self.current_matching.goal.position.y - self.current_molecule.center.y])
        return molecule_goal_vector / np.linalg.norm(molecule_goal_vector)

    def __angle_to_goal(self) -> float:
        molecule_goal_vector = np.array([self.current_matching.goal.position.x - self.current_molecule.center.x, self.current_matching.goal.position.y - self.current_molecule.center.y])
        return np.arctan2(molecule_goal_vector[1], molecule_goal_vector[0]) - np.deg2rad(self.current_molecule.rotation)

    def __distance_to_goal(self) -> float:
        return np.linalg.norm(np.array([self.current_matching.goal.position.x - self.current_molecule.center.x, self.current_matching.goal.position.y - self.current_molecule.center.y]))

    def _get_current_stm_state(self, use_measured_data: bool = False, reset: bool = False, action_intersects_moiety: bool = False):
        info = dict()
        if self.running_simulation:
            return self.analyse_moieties.update_moiety_information(reset)[9]
        else:
            if use_measured_data:
                self.filename_per_timestep = self._get_next_measured_data()
                self.device.set_filename_per_timestep(self.filename_per_timestep)
                self.device.scan_topography(center_nm=self.center_of_scan_position, topography_size_nm=4, number_of_topography_points=128)
                self.device.current_image_file = self._get_next_measured_data()
            else:
                self.filename_per_timestep = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")+'_'+str(self.experiment_name)+'_'+str(self.current_episode)+'_'+str(self.steps).zfill(4)
                self.device.set_filename_per_timestep(self.filename_per_timestep)
                #if action_intersects_moiety:
                info['perfect_scan'] = False
                while not info['perfect_scan']:
                    self.device.scan_topography(center_nm=self.center_of_scan_position, topography_size_nm=4, number_of_topography_points=128)
                    update_moiety, info = self.analyse_moieties.update_moiety_information(self.device.current_image_file, self.start_tip_position, self.end_tip_position, self.active_moiety_index)
                # else:
                #     info['perfect_scan'] = False
                #     while not info['perfect_scan']:
                #         self.device.scan_topography(center_nm=self.current_molecule.center, topography_size_nm=4, number_of_topography_points=128)
                #         update_moiety, info = self.analyse_moieties.update_moiety_information(self.device.current_image_file, self.start_tip_position, self.end_tip_position, self.active_moiety_index)
            return update_moiety, info

    def _set_center_of_scan_position(self, center):
        self.center_of_scan_position = center

    def _update_current_moiety(self) -> dict[Molecule | MoleculeExperiment, tuple[Point, Movement]]:
        """Move the active molecule to the pose reported by object recognition.

        The position is taken from ``analyse_moieties.moiety_position_nm``. The
        orientation is converted from radians to whole degrees, taken modulo 90,
        and snapped with ``find_nearest`` onto ``[0, 30, 60, 90]``; a result of
        90 is folded back to 0.

        Returns:
            ``{}`` in simulation mode. Otherwise a one-entry dict mapping the
            current molecule to ``(previous_position, movement)``.
        """
        if self.running_simulation:
            return {}
        index = self.active_moiety_index
        new_position = Point(self.analyse_moieties.moiety_position_nm[index])
        rounded_deg = int(np.round(np.rad2deg(self.analyse_moieties.moiety_orientation_rad[index]), 0))
        snapped_deg = find_nearest([0, 30, 60, 90], rounded_deg % 90)
        new_orientation = 0 if snapped_deg == 90 else snapped_deg
        logger.info(f"{self.current_molecule} \t {self.current_molecule.orientations}")

        movement = self.current_molecule.move_to(new_position, new_orientation)
        logger.info(f"{self.current_molecule} \t {self.current_molecule.previous_position} \t {self.current_molecule.previous_orientation}")

        return {self.current_molecule: (self.current_molecule.previous_position, movement)}

    def _get_next_measured_data(self):
        if self.running_simulation:
            return None
        # get all the sxm-files in measured_data_dir and remove dublicates
        measured_data_dir = os.path.normpath("D:/Measurement_Data/2025/05/FePc_Au111_4K/STM_data")
        # sort by name
        all_files = os.listdir(measured_data_dir)
        all_files.sort()
        all_files = [file for file in all_files if file.endswith(".sxm")]
        # remove overview
        all_files = [file for file in all_files if "overview" not in file]

        # get files of unique names
        unique_files = list()
        unique_files = [file for file in all_files if file not in unique_files]
        self.device.current_image_file = unique_files[self.unwrapped.steps]
        self.device.current_image_file = os.path.join(measured_data_dir, self.device.current_image_file)
        return self.device.current_image_file
