from cProfile import label
import os
import json
from PIL import Image
from tqdm import tqdm

import sys
sys.path.append("..")

from data_generate.split_generator import SplitGenerator

def _sanitize_label(lbl):
    """Return *lbl* usable as a single directory name ('/' -> '-', all occurrences)."""
    return lbl.replace('/', '-')


def _read_nonempty_lines(path):
    """Return the lines of the text file at *path*, without line endings.

    Blank lines are skipped, so the final entry is kept whether or not the
    file ends with a trailing newline.
    """
    with open(path, encoding='utf-8') as txt_file:
        return [line.rstrip('\n') for line in txt_file if line.strip()]


def prepare_images_and_classes(image_folder_dir, label_only_file_dir, image_label_file_dir_list, crop_bottom_pixel=20):
    """Crop the bottom strip off each image and move it into a per-class folder.

    Args:
        image_folder_dir: Directory containing the raw '<image id>.jpg' files.
            One sub-folder per class label is created inside it, and the
            cropped images are written into those sub-folders.
        label_only_file_dir: Path to a text file with one class label per line.
        image_label_file_dir_list: Paths to text files whose lines have the
            form '<image id> <label>'; everything after the first space is the
            label, so labels may contain spaces.
        crop_bottom_pixel: Number of pixel rows removed from the bottom of
            every image (the copyright info).

    The operation is destructive: each original image is deleted once its
    cropped copy has been saved. An image whose original is already gone but
    whose cropped copy exists is skipped, so an interrupted run can be resumed.
    """
    # Create one folder per class ('/' in labels would otherwise nest folders).
    for lbl in _read_nonempty_lines(label_only_file_dir):
        os.makedirs(os.path.join(image_folder_dir, _sanitize_label(lbl)), exist_ok=True)

    # Each entry is of the form '<image id> <label>'.
    image_label_list = []
    for image_label_file_dir in image_label_file_dir_list:
        image_label_list.extend(_read_nonempty_lines(image_label_file_dir))

    # Crop away the bottom copyright info and move each image to its class folder.
    for img_lbl in tqdm(image_label_list, desc='Preparing images'):
        img, lbl = img_lbl.split(sep=' ', maxsplit=1)
        img_file_name = img + '.jpg'

        ori_img_dir = os.path.join(image_folder_dir, img_file_name)
        dest_img_dir = os.path.join(image_folder_dir, _sanitize_label(lbl), img_file_name)

        if not os.path.exists(ori_img_dir) and os.path.exists(dest_img_dir):
            # Already cropped and moved by an earlier (interrupted) run.
            continue

        # The context manager releases the source file before it is deleted below.
        with Image.open(ori_img_dir) as im:
            im_width, im_height = im.size
            im_crop = im.crop((0, 0, im_width, im_height - crop_bottom_pixel))
            im_crop.save(fp=dest_img_dir)

        # Remove the original, uncropped image only after the copy was saved.
        os.remove(ori_img_dir)

if __name__ == "__main__":
    # Load the config file. NOTE(review): the path is relative to the current
    # working directory, so this assumes the script is launched from a folder
    # that has ../config/5-shot next to it.
    config_name = 'prompt.json'
    with open(os.path.join('../config/5-shot', config_name), encoding='utf-8') as jsonfile:
        config = json.load(jsonfile)

    data_dir = config['data_dir']
    aircraft_raw_dir = os.path.join(data_dir, 'raw', 'aircraft_raw')

    dest_dir = os.path.join(data_dir, '5-shot', 'aircraft')
    # Must be a directory: the label/split text-file names are joined onto it
    # below (it used to end in 'variant.txt', yielding '.../variant.txt/variants.txt').
    # NOTE(review): confirm variants.txt and images_variant_*.txt live directly
    # in aircraft_raw.
    label_dir = aircraft_raw_dir
    raw_data_dir = os.path.join(aircraft_raw_dir, 'data')
    few_data_dir = os.path.join(aircraft_raw_dir, 'few_data')

    # One-off preprocessing: crop the copyright info and sort the images into
    # per-class folders. Disabled because it has already been done; it deletes
    # the original images, so it should not be repeated on the same data.
    RUN_PREPARE_IMAGES = False
    if RUN_PREPARE_IMAGES:
        prepare_images_and_classes(
            image_folder_dir=raw_data_dir,
            label_only_file_dir=os.path.join(label_dir, 'variants.txt'),
            image_label_file_dir_list=[
                os.path.join(label_dir, 'images_variant_trainval.txt'),
                os.path.join(label_dir, 'images_variant_test.txt'),
            ],
            crop_bottom_pixel=20,
        )

    #* split dataset into 64 train / 16 val / 20 test classes
    split_aircraft = SplitGenerator(data_dir=raw_data_dir, dest_dir=dest_dir, few_data_dir=few_data_dir, split_dir=None, verbose='aircraft')
    split_aircraft.split_train_val_test(nclass_train=64, nclass_val=16, nclass_test=20, save_split_npy=True, max_num=100, option='aircraft')

