import Augmentor
from PIL import Image
import os
import numpy as np
import torch
from torchvision.utils import save_image
import torchvision.transforms as tvts
import random
from os import listdir
from os.path import isfile, join


def check_consistency(src_path=None, out_dir="augmented_imgs"):
    """Debug helper: round-trip one image through NumPy, PIL and a tensor.

    Loads ``src_path`` as RGB and prints the array shape. It then re-saves the
    image with PIL as ``<out_dir>/PIL_IMG.jpg``, prints the ``ToTensor`` result
    and its shape, and writes that tensor with ``save_image`` as
    ``<out_dir>/Tensor_IMG.png`` (presumably so the two outputs can be
    compared by eye).

    Args:
        src_path: image to inspect. Defaults to the sample file under
            ``augmented_imgs/output`` that this helper was written against.
        out_dir: existing directory where the two output files are written.
    """
    if src_path is None:
        src_path = os.path.join(
            'augmented_imgs', 'output',
            'augmented_imgs_original_img.jpg_9b3b2501-0fad-466e-8f2e-891478d0aff2.jpg',
        )
    with Image.open(src_path) as src:
        arr = np.array(src.convert("RGB"))
    print("arr shape", arr.shape)
    img = Image.fromarray(arr.astype("uint8"))
    img.save(os.path.join(out_dir, 'PIL_IMG.jpg'))
    ti = tvts.ToTensor()(img)
    print(ti.shape)
    print(ti)
    # Reuse the tensor computed above rather than converting the image a second time.
    save_image(ti, os.path.join(out_dir, "Tensor_IMG.png"))


def jitter_img(img, count):
    """Return ``count`` colour-jittered copies of ``img``.

    Every copy is produced by a separate call to
    ``ColorJitter(brightness=.5, hue=.3)``, so each copy gets its own random
    jitter parameters.

    Args:
        img: image accepted by torchvision's ``ColorJitter`` (a PIL image or tensor).
        count: number of copies to produce.

    Returns:
        A list of ``count`` jittered images.
    """
    jitter = tvts.ColorJitter(brightness=.5, hue=.3)
    return [jitter(img) for _ in range(count)]


def gaussian_blur(img, count):
    """Produce ``count`` Gaussian-blurred variants of ``img``.

    The blur uses a (5, 9) kernel. Sigma is drawn from the range [0.1, 5.0]
    on every call, so the variants generally differ.

    Returns:
        A list holding the ``count`` blurred images.
    """
    kernel = (5, 9)
    sigma_range = (0.1, 5.)
    blur = tvts.GaussianBlur(kernel_size=kernel, sigma=sigma_range)
    return [blur(img) for _ in range(count)]

def random_adjust_sharpness(img, count):
    """Return ``count`` randomly sharpness-adjusted copies of ``img``.

    Uses ``RandomAdjustSharpness(sharpness_factor=2)``. The adjustment is
    applied at random, so some copies may be unchanged.

    Args:
        img: image accepted by torchvision's ``RandomAdjustSharpness``.
        count: number of copies to produce.

    Returns:
        A list of ``count`` images.
    """
    sharpness_adjuster = tvts.RandomAdjustSharpness(sharpness_factor=2)
    return [sharpness_adjuster(img) for _ in range(count)]


def random_rotation(img, count):
    """Return ``count`` copies of ``img``, each rotated by a random angle in [0, 180] degrees.

    Args:
        img: image accepted by torchvision's ``RandomRotation``.
        count: number of copies to produce.

    Returns:
        A list of ``count`` rotated images.
    """
    rotater = tvts.RandomRotation(degrees=(0, 180))
    return [rotater(img) for _ in range(count)]

def random_horizontal_flip(img, count):
    """Return ``count`` copies of ``img``, each horizontally flipped with probability 0.5.

    Args:
        img: image accepted by torchvision's ``RandomHorizontalFlip``.
        count: number of copies to produce.

    Returns:
        A list of ``count`` images (flipped or unchanged).
    """
    hflipper = tvts.RandomHorizontalFlip(p=0.5)
    return [hflipper(img) for _ in range(count)]

def random_vertical_flip(img, count):
    """Return ``count`` copies of ``img``, each vertically flipped with probability 0.5.

    Args:
        img: image accepted by torchvision's ``RandomVerticalFlip``.
        count: number of copies to produce.

    Returns:
        A list of ``count`` images (flipped or unchanged).
    """
    vflipper = tvts.RandomVerticalFlip(p=0.5)
    return [vflipper(img) for _ in range(count)]

# Transforms used by ``augment_folder``: each one is applied ``augment_count``
# times to every sampled image. RandomRotation(degrees=(0, 180)) is currently
# left out of this list.
transforms = [
    tvts.ColorJitter(brightness=.5, hue=.3),
    tvts.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)),
    tvts.RandomAdjustSharpness(sharpness_factor=2),
    tvts.RandomHorizontalFlip(p=0.5),
    tvts.RandomVerticalFlip(p=0.5),
]

def select_random_imgs(img_list, count):
    """Return up to ``count`` distinct items sampled at random from ``img_list``.

    If ``count`` exceeds ``len(img_list)``, every item is returned (in random
    order) instead of ``random.sample`` raising ``ValueError``.

    Args:
        img_list: sequence to sample from (e.g. file names). It is not modified.
        count: number of items wanted.

    Returns:
        A new list of sampled items.

    Raises:
        ValueError: if ``count`` is negative.
    """
    return random.sample(img_list, min(count, len(img_list)))

def create_foldername(fname):
    """Derive the augmented-dataset folder name by inserting ``augment`` as the third ``_`` field.

    For example, ``"DarcyCIL_DB_TR"`` becomes ``"DarcyCIL_DB_augment_TR"``.
    Names with fewer than two ``_``-separated fields get ``_augment`` appended.
    """
    parts = fname.split("_")
    head, tail = parts[:2], parts[2:]
    return "_".join(head + ["augment"] + tail)

def save_original_imgs(source_fname, dest_fname, img_files):
    """Re-save each listed image from ``source_fname`` into ``dest_fname`` as ``.jpg``.

    Each image is loaded as RGB, converted to a tensor and written with
    torchvision's ``save_image``. The output keeps the file's base name and
    uses a ``.jpg`` extension.

    Args:
        source_fname: directory containing the source images.
        dest_fname: directory to write into (must already exist).
        img_files: file names, relative to ``source_fname``, to re-save.
    """
    to_tensor = tvts.ToTensor()
    for img_file in img_files:
        # splitext keeps inner dots ("x.1.png" -> "x.1"); splitting on the first
        # "." would map "x.1.png" and "x.2.png" to the same output and overwrite.
        img_name = os.path.splitext(img_file)[0]
        img = open_PILimg(path=os.path.join(source_fname, img_file))
        save_image(to_tensor(img), os.path.join(dest_fname, "{}.jpg".format(img_name)))


def open_PILimg(path):
    """Load the image at ``path`` and return it as an RGB ``PIL.Image``.

    The file is opened in a ``with`` block, so the handle is closed once the
    RGB copy exists. ``convert`` returns a new, fully loaded image that does
    not depend on the open file.
    """
    with Image.open(path) as img:
        return img.convert("RGB")

def save_PILimg(path, img):
    """Write ``img`` to ``<path>.jpg`` by converting it to a tensor and calling ``save_image``."""
    as_tensor = tvts.ToTensor()(img)
    target = "{}.jpg".format(path)
    save_image(as_tensor, target)
    
    
def augment_folder(root_fname, fname, min_images, augment_count=25):
    """Build the augmented copy of one class folder.

    Steps:

    1. Create ``<create_foldername(root_fname)>/<fname>``.
    2. Re-save every image from ``root_fname/fname`` there.
    3. Sample up to ``min_images`` random images.
    4. For each sampled image and each transform in the module-level
       ``transforms`` list, write ``augment_count`` transformed copies named
       ``trans_img_<stem>_<TransformClass>_<i>.jpg``.

    Args:
        root_fname: dataset root directory (e.g. ``"DarcyCIL_DB_TR"``).
        fname: name of the class sub-folder inside ``root_fname``.
        min_images: maximum number of source images to augment. Despite the
            name, this caps the sample size.
        augment_count: number of copies generated per (image, transform) pair.
    """
    src_dir = os.path.join(root_fname, fname)
    # Keep regular files only; a sub-directory would make Image.open fail.
    onlyfiles = [
        f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))
    ]
    count = min(len(onlyfiles), min_images)
    print("count is {}".format(count))
    random_imgs = select_random_imgs(onlyfiles, count)

    augment_fname = create_foldername(root_fname)
    save_fname = os.path.join(augment_fname, fname)
    os.makedirs(save_fname, exist_ok=True)  # also creates augment_fname

    save_original_imgs(source_fname=src_dir, dest_fname=save_fname, img_files=onlyfiles)

    to_tensor = tvts.ToTensor()
    for img_file in random_imgs:
        # splitext keeps inner dots so distinct source files get distinct outputs.
        img_name = os.path.splitext(img_file)[0]
        img = open_PILimg(path=os.path.join(src_dir, img_file))

        for transform in transforms:
            transname = type(transform).__name__
            # Save each augmented copy as soon as it is made instead of
            # holding all augment_count images in memory.
            for i in range(augment_count):
                out_path = os.path.join(
                    save_fname, "trans_img_{}_{}_{}.jpg".format(img_name, transname, i)
                )
                save_image(to_tensor(transform(img)), out_path)

    print("no of images in {} is {}".format(save_fname, len(os.listdir(save_fname))))

import shutil

def select_random_files(count, source_fname, dest_fname):
    """Copy up to ``count`` randomly chosen files from ``source_fname`` into ``dest_fname``.

    Sub-directories of ``source_fname`` are ignored. ``dest_fname`` is
    created if it does not already exist. When fewer than ``count`` files are
    available, all of them are copied.

    Args:
        count: number of files to copy.
        source_fname: directory to sample from.
        dest_fname: directory to copy into.
    """
    files = [
        f for f in os.listdir(source_fname)
        if os.path.isfile(os.path.join(source_fname, f))
    ]
    files = select_random_imgs(files, count=min(count, len(files)))
    os.makedirs(dest_fname, exist_ok=True)

    for file_name in files:
        source_path = os.path.join(source_fname, file_name)
        destination_path = os.path.join(dest_fname, file_name)
        shutil.copy(source_path, destination_path)
        print(f"Copied {file_name} to {dest_fname}")


# Each entry is [dataset root directory, max images augmented per class folder].
# NOTE(review): this loop runs whenever the module is imported; consider moving
# it under an ``if __name__ == "__main__":`` guard.
root_folder = [["DarcyCIL_DB_TR", 10], ["DarcyCIL_DB_TE", 2]]
for root, min_images in root_folder:
    # Every immediate sub-directory of the root is treated as one class folder.
    folders = [
        entry
        for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    ]
    for folder in folders:
        augment_folder(root, folder, min_images)


