"""
image_augmentation.py

This script performs data augmentation on images using various transformations such as color jitter,
Gaussian blur, sharpness adjustment, horizontal flip, and vertical flip. 
It generates augmented images from an existing dataset and saves them to new folders.

"""

from PIL import Image
import os
from torchvision.utils import save_image
import torchvision.transforms as tvts
import random

# Augmentations used by augment_folder: every selected image is passed through
# each of these separately, and each one produces its own set of output files.
transforms = [
    tvts.ColorJitter(brightness=0.5, hue=0.3),
    tvts.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.0)),
    tvts.RandomAdjustSharpness(sharpness_factor=2),
    tvts.RandomHorizontalFlip(p=0.5),
    tvts.RandomVerticalFlip(p=0.5),
]

def select_random_imgs(img_list, count):
    """
    Draw ``count`` distinct entries from ``img_list`` uniformly at random.

    Sampling is without replacement and does not modify ``img_list``;
    ``random.sample`` raises ``ValueError`` when ``count`` is negative or
    larger than the list.

    Parameters:
        img_list (list): Candidate image filenames.
        count (int): How many filenames to draw.

    Returns:
        list: The drawn filenames, in selection order.
    """
    picked = random.sample(img_list, k=count)
    return picked

def create_foldername(fname):
    """
    Derive the output folder name by splicing ``"augment"`` into ``fname``.

    ``fname`` is split on underscores and ``"augment"`` becomes the third
    component; when there are fewer than two components it is appended at
    the end. Examples: ``"GNSS/Train"`` -> ``"GNSS/Train_augment"``,
    ``"a_b_c"`` -> ``"a_b_augment_c"``.

    Parameters:
        fname (str): Original folder name.

    Returns:
        str: Augmented folder name.
    """
    parts = fname.split("_")
    head, tail = parts[:2], parts[2:]
    return "_".join(head + ["augment"] + tail)

def save_original_imgs(source_fname, dest_fname, img_files):
    """
    Re-save the original images into the destination folder as ``.jpg`` files.

    Each file is opened via ``open_PILimg`` (RGB), converted to a tensor and
    written as ``<stem>.jpg`` in ``dest_fname`` with
    ``torchvision.utils.save_image``.

    Parameters:
        source_fname (str): Folder containing the source images.
        dest_fname (str): Folder the copies are written to (must already exist).
        img_files (list): Filenames, relative to ``source_fname``, to copy.
    """
    to_tensor = tvts.ToTensor()
    for img_file in img_files:
        # splitext strips only the final extension; split(".")[0] would map
        # "a.v1.png" and "a.v2.png" to the same "a.jpg" and overwrite it.
        stem = os.path.splitext(img_file)[0]
        pil_img = open_PILimg(path=os.path.join(source_fname, img_file))
        save_image(to_tensor(pil_img), os.path.join(dest_fname, "{}.jpg".format(stem)))

def open_PILimg(path):
    """
    Open an image file with PIL and return an RGB copy of it.

    ``Image.open`` is lazy and keeps the underlying file open until the
    image object is garbage-collected. ``convert`` returns a new, fully
    loaded image, so the source file is closed via the context manager
    before returning. This avoids leaking one file descriptor per image.

    Parameters:
        path (str): File path to the image.

    Returns:
        PIL.Image.Image: The image converted to RGB mode.
    """
    with Image.open(path) as src:
        return src.convert("RGB")
    
def augment_folder(root_fname, fname, min_images, augment_count=25):
    """
    Augment images from one sub-folder and write the results to a parallel folder.

    Every regular file in ``<root_fname>/<fname>`` is re-saved as JPEG into
    ``<create_foldername(root_fname)>/<fname>``. In addition, up to
    ``min_images`` of those files are picked at random. Each module-level
    ``transforms`` entry is applied ``augment_count`` times to every picked
    image, and each result is saved as
    ``trans_img_<stem>_<TransformClassName>_<i>.jpg``.

    Parameters:
        root_fname (str): Root folder containing the source sub-folders.
        fname (str): Name of the sub-folder inside ``root_fname`` to augment.
        min_images (int): Upper bound on how many source images are randomly
            selected for augmentation (all of them if the folder has fewer).
        augment_count (int): Number of augmented outputs generated per
            transform for each selected image.
    """
    src_dir = os.path.join(root_fname, fname)
    # os.listdir also returns sub-directories (which cannot be opened as
    # images) and has arbitrary order; keep only files and sort them so a
    # seeded `random` yields a reproducible selection.
    onlyfiles = sorted(
        entry for entry in os.listdir(src_dir)
        if os.path.isfile(os.path.join(src_dir, entry))
    )
    count = min(len(onlyfiles), min_images)
    random_imgs = select_random_imgs(onlyfiles, count)

    augment_fname = create_foldername(root_fname)
    save_fname = os.path.join(augment_fname, fname)
    # Creates augment_fname as well; exist_ok avoids the check-then-create race.
    os.makedirs(save_fname, exist_ok=True)

    save_original_imgs(source_fname=src_dir, dest_fname=save_fname, img_files=onlyfiles)

    to_tensor = tvts.ToTensor()
    for img_file in random_imgs:
        # splitext keeps dotted stems intact ("a.v1.png" -> "a.v1").
        stem = os.path.splitext(img_file)[0]
        img = open_PILimg(path=os.path.join(src_dir, img_file))

        for transform in transforms:
            transname = type(transform).__name__
            # Save each sample as soon as it is produced rather than holding
            # augment_count transformed images in memory at once.
            for i in range(augment_count):
                out_path = os.path.join(
                    save_fname, "trans_img_{}_{}_{}.jpg".format(stem, transname, i)
                )
                save_image(to_tensor(transform(img)), out_path)

    

if __name__ == "__main__":
    # Example usage. Each job is (root folder, upper bound on images to
    # augment per sub-folder); every immediate sub-directory of the root
    # is passed to augment_folder.
    jobs = (("GNSS/Train", 10),)
    for dataset_root, max_picks in jobs:
        subdirs = list(
            filter(lambda name: os.path.isdir(os.path.join(dataset_root, name)),
                   os.listdir(dataset_root))
        )
        for subdir in subdirs:
            augment_folder(dataset_root, subdir, max_picks)


