# Augment images and their YOLO bounding-box labels with albumentations to enlarge a small dataset.
#%% import packages and configuration
import os
import random
import numpy as np
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import yaml



# Load the run configuration from config.yaml, which lives next to this script.

# Run from the script's own folder so config.yaml is found next to the script
# regardless of the directory it is launched from.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) would
# garble non-ASCII class names. safe_load is sufficient for plain config data.
with open(os.path.normpath(os.path.join(os.getcwd(), "config.yaml")), "r", encoding="utf-8") as ymlfile:
    cfg = yaml.safe_load(ymlfile)

training_dir = cfg["training_dir"]
# NOTE(review): plain string concatenation, so training_dir presumably ends
# with a path separator; confirm against config.yaml.
output_dir = cfg["training_dir"]+cfg["dataset_name"]
classes = cfg["classes"]
number_augmentations = cfg["number_augmentations"]
type_augmentations = cfg["type_augmentations"]
augmentation_params = cfg["augmentation_params"]

scaling_factors = augmentation_params["scaling_factors"]
shift_factor = augmentation_params["shift_factor"]
brightness_factor = augmentation_params["brightness_factor"]
contrast_factor = augmentation_params["contrast_factor"]
gaussblur_kernel = augmentation_params["gaussblur_kernel"]
compression_quality = augmentation_params["compression_quality"]

# Map each class index 0..len(classes)-1 to its name. Indexing (rather than
# enumerate) keeps the same result whether `classes` is a list or an
# int-keyed mapping.
category_id_to_name = {idx: classes[idx] for idx in range(len(classes))}


#%% define functions

def visualize_bbox(img, bbox, class_name, thickness=2):
    """Draw a single labelled bounding box onto ``img`` and return it.

    ``bbox`` is ``(x_center, y_center, width, height)`` in pixel units
    (``visualize_image`` scales its boxes to pixels before calling this).
    The box outline and a filled label strip just above its top-left corner
    are drawn directly into ``img``, and the same array is returned.
    """
    outline_color = (255, 0, 0)  # red when the array is rendered as RGB (e.g. plt.imshow)
    label_color = (255, 255, 255)  # white
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35

    cx, cy, box_w, box_h = bbox
    left = round(cx - box_w / 2, 2)
    top = round(cy - box_h / 2, 2)
    # OpenCV drawing needs integer pixel coordinates; the far edges are taken
    # from the (still fractional) near edges plus the box size.
    left, right = int(left), int(left + box_w)
    top, bottom = int(top), int(top + box_h)

    cv2.rectangle(img, (left, top), (right, bottom), color=outline_color, thickness=thickness)

    (text_w, text_h), _ = cv2.getTextSize(class_name, font, font_scale, 1)
    strip_top = top - int(1.3 * text_h)
    cv2.rectangle(img, (left, strip_top), (left + text_w, top), outline_color, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(left, top - int(0.3 * text_h)),
        fontFace=font,
        fontScale=font_scale,
        color=label_color,
        lineType=cv2.LINE_AA,
    )
    return img


def visualize_image(image, bboxes, category_ids, category_id_to_name):
    """Plot ``image`` with its YOLO boxes drawn on top, in a new 12x12 figure.

    Args:
        image: image array of shape (H, W) or (H, W, C); it is copied, not modified.
        bboxes: array-like of normalized YOLO boxes, one
            ``(x_center, y_center, width, height)`` row per box (may be empty).
        category_ids: class id per box, in the same order as ``bboxes``.
        category_id_to_name: mapping from class id to the label text drawn.
    """
    height, width = image.shape[:2]
    # Normalized x / width scale with the image width, y / height with the
    # image height. Scaling everything by the height distorts boxes on
    # non-square images.
    scale = np.array([width, height, width, height], dtype=float)
    # asarray + reshape: accept lists and empty results (np.array([]) has shape (0,)).
    pixel_boxes = np.asarray(bboxes, dtype=float).reshape(-1, 4) * scale

    img = image.copy()
    for bbox, category_id in zip(pixel_boxes, category_ids):
        img = visualize_bbox(img, bbox, category_id_to_name[category_id])
    plt.figure(figsize=(12, 12))
    plt.axis('off')
    # NOTE(review): images in this file come from cv2.imread (BGR order) while
    # plt.imshow interprets arrays as RGB, so photo colours appear channel-swapped.
    plt.imshow(img)

def check_boxes(bboxes):
    """Clip normalized YOLO boxes to the image area ``[0, 1]``, in place.

    Each row is ``(x_center, y_center, width, height)`` in coordinates
    normalized to the image size. For every edge lying outside ``[0, 1]`` a
    message is printed, and the box is shrunk to its visible part. Rows that
    are already inside the image are left untouched.

    Args:
        bboxes: 2-D float array with 4 columns (0 rows allowed). Modified in place.

    Returns:
        The same array object, for convenience.
    """
    for i in range(bboxes.shape[0]):
        x_center, y_center, w, h = bboxes[i, :]
        x_min = x_center - w / 2
        y_min = y_center - h / 2
        x_max = x_min + w
        y_max = y_min + h

        if x_min < 0:
            print("x_min < 0 for entry "+str(i))
        if y_min < 0:
            print("y_min < 0 for entry "+str(i))
        if x_max > 1:
            print("x_max > 1 for entry "+str(i))
        if y_max > 1:
            print("y_max > 1 for entry "+str(i))

        # Clip both edges first, then rebuild centre and size from the clipped
        # edges. Correcting each edge separately from the original centre/size
        # made a box that sticks out on both sides lose the first correction.
        if x_min < 0 or x_max > 1:
            x_min, x_max = max(x_min, 0.0), min(x_max, 1.0)
            bboxes[i, 0] = (x_min + x_max) / 2
            bboxes[i, 2] = x_max - x_min
        if y_min < 0 or y_max > 1:
            y_min, y_max = max(y_min, 0.0), min(y_max, 1.0)
            bboxes[i, 1] = (y_min + y_max) / 2
            bboxes[i, 3] = y_max - y_min

    return bboxes

def define_augmentation_pipe(type_augmentations):
    """Build the albumentations pipeline selected by ``type_augmentations``.

    Each recognised name adds one transform, always in this order: ``shift``,
    ``scale``, ``rotate90``, ``flip``, ``brightness``, ``gaussblur``,
    ``compression``, ``blur``, ``rotate``. Strengths come from the
    module-level settings read from config.yaml (``shift_factor``,
    ``scaling_factors``, ``brightness_factor``, ``contrast_factor``,
    ``gaussblur_kernel``, ``compression_quality``).

    Boxes are expected in YOLO format, with their labels passed as
    ``category_ids``. Boxes keeping less than 30% of their area after a
    transform are dropped (``min_visibility=0.3``).

    Args:
        type_augmentations: container of augmentation names, tested with
            ``in``. If it is a single string, this is substring matching
            (e.g. 'blur' also matches 'gaussblur').

    Returns:
        An ``A.ReplayCompose`` pipeline.
    """
    steps = []
    if 'shift' in type_augmentations:
        steps.append(A.ShiftScaleRotate(p=0.5, shift_limit=shift_factor, scale_limit=0, rotate_limit=0, border_mode=cv2.BORDER_CONSTANT, value=0))
    if 'scale' in type_augmentations:
        # ShiftScaleRotate samples the zoom factor from (1 + low, 1 + high), so
        # a factor range (f_lo, f_hi) maps to scale_limit (f_lo - 1, f_hi - 1).
        # The previous (1 - f_lo, 1 - f_hi) sampled 2 - f instead, which only
        # coincides for ranges symmetric around 1 such as (0.8, 1.2).
        # NOTE(review): assumes scaling_factors are the desired zoom factors; confirm in config.yaml.
        steps.append(A.ShiftScaleRotate(p=0.5, shift_limit=0, scale_limit=(scaling_factors[0] - 1, scaling_factors[1] - 1), rotate_limit=0, border_mode=cv2.BORDER_CONSTANT, value=0))
    if 'rotate90' in type_augmentations:
        steps.append(A.RandomRotate90(p=0.5))
    if 'flip' in type_augmentations:
        steps.append(A.HorizontalFlip(p=0.5))
    if 'brightness' in type_augmentations:
        steps.append(A.RandomBrightnessContrast(p=0.25, brightness_limit=brightness_factor, contrast_limit=contrast_factor))
    if 'gaussblur' in type_augmentations:
        steps.append(A.GaussianBlur(p=0.5, blur_limit=gaussblur_kernel))
    if 'compression' in type_augmentations:
        steps.append(A.ImageCompression(p=0.1, quality_lower=compression_quality, quality_upper=100))
    if 'blur' in type_augmentations:
        steps.append(A.Blur(p=0.25, blur_limit=3))
    if 'rotate' in type_augmentations:
        steps.append(A.Rotate(p=0.5, limit=90, border_mode=cv2.BORDER_CONSTANT, value=0))

    # Hand the complete list to the constructor. ReplayCompose.__init__ marks
    # its child transforms as deterministic so their sampled parameters are
    # recorded in the "replay" output. Transforms appended to `.transforms`
    # after construction skip that setup.
    return A.ReplayCompose(
        steps,
        bbox_params=A.BboxParams(format='yolo', label_fields=['category_ids'], min_visibility=0.3),
    )


def build_transform(image=None, label=None):
    """Create the augmentation pipeline and optionally preview it on one sample.

    Args:
        image: optional image array (e.g. from ``cv2.imread``).
        label: optional YOLO label array with one ``class x_center y_center w h``
            row per box (e.g. from ``np.loadtxt(..., ndmin=2)``). It is not modified.

    Returns:
        The pipeline built by ``define_augmentation_pipe(type_augmentations)``.

    If both ``image`` and ``label`` are numpy arrays, the sample is plotted
    once as-is and once after one pass through the pipeline.
    """
    transform = define_augmentation_pipe(type_augmentations)

    if isinstance(image, np.ndarray) and isinstance(label, np.ndarray):
        # Copy: check_boxes corrects boxes in place, and label[:, 1:] is a
        # view that would otherwise modify the caller's label array.
        bboxes = check_boxes(label[:, 1:].copy())
        category_id = label[:, 0].astype(np.int32)

        visualize_image(image, bboxes, category_id, category_id_to_name)

        # Seed *before* sampling so the previewed augmentation can be
        # repeated. Seeding after the transform, as before, had no effect on
        # this preview. NOTE(review): this seeds Python's `random` module only;
        # confirm that the installed albumentations version draws from it.
        random.seed(14)
        transformed = transform(image=image, bboxes=bboxes, category_ids=category_id)
        visualize_image(transformed['image'], np.array(transformed['bboxes']), transformed['category_ids'], category_id_to_name)

    return transform

def _load_yolo_label(path):
    """Load a YOLO label file as an (N, 5) float array; an empty file yields shape (0, 5)."""
    label = np.loadtxt(path, delimiter=" ", ndmin=2)
    if label.size == 0:
        # loadtxt returns a degenerate shape for empty files; normalise it so
        # the column slicing below still yields (0, 4) boxes and (0,) ids.
        label = label.reshape(0, 5)
    return label


def _to_yolo_rows(category_ids, bboxes):
    """Stack class ids and YOLO boxes into an (N, 5) float array; N may be 0."""
    ids = np.asarray(category_ids, dtype=float).reshape(-1, 1)
    # reshape(-1, 4) keeps the 2-D shape when no box survived (np.array([]) is 1-D).
    boxes = np.asarray(bboxes, dtype=float).reshape(-1, 4)
    return np.hstack((ids, boxes))


def augment_data(input_dir, number_augmentations, transform, output_dir = False):
    """Write ``number_augmentations`` augmented copies of every image/label pair.

    Expects ``<input_dir>/images/<name>.<ext>`` with a matching
    ``<input_dir>/labels/<name>.txt`` of ``class x_center y_center w h`` rows.
    Copy ``i`` of ``<name>`` is saved as ``images/<name>_aug<i>.jpeg`` and
    ``labels/<name>_aug<i>.txt`` (integer class id, 5-decimal coordinates).

    Args:
        input_dir: folder containing the ``images`` and ``labels`` subfolders.
        number_augmentations: number of augmented copies per image.
        transform: albumentations pipeline, e.g. from ``build_transform()``.
        output_dir: folder receiving ``images``/``labels``. ``False`` (the
            default) or ``None`` writes next to the inputs. Missing output
            folders are created.

    Raises:
        OSError: if an augmented image cannot be written.
        FileNotFoundError: if an image has no matching label file.
    """
    image_dir = os.path.join(input_dir, "images")
    label_dir = os.path.join(input_dir, "labels")

    if output_dir is False or output_dir is None:
        image_output_dir = image_dir
        label_output_dir = label_dir
    else:
        image_output_dir = os.path.join(output_dir, "images")
        label_output_dir = os.path.join(output_dir, "labels")
    # cv2.imwrite does not create folders; it just returns False.
    os.makedirs(image_output_dir, exist_ok=True)
    os.makedirs(label_output_dir, exist_ok=True)

    # The listing is taken once, before anything is written, so outputs saved
    # into image_dir are not re-augmented during this call. Sorting gives a
    # reproducible processing order (os.listdir order is arbitrary).
    for image_name in sorted(os.listdir(image_dir)):
        image = cv2.imread(os.path.join(image_dir, image_name))
        if image is None:
            # Not a readable image (e.g. a hidden/system file): skip it instead
            # of failing inside the transform.
            print("skipping "+image_name+": not a readable image")
            continue
        extensionless_name = os.path.splitext(image_name)[0]

        label = _load_yolo_label(os.path.join(label_dir, extensionless_name+".txt"))
        bboxes = check_boxes(label[:, 1:])
        category_id = label[:, 0].astype(np.int32)

        for i in range(number_augmentations):
            transformed = transform(image=image, bboxes=bboxes, category_ids=category_id)
            aug_name = extensionless_name+"_aug"+str(i)

            # save all images and labels, even if no bounding box is left
            image_path = os.path.join(image_output_dir, aug_name+".jpeg")
            if not cv2.imwrite(image_path, transformed['image']):
                raise OSError("could not write "+image_path)
            np.savetxt(
                os.path.join(label_output_dir, aug_name+".txt"),
                _to_yolo_rows(transformed['category_ids'], transformed['bboxes']),
                delimiter=" ",
                # YOLO class ids are integers; coordinates keep 5 decimals.
                fmt=["%d"] + ["%1.5f"] * 4,
            )

        print("image "+extensionless_name+" augmented "+str(number_augmentations)+" times")


# #%%  test augmentation pipeline
# # define paths
# image_dir = os.path.join(output_dir, "train/images")
# label_dir = os.path.join(output_dir, "train/labels")
# train_dir = os.path.join(output_dir, "train")

# # load an exemplary image and label
# image = cv2.imread(os.path.join(image_dir, "06.jpeg"))
# label = np.loadtxt(os.path.join(label_dir, "06.txt"), delimiter=" ", ndmin=2)

# # define augmentation pipeline
# transform = build_transform(image, label)

# #%% augment data
# augment_data(train_dir, number_augmentations, transform)

#%%