# YOLOv8 object detection driven by a YAML configuration file,
# trained on a dataset of self-labeled and augmented images.

import os
import cv2
import yaml
import math
import random
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
import matplotlib.patches as patches

#from scaling_tool import *
import torch
import datetime



#%% define functions

class ObjectSegmentation():
    """Segmentation counterpart of ``ObjectRecognition`` (not implemented yet).

    Currently it only reads ``config.yaml`` from the directory of this script
    and keeps the parsed settings on the instance; no model is loaded.
    """

    def __init__(self, model_file = None):
        # Optional model file; stored for later use, not loaded yet.
        self.model_file = model_file

        config_path = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml"))
        with open(config_path, "r") as ymlfile:
            # Keep the parsed configuration: it used to be bound to a local
            # variable and was lost as soon as __init__ returned.
            self.cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)

class ObjectRecognition():
    """YOLOv8 object detection configured through ``config.yaml``.

    The configuration file is read from the directory of this script. Besides
    training (``train``) and prediction (``predict``), the class offers
    interactive OpenCV tools to review/correct detections and to outline
    obstacle polygons.
    """

    def __init__(self, model_file = None):
        """Load the settings from ``config.yaml`` and load the model.

        Parameters:
            model_file (str | None): file passed to ``YOLO``. If None, the
                ``weights/best.pt`` of the newest ``train*`` run in
                ``output_dir`` is used, falling back to ``yolov8n.yaml``.
        """
        config_path = os.path.normpath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.yaml"))
        with open(config_path, "r") as ymlfile:
            cfg = yaml.load(ymlfile, Loader=yaml.FullLoader)

        # NOTE: plain string concatenation (not os.path.join), so training_dir
        # is presumably expected to end with a path separator.
        self.input_dir = cfg["training_dir"] + cfg["dataset_name"]
        self.output_dir = cfg["output_dir"]
        self.test_dir = cfg["test_dir"]
        os.makedirs(self.output_dir, exist_ok=True)

        self.classes = cfg["classes"]
        self.epochs = cfg["epochs"]
        self.img_size = cfg["img_size"]

        self.autoscale = cfg["autoscale"]
        self.train_size_in_nm = cfg["train_size_in_nm"]
        self.train_size_in_pxl = self.img_size
        self.test_size_in_nm = cfg["test_size_in_nm"]

        self.model_file = model_file
        # dataset description handed to YOLO for training and validation
        self.data = os.path.normpath(os.path.join(self.input_dir, "data.yaml"))

        # use GPU acceleration when available
        self.device = 'cuda' if torch.cuda.is_available() else "cpu"
        print("Using Device: ", self.device)

        self.model = self.load_model(model_file=self.model_file, load_model_dir=self.output_dir)

        # 24 RGB colors, scaled to 0..1 for matplotlib, one per class; class
        # ids beyond that wrap around in plot_predictions.
        self.all_colors = np.array([(21, 96, 130), (0, 255, 0), (255, 0, 0), (0, 255, 255), (255, 255, 0), (255, 0, 255),
                    (255, 100, 0), (255, 255, 100), (0, 255, 100), (50, 255, 255), (50, 0, 255), (255, 50, 255),
                    (255, 50, 0), (255, 255, 50), (100, 255, 0), (100, 255, 255), (0, 100, 255), (255, 100, 255),
                    (255, 0, 100), (255, 255, 150), (50, 255, 50), (150, 255, 255), (0, 50, 255), (255, 150, 255)
                    ])/255
        self.max_classes = len(self.all_colors)

    @staticmethod
    def _find_latest_best_weights(load_model_dir):
        """Return ``<run>/weights/best.pt`` of the newest ``train*`` run folder
        in ``load_model_dir`` that contains such a file, or None."""
        try:
            entries = os.listdir(load_model_dir)
        except OSError:
            return None
        candidates = []
        for entry in entries:
            run_dir = os.path.join(load_model_dir, entry)
            weights = os.path.join(run_dir, "weights", "best.pt")
            # Skip plain files and interrupted runs without saved weights.
            if entry.startswith('train') and os.path.isdir(run_dir) and os.path.isfile(weights):
                candidates.append((run_dir, weights))
        if not candidates:
            return None
        return max(candidates, key=lambda run: os.path.getmtime(run[0]))[1]

    def load_model(self, model_file, load_model_dir):
        """Return a YOLO model.

        Priority:
            1. ``model_file`` if it is given.
            2. ``weights/best.pt`` of the newest ``train*`` run in ``load_model_dir``.
            3. ``YOLO('yolov8n.yaml')``. According to the Ultralytics docs this
               builds the architecture without pretrained weights.
        """
        if model_file is not None:
            return YOLO(model_file)

        best_and_latest_model = self._find_latest_best_weights(load_model_dir)
        if best_and_latest_model is not None:
            print(f"Loading: {best_and_latest_model}")
            return YOLO(best_and_latest_model)

        # alternatives: yolov8n.pt, yolov8s.pt, yolov8m.pt, yolov8x.pt
        return YOLO('yolov8n.yaml')

    def train(self):
        """Train ``self.model`` on ``self.data``, then validate it.

        Outputs go to ``<output_dir>/train_NNN`` and ``<output_dir>/val_NNN``.
        NNN is half the number of sub-directories already present, plus one,
        because every earlier run created one train and one val folder.

        Returns:
            The metrics object returned by ``self.model.val``. It is also
            stored in ``self.val_metrics``. The training results are stored
            in ``self.train_results``.
        """
        os.makedirs(self.output_dir, exist_ok=True)

        existing_dirs = [name for name in os.listdir(self.output_dir)
                         if os.path.isdir(os.path.join(self.output_dir, name))]
        run_number = f"{len(existing_dirs) // 2 + 1:03d}"
        train_output = f"train_{run_number}"
        val_output = f"val_{run_number}"

        self.train_results = self.model.train(data=self.data, epochs=self.epochs, imgsz=self.img_size,
                                              project=self.output_dir, name=train_output)
        self.val_metrics = self.model.val(data=self.data, project=self.output_dir, name=val_output)
        # The original only evaluated box.map / box.maps without using them.
        print(f"Validation mAP50-95: {self.val_metrics.box.map}")
        return self.val_metrics

    def predict(self, img, confidence_threshold = 0.6):
        """Detect objects in ``img`` and return the (possibly corrected) predictions.

        Parameters:
            img: input for ``self.model.predict``. If the manual review is
                triggered it is also shown with OpenCV, so it should be an
                image array (not a path) in that case.
            confidence_threshold (float): if no box is found, or any box has
                a confidence below this value, the interactive review
                (``object_recognition_with_manual_input``) is started.

        Returns:
            np.ndarray of shape (N, 6). Each row is
            ``[class_id, x_center, y_center, width, height, confidence]``, with
            coordinates normalized to the image size (YOLO ``xywhn``). The
            array is also stored in ``self.pred``.
        """
        results = self.model.predict(img, device=self.device)
        boxes = results[0].boxes

        pred_class = boxes.cls.cpu().numpy().reshape(-1, 1)
        pred_bbox = boxes.xywhn.cpu().numpy()
        pred_confidence = boxes.conf.cpu().numpy().reshape(-1, 1)
        self.pred = np.concatenate((pred_class, pred_bbox, pred_confidence), axis=1)

        # Remove boxes overlapping another box by more than 80 % IoU and keep
        # the more confident one. Columns [x, y, w, h, conf] are passed so the
        # helper can read the confidence at index 4. Passing only xywh raised
        # an IndexError on the first overlap.
        _, removed_indices = self.remove_overlapping_bboxes(self.pred[:, 1:6])
        self.pred = np.delete(self.pred, removed_indices, axis=0)

        # If nothing was found or the model is uncertain, human input is required.
        # Use "<" to match the review step, which auto-accepts conf >= threshold.
        confidence = self.pred[:, 5]
        if confidence.size == 0 or (confidence < confidence_threshold).any():
            self.pred = self.object_recognition_with_manual_input(img, self.pred, confidence_threshold)

        return self.pred

    @staticmethod
    def _yolo_to_corners(box, width, height):
        """Convert a normalized ``(x_center, y_center, w, h)`` box to integer
        pixel corners ``(x_min, y_min, x_max, y_max)``."""
        x_c, y_c, box_w, box_h = (float(v) for v in box[:4])
        x_c, box_w = x_c * width, box_w * width
        y_c, box_h = y_c * height, box_h * height
        return (int(x_c - box_w / 2), int(y_c - box_h / 2), int(x_c + box_w / 2), int(y_c + box_h / 2))

    @staticmethod
    def _corners_to_yolo(corners, width, height):
        """Clip pixel corners ``(x_min, y_min, x_max, y_max)`` to the image and
        return the normalized YOLO box ``[x_center, y_center, w, h]``.

        x values are divided by ``width`` and y values by ``height``.
        """
        x_min, y_min, x_max, y_max = corners
        x_min, x_max = min(max(x_min, 0), width), min(max(x_max, 0), width)
        y_min, y_max = min(max(y_min, 0), height), min(max(y_max, 0), height)
        return np.array([(x_min + x_max) / 2 / width, (y_min + y_max) / 2 / height,
                         (x_max - x_min) / width, (y_max - y_min) / height])

    @staticmethod
    def _draw_box(canvas, corners, text, color, thickness=1):
        """Draw a rectangle from pixel ``corners`` on ``canvas``, with ``text`` above its top-left corner."""
        x_min, y_min, x_max, y_max = corners
        cv2.rectangle(canvas, (x_min, y_min), (x_max, y_max), color, thickness)
        cv2.putText(canvas, text, (x_min - 10, y_min - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)

    def _save_manual_labels(self, img, labels, boxes):
        """Save ``img`` and its YOLO label file to ``<input_dir>/raw_images`` and ``<input_dir>/raw_labels``."""
        image_dir = os.path.join(self.input_dir, "raw_images")
        label_dir = os.path.join(self.input_dir, "raw_labels")
        # Create the folders if needed. cv2.imwrite reports a failure only
        # through its return value, and open() would raise.
        os.makedirs(image_dir, exist_ok=True)
        os.makedirs(label_dir, exist_ok=True)
        filename_per_timestep = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + '_manual_input'
        if not cv2.imwrite(os.path.join(image_dir, filename_per_timestep + ".jpeg"), img):
            print("Could not save image.")
            return
        with open(os.path.join(label_dir, filename_per_timestep + ".txt"), "w") as text_file:
            for label, box in zip(labels, boxes):
                text_file.write(str(label) + " " + " ".join(f"{num:.8f}" for num in box) + "\n")
        print("Saved image and labels.")

    def object_recognition_with_manual_input(self, img, detections, confidence_threshold=0.5):
        """Let a human review low-confidence detections and add missing boxes.

        Works in an OpenCV window named "Image":

        1. Detections with ``confidence >= confidence_threshold`` are accepted
           and drawn in green.
        2. Each remaining detection is shown in red. Press 'o' to accept it
           or 'x' to decline it. 'd' deletes the last accepted entry. Any key
           other than 'o'/'x' (including 'd') then asks for a replacement box:
           drag a rectangle with the left mouse button and press a number key
           (1-9 selects ``self.classes[0..8]``), or press 'x' to skip.
        3. More boxes can be drawn and labelled the same way. 's' saves the
           image and the current labels to the dataset's
           ``raw_images``/``raw_labels`` folders, and 'q' finishes.

        Parameters:
            img (np.ndarray): the image the detections belong to.
            detections: rows ``[class_id, x_center, y_center, width, height,
                confidence]`` with coordinates normalized to the image size.
                Rows that do not have exactly 6 values are ignored.
            confidence_threshold (float): detections at or above this value
                are accepted without review.

        Returns:
            np.ndarray of shape (N, 6) in the same layout as ``detections``.
            Manually drawn boxes get a confidence of 1.0. If no box was kept,
            an empty (0, 6) array is returned.

        Raises:
            ValueError: if ``img`` is None.
        """
        if img is None:
            raise ValueError("Could not load image.")

        img_h, img_w = img.shape[:2]
        clone = img.copy()
        final_boxes = []
        final_labels = []
        final_confidence = []

        def accept(label, box, conf):
            # Keep a box (normalized xywh) and draw it in green on the working image.
            final_boxes.append(np.asarray(box, dtype=float))
            final_labels.append(int(label))
            final_confidence.append(conf)
            self._draw_box(clone, self._yolo_to_corners(box, img_w, img_h), self.classes[int(label)], (0, 255, 0))

        def redraw():
            # Rebuild the working image from the original and all boxes kept so far.
            nonlocal clone
            clone = img.copy()
            for label, box in zip(final_labels, final_boxes):
                self._draw_box(clone, self._yolo_to_corners(box, img_w, img_h), self.classes[label], (0, 255, 0))

        # Step 1: accept all confident predictions immediately
        valid = [det for det in detections if len(det) == 6]
        for label, x_c, y_c, box_w, box_h, conf in valid:
            if conf >= confidence_threshold:
                accept(label, (x_c, y_c, box_w, box_h), conf)

        print("Final values", final_boxes, final_labels, final_confidence)

        # Shared state for the mouse callback (single-item lists so the callback can update them)
        drawing = [False]
        start_point = [None]
        temp_img = [clone.copy()]
        pending_box = [None]

        def draw_rectangle(event, x, y, flags, param):
            if event == cv2.EVENT_LBUTTONDOWN:
                drawing[0] = True
                start_point[0] = (x, y)
            elif event == cv2.EVENT_MOUSEMOVE and drawing[0]:
                temp_img[0] = clone.copy()
                cv2.rectangle(temp_img[0], start_point[0], (x, y), (0, 255, 0), 1)
            elif event == cv2.EVENT_LBUTTONUP and drawing[0]:
                # Only react to a release that follows a press in this window.
                # Otherwise start_point is None and unpacking it fails.
                drawing[0] = False
                (x1, y1), (x2, y2) = start_point[0], (x, y)
                pending_box[0] = [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
                # Keep the finished rectangle visible until it is labelled.
                temp_img[0] = clone.copy()
                cv2.rectangle(temp_img[0], tuple(pending_box[0][:2]), tuple(pending_box[0][2:]), (0, 255, 0), 1)
                print("Box drawn. Press a number key to label it:")
                for idx, lbl in enumerate(self.classes, 1):
                    print(f"{idx}: {lbl}")

        def current_view():
            # Show the rubber-band/pending rectangle while drawing or waiting for a label.
            return temp_img[0] if drawing[0] or pending_box[0] is not None else clone

        def label_pending_box(key):
            # Label the pending box with class index ``key - '1'``.
            # Returns True if a box was added.
            if pending_box[0] is None or not ord('1') <= key <= ord('9'):
                return False
            idx = key - ord('1')
            if idx >= len(self.classes):
                return False
            # Store the class index (not its name). Normalize x by the width
            # and y by the height, which is correct for non-square images.
            # Manual input gets confidence 1.0.
            accept(idx, self._corners_to_yolo(pending_box[0], img_w, img_h), 1.0)
            pending_box[0] = None
            return True

        cv2.namedWindow("Image")
        cv2.setMouseCallback("Image", draw_rectangle)

        # Step 2: review low-confidence detections one by one
        for label, x_c, y_c, box_w, box_h, conf in valid:
            if conf >= confidence_threshold:
                continue  # already accepted in step 1

            bbox = np.array([x_c, y_c, box_w, box_h])
            preview = clone.copy()
            self._draw_box(preview, self._yolo_to_corners(bbox, img_w, img_h),
                           f"({conf:.2f}) {self.classes[int(label)]}", (0, 0, 255), thickness=2)
            cv2.imshow("Image", preview)
            print(f"Low-confidence detection: {label} @ {bbox} (confidence: {conf:.2f})")
            print("Press 'o' to accept, 'x' to decline, 'd' to delete the last entry or draw a new box to relabel")

            key = cv2.waitKey(0) & 0xFF
            if key in (ord('o'), ord('O')):
                accept(label, bbox, conf)
                continue
            if key in (ord('x'), ord('X')):
                continue
            if key in (ord('d'), ord('D')) and final_boxes:
                # Delete the last entry and redraw the remaining boxes.
                # Previously every drawn box vanished from the display.
                final_boxes.pop()
                final_labels.pop()
                final_confidence.pop()
                redraw()

            # Any other key: the user draws a replacement box and labels it
            print("Draw a new bounding box...")
            print("...then press a number key to label it, or 'x' to skip.")
            pending_box[0] = None
            # Poll until a box is drawn AND labelled. The old
            # `while pending_box is None` loop stopped as soon as a box was
            # drawn, so the number key was never read.
            while True:
                cv2.imshow("Image", current_view())
                k = cv2.waitKey(1) & 0xFF
                if label_pending_box(k):
                    break
                if k in (ord('x'), ord('X')) and not drawing[0]:
                    pending_box[0] = None
                    break

        # Do not carry an unlabelled box over into step 3.
        pending_box[0] = None

        # Step 3: optional additional manual boxes
        print("Add more boxes manually or press 'q' to finish.")
        while True:
            cv2.imshow("Image", current_view())
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):
                break
            label_pending_box(key)
            if key == ord('s'):
                self._save_manual_labels(img, final_labels, final_boxes)

        cv2.destroyAllWindows()

        if not final_boxes:
            # Concatenating an empty box list used to raise; return an empty table.
            return np.empty((0, 6))
        return np.column_stack((np.asarray(final_labels, dtype=float),
                                np.vstack(final_boxes),
                                np.asarray(final_confidence, dtype=float)))

    def remove_inner_bboxes(self, bboxes):
        """Drop boxes that lie completely inside another box.

        Parameters:
            bboxes: sequence of boxes ``(x_center, y_center, width, height, ...)``.
                Only the first four values of each box are used.

        Returns:
            (float32 array of the kept boxes, list of removed indices).
            Among identical boxes (each contained in the other) the first one
            is kept instead of removing all of them.
        """
        def is_inside(inner, outer, eps=1e-6):
            xi, yi, wi, hi = inner[:4]
            xo, yo, wo, ho = outer[:4]
            return (
                xi - wi / 2 >= xo - wo / 2 - eps and
                yi - hi / 2 >= yo - ho / 2 - eps and
                xi + wi / 2 <= xo + wo / 2 + eps and
                yi + hi / 2 <= yo + ho / 2 + eps
            )

        filtered_bboxes = []
        removed_indices = []
        for i, box in enumerate(bboxes):
            # On mutual containment (duplicates) only the later box is redundant.
            redundant = any(
                j != i and is_inside(box, other) and (j < i or not is_inside(other, box))
                for j, other in enumerate(bboxes)
            )
            if redundant:
                removed_indices.append(i)
            else:
                filtered_bboxes.append(box)

        return np.array(filtered_bboxes, dtype=np.float32), removed_indices

    def remove_overlapping_bboxes(self, bboxes, iou_threshold=0.8):
        """Greedily remove boxes whose IoU with another kept box exceeds ``iou_threshold``.

        Parameters:
            bboxes: sequence of boxes ``(x_center, y_center, width, height[, confidence])``.
                Of two overlapping boxes the one with the higher confidence
                (index 4) is kept, and a tie keeps the earlier box. Without a
                confidence column the earlier box is kept.
            iou_threshold (float): pairs with an IoU strictly above this value overlap.

        Returns:
            (float32 array of the kept boxes, list of removed indices)
        """
        def to_corners(box):
            x, y, w, h = box[:4]
            return x - w / 2, y - h / 2, x + w / 2, y + h / 2

        def compute_iou(box1, box2):
            x1_min, y1_min, x1_max, y1_max = to_corners(box1)
            x2_min, y2_min, x2_max, y2_max = to_corners(box2)

            inter_w = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
            inter_h = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
            inter_area = inter_w * inter_h
            union_area = ((x1_max - x1_min) * (y1_max - y1_min)
                          + (x2_max - x2_min) * (y2_max - y2_min) - inter_area)
            if union_area == 0:
                return 0.0
            return inter_area / union_area

        def score(box):
            return box[4] if len(box) > 4 else 0.0

        keep = [True] * len(bboxes)
        removed_indices = []
        for i, box1 in enumerate(bboxes):
            if not keep[i]:
                continue
            for j, box2 in enumerate(bboxes):
                if i == j or not keep[j] or not compute_iou(box1, box2) > iou_threshold:
                    continue
                if score(box1) >= score(box2):
                    keep[j] = False
                    removed_indices.append(j)
                else:
                    keep[i] = False
                    removed_indices.append(i)
                    break  # box1 is removed, no need to compare further

        filtered_bboxes = [box for box, kept in zip(bboxes, keep) if kept]
        return np.array(filtered_bboxes, dtype=np.float32), removed_indices

    def plot_predictions(self, image, pred, output_dir = None, name = None, confidence = True, numbering = True, plot = False):
        """Plot ``image`` with the YOLO-format boxes in ``pred`` using matplotlib.

        Nothing happens unless ``plot`` is True.

        Parameters:
            image: image array accepted by ``ax.imshow``.
            pred (np.ndarray): rows ``[class_id, x_center, y_center, width,
                height, confidence]`` with normalized coordinates.
            output_dir, name: if both are given, the figure is also saved as
                ``<output_dir>/<name>.eps``, or ``<name>_c.eps`` when
                confidences are shown.
            confidence (bool): write each box's confidence at its corner.
            numbering (bool): write each box's row index above its corner.
            plot (bool): create, save (optionally) and show the figure.
        """
        if not plot:
            return

        n_colors = len(self.all_colors)
        img_h, img_w = image.shape[:2]

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        ax.axis('off')

        for i, box in enumerate(pred):
            color = self.all_colors[int(box[0]) % n_colors]
            # scale the normalized box to image pixels
            width = box[3] * img_w
            height = box[4] * img_h
            anchor = (box[1] * img_w - width / 2, box[2] * img_h - height / 2)
            rect = patches.Rectangle(anchor, width, height, linewidth=4, alpha=0.8, edgecolor=color, facecolor='none')
            ax.add_patch(rect)

            # Use a separate value here; the original assigned the text to the
            # `confidence` flag itself.
            if confidence:
                ax.text(anchor[0], anchor[1], str(box[5]), fontsize=10, color=color)
            if numbering:
                ax.text(anchor[0], anchor[1] - 10, str(i), fontsize=10, color=color)

        # legend with one colored patch per class
        legend_elements = [patches.Patch(facecolor=self.all_colors[i % n_colors],
                                         edgecolor=self.all_colors[i % n_colors], label=class_name)
                           for i, class_name in enumerate(self.classes)]
        ax.legend(handles=legend_elements, loc='best', prop={'size': 18})

        # Save before show(): a blocking backend only returns from show() once
        # the window has been closed.
        if output_dir is not None and name is not None:
            os.makedirs(output_dir, exist_ok=True)
            suffix = "_c.eps" if confidence else ".eps"
            fig.savefig(os.path.join(output_dir, name) + suffix, format="eps")

        plt.show()
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

    def get_obstacles(self, image, image_path, pred = None):
        """Return obstacle polygons in pixel coordinates for ``image``.

        If ``<image_path without extension>.yaml`` exists, the polygons are
        loaded from it. Otherwise the user draws them in an OpenCV window:
        a left click adds a point, and a right click closes the current
        polygon (at least 3 points; fewer points are discarded). 'u' undoes
        the last point and 'q' finishes. The result is written to that YAML
        file.

        Parameters:
            image (np.ndarray): image to draw on (only used if no YAML file exists).
            image_path (str): path whose extension is replaced by ``.yaml``.
            pred: unused.

        Returns:
            list of polygons, each a list of (x, y) points. Also stored in
            ``self._obstacles_px``.
        """
        filename = os.path.normpath(os.path.splitext(image_path)[0] + ".yaml")
        if os.path.exists(filename):
            with open(filename, "r") as ymlfile:
                # An empty file loads as None; treat it as "no obstacles".
                self._obstacles_px = yaml.load(ymlfile, Loader=yaml.FullLoader) or []
            return self._obstacles_px

        clone = image.copy()
        self._obstacles_px = []
        current_polygon = []

        def redraw():
            # Rebuild the display from the image, the finished polygons and the open polygon.
            nonlocal clone
            clone = image.copy()
            for poly in self._obstacles_px:
                for i, point in enumerate(poly):
                    cv2.circle(clone, point, 3, (0, 255, 0), -1)
                    cv2.line(clone, point, poly[(i + 1) % len(poly)], (255, 0, 0), 2)
            for i, point in enumerate(current_polygon):
                cv2.circle(clone, point, 3, (0, 255, 0), -1)
                if i > 0:
                    cv2.line(clone, current_polygon[i - 1], point, (255, 0, 0), 2)

        def click_event(event, x, y, flags, param):
            nonlocal current_polygon
            if event == cv2.EVENT_LBUTTONDOWN:
                # add a point and connect it to the previous one
                current_polygon.append((x, y))
                cv2.circle(clone, (x, y), 3, (0, 255, 0), -1)
                if len(current_polygon) > 1:
                    cv2.line(clone, current_polygon[-2], current_polygon[-1], (255, 0, 0), 2)
            elif event == cv2.EVENT_RBUTTONDOWN:
                if len(current_polygon) > 2:
                    # close and store the polygon
                    cv2.line(clone, current_polygon[-1], current_polygon[0], (255, 0, 0), 2)
                    self._obstacles_px.append(current_polygon.copy())
                    current_polygon = []
                else:
                    # Too few points for a polygon: drop them and their markers.
                    current_polygon = []
                    redraw()

        cv2.namedWindow("Image")
        cv2.setMouseCallback("Image", click_event)

        print("=== Setup obstacles ===")
        print("u ... undo last point")
        print("q ... quit")

        while True:
            cv2.imshow("Image", clone)
            key = cv2.waitKey(1) & 0xFF

            if key == ord('q'):
                # Keep an unfinished polygon if it already has at least 3 points.
                if len(current_polygon) > 2:
                    self._obstacles_px.append(current_polygon.copy())
                current_polygon = []
                break
            if key == ord('u') and current_polygon:
                current_polygon.pop()
                redraw()

        cv2.destroyAllWindows()

        # save obstacle coordinates to file
        with open(filename, "w") as outfile:
            yaml.dump(self._obstacles_px, outfile, default_flow_style=False)

        return self._obstacles_px


if __name__ == "__main__":
    # Script entry point: train on the dataset configured in config.yaml and
    # write the trained weights plus the validation results to the configured
    # output directory.
    # Tip: CUDA 12.1 works with torch 2.6.0, e.g.
    #   pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
    obj_recog = ObjectRecognition(model_file=None)
    obj_recog.train()