import numpy as np
import os
import time
import pickle
import yaml
import matplotlib.pyplot as plt
import gc

from shapely import Point
import cv2
from loguru import logger

from stm_remote_control.stm import CreatecSTM, NanonisSTM
from object_recognition.Yolo.object_detection import ObjectRecognition
from object_recognition.post_processing.analyse_moieties import AnalyseMoietyCreatec, AnalyseMoietyNanonis

class ActionSpaceMask:
    """
    Measures and plots an action space mask for STM lateral manipulation.

    A grid of tip heights ``z`` (nm) and bias voltages ``v`` (mV) is swept. At every
    grid point the selected moiety is laterally manipulated and re-imaged, and the
    outcome (translation, rotation, changed, destroyed) is recorded. The accumulated
    results are pickled after every grid point so an interrupted sweep can be resumed.
    """

    # Image file extension written by each supported STM back end.
    _IMAGE_EXTENSIONS = {"createc": ".dat", "nanonis": ".sxm"}

    def __init__(self, hardware: str, moiety_type: str, type_of_measurement: str, *,
                 z_range_nm=(0, 1), z_step_nm=0.05,
                 v_range_mV=(-2000, 2000), v_step_mV=250,
                 overview_pattern: str = 'Au(111)-FePc-5K'):
        """
        Initializes the action space mask evaluator.

        Reads ``input.yaml`` from the current working directory, adds a log file sink,
        locates the newest overview image, creates the hardware/analysis objects and
        builds the (z, v) parameter grid.

        Parameters:
            hardware (str): STM back end, "createc" or "nanonis".
            moiety_type (str): Moiety key used to look up thresholds and sizes in input.yaml.
            type_of_measurement (str): "translation" or "orientation".
            z_range_nm (tuple): (min, max) tip height relative to the setpoint; max is included.
            z_step_nm (float): Step size of the tip height grid.
            v_range_mV (tuple): (min, max) bias voltage; max is included.
            v_step_mV (float): Step size of the bias voltage grid.
            overview_pattern (str): Substring that identifies overview images in the
                STM data directory.

        Raises:
            ValueError: If ``hardware`` is not supported.
            AssertionError: If ``type_of_measurement`` is not supported.
            FileNotFoundError: If input.yaml is missing or no matching overview image exists.
        """
        # Validate the back end before any I/O. Previously an unsupported value only
        # surfaced later as an unbound-variable error.
        if hardware not in self._IMAGE_EXTENSIONS:
            raise ValueError(f"Hardware {hardware} not supported")
        image_extension = self._IMAGE_EXTENSIONS[hardware]

        # Setup from yaml file
        self.input_yaml_file = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
        with open(self.input_yaml_file, "r") as file:
            yaml_data = yaml.safe_load(file)

        self.experiment_name = yaml_data["experiment_name"]
        self.main_directory = yaml_data["main_dir"]
        self.tip_formation = yaml_data["tip_formation"]
        self.directory_for_stm_data = os.path.normpath(os.path.join(self.main_directory, "STM_data"))
        os.makedirs(self.directory_for_stm_data, exist_ok=True)

        # Logging setup
        log_dir = os.path.join(os.getcwd(), "log_experiments")
        os.makedirs(log_dir, exist_ok=True)
        log_file = os.path.join(log_dir, self.experiment_name + '.log')
        logger.add(log_file, level="INFO")

        self.type_of_measurement = type_of_measurement
        assert self.type_of_measurement in ["translation", "orientation"], f"Type of measurement {self.type_of_measurement} not supported"

        # Use the most recently modified overview image written by this back end.
        overview_candidates = [
            f for f in os.listdir(self.directory_for_stm_data)
            if overview_pattern in f and f.endswith(image_extension)
        ]
        if not overview_candidates:
            raise FileNotFoundError(
                f"No overview image containing '{overview_pattern}' with extension "
                f"'{image_extension}' found in {self.directory_for_stm_data}"
            )
        overview_image = os.path.normpath(sorted(
            overview_candidates,
            key=lambda f: os.path.getmtime(os.path.join(self.directory_for_stm_data, f)),
        )[-1])
        self.overview_path = os.path.normpath(os.path.join(self.directory_for_stm_data, overview_image))

        # Initialize hardware and analysis routines
        self.hardware = hardware
        if hardware == "createc":
            self.stm = CreatecSTM()
            self.analyse_moieties = AnalyseMoietyCreatec()
            self.stm.connect_to_stm()
        else:  # "nanonis" (validated above)
            self.stm = NanonisSTM()
            self.analyse_moieties = AnalyseMoietyNanonis()

        # Per-moiety thresholds and size limits from input.yaml.
        yaml_path = os.path.join(os.getcwd(), "input.yaml")
        get_yaml_data = self.analyse_moieties.get_yaml_data
        self.translation_threshold = get_yaml_data(yaml_path, moiety_type, "asm_translation_threshold_nm")
        self.rotation_threshold = get_yaml_data(yaml_path, moiety_type, "asm_rotation_threshold_deg")
        self.max_moiety_size_nm = get_yaml_data(yaml_path, moiety_type, "max_moiety_size_nm")
        self.min_moiety_size_nm = get_yaml_data(yaml_path, moiety_type, "min_moiety_size_nm")

        # Parameter grids for z and v with independent step sizes. Adding one step to
        # the stop value makes np.arange include the upper bound. z is rounded so the
        # grid values compare reliably with stored/rounded values.
        self.z_values_nm = np.round(np.arange(z_range_nm[0], z_range_nm[1] + z_step_nm, z_step_nm), 7)
        self.v_values_mV = np.arange(v_range_mV[0], v_range_mV[1] + v_step_mV, v_step_mV)

    def init_object(self, numbering=True, plot=True):
        """
        Detects moieties in the overview image and lets the user pick the moiety to be
        manipulated and the manipulation direction.

        Parameters:
            numbering (bool): Forwarded to ``init_moiety_information``.
            plot (bool): Forwarded to ``init_moiety_information``.

        Side effects:
            Selects the current moiety on ``self.analyse_moieties`` and stores its type,
            position, orientation, bounding box and confidence on ``self``. Prompts for a
            direction (R, L, U, D) and sets ``self.set_end_position``,
            ``self.init_start_position`` (the moiety position) and
            ``self.init_end_position`` (5 nm away for "translation", 4 nm for "orientation").
        """
        (self.moiety_type, self.moiety_position_nm, self.moiety_orientation_rad,
         self.moiety_bbox_width_height_nm, self.confidence) = \
            self.analyse_moieties.init_moiety_information(self.overview_path, numbering=numbering, plot=plot)

        # Ask for an index unless exactly one moiety was detected.
        if len(self.moiety_type) != 1:
            moiety_index = int(input("Enter the index of the object to manipulate: "))
            self.analyse_moieties._set_current_moiety(moiety_index)
        else:
            self.analyse_moieties._set_current_moiety(0)

        current = self.analyse_moieties._current_moiety
        self.moiety_type = self.analyse_moieties.moiety_types[current]
        self.moiety_position_nm = self.analyse_moieties.moiety_position_nm[current]
        self.moiety_orientation_rad = self.analyse_moieties.moiety_orientation_rad[current]
        self.moiety_bbox_width_height_nm = self.analyse_moieties.moiety_bbox_width_height_nm[current]
        self.confidence = self.analyse_moieties.confidence[current]

        # Direction key -> unit (x, y) offset of the end position relative to the start.
        direction_offsets = {'R': (1, 0), 'L': (-1, 0), 'U': (0, 1), 'D': (0, -1)}
        while True:
            self.set_end_position = input("Enter a direction to initialize the end position of manipulation (R, L, U, D): ").strip().upper()
            if self.set_end_position in direction_offsets:
                print(f"You entered: {self.set_end_position}")
                break
            print("Invalid input! Please enter only R, L, U, or D.")

        manipulation_distance_nm = {"translation": 5, "orientation": 4}[self.type_of_measurement]
        self.init_start_position = self.moiety_position_nm
        self.init_end_position = self.init_start_position + manipulation_distance_nm * np.array(
            direction_offsets[self.set_end_position]
        )

    def detect_object(self, center_nm):
        """
        Scans a topography around ``center_nm`` and re-detects the active moiety.

        The image is saved under ``self.filename_per_timestep``, which must be set by the
        caller beforehand. The scan is ``2 * max_moiety_size_nm`` wide with 128 points.

        Parameters:
            center_nm: Scan centre, passed unchanged to ``stm.scan_topography``.

        Returns:
            tuple: (moiety_type, moiety_position_nm, moiety_orientation_rad,
            matching_reference_contour_nm, confidence) of the active moiety. The same
            values are also stored on ``self``.
        """
        self.stm.set_filename_per_timestep(filename_per_timestep=self.filename_per_timestep)

        # Scan topography based on the object's current information
        topography_size_nm = self.max_moiety_size_nm * 2
        self.stm.scan_topography(
            center_nm=center_nm,
            topography_size_nm=topography_size_nm,
            number_of_topography_points=128,
        )

        image_path = os.path.normpath(os.path.join(self.directory_for_stm_data, self.filename_per_timestep))
        print(image_path)

        # Get the new STM image and perform object detection
        self.analyse_moieties.update_moiety_information(image_path, active_moiety_index=self.analyse_moieties._current_moiety)
        (
            self.moiety_type,
            self.moiety_position_nm,
            self.moiety_orientation_rad,
            moiety_bbox_width_height_nm,
            self.confidence,
            self.moiety_target_contour_nm,
            self.moiety_matching_reference_contour_nm, _, _
        ) = self.analyse_moieties.get_active_moiety_information()

        self.analyse_moieties.plot_moieties(self.overview_path, plot=False)

        print("Analyse Moieties finished")
        return self.moiety_type, self.moiety_position_nm, self.moiety_orientation_rad, self.moiety_matching_reference_contour_nm, self.confidence

    @staticmethod
    def rotate_around_origin(origin, point, angle_rad):
        """
        Rotate a point counterclockwise by a given angle around a given origin.

        The angle should be given in radians. Returns the rotated point as (x, y).
        """
        ox, oy = origin
        px, py = point

        qx = ox + np.cos(angle_rad) * (px - ox) - np.sin(angle_rad) * (py - oy)
        qy = oy + np.sin(angle_rad) * (px - ox) + np.cos(angle_rad) * (py - oy)
        return qx, qy

    def manipulate_object(self, z, v):
        """
        Laterally manipulates the current moiety with tip height ``z`` and bias ``v``.

        The user-defined manipulation vector (``init_end_position - init_start_position``)
        is rotated by the moiety orientation plus 0, +90, 180 and -90 degrees and attached
        to the current moiety position. Of these four candidate end points, the method
        picks the one closest to the initial end position if the moiety is (at least as)
        close to the initial start position, and the one closest to the initial start
        position otherwise.

        The tip start point depends on the measurement type:
            - "translation": on the opposite side of the moiety centre, at
              0.75 * max_moiety_size_nm from it, so the path goes through the centre.
            - "orientation": measured from the chosen end point, at a distance of
              max_moiety_size_nm / 2 + |end - centre|, with the direction turned by
              atan2(min_moiety_size_nm / 3, |end - centre|), so the path passes the
              centre off-axis.

        Parameters:
            z: Tip height passed to ``perform_lateral_manipulation`` as ``z_position_nm``.
            v: Bias voltage in mV passed as ``voltage_mV``.

        Returns:
            None. Sets ``self.lateral_manipulation_start`` / ``self.lateral_manipulation_end``
            and drives the STM.
        """
        # Retrieve current molecule position and orientation (in nm and rad)
        molecule_pos = self.moiety_position_nm
        molecule_orientation = self.moiety_orientation_rad

        lateral_manipulation_vector_nm = self.init_end_position - self.init_start_position

        # Candidate end positions: the manipulation vector rotated by the orientation
        # plus 0, 90, 180 and -90 degrees (same order as the distance list below).
        candidates = [
            molecule_pos + self.rotate_around_origin(
                origin=[0, 0],
                point=lateral_manipulation_vector_nm,
                angle_rad=molecule_orientation + offset,
            )
            for offset in (0, np.pi / 2, np.pi, -np.pi / 2)
        ]

        # Head for whichever initial position the moiety is currently farther from.
        distance_from_start = np.linalg.norm(molecule_pos - self.init_start_position)
        distance_from_end = np.linalg.norm(molecule_pos - self.init_end_position)
        target = self.init_end_position if distance_from_start <= distance_from_end else self.init_start_position
        candidate_distances = [np.linalg.norm(candidate - target) for candidate in candidates]
        self.lateral_manipulation_end = candidates[np.argmin(candidate_distances)]

        # Vector from the molecule centre to the selected end position.
        dx = self.lateral_manipulation_end[0] - molecule_pos[0]
        dy = self.lateral_manipulation_end[1] - molecule_pos[1]

        if self.type_of_measurement == "translation":
            # Manipulation vector goes through the molecule centre.
            length_from_start = (self.max_moiety_size_nm / 2) * 1.5
            angle_in_reverse = np.arctan2(-dy, -dx)
            self.lateral_manipulation_start = np.array([
                length_from_start * np.cos(angle_in_reverse) + molecule_pos[0],
                length_from_start * np.sin(angle_in_reverse) + molecule_pos[1],
            ])
        elif self.type_of_measurement == "orientation":
            # Manipulation vector is offset from the centre (angle derived from min size / 3).
            end_distance = np.linalg.norm([dx, dy])
            length_from_start = self.max_moiety_size_nm / 2 + end_distance
            angle_in_reverse = np.arctan2(-dy, -dx) + np.arctan2(self.min_moiety_size_nm / 3, end_distance)
            # The start position is measured from the end position.
            self.lateral_manipulation_start = np.array([
                length_from_start * np.cos(angle_in_reverse) + self.lateral_manipulation_end[0],
                length_from_start * np.sin(angle_in_reverse) + self.lateral_manipulation_end[1],
            ])

        self.stm.perform_lateral_manipulation(
            x_start_position_nm=self.lateral_manipulation_start[0],
            y_start_position_nm=self.lateral_manipulation_start[1],
            x_end_position_nm=self.lateral_manipulation_end[0],
            y_end_position_nm=self.lateral_manipulation_end[1],
            z_position_nm=z,
            voltage_mV=v
        )

    def _latest_pickle_path(self, predicate):
        """
        Return the full path of the most recently modified ``.pkl`` file in the STM data
        directory whose name satisfies ``predicate``, or None if there is none.
        """
        files = [
            f for f in os.listdir(self.directory_for_stm_data)
            if predicate(f) and f.endswith('.pkl')
        ]
        if not files:
            return None
        latest = sorted(
            files,
            key=lambda f: os.path.getmtime(os.path.join(self.directory_for_stm_data, f)),
        )[-1]
        return os.path.join(self.directory_for_stm_data, latest)

    def _save_measured_data(self, measured_data):
        """Pickle the full results list to ``<filename_per_timestep>.pkl`` in the STM data directory."""
        pickle_filename = os.path.join(self.directory_for_stm_data, f"{self.filename_per_timestep}.pkl")
        with open(pickle_filename, "wb") as f:
            pickle.dump(measured_data, f)
        print(f"Saved measured data to {pickle_filename}")

    def _plot_manipulation_step(self, position_before, orientation_before_rad, position_after, orientation_after_rad):
        """
        Save a PNG with the moiety position/orientation before and after one manipulation,
        the initial start/end positions and the tip path.
        """
        plt.figure(figsize=(10, 10))
        plt.scatter(position_before[0], position_before[1], c='blue', label='Moiety Position Before')
        plt.scatter(position_after[0], position_after[1], c='k', label='Moiety Position After')
        plt.scatter(self.init_start_position[0], self.init_start_position[1], c='red', marker='x', s=100, label='Initial Start Position')
        plt.scatter(self.init_end_position[0], self.init_end_position[1], c='red', marker='x', s=100, label='Initial End Position')
        plt.scatter(self.lateral_manipulation_start[0], self.lateral_manipulation_start[1], c='green', label='Manipulation Start Position')
        plt.scatter(self.lateral_manipulation_end[0], self.lateral_manipulation_end[1], c='orange', label='Manipulation End Position')
        plt.quiver(self.lateral_manipulation_start[0], self.lateral_manipulation_start[1],
                   self.lateral_manipulation_end[0] - self.lateral_manipulation_start[0],
                   self.lateral_manipulation_end[1] - self.lateral_manipulation_start[1],
                   angles='xy', scale_units='xy', scale=1, color='purple', label='Manipulation Vector')
        # Unit arrows for the molecule orientation before and after manipulation.
        plt.quiver(position_before[0], position_before[1],
                   np.cos(orientation_before_rad), np.sin(orientation_before_rad),
                   angles='xy', scale_units='xy', scale=1, color='blue', label='Molecule Orientation')
        plt.quiver(position_after[0], position_after[1],
                   np.cos(orientation_after_rad), np.sin(orientation_after_rad),
                   angles='xy', scale_units='xy', scale=1, color='k', label='Molecule Orientation After')

        plt.legend()
        # +/- 15 nm window around the initial start position.
        plt.xlim(-15 + self.init_start_position[0], 15 + self.init_start_position[0])
        plt.ylim(-15 + self.init_start_position[1], 15 + self.init_start_position[1])

        plt.savefig(os.path.join(self.directory_for_stm_data, f"asm_manipulation_{self.filename_per_timestep}.png"))
        plt.close()
        gc.collect()

    def measure_action_space_mask(self, conf_threshold=0.4, user_input=True):
        """
        Sweeps through the parameter space of z and v and records the action results.

        For each (z, v) grid point the moiety is manipulated and re-imaged, and one
        result dict is appended to the results list. The whole list is pickled after
        every grid point. Results from the newest pickle of this measurement type are
        loaded first, and grid points already present there are skipped.

        Parameters:
            conf_threshold (float): A detection with confidence below this value (or
                exactly 1.0, or with a changed moiety type) triggers a user prompt to
                classify the outcome.
            user_input (bool): Fallback for the ``user_input_asm`` key of input.yaml.
                input.yaml is re-read before every grid point, and its value takes
                precedence. When enabled, the user may skip a grid point. Skipped points
                are recorded as destroyed.
        """
        # Moiety information before the first manipulation (image name has no z/v yet).
        time_stamp = time.strftime("%Y%m%d-%H%M%S")
        self.filename_per_timestep = f"{time_stamp}_ASM_{self.moiety_type}_{self.type_of_measurement}_zX_vX"
        (moiety_type_before, moiety_position_before, moiety_orientation_rad_before,
         moiety_contour_before, conf) = self.detect_object(center_nm=Point(self.moiety_position_nm))

        # Resume from the newest pickle of this measurement type, if any.
        # NOTE: pickle.load can execute arbitrary code; only use a trusted data directory.
        latest_pickle = self._latest_pickle_path(
            lambda f: '_asm_' in f.lower() and self.type_of_measurement in f.lower()
        )
        if latest_pickle is not None and os.path.exists(latest_pickle):
            with open(latest_pickle, "rb") as f:
                measured_data = pickle.load(f)
        else:
            measured_data = []

        # Grid points at which a previous run destroyed the molecule.
        skip_conditions = {
            (entry["z"], entry["v"]) for entry in measured_data if entry.get("molecule_destroyed")
        }

        # measure bias voltage from low to high and alternate between positive and negative
        asm_v_values_mV = sorted(self.v_values_mV, key=abs)
        for z in self.z_values_nm:
            for v in asm_v_values_mV:

                print(f"Manipulating object at z={z} nm, v={v} mV")

                # Skip if z and v is already measured
                if any(d["z"] == z and d["v"] == v for d in measured_data):
                    continue

                # Skip if a destruction happened at a same-or-smaller z with a bias at
                # least as large on the same side of zero (v_skip == 0 matches any bias).
                if any(
                    z_skip <= z and (
                        (v_skip >= 0 and v >= v_skip) or
                        (v_skip <= 0 and v <= v_skip)
                    )
                    for (z_skip, v_skip) in skip_conditions
                ):
                    print(f"Skipping z={z} nm, v={v} mV due to prior destructive manipulation.")
                    continue

                # Outcome flags describe THIS grid point only. They used to be set once
                # before the sweep, so an earlier "changed"/"destroyed" answer leaked into
                # every later record that did not prompt the user.
                molecule_changed = False
                molecule_destroyed = False

                # Image name for the state after this manipulation.
                time_stamp = time.strftime("%Y%m%d-%H%M%S")
                self.filename_per_timestep = (
                    f"{time_stamp}_ASM_{self.moiety_type}_{self.type_of_measurement}_z{z}_v{v}"
                    + self._IMAGE_EXTENSIONS.get(self.hardware, "")
                )

                # input.yaml is re-read at every grid point so prompting can be toggled
                # mid-sweep. The user_input argument only applies when the key is absent.
                self.input_yaml_file = os.path.normpath(os.path.join(os.getcwd(), "input.yaml"))
                with open(self.input_yaml_file, "r") as file:
                    yaml_data = yaml.safe_load(file) or {}
                ask_user = yaml_data.get("user_input_asm", user_input)

                if ask_user:
                    # Ask user if they want to skip this combination
                    user_input_skip = input(f"Do you want to skip z={z} nm, v={v} mV? (y/n): ").strip().lower()
                    if user_input_skip == 'y':
                        # Recorded as destroyed, so a resumed run treats it as a skip condition.
                        measured_data.append({
                            "z": z,
                            "v": v,

                            "lat_manip_start": np.array([0, 0]),
                            "lat_manip_end": np.array([0, 0]),

                            "moiety_type_before": moiety_type_before,
                            "moiety_type_after": 'defect - user input',
                            "moiety_position_before": moiety_position_before,
                            "moiety_position_after": moiety_position_before,
                            "moiety_orientation_before": moiety_orientation_rad_before,
                            "moiety_orientation_after": moiety_orientation_rad_before,
                            "moiety_contour_before": moiety_contour_before,
                            "moiety_contour_after": moiety_contour_before,
                            "confidence": conf,

                            "translation": 0,
                            "rotation": 0,
                            "molecule_changed": True,
                            "molecule_destroyed": True,

                            "tip formation": self.tip_formation
                        })
                        print(f"Skipping z={z} nm, v={v} mV")
                        self._save_measured_data(measured_data)
                        continue

                # Manipulate, then re-image around the tip's rough lateral position.
                self.manipulate_object(z, v)

                (moiety_type, moiety_position, moiety_orientation_rad,
                 moiety_contour, conf) = self.detect_object(
                    center_nm=self.stm._get_rough_lat_position_for_exact_search())
                print("Before manipulation:")
                print(f"Moiety type: {moiety_type_before}, position: {moiety_position_before}, orientation: {np.rad2deg(moiety_orientation_rad_before)}")
                print("After manipulation:")
                print(f"Moiety type: {moiety_type}, position: {moiety_position}, orientation: {np.rad2deg(moiety_orientation_rad)}, confidence: {conf}")

                # Determine outcome after manipulation.
                translation = np.linalg.norm(np.array(moiety_position) - np.array(moiety_position_before))
                # NOTE(review): plain absolute difference, not wrapped. If orientations are
                # reported in a periodic range, a small change across its boundary shows up
                # as a large rotation. Confirm the range used by the analysis code.
                rotation = np.abs(moiety_orientation_rad - moiety_orientation_rad_before)

                # A changed type or a suspicious confidence needs the user to classify the
                # outcome. NOTE(review): conf == 1.0 also prompts; presumably a sentinel
                # value from the analysis code. Confirm.
                if (moiety_type != moiety_type_before) or (conf < conf_threshold) or (conf == 1.0):
                    while True:
                        print(f"Moiety type changed from {moiety_type_before} to {moiety_type} or confidence is low ({conf:.2f}).")
                        answer = input("What happend to the moiety? (1: nothing / 2: changed / 3: destroyed): ").strip().lower()
                        if answer == '1':
                            moiety_type = moiety_type_before
                            break
                        if answer == '2':
                            molecule_changed = True
                            break
                        if answer == '3':
                            molecule_destroyed = True
                            moiety_type = moiety_type_before
                            # Avoid this and harsher parameters for the rest of the sweep.
                            skip_conditions.add((z, v))
                            break
                        print("Invalid input! Please enter 1, 2, or 3.")

                self._plot_manipulation_step(
                    moiety_position_before, moiety_orientation_rad_before,
                    moiety_position, moiety_orientation_rad,
                )

                # One record per grid point: parameters, tip path, state before/after and outcome.
                measured_data.append({
                    "z": z,
                    "v": v,

                    "lat_manip_start": self.lateral_manipulation_start,
                    "lat_manip_end": self.lateral_manipulation_end,

                    "moiety_type_before": moiety_type_before,
                    "moiety_type_after": moiety_type,
                    "moiety_position_before": moiety_position_before,
                    "moiety_position_after": moiety_position,
                    "moiety_orientation_before": moiety_orientation_rad_before,
                    "moiety_orientation_after": moiety_orientation_rad,
                    "moiety_contour_before": moiety_contour_before,
                    "moiety_contour_after": moiety_contour,
                    "confidence": conf,

                    "translation": translation,
                    "rotation": rotation,
                    "molecule_changed": molecule_changed,
                    "molecule_destroyed": molecule_destroyed,

                    "tip formation": self.tip_formation
                })
                self._save_measured_data(measured_data)

                # The state after this manipulation is the "before" state of the next one.
                moiety_type_before = moiety_type
                moiety_position_before = moiety_position
                moiety_orientation_rad_before = moiety_orientation_rad
                moiety_contour_before = moiety_contour

                self.moiety_type = moiety_type
                self.moiety_position_nm = moiety_position
                self.moiety_orientation_rad = moiety_orientation_rad
                self.moiety_contour = moiety_contour

    def _select_threshold(self, type_of_measurement, measured_entity=None):
        """
        Return ``(threshold, measured_entity)`` used to classify measured entries.

        By default, both measurement types are classified by translation distance, which
        matches the previous effective behaviour. Pass ``measured_entity="rotation"`` to
        classify by rotation instead (the threshold is converted from degrees to radians).
        NOTE(review): confirm whether "orientation" should default to "rotation".

        Raises:
            ValueError: For an unknown measurement type or measured entity.
        """
        if type_of_measurement not in ("translation", "orientation"):
            raise ValueError(f"Invalid performed task: {type_of_measurement}")
        if measured_entity is None or measured_entity == "translation":
            return self.translation_threshold, "translation"
        if measured_entity == "rotation":
            return np.deg2rad(self.rotation_threshold), "rotation"
        raise ValueError(f"Invalid measured entity: {measured_entity}")

    def _compute_mask(self, measured_data, measured_entity, threshold):
        """
        Build the ``(len(z_values_nm), len(v_values_mV))`` outcome grid.

        Codes: 0 = not measured, 1 = molecule changed, 2 = molecule destroyed,
        3 = |entry[measured_entity]| > threshold, 4 = |entry[measured_entity]| <= threshold.
        Entries are applied in order, so a later entry for the same (z, v) overrides an
        earlier one. Entries whose z (rounded to 7 decimals) or v is not on the grid are ignored.
        """
        mask = np.zeros((len(self.z_values_nm), len(self.v_values_mV)), dtype=int)
        z_index = {z: i for i, z in enumerate(self.z_values_nm)}
        v_index = {v: j for j, v in enumerate(self.v_values_mV)}
        for data in measured_data:
            i = z_index.get(round(data["z"], 7))
            j = v_index.get(data["v"])
            if i is None or j is None:
                continue
            value = abs(data[str(measured_entity)])
            if data["molecule_changed"]:
                mask[i, j] = 1
            elif data["molecule_destroyed"]:
                mask[i, j] = 2
            elif value > threshold:
                mask[i, j] = 3
            elif value <= threshold:
                mask[i, j] = 4
        return mask

    def plot_action_space_mask(self, moiety_type: str, type_of_measurement: str, measured_entity=None):
        """
        Plot a mask for the action space based on the measured data. The mask depends on
        the type of moiety and the performed task.

        Loads the newest ASM pickle whose name contains ``moiety_type`` and
        ``type_of_measurement``, classifies every grid point (see ``_compute_mask``) and
        saves a transparent PNG in the STM data directory.

        Parameters:
            moiety_type (str): Moiety name that must appear in the pickle file name.
            type_of_measurement (str): "translation" or "orientation".
            measured_entity (str | None): "translation" (default) or "rotation".

        Returns:
            numpy.ndarray: The integer mask.

        Raises:
            FileNotFoundError: If no matching pickle exists.
            ValueError: For an invalid measurement type or measured entity.
        """
        full_path = self._latest_pickle_path(
            lambda f: '_asm_' in f.lower() and moiety_type in f and type_of_measurement in f
        )
        if full_path is None:
            raise FileNotFoundError("No pickled data found.")
        if not os.path.exists(full_path):
            raise FileNotFoundError(f"File {full_path} not found.")
        with open(full_path, "rb") as f:
            measured_data = pickle.load(f)

        threshold, measured_entity = self._select_threshold(type_of_measurement, measured_entity)
        mask = self._compute_mask(measured_data, measured_entity, threshold)

        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111)

        # (color, legend label) per mask code. "changed" and "destroyed" share one style.
        styles = {
            1: ('grey', 'pick up'),
            2: ('grey', 'pick up'),
            3: ('fuchsia', 'manipulation'),
            4: ('cyan', 'no change'),
        }
        labelled = set()  # each legend label is shown only once
        for i, z in enumerate(self.z_values_nm):
            for j, v in enumerate(self.v_values_mV):
                style = styles.get(mask[i, j])
                if style is None:
                    continue
                color, label = style
                plt.scatter(z, v, c=color, marker='o', alpha=0.8,
                            label=label if label not in labelled else '_')
                labelled.add(label)

        plt.title(f"Action Space Mask for {moiety_type} - {measured_entity}")
        plt.ylabel('bias voltage (mV)')
        plt.xlabel('approach distance (nm)')
        plt.xlim(self.z_values_nm[0] - 0.03, self.z_values_nm[-1])
        # 11 evenly spaced x-ticks across the z grid.
        plt.xticks(np.linspace(self.z_values_nm[0], self.z_values_nm[-1], int(1.1 / 0.1)))
        plt.ylim(self.v_values_mV[0] - 50, self.v_values_mV[-1] + 50)
        for spine in ('bottom', 'left', 'right', 'top'):
            ax.spines[spine].set_color('white')
        ax.xaxis.label.set_color('white')
        ax.yaxis.label.set_color('white')
        ax.title.set_color('white')
        ax.tick_params(axis='x', colors='white')
        ax.tick_params(axis='y', colors='white')
        ax.legend()
        plt.savefig(os.path.join(self.directory_for_stm_data, f"action_space_mask_{moiety_type}_type_of_measurement_{type_of_measurement}_{measured_entity}.png"), transparent=True)
        # Release the figure; it was previously left open on every call.
        plt.close(fig)

        return mask