import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import numpy as np
from molecule_movement.parsing.CircularMockUpData import CircularMockUpData
from skimage.metrics import structural_similarity as ssim
import scipy
import time

from .analyse_moieties import AnalyseMoieties

from molecule_movement import Molecule, MoleculeExperiment, Goal, Obstacle
from molecule_movement.parsing import CircularMockUpData
from molecule_movement.shapes import RECTANGLE, resize, DDNB

from loguru import logger

from shapely import Point
import shapely

class AnalyseMoietySimulation(AnalyseMoieties):
    def __init__(self, molecules: list[Molecule] | list[MoleculeExperiment], goals: list[Goal], obstacles: list[Obstacle]):
        """
        Set up the simulated moiety analysis.

        Parameters:
        -----------
        molecules: list[Molecule] | list[MoleculeExperiment]
            The molecules whose state is reported as moiety observations.
        goals: list[Goal]
            Accepted but not stored or used by this constructor.
        obstacles: list[Obstacle]
            Accepted but not stored or used by this constructor; the obstacle
            polygons below are hard-coded instead.
        """
        super().__init__()

        self.molecules = molecules

        # One simulated moiety type, initially placed at the origin.
        self.moiety_types = ["molecule_FePc_Au111_4K"]
        self.moiety_position_nm = [Point(0, 0)]

        # Hard-coded obstacle polygons, one (x, y) vertex list per obstacle
        # (presumably in nm, as the attribute name suggests).
        obstacle_vertices = (
            ((-32.769, -2.5253), (-26.441, -2.5253), (-32.612, -11.588)),
            ((-20.425, -6.744), (-14.722, -7.8378), (-14.019, -13.385), (-19.566, -17.057),
             (-25.737, -13.385), (-24.019, -10.416), (-22.456, -9.9471)),
            ((-32.769, -20.885), (-28.159, -24.4), (-31.753, -26.197), (-32.769, -25.572)),
            ((2.3876, -13.619), (4.9657, -13.541), (5.8251, -15.26), (4.497, -16.51), (1.5282, -15.572)),
            ((7.0751, -21.119), (-6.4405, -42.369), (7.1532, -42.369)),
        )
        self._obstacles_nm = [np.array(vertices) for vertices in obstacle_vertices]


    def load_reference_contour(self, moiety_type, at_origin=False):
        """
        Return the reference contour used for every moiety in the simulation.

        ``moiety_type`` and ``at_origin`` are accepted but ignored here: the contour
        is always the exterior ring of ``resize(DDNB, 1, 1)``, wrapped in an extra
        leading axis so the result holds a single contour.
        """
        outline = resize(DDNB, 1, 1).exterior.coords
        return np.array([outline])

    def get_stm_image(self, img_file):
        if img_file.endswith(".dat"):
            img_file = img_file.replace(".dat", ".dat.jpeg")
        # Add .dat.jpeg to the image file
        elif not img_file.endswith(".dat.jpeg"):
            img_file = img_file + ".dat.jpeg"

        # Get number of pixels from dat file
        dat_file = self.get_datafile_for_image(img_file)
        pixels = int(self.get_data_for_image(dat_file, "Num.X / Num.X"))

        # Check if the image file exists for the given path
        if not os.path.exists(img_file):
            raise ValueError(f"File not found: {img_file}")

        # Load the STM image
        img = cv2.imread(img_file, cv2.IMREAD_COLOR)
        return img, pixels

    @staticmethod
    def remove_background(image):
        """
        Interactively cut the foreground out of ``image`` with OpenCV GrabCut.

        The user draws a rectangle around the foreground in an OpenCV window.
        GrabCut (5 iterations) is initialised from that rectangle, pixels labelled
        as sure or probable background are zeroed, and the result is shown until
        a key is pressed.

        Returns:
        --------
        np.ndarray or None
            The image with background pixels set to 0, or None if no ROI was selected.
        """
        # Label mask plus the two 1x65 model buffers that cv2.grabCut fills in place.
        labels = np.zeros(image.shape[:2], np.uint8)
        background_model = np.zeros((1, 65), np.float64)
        foreground_model = np.zeros((1, 65), np.float64)

        # The interactive ROI tells GrabCut which region is meant to be kept.
        print("Select the region of interest (ROI) for the foreground object and press ENTER or SPACE when done.")
        roi = cv2.selectROI("Input Image", image, fromCenter=False, showCrosshair=True)
        cv2.destroyWindow("Input Image")

        if roi == (0, 0, 0, 0):
            print("No ROI selected. Exiting.")
            return None

        cv2.grabCut(image, labels, roi, background_model, foreground_model, 5, cv2.GC_INIT_WITH_RECT)

        # Keep (1) everything GrabCut did not label as sure/probable background (0).
        is_background = np.isin(labels, (cv2.GC_BGD, cv2.GC_PR_BGD))
        keep = (~is_background).astype('uint8')
        foreground = image * keep[:, :, np.newaxis]

        cv2.imshow("Foreground", foreground)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        return foreground



    def get_moiety_information_px(self, reference_path, target_path, moiety_type, bbox, plot=False):
        """
        Get the moiety's pixel-level position and orientation by matching a reference contour
        to the detected target region in a microscopy image using chamfer matching.

        Parameters:
        -----------
        ref_img : str
            Path to the reference image file.
        img_file : str
            Path to the target image file.
        moiety_type : str
            Type of moiety (e.g., 'molecule').
        bbox : list[float]
            Bounding box around predicted object (center_x, center_y, half_width, half_height) in normalized coordinates.
        plot : bool
            Whether to visualize matching results.

        Returns:
        --------
        best_matching_position : np.ndarray
            [x, y] coordinates of moiety position in pixels.
        best_matching_orientation_rad : float
            Orientation in radians.
        best_target_contour : np.ndarray
            Target contour points.
        best_matching_reference_contour : np.ndarray
            Transformed reference contour points.
        """
        # === IMAGE LOADING & PREPROCESSING ========================================
        ref_img_gray = cv2.imread(reference_path, cv2.IMREAD_GRAYSCALE)
        tgt_img_gray = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)

        # Ensure that the reference and target images have the same size
        dat_reference_path = reference_path.replace(".jpeg", ".dat")
        if not os.path.exists(dat_reference_path):
            dat_reference_path = reference_path.replace(".dat.jpeg", ".dat")
        dat_target_path = target_path.replace(".jpeg", ".dat")
        if not os.path.exists(dat_target_path):
            dat_target_path = target_path.replace(".dat.jpeg", ".dat")

        _, img_size_px, _, _ = self.get_data_for_image(dat_reference_path, "Num.X / Num.X")
        size_px = img_size_px['x']
        assert size_px == img_size_px['y'], "Non-square image dimensions not supported."

        return self.determine_moiety_information_px(ref_img_gray, tgt_img_gray, size_px, moiety_type, bbox, plot=True)

    def init_moiety_information(self, img_file, numbering=True, plot=False):
        """
        Initialize the information of the moieties based on the prediction of the object recognition.

        Parameters:
        -----------
        pred: list
            The prediciton of the object recognition:
            0: type of moiety
            1: x_center
            2: y_center
            3: width
            4: height
            5: confidence

        Returns:
        --------
        moiety_type: str
            List containing the type of the moieties
        moiety_positions: np.array
            List containing the positions of the moieties in pixel
        moiety_orientations: float
            List containing the orientations of the moieties in degrees
        moiety_bbox: np.array
            List containing the bounding boxes of the moieties in pixel
        """
        # Object detection prediction
        return None, None, None, None, None
        #return self.moiety_types, self.moiety_position_nm, self.moiety_orientation_rad, self.moiety_bbox_width_height_nm, self.confidence

    def get_moiety_indices(self, moiety_type, moiety_position_nm, position_threshold_nm=0.5):
        """
        Determine the index of the moiety based on the type and the bounding box.

        Parameters:
        -----------
        moiety_type: str
            The type of the moiety.
        bbox: np.array
            The bounding box of the moiety.

        Returns:
        --------
        moiety_index: int
            The index of the moiety.
        """
        # Determine the index of the moiety
        moiety_index_type = np.where((self.moiety_types == moiety_type))

        # Determine the index of the moiety based on the position in nm within a tolerance of 0.5 nm
        moiety_index_position = np.where(np.linalg.norm(self.moiety_position_nm - moiety_position_nm, axis=1) < position_threshold_nm)

        moiety_index = np.intersect1d(moiety_index_type, moiety_index_position)

        return moiety_index

    def get_obstacles(self):
        # Get indices where moiety_type is not molecule or atom
        return self.obstacles



    def update_moiety_information(self,reset:bool=False):
        """
        Advance the simulated molecules and refresh the cached moiety observation.

        Every molecule is first moved to its centre translated by (-1, -1) via
        ``move_to(..., 0)``, so each call changes the simulation state. The state
        is then read back into the observation attributes.

        Parameters:
        -----------
        reset: bool
            Not used by this implementation.

        Returns:
        --------
        tuple
            (moiety_types, moiety_position_nm, moiety_orientation_rad,
            moiety_bbox_width_height_nm, confidence, moiety_target_contour_nm,
            moiety_matching_reference_contour_nm, moiety_types_before,
            moiety_type_changes, moved_molecules). ``moved_molecules`` maps each
            molecule whose ``movement.moved()`` is truthy to
            ``(previous_position, movement)``.
        """
        molecules = self.molecules

        # Step the simulation first so that the state read below reflects the move.
        for molecule in molecules:
            molecule.move_to(shapely.affinity.translate(molecule.center, -1, -1), 0)

        # Types never change in the simulation; "before" refers to the same list.
        # NOTE(review): moiety_types is not rebuilt per molecule and may differ in length.
        self.moiety_types_before = self.moiety_types
        self.moiety_type_changes = [False] * len(molecules)

        self.moiety_position_nm = [molecule.center for molecule in molecules]
        self.moiety_orientation_rad = [molecule.orientation for molecule in molecules]
        self.moiety_bbox_width_height_nm = [(molecule.size_x, molecule.size_y) for molecule in molecules]
        self.confidence = [1.0] * len(molecules)

        # Placeholder contours: a radius-2 buffer around the origin for every molecule.
        self.moiety_target_contour_nm = [Point(0, 0).buffer(2)] * len(molecules)
        self.moiety_matching_reference_contour_nm = [Point(0, 0).buffer(2)] * len(molecules)

        moved_molecules = {}
        for molecule in molecules:
            movement = molecule.movement
            if not movement.moved():
                continue
            moved_molecules[molecule] = (molecule.previous_position, movement)

        return (
            self.moiety_types,
            self.moiety_position_nm,
            self.moiety_orientation_rad,
            self.moiety_bbox_width_height_nm,
            self.confidence,
            self.moiety_target_contour_nm,
            self.moiety_matching_reference_contour_nm,
            self.moiety_types_before,
            self.moiety_type_changes,
            moved_molecules,
        )

    def save_reference_image(self, img_file, moiety_type):
        """
        Save the current STM image as a new reference image if no reference image exists for the current moiety.
        If a reference image already exists, it is not overwritten.
        """
        # Get image
        img, img_size = self.get_stm_image(img_file)

        reference_filename = f"{moiety_type}_{img_size}.jpeg"
        reference_image_path = os.path.join(self.reference_image_dir, reference_filename)

        # Check if the reference image already exists
        if not os.path.exists(reference_image_path):
            cv2.imwrite(reference_image_path, img)

    def save_reference_contour(self, ref_contour_nm, moiety_type, plot=True):
        """ Saves the contour of the moiety as a reference contour file.
        """
        # Save file if it does not exist
        file = os.path.join(self.reference_image_dir, f"{moiety_type}_contour_shape.npy")
        if not os.path.exists(file):
            np.save(file, ref_contour_nm)

        # Plot the contour if requested
        if plot:
            moiety_matching_reference_contour_nm = np.load(file)
            fig = plt.figure(figsize=(10, 10))
            ax = fig.add_subplot(111)
            ax.set_title(f"Reference contour for {moiety_type}")
            ax.set_xlabel("x (nm)")
            ax.set_ylabel("y (nm)")
            ax.scatter(moiety_matching_reference_contour_nm.T[0], moiety_matching_reference_contour_nm.T[1], c='red', label='Matching Reference Contour')
            ax.legend()
            plt.show()
            print("Reference contour saved and plotted.")

    def get_center_of_image_nm(self, img_file):

        dat_file = self.get_datafile_for_image(img_file)
        global_image_offset_x_dac = -self.get_data_for_image(dat_file, 'Scanrotoffx / OffsetX')
        global_image_offset_y_dac = -self.get_data_for_image(dat_file, 'Scanrotoffy / OffsetY')
        real_img_size_y_px = self.get_data_for_image(dat_file, "Num.Y / Num.Y")
        real_img_size_y_dac = self.convert_pixel_to_dac(real_img_size_y_px, img_file=img_file)
        center_of_image = np.array([global_image_offset_x_dac, global_image_offset_y_dac + real_img_size_y_dac/2])
        return self.convert_dac_to_nm(center_of_image, img_file=img_file)


    def get_origin_of_image_nm(self, img_file):
        """
        Get the origin of the image in the global coordinate system.

        Parameters:
        -----------
        img_file: str
            The path to the image file.

        Returns:
        --------
        origin_position_of_overview_image: np.array
            The origin position of the image in the global coordinate system.
        """

        dat_file = self.get_datafile_for_image(img_file)
        global_image_offset_x_dac = -self.get_data_for_image(dat_file, 'Scanrotoffx / OffsetX')
        global_image_offset_y_dac = -self.get_data_for_image(dat_file, 'Scanrotoffy / OffsetY')
        real_img_size_x_px = self.get_data_for_image(dat_file, "Num.X / Num.X")
        real_img_size_y_px = self.get_data_for_image(dat_file, "Num.Y / Num.Y")
        real_img_size_x_dac = self.convert_pixel_to_dac(real_img_size_x_px, img_file=img_file)
        #real_img_size_y_dac = self.convert_pixel_to_dac(real_img_size_y_px, img_file=img_file)
        origin_position_of_overview_image_dac = np.array([global_image_offset_x_dac - real_img_size_x_dac/2, global_image_offset_y_dac])
        return self.convert_dac_to_nm(origin_position_of_overview_image_dac, img_file=img_file)


    def plot_moieties(self, img_file, plot=True):
        """
        Plot the overview image and the measured target image with the determined contours.
        """
        if img_file.endswith(".dat"):
            img_file = img_file.replace(".dat", ".dat.jpeg")
        dat_file = self.get_datafile_for_image(img_file)
        real_img_size = self.get_data_for_image(dat_file, "Num.X / Num.X")
        origin_img_nm = self.get_origin_of_image_nm(img_file)


        # Plot the moiety positions and contours in the overview image
        if plot:
            overview_img = cv2.imread(img_file)
            overview_img = cv2.resize(overview_img, (real_img_size, real_img_size))

            for i, moiety_position_nm in enumerate(self.moiety_position_nm):
                cv2.drawMarker(overview_img, np.array(self.convert_nm_to_pixel(moiety_position_nm-origin_img_nm, img_file=img_file)).astype(int), (106, 24, 72), markerType=cv2.MARKER_CROSS, markerSize=20, thickness=2)

                if self.moiety_target_contour_nm[i] is not None:
                    #cv2.drawContours(overview_img, np.array(self.convert_nm_to_pixel(self.moiety_target_contour_nm[i]-origin_img_nm, img_file=img_file)).astype(int), -1, (255, 0, 0), 2)
                    cv2.drawMarker(overview_img, np.array(self.convert_nm_to_pixel(moiety_position_nm-origin_img_nm, img_file=img_file)).astype(int), (96, 32, 0), markerType=cv2.MARKER_CROSS, markerSize=20, thickness=2)
                    cv2.drawContours(overview_img, np.array(self.convert_nm_to_pixel(self.moiety_matching_reference_contour_nm[i]-origin_img_nm, img_file=img_file)).astype(int), -1, (96, 32, 0), 2)
                    # Draw the orientation of the moiety
                    cv2.line(overview_img, tuple(np.array(self.convert_nm_to_pixel(moiety_position_nm-origin_img_nm, img_file=img_file)).astype(int)),
                        (int(np.array(self.convert_nm_to_pixel(moiety_position_nm-origin_img_nm, img_file=img_file)).astype(int)[0] + 50*np.cos(self.moiety_orientation_rad[i])),
                        int(np.array(self.convert_nm_to_pixel(moiety_position_nm-origin_img_nm, img_file=img_file)).astype(int)[1] + 50*np.sin(self.moiety_orientation_rad[i]))), (96, 32, 0), 2)

            cv2.imshow("Overview Image", overview_img)
            cv2.waitKey(0)
            cv2.destroyAllWindows()


    def convert_image_pixel_to_nm(self, img_file):
        # Check if the dat file exists
        if os.path.exists(img_file+".dat"):
            dat_file = img_file+".dat"
        else:
            dat_file = img_file.replace(".jpeg", ".dat")
            if not os.path.exists(dat_file):
                dat_file = img_file.replace(".dat.jpeg", ".dat")

        CONSTANT_AD_CONVERTER = 524287
        # --- These are just defaul values. They are overwritten by the values from the STM/AFM program if the connection is established.
        deltaX = self.get_data_for_image(dat_file, "Delta X / Delta X [Dac]")
        deltaY = self.get_data_for_image(dat_file, "Delta Y / Delta Y [Dac]")

        numX = self.get_data_for_image(dat_file, "Num.X")
        numY = self.get_data_for_image(dat_file, "Num.Y")

        GainX = self.get_data_for_image(dat_file, "GainX / GainX")
        GainY = self.get_data_for_image(dat_file, "GainY / GainY")

        Xpiezoconst = self.get_data_for_image(dat_file, "Xpiezoconst / Xpiezoconst")
        Ypiezoconst = self.get_data_for_image(dat_file, "YPiezoconst / YPiezoconst")

        assert deltaX*numX/CONSTANT_AD_CONVERTER*GainX*Xpiezoconst == deltaY*numY/CONSTANT_AD_CONVERTER*GainY*Ypiezoconst

        return deltaX*numX/CONSTANT_AD_CONVERTER*GainX*Xpiezoconst

    def convert_pixel_to_dac(self, pixel, img_file):
        # Check if the dat file exists
        if os.path.exists(img_file+".dat"):
            dat_file = img_file+".dat"
        else:
            dat_file = img_file.replace(".jpeg", ".dat")
            if not os.path.exists(dat_file):
                dat_file = img_file.replace(".dat.jpeg", ".dat")

        # --- These are just defaul values. They are overwritten by the values from the STM/AFM program if the connection is established.
        deltaX = self.get_data_for_image(dat_file, "Delta X / Delta X [Dac]")
        deltaY = self.get_data_for_image(dat_file, "Delta Y / Delta Y [Dac]")

        assert deltaX == deltaY

        return pixel*deltaX

    def convert_pixel_to_nm(self, pixel, img_file):
        # Check if the dat file exists
        if os.path.exists(img_file+".dat"):
            dat_file = img_file+".dat"
        else:
            dat_file = img_file.replace(".jpeg", ".dat")
            if not os.path.exists(dat_file):
                dat_file = img_file.replace(".dat.jpeg", ".dat")

        CONSTANT_AD_CONVERTER = 524287
        # --- These are just defaul values. They are overwritten by the values from the STM/AFM program if the connection is established.
        deltaX = self.get_data_for_image(dat_file, "Delta X / Delta X [Dac]")
        deltaY = self.get_data_for_image(dat_file, "Delta Y / Delta Y [Dac]")

        numX = self.get_data_for_image(dat_file, "Num.X")
        numY = self.get_data_for_image(dat_file, "Num.Y")

        GainX = self.get_data_for_image(dat_file, "GainX / GainX")
        GainY = self.get_data_for_image(dat_file, "GainY / GainY")

        Xpiezoconst = self.get_data_for_image(dat_file, "Xpiezoconst / Xpiezoconst")
        Ypiezoconst = self.get_data_for_image(dat_file, "YPiezoconst / YPiezoconst")

        assert deltaX/CONSTANT_AD_CONVERTER*GainX*Xpiezoconst == deltaY/CONSTANT_AD_CONVERTER*GainY*Ypiezoconst

        return pixel*deltaX*GainX*Xpiezoconst/CONSTANT_AD_CONVERTER

    def convert_nm_to_pixel(self, nm, img_file):
        # Check if the dat file exists
        if os.path.exists(img_file+".dat"):
            dat_file = img_file+".dat"
        else:
            dat_file = img_file.replace(".jpeg", ".dat")
            if not os.path.exists(dat_file):
                dat_file = img_file.replace(".dat.jpeg", ".dat")

        CONSTANT_AD_CONVERTER = 524287
        # --- These are just defaul values. They are overwritten by the values from the STM/AFM program if the connection is established.
        deltaX = self.get_data_for_image(dat_file, "Delta X / Delta X [Dac]")
        deltaY = self.get_data_for_image(dat_file, "Delta Y / Delta Y [Dac]")

        numX = self.get_data_for_image(dat_file, "Num.X")
        numY = self.get_data_for_image(dat_file, "Num.Y")

        GainX = self.get_data_for_image(dat_file, "GainX / GainX")
        GainY = self.get_data_for_image(dat_file, "GainY / GainY")

        Xpiezoconst = self.get_data_for_image(dat_file, "Xpiezoconst / Xpiezoconst")
        Ypiezoconst = self.get_data_for_image(dat_file, "YPiezoconst / YPiezoconst")

        assert deltaX/CONSTANT_AD_CONVERTER*GainX*Xpiezoconst == deltaY/CONSTANT_AD_CONVERTER*GainY*Ypiezoconst

        return nm*CONSTANT_AD_CONVERTER/(deltaX*GainX*Xpiezoconst)

    def convert_dac_to_nm(self, dac, img_file):
        # Check if the dat file exists
        if os.path.exists(img_file+".dat"):
            dat_file = img_file+".dat"
        else:
            dat_file = img_file.replace(".jpeg", ".dat")
            if not os.path.exists(dat_file):
                dat_file = img_file.replace(".dat.jpeg", ".dat")

        CONSTANT_AD_CONVERTER = 524287
        # --- These are just defaul values. They are overwritten by the values from the STM/AFM program if the connection is established.

        GainX = self.get_data_for_image(dat_file, "GainX / GainX")
        GainY = self.get_data_for_image(dat_file, "GainY / GainY")

        Xpiezoconst = self.get_data_for_image(dat_file, "Xpiezoconst / Xpiezoconst")
        Ypiezoconst = self.get_data_for_image(dat_file, "YPiezoconst / YPiezoconst")

        assert GainX*Xpiezoconst/CONSTANT_AD_CONVERTER == GainY*Ypiezoconst/CONSTANT_AD_CONVERTER

        return dac*GainX*Xpiezoconst/CONSTANT_AD_CONVERTER

    def get_data_for_image(self, file, *args):
        """ Get *.dat file entry for specific arguments.

            Parameters
            ----------------
            args | str
                Keyword for which the value is read from the *.dat file.

            Return
            ----------------
            value_read | str, float, int
        """

        with open(file, "r", encoding="ISO-8859-1") as f:
            lines = f.readlines()

        values = []
        for arg in args:
            for line in lines:
                if arg in line:
                    values.append(line.split("=")[1].strip())

        # Convert value to int, float, or str
        if len(values) == 1:
            try:
                return int(values[0])
            except ValueError:
                try:
                    return float(values[0])
                except ValueError:
                    return values[0]

        return np.array(values)


    def get_observation(self):
        return self.moiety_types,\
                self.moiety_position_nm,\
                self.moiety_orientation_rad,\
                self.moiety_bbox_width_height_nm,\
                self.confidence,\
                self.moiety_target_contour_nm,\
                self.moiety_matching_reference_contour_nm,\
                self.moiety_types_before,\
                self.moiety_type_changes
