# Author: Michael Obermayr
# set up whole dataset for object detection with yolov8
# labelling, augmentation and splitting into train and validation set

#%% import libraries and configuration
import os
import cv2
import yaml
import shutil
import random
import datetime
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt

from labelling_tool import *
from augmentation_tool import *


# Set the cwd to the directory containing the script, so config.yaml is resolved relative to it
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# import configuration from config.yaml
# safe_load: the config only holds plain data (paths, lists, numbers), so there is no need for a
# loader that can construct Python objects from tags. YAML is Unicode text (normally UTF-8), so
# the encoding is given explicitly instead of relying on the platform's locale default.
with open(os.path.normpath(os.path.join(os.getcwd(), "config.yaml")), "r", encoding="utf-8") as ymlfile:
    cfg = yaml.safe_load(ymlfile)

input_dir = cfg["input_dir"]
# os.path.join inserts a separator when training_dir has no trailing slash; plain string "+"
# would glue the two names together (e.g. "datasets" + "run1" -> "datasetsrun1").
output_dir = os.path.join(cfg["training_dir"], cfg["dataset_name"])
all_classes = cfg["classes"]
number_augmentations = cfg["number_augmentations"]
type_augmentations = cfg["type_augmentations"]
augmentation_params = cfg["augmentation_params"]

scaling_factors = augmentation_params["scaling_factors"]
shift_factor = augmentation_params["shift_factor"]
brightness_factor = augmentation_params["brightness_factor"]
contrast_factor = augmentation_params["contrast_factor"]
gaussblur_kernel = augmentation_params["gaussblur_kernel"]
compression_quality = augmentation_params["compression_quality"]

# 24 colors, one per class -> at most 24 classes can be drawn with a distinct color
# (presumably BGR, as used by OpenCV drawing functions — confirm against labelling_tool)
all_colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (0, 255, 255), (255, 255, 0), (255, 0, 255),
            (255, 100, 0), (255, 255, 100), (0, 255, 100), (50, 255, 255), (50, 0, 255), (255, 50, 255),
            (255, 50, 0), (255, 255, 50), (100, 255, 0), (100, 255, 255), (0, 100, 255), (255, 100, 255),
            (255, 0, 100), (255, 255, 150), (50, 255, 50), (150, 255, 255), (0, 50, 255), (255, 150, 255)
            ]
max_classes = len(all_colors)

#%% ======================================================================================================================
# define functions

def create_folder_structure():
    """Create the dataset directory tree below ``output_dir`` and write its ``data.yaml``.

    Sub-folders ``raw_images`` and ``raw_labels`` are created, followed by an ``images`` and a
    ``labels`` folder for each of the ``train``, ``valid`` and ``test`` splits. Folders that
    already exist are left untouched. ``data.yaml`` is then (over)written with the split paths,
    the class count and names, and a ``metadata`` section describing how the set was built.

    Reads the module-level ``output_dir``, ``all_classes``, ``input_dir``,
    ``number_augmentations``, ``type_augmentations`` and ``augmentation_params``.
    """
    sub_folders = ["raw_images", "raw_labels"]
    for split in ("train", "valid", "test"):
        sub_folders.append(split + "/images")
        sub_folders.append(split + "/labels")
    for sub_folder in sub_folders:
        os.makedirs(os.path.join(output_dir, sub_folder), exist_ok=True)

    # create yaml file
    yaml_path = os.path.join(output_dir, "data.yaml")
    with open(yaml_path, "w") as yaml_file:
        lines = [
            "train: " + os.path.join(output_dir, "train/images"),
            "val: " + os.path.join(output_dir, "valid/images"),
            "test: " + os.path.join(output_dir, "test/images"),
            "nc: " + str(len(all_classes)),
            "names: " + str(all_classes),
            "",
            "metadata:",
            "    created_by: Michael",
            "    date: " + str(datetime.datetime.now()),
            "    Original dataset: " + str(input_dir),
            "    Number of augmentations: " + str(number_augmentations),
            "    Type of augmentations: " + str(type_augmentations),
            "    Augmentation parameters: " + str(augmentation_params),
        ]
        # every entry (including the blank separator) is terminated by a newline
        yaml_file.write("".join(line + "\n" for line in lines))


#%% ======================================================================================================================
# labelling
do_labelling = False
show_instructions = True
create_folder_structure()

if do_labelling:
    # Label every image found in the configured input directory (the original dataset).
    # label_an_image receives output_dir and the image name, so its results are stored below
    # output_dir (presumably raw_images/raw_labels — confirm in labelling_tool).
    for image_path in get_all_images(input_dir, include_sxm=True):
        if image_path.lower().endswith(".sxm"):
            # convert .sxm scan to an image array
            image = convert_sxm_to_image(image_path)
        else:
            # read image
            image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable/corrupt file by returning None, not by raising
            print("Skipping unreadable image: " + image_path)
            continue
        image_name = os.path.splitext(os.path.basename(image_path))[0]  # file name without extension

        current_class = label_an_image(image, output_dir, image_name, show_instructions=show_instructions)


#%% ======================================================================================================================
# augmentation

training_dir = os.path.normpath(os.path.join(output_dir, "train"))
test_dir = os.path.normpath(os.path.join(output_dir, "test"))
validation_dir = os.path.normpath(os.path.join(output_dir, "valid"))

# Number of augmented copies per training image passed to augment_data for the validation set.
number_val_augmentations = 10

# NOTE(review): train/, valid/ and test/ are not emptied before this section runs, so re-running it
# adds to the output of earlier runs (and augments images that were already augmented).
# Clear those folders first when regenerating the dataset.

# copy raw images and labels into training folder and raw images into test folder
raw_image_dir = os.path.join(output_dir, "raw_images")
for original_image in get_all_images(raw_image_dir, include_sxm=True):
    shutil.copy(original_image, os.path.join(test_dir, "images"))
    shutil.copy(original_image, os.path.join(training_dir, "images"))
raw_label_dir = os.path.join(output_dir, "raw_labels")
for original_label in get_all_labels(raw_label_dir):
    shutil.copy(original_label, os.path.join(training_dir, "labels"))

# reload one image/label pair to check if everything is correct
input_images = sorted(os.listdir(os.path.join(training_dir, "images")))
input_labels = sorted(os.listdir(os.path.join(training_dir, "labels")))

# Pair the example image with the label file that shares its file stem. Taking input_images[0]
# and input_labels[0] independently silently mismatches them as soon as one image has no label.
label_by_stem = {os.path.splitext(name)[0]: name for name in input_labels}
example_image_name = next(
    (name for name in input_images if os.path.splitext(name)[0] in label_by_stem), None
)
if example_image_name is None:
    raise RuntimeError("No image with a matching label file found in " + training_dir)
example_label_name = label_by_stem[os.path.splitext(example_image_name)[0]]

example_image_dir = os.path.join(training_dir, "images", example_image_name)
example_label_dir = os.path.join(training_dir, "labels", example_label_name)

example_image = cv2.imread(example_image_dir)
if example_image is None:
    # cv2.imread returns None instead of raising for files it cannot decode (e.g. a copied .sxm)
    raise RuntimeError("Could not read example image " + example_image_dir)
# ndmin=2 keeps a single-object label file as a 2-D array (one row per object)
example_label = np.loadtxt(example_label_dir, delimiter=" ", ndmin=2)

# define augmentation pipeline (build_transform presumably also visualizes the example)
transform = build_transform(example_image, example_label)

# augment number_val_augmentations images per training image into the validation folder
# NOTE(review): if augment_data writes augmented copies of the *training* images here, validation
# metrics will be optimistic because the model has seen the source images. Consider a held-out split.
augment_data(training_dir, number_val_augmentations, transform, validation_dir)

# copy all original images from the input folder into the test folder
for original_image in get_all_images(input_dir):
    shutil.copy(original_image, os.path.join(test_dir, "images"))

# augment all images in train folder
augment_data(training_dir, number_augmentations, transform)

# %%
