import numpy as np
import pandas as pd
import tensorflow as tf
import os
from PIL import Image


# Root directory of the gesture_2012 dataset. Paths in this file are built by
# plain string concatenation onto this value, so the trailing '/' is required.
gesture_2012_path = "resources/gesture_2012/"
# The dataset path above contains five directories: part_1, part_2, part_3, part_4, part_5.
# Each of these subdirectories contains images.
# All these images are in PNG format and have different sizes.
# The following function iterates over all these images in each subdirectory and
# adds padding to them so that all images have the same size of
# 660 * 660 (max_width * max_height).
# The padding is added such that the original image is centered in the new (black) image.
def padding_gesture_2012_images(max_width=660, max_height=660):
    """Pad every PNG in part_1 .. part_5 onto a centered black canvas.

    Each ``.png`` file found in ``gesture_2012_path + 'part_<i>/'`` (i = 1..5)
    is pasted, centered, onto a new black RGB canvas of
    ``max_width x max_height`` pixels. The result is written under its original
    filename to ``gesture_2012_path + 'padded_images/'``.

    Args:
        max_width: Width of the output canvas in pixels (default 660).
        max_height: Height of the output canvas in pixels (default 660).

    Note:
        Files with the same name in different part_* directories overwrite
        each other in the single destination directory. Images larger than the
        canvas get a negative paste offset (presumably clipped by PIL; verify
        if such images exist in the dataset).
    """
    dst_directory = gesture_2012_path + 'padded_images/'
    os.makedirs(dst_directory, exist_ok=True)
    for part in range(1, 6):
        src_directory = gesture_2012_path + 'part_' + str(part) + '/'
        for filename in os.listdir(src_directory):
            if not filename.endswith(".png"):
                continue
            # Image.open keeps the file handle open until the image is closed;
            # the context manager releases it once the pixels have been pasted.
            with Image.open(src_directory + filename) as img:
                width, height = img.size
                new_im = Image.new("RGB", (max_width, max_height))
                # int(... / 2) keeps the original truncation-toward-zero offsets.
                offset = (int((max_width - width) / 2), int((max_height - height) / 2))
                new_im.paste(img, offset)
            new_im.save(dst_directory + filename)

# the following line needs to be executed only once
# it pads all the images in the gesture_2012 dataset
# and saves them in the directory gesture_2012/padded_images
# padding_gesture_2012_images()

# The following function reads the images from padded_images/train and padded_images/test,
# resizes them to the given img_shape (default is 256 * 256 * 3) using antialiased (Lanczos)
# resampling, and saves them in gesture_2012/resized_padded_images/train and .../test
def resize_padded_gesture_2012_images(img_shape=(256, 256, 3)):
    """Resize the padded train/test images to ``img_shape`` and save them.

    Reads every ``.png`` in ``padded_images/train`` and ``padded_images/test``
    under ``gesture_2012_path``. Each image is resized with Lanczos resampling
    and written under the same filename to ``resized_padded_images/train`` or
    ``resized_padded_images/test``.

    Args:
        img_shape: Target shape as ``(height, width, channels)``. Only height
            and width are used; the channel count of each source image is kept.
    """
    src_root = gesture_2012_path + 'padded_images/'
    dst_directory = gesture_2012_path + 'resized_padded_images/'
    splits = ('train', 'test')
    os.makedirs(dst_directory, exist_ok=True)
    for split in splits:
        os.makedirs(dst_directory + split + '/', exist_ok=True)

    height, width = img_shape[:2]
    for split in splits:
        src_directory = src_root + split + '/'
        for filename in os.listdir(src_directory):
            if not filename.endswith(".png"):
                continue
            with Image.open(src_directory + filename) as img:
                # PIL sizes are (width, height), whereas img_shape is (H, W, C).
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                resized = img.resize((width, height), Image.LANCZOS)
            resized.save(dst_directory + split + '/' + filename)


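# The following function creates train and test subdirectories
# and randomly moves each image into train with probability 0.95 and into test otherwise
# (roughly a 95% / 5% split). It does nothing if either subdirectory already exists.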
def create_train_and_test_subdirectories():
    """Split the flat resized image directory into train/ and test/ (once).

    Does nothing if ``resized_padded_images/train`` or
    ``resized_padded_images/test`` already exists. Otherwise, every top-level
    ``.png`` is moved into train/ when ``np.random.rand() < 0.95`` (global
    NumPy RNG) and into test/ otherwise.

    While moving, the class index of the label is inserted as a new third
    '_'-separated field. The label is the second field; '0'-'9' map to 0-9 and
    'a'-'z' map to 10-35. For example, ``x_b_rest.png`` becomes
    ``x_b_11_rest.png``. Unknown labels raise ``KeyError``.
    """
    data_path = gesture_2012_path + 'resized_padded_images/'
    train_dir = data_path + 'train'
    test_dir = data_path + 'test'
    # An existing train/ or test/ means the split has already been performed.
    if os.path.exists(train_dir) or os.path.exists(test_dir):
        return
    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    label_to_index = {str(digit): digit for digit in range(10)}
    for code in range(ord('a'), ord('z') + 1):
        label_to_index[chr(code)] = code - ord('a') + 10

    for name in os.listdir(data_path):
        if not name.endswith(".png"):
            continue
        fields = name.split('_')
        class_index = label_to_index[fields[1]]
        renamed = '_'.join(fields[:2] + [str(class_index)] + fields[2:])
        target = 'train/' if np.random.rand() < 0.95 else 'test/'
        os.rename(data_path + name, data_path + target + renamed)


# The following function reads the images from resized_padded_images/<split>
# and returns a tensorflow dataset. The dataset is shuffled and batched.
# Because the size of the dataset exceeds the memory of the GPU, the dataset is not cached.

def gesture_2012_loader(split, num_labels=36, batch_size=64, seed=0, include_labels=False, to_one_hot=True):
    """Build a shuffled, batched ``tf.data.Dataset`` for one gesture_2012 split.

    Seeds NumPy and TensorFlow with ``seed``. It then calls
    ``create_train_and_test_subdirectories`` so the on-disk train/test split
    exists, and lists ``resized_padded_images/<split>/*.png``.

    Args:
        split: Either 'train' or 'test'.
        num_labels: Depth of the one-hot label vectors.
        batch_size: Number of elements per batch.
        seed: Seed for NumPy, TensorFlow and both shuffles.
        include_labels: If True, elements are ``(image, label)``; otherwise
            only ``image``.
        to_one_hot: If True (and labels are included), labels are one-hot
            vectors of length ``num_labels``; otherwise int32 class indices.

    Returns:
        A batched dataset of float32 images scaled to [-1, 1], optionally
        paired with labels.

    Raises:
        ValueError: If ``split`` is neither 'train' nor 'test'.
    """
    # Validate before seeding or touching the filesystem.
    if split not in ('train', 'test'):
        raise ValueError('split should be either train or test')
    np.random.seed(seed)
    tf.random.set_seed(seed)
    create_train_and_test_subdirectories()
    directory = gesture_2012_path + 'resized_padded_images/' + split + '/'
    filenames = tf.data.Dataset.list_files(directory + '*.png', shuffle=True, seed=seed)

    def _read_image_and_label(path):
        # After create_train_and_test_subdirectories, the third '_'-separated
        # field of the filename is the numeric class index. Use the last path
        # component rather than a fixed index, so parsing does not depend on
        # how deep gesture_2012_path is.
        basename = tf.strings.split(path, sep='/')[-1]
        label = tf.strings.to_number(tf.strings.split(basename, sep='_')[2], out_type=tf.int32)
        image = tf.image.decode_png(tf.io.read_file(path), channels=3)
        # Scale uint8 [0, 255] to float32 [-1, 1].
        image = tf.cast(image, tf.float32) / 127.5 - 1.0
        return image, label

    dataset = filenames.map(_read_image_and_label)
    if include_labels:
        if to_one_hot:
            dataset = dataset.map(lambda x, y: (x, tf.one_hot(y, num_labels)))
    else:
        # Drop the labels; only images are returned.
        dataset = dataset.map(lambda x, y: x)
    # NOTE(review): this buffer holds up to 10000 *decoded* float32 images
    # (~7.9 GB at 256x256x3). list_files above already shuffles the filenames,
    # so consider shrinking this buffer or shuffling before decoding.
    dataset = dataset.shuffle(10000, seed=seed)
    dataset = dataset.batch(batch_size)
    return dataset

def load_dataset(dataset_name, split, batch_size, seed=0, include_labels=False):
    """Return the dataset loader output for ``dataset_name``.

    Only "gesture_2012" is supported; it is forwarded to
    ``gesture_2012_loader`` with the remaining arguments.

    Raises:
        ValueError: For any other dataset name.
    """
    if dataset_name != "gesture_2012":
        raise ValueError(f"Unknown dataset name: {dataset_name}.")
    return gesture_2012_loader(split=split, batch_size=batch_size, seed=seed, include_labels=include_labels)


def count_num_of_each_label_in_gesture_train():
    """Count how many training examples carry each class index.

    Loads the training split with integer (not one-hot) labels and tallies
    every label into a dict keyed 0..35. It prints the counts, their total and
    ``len()`` of the batched dataset, then returns the dict.

    Returns:
        dict mapping class index (0-35) to its number of training examples.

    Raises:
        KeyError: If a label outside 0-35 is encountered.
    """
    train_dataset = gesture_2012_loader(split='train', batch_size=64, include_labels=True, to_one_hot=False)

    label_dict = {key: 0 for key in range(36)}

    for _, labels in train_dataset:
        # One host conversion per batch instead of one .numpy() call per element.
        for label in labels.numpy():
            label_dict[int(label)] += 1

    print(label_dict)
    print("Value: ", sum(label_dict.values()))
    # For a batched dataset, len() is the number of batches, not of examples.
    print("Length", len(train_dataset))
    return label_dict

