#%% import libraries and configuration
import os
import cv2
import yaml
import numpy as np
import datetime

from object_recognition.post_processing.smxReader import NanonisSXM
# Load the run configuration from config.yaml, which sits next to this script.

# Set the cwd to the directory containing the script so that config.yaml (and
# any relative paths in the configuration) resolve against the script location.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# The configuration is plain data, so safe_load is sufficient and avoids
# constructing arbitrary Python objects from YAML tags.
with open(os.path.join(os.getcwd(), "config.yaml"), "r", encoding="utf-8") as ymlfile:
    cfg = yaml.safe_load(ymlfile)

input_dir = cfg["input_dir"]
output_dir = cfg["training_dir"]
all_classes = cfg["classes"]
number_augmentations = cfg["number_augmentations"]
type_augmentations = cfg["type_augmentations"]

# 24 distinct colour tuples for OpenCV drawing calls (displayed in BGR order).
# Class indices beyond the palette wrap around: boxes are coloured with
# all_colors[class % max_classes].
all_colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255), (255, 255, 0), (255, 0, 255),
            (255, 100, 0), (255, 255, 100), (0, 255, 100), (50, 255, 255), (50, 0, 255), (255, 50, 255),
            (255, 50, 0), (255, 255, 50), (100, 255, 0), (100, 255, 255), (0, 100, 255), (255, 100, 255),
            (255, 0, 100), (255, 255, 150), (50, 255, 50), (150, 255, 255), (0, 50, 255), (255, 150, 255)
            ]
max_classes = len(all_colors)

def convert_sxm_to_image(img_file):
    """Render the Z (topography) channel of a Nanonis ``.sxm`` scan as an 8-bit RGB image.

    Parameters
    ----------
    img_file : str
        Path to a measurement. Anything that does not end in ``.sxm`` is
        returned unchanged, so ordinary image paths can be passed straight through.

    Returns
    -------
    numpy.ndarray or str
        For ``.sxm`` input: a ``uint8`` image with three identical channels
        holding the Z data min-max scaled to 0-255. The data is flipped
        vertically first when the header's ``SCAN_DIR`` is ``'up'``. A flat
        (or entirely non-finite) scan yields an all-zero image with the same
        3-channel layout. For any other path: ``img_file`` itself.

    Raises
    ------
    ValueError
        If the measurement holds fewer than four channels.
    """
    if not img_file.endswith(".sxm"):
        return img_file

    load = NanonisSXM(img_file)
    # Explicit check instead of ``assert`` so it is not stripped under ``python -O``.
    if len(load.channels_name) < 4:
        raise ValueError("The measurement does not contain the necessary channels: Current, Z, Y, X")

    # Only the Z channel is rendered, so only it is read (and flipped if needed).
    z_data = np.asarray(load.retrieve_channel_data('Z'))
    if load.header['SCAN_DIR'][0][0] == 'up':
        z_data = np.flip(z_data, axis=0)

    # Derive the scaling range from finite values only, so NaN/inf entries do
    # not poison min/max. NOTE(review): presumably an interrupted scan leaves
    # NaNs in unmeasured lines - confirm against NanonisSXM.
    finite = np.isfinite(z_data)
    if finite.any():
        z_min = z_data[finite].min()
        z_max = z_data[finite].max()
    else:
        z_min = z_max = 0

    if z_max == z_min:
        # Nothing to scale. Build a 2-D grey image and let the conversion below
        # add the channels, so every return path has the same 3-channel layout.
        img_Z = np.zeros(z_data.shape, dtype=np.uint8)
    else:
        z_filled = np.where(finite, z_data, z_min)
        img_Z = ((z_filled - z_min) / (z_max - z_min) * 255).astype(np.uint8)

    return cv2.cvtColor(img_Z, cv2.COLOR_GRAY2RGB)

#%% define functions ====================================================================================================
def get_all_images(dataset_dir, include_sxm=False):
    """List the image files found directly inside *dataset_dir* (not recursive).

    Parameters
    ----------
    dataset_dir : str
        Directory to scan.
    include_sxm : bool, optional
        Also return Nanonis ``.sxm`` measurement files.

    Returns
    -------
    list of str
        ``"<dataset_dir>/<name>"`` for every ``.png``, ``.jpeg`` or ``.jpg``
        entry (plus ``.sxm`` entries when requested), in ``os.listdir`` order.
        Extension matching is case-sensitive.
    """
    suffixes = (".png", ".jpeg", ".jpg")
    if include_sxm:
        suffixes = (".sxm",) + suffixes
    return [
        str(dataset_dir) + '/' + str(entry)
        for entry in os.listdir(dataset_dir)
        if entry.endswith(suffixes)
    ]

def get_all_labels(dataset_dir):
    """Return ``"<dataset_dir>/<name>"`` for every ``.txt`` file directly in *dataset_dir*.

    Entries are listed in ``os.listdir`` order; the extension check is case-sensitive.
    """
    label_paths = [
        str(dataset_dir) + '/' + str(entry)
        for entry in os.listdir(dataset_dir)
        if entry.endswith(".txt")
    ]
    return label_paths

def convert_boxes_to_yolo(bboxes, image_shape):
    """Turn corner-format boxes into YOLO label lines.

    Parameters
    ----------
    bboxes : iterable
        Boxes as ``[class, [x_start, y_start], [x_end, y_end]]`` in pixels.
    image_shape : sequence
        Image shape; ``image_shape[0]`` is the height, ``image_shape[1]`` the width.

    Returns
    -------
    list of str
        One ``"class x_center y_center width height"`` line per box, with all
        four geometry values normalised by the image size.
    """
    img_h = image_shape[0]
    img_w = image_shape[1]

    def to_line(box):
        label, corner_a, corner_b = box[0], box[1], box[2]
        left, top = corner_a[0], corner_a[1]
        box_w = corner_b[0] - corner_a[0]
        box_h = corner_b[1] - corner_a[1]
        fields = (
            label,
            float((left + box_w/2)/img_w),
            float((top + box_h/2)/img_h),
            float(box_w/img_w),
            float(box_h/img_h),
        )
        return ' '.join(map(str, fields))

    return [to_line(box) for box in bboxes]

def save_yolo_boxes(output_dir, name, bboxes):
    """Write label lines to ``<output_dir>/<name>.txt``, overwriting any existing file.

    Lines are separated by ``"\\n"`` with no trailing newline; an empty
    *bboxes* produces an empty file.
    """
    label_path = "/".join((output_dir, name + ".txt"))
    with open(label_path, "w") as handle:
        handle.write("\n".join(bboxes))
    
def resize_image_and_labels(image, start_and_end_points, input_shape, target_shape):
    """Centre a cropped image on a canvas of ``target_shape`` and map its boxes onto it.

    The crop keeps its pixel resolution: it is padded (not scaled) up to
    ``target_shape[0]`` rows by ``target_shape[1]`` columns. The border is
    filled with the mean of the crop's outermost rows and columns, computed
    per channel for multi-channel images.

    Parameters
    ----------
    image : numpy.ndarray
        The cropped image; expected to be no larger than ``target_shape`` in
        either dimension.
    start_and_end_points : list
        Boxes as ``[class, [x_start, y_start], [x_end, y_end]]`` in the
        coordinates of the square display window the crop was shown in.
    input_shape : int
        Side length of that square display window.
    target_shape : sequence of int
        Output shape; only the first two entries (rows, cols) are used.

    Returns
    -------
    tuple
        ``(padded_image, boxes)``. The padded image is exactly
        ``target_shape[0]`` x ``target_shape[1]``, and the boxes are in its
        pixel coordinates (truncated to int).
    """
    rows, cols = image.shape[0], image.shape[1]

    borders = np.concatenate([image[0, :], image[-1, :], image[:, 0], image[:, -1]])
    padding_value = np.mean(borders, axis=0)
    # Hand OpenCV the border colour as plain Python floats (one per channel)
    # rather than a NumPy array/scalar.
    border_value = [float(v) for v in np.atleast_1d(padding_value)]

    # Split the missing extent so both sides add up to it exactly, making the
    # output match target_shape (which the labels are later normalised by).
    # For an odd difference the extra pixel goes to the bottom / right.
    missing_rows = int(target_shape[0]) - rows
    missing_cols = int(target_shape[1]) - cols
    pad_top = missing_rows // 2
    pad_bottom = missing_rows - pad_top
    pad_left = missing_cols // 2
    pad_right = missing_cols - pad_left
    resized_image = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right,
                                       cv2.BORDER_CONSTANT, value=border_value)

    # Map window coordinates -> crop pixels, then shift by the padding offset.
    resized_labels = []
    for box in start_and_end_points:
        x_start = int(box[1][0]/input_shape*cols + pad_left)
        y_start = int(box[1][1]/input_shape*rows + pad_top)
        x_end = int(box[2][0]/input_shape*cols + pad_left)
        y_end = int(box[2][1]/input_shape*rows + pad_top)
        resized_labels.append([box[0], [x_start, y_start], [x_end, y_end]])

    return resized_image, resized_labels

def resize_labels(start_and_end_points, input_shape, target_shape):
    """Rescale boxes drawn in a square window of side *input_shape* to an image of *target_shape*.

    Boxes are ``[class, [x_start, y_start], [x_end, y_end]]``. x values are
    scaled by ``target_shape[1]`` and y values by ``target_shape[0]``, then
    truncated with ``int()``. Returns a new list of boxes in the same format.
    """
    def scale(value, axis):
        return int(value/input_shape*target_shape[axis])

    return [
        [box[0],
         [scale(box[1][0], 1), scale(box[1][1], 0)],
         [scale(box[2][0], 1), scale(box[2][1], 0)]]
        for box in start_and_end_points
    ]

def instruction_window(show=True):
    """Show an "Instructions" window listing the annotation tool's key bindings.

    The text is drawn in black on a light-grey (200) 300x800 canvas. Does
    nothing when *show* is false.
    """
    if not show:
        return

    help_lines = (
        "Press 'l' to change class label.",
        "Press 'd' to delete last drawn box.",
        "Press 'c' to crop image to last drawn box.",
        "Press 's' to save image and labels.",
        "Press 'q' to quit / label next image.",
    )
    canvas = np.full((300, 800, 3), 200, dtype=np.uint8)
    for row, text in enumerate(help_lines, start=1):
        cv2.putText(canvas, text, (50, 50 * row), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

    cv2.imshow("Instructions", canvas)

def draw_rectangle(action, x, y, flags, *userdata):
    """Mouse callback for the "Annotation Tool" window: draw and record boxes.

    Registered with ``cv2.setMouseCallback``. It shares state with
    ``label_an_image`` through module globals:

    - reads ``image`` and ``current_class``;
    - sets ``start_point``, ``end_point`` and ``draw``;
    - appends finished boxes to ``start_and_end_points``.

    A left press starts a box. Moving while the button is held previews it on
    a copy of ``image``. The release stores ``[class, top_left, bottom_right]``
    when the drag spans more than 10 px in width or height. ``flags`` and
    ``userdata`` are unused.
    """
    global start_point, end_point, draw
    global start_and_end_points, image, current_class

    color = all_colors[current_class % max_classes]

    # Keep coordinates inside the displayed image: a drag that leaves the
    # window may report positions outside it, which would otherwise become
    # negative / out-of-range crop indices later on.
    x = min(max(x, 0), image.shape[1] - 1)
    y = min(max(y, 0), image.shape[0] - 1)

    if action == cv2.EVENT_LBUTTONDOWN:
        start_point = [x, y]
        draw = True

    elif action == cv2.EVENT_LBUTTONUP:
        if not draw:
            # Release without a matching press in this window: there is no
            # start point for this drag, so ignore it rather than reuse a stale one.
            return
        end_point = [x, y]
        draw = False

        # width or height of rectangle must be larger than 10 pixels
        if abs(start_point[0] - end_point[0]) > 10 or abs(start_point[1] - end_point[1]) > 10:
            # store corners as top-left / bottom-right whatever the drag direction
            top_left = [min(start_point[0], end_point[0]), min(start_point[1], end_point[1])]
            bottom_right = [max(start_point[0], end_point[0]), max(start_point[1], end_point[1])]

            start_and_end_points.append([current_class, top_left, bottom_right])
            cv2.rectangle(image, start_point, end_point, color, 1)
        else:
            print("Rectangle too small, not drawn.")
        cv2.imshow("Annotation Tool", image)

    elif draw:
        # live preview on a copy so the committed image stays unchanged
        preview = image.copy()
        cv2.rectangle(preview, start_point, (x, y), color, 1)
        cv2.imshow("Annotation Tool", preview)

def label_an_image(original_image, output_dir, filename, window_size=1000, show_instructions=True):
    """Interactively annotate one image with class-labelled bounding boxes.

    The image is shown stretched to a ``window_size`` x ``window_size``
    "Annotation Tool" window. Boxes are drawn with the mouse (see
    ``draw_rectangle``). Keys are handled until 'q' is pressed:

    - 'l': cycle to the next class in ``all_classes``.
    - 'd': delete the last box, or undo the crop when no boxes are left.
    - 'c': crop the view to the last drawn box (once per image); the boxes
      drawn so far are discarded.
    - 's': save the image as a greyscale JPEG in ``<output_dir>/raw_images``
      and its YOLO labels in ``<output_dir>/raw_labels``. Both directories are
      created if missing.

    Parameters
    ----------
    original_image : numpy.ndarray
        Image to label. On save, 3-channel images are converted with
        ``COLOR_BGR2GRAY``; 2-D images are written as they are.
    output_dir : str
        Root directory for the saved images and labels.
    filename : str
        Base name (no extension) for saved files. Uncropped saves overwrite
        ``<filename>``. Cropped saves go to ``<filename>_cropped``,
        ``<filename>_cropped_2``, ... so earlier crops are kept.
    window_size : int
        Side length of the square annotation window. Boxes live in this
        coordinate system until they are rescaled on save.
    show_instructions : bool
        Whether to open the key-binding help window.

    Notes
    -----
    State is shared with the mouse callback via the module globals
    ``start_and_end_points``, ``current_class``, ``image``, ``draw`` and
    ``cropping``. ``current_class`` is kept across calls, so the last used
    class carries over to the next image.
    """
    global start_and_end_points, current_class
    global image
    global draw, cropping

    # output directories (created on first save)
    image_dir = os.path.join(output_dir, "raw_images")
    label_dir = os.path.join(output_dir, "raw_labels")

    # Only initialise the class on the very first call; later calls keep it.
    if "current_class" not in globals():
        current_class = 0

    # empty_image is the clean displayed view; image carries the drawn boxes.
    empty_image = cv2.resize(original_image, (window_size, window_size))
    image = empty_image.copy()
    original_crop = None
    start_and_end_points = []
    draw = False
    cropping = False

    cv2.imshow("Annotation Tool", image)
    cv2.setWindowTitle("Annotation Tool", "Current class: " + all_classes[current_class])
    instruction_window(show=show_instructions)
    # Register once: the callback stays attached to the window for its lifetime.
    cv2.setMouseCallback("Annotation Tool", draw_rectangle)

    while True:
        key = cv2.waitKey(1) & 0xFF

        if key == ord('q'):
            break

        elif key == ord('l'):
            current_class = (current_class + 1) % len(all_classes)
            cv2.setWindowTitle("Annotation Tool", "Current class: " + all_classes[current_class])

        elif key == ord('d'):
            if start_and_end_points:
                # drop the last box and redraw the remaining ones on a clean view
                start_and_end_points.pop()
                image = empty_image.copy()
                for box in start_and_end_points:
                    cv2.rectangle(image, box[1], box[2], all_colors[box[0] % max_classes], 1)
                cv2.imshow("Annotation Tool", image)
            elif cropping:
                empty_image = cv2.resize(original_image, (window_size, window_size))
                image = empty_image.copy()
                cv2.imshow("Annotation Tool", image)
                print("Cropping reset.")
                cropping = False
            else:
                print("No boxes to delete.")

        elif key == ord('c'):
            if cropping:
                print("Already cropped.")
            elif not start_and_end_points:
                print("No box to crop.")
            else:
                top_left = start_and_end_points[-1][1]
                bottom_right = start_and_end_points[-1][2]

                # window coordinates -> pixel coordinates of the original image
                crop_x1 = int(top_left[0]/window_size * original_image.shape[1])
                crop_x2 = int(bottom_right[0]/window_size * original_image.shape[1])
                crop_y1 = int(top_left[1]/window_size * original_image.shape[0])
                crop_y2 = int(bottom_right[1]/window_size * original_image.shape[0])

                if crop_x2 <= crop_x1 or crop_y2 <= crop_y1:
                    # A thin box can map to zero rows/columns of the source
                    # image, which cannot be resized for display.
                    print("Box too small to crop.")
                else:
                    original_crop = original_image[crop_y1:crop_y2, crop_x1:crop_x2]
                    empty_image = cv2.resize(original_crop, (window_size, window_size))
                    image = empty_image.copy()
                    cv2.imshow("Annotation Tool", image)
                    # the crop box itself is not a label: start afresh
                    start_and_end_points = []
                    cropping = True
                    print("Cropped image.")

        elif key == ord('s'):
            if cropping:
                # add an increasing number so earlier crops are not overwritten
                new_filename = filename + "_cropped"
                i = 2
                while os.path.exists(os.path.join(image_dir, new_filename + ".jpeg")):
                    new_filename = filename + "_cropped_" + str(i)
                    i += 1

                resized_image, resized_labels = resize_image_and_labels(
                    original_crop, start_and_end_points, window_size, original_image.shape)
            else:
                resized_labels = resize_labels(start_and_end_points, window_size, original_image.shape)
                resized_image = original_image
                new_filename = filename

            # save greyscale; images that are already 2-D need no conversion
            if resized_image.ndim == 3:
                resized_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)

            os.makedirs(image_dir, exist_ok=True)
            os.makedirs(label_dir, exist_ok=True)
            image_path = os.path.join(image_dir, new_filename + ".jpeg")
            # cv2.imwrite reports failure via its return value, not an exception
            if cv2.imwrite(image_path, resized_image):
                yolo_boxes = convert_boxes_to_yolo(resized_labels, original_image.shape)
                save_yolo_boxes(label_dir, new_filename, yolo_boxes)
                print("Saved image and labels.")
            else:
                print("Could not write image to " + image_path + "; labels not saved.")

    # close the annotation and instruction windows
    cv2.destroyAllWindows()

# %%
