import os
import json
from PIL import Image
from tqdm import tqdm

from data_generate.split_generator import SplitGenerator

def _label_to_dirname(label):
    """Return ``label`` with its first '/' replaced by '-' so it maps to one folder name.

    Only the first slash is replaced, to keep the exact folder names the
    original implementation produced (e.g. 'A/B' -> 'A-B').
    """
    return label.replace('/', '-', 1)


def _read_nonempty_lines(file_dir):
    """Return the non-blank lines of a text file, without line terminators.

    Unlike ``read().split('\\n')[:-1]`` this does not drop the last line when
    the file lacks a trailing newline, and it ignores blank lines.
    """
    with open(file_dir) as txt_file:
        lines = (line.rstrip('\r\n') for line in txt_file)
        return [line for line in lines if line.strip()]


def prepare_images_and_classes(image_folder_dir, label_only_file_dir, image_label_file_dir_list, crop_bottom_pixel=20):
    """Crop images in ``image_folder_dir`` and move them into per-class sub-folders.

    Args:
        image_folder_dir: Folder that initially holds the ``<image>.jpg`` files.
            One sub-folder per class label is created inside it.
        label_only_file_dir: Path to a text file with one class label per line.
        image_label_file_dir_list: Paths to text files whose lines have the
            form ``'<image> <label>'``. The label may itself contain spaces;
            only the first space separates it from the image name.
        crop_bottom_pixel: Number of pixel rows removed from the bottom of
            every image (the original comment says this strips copyright info).

    Side effects:
        Creates the class folders and writes each cropped image to
        ``image_folder_dir/<label>/<image>.jpg``. It then deletes the original
        ``image_folder_dir/<image>.jpg``.

    Raises:
        ValueError: if an image-label line contains no space.
        OSError: if an image file is missing or cannot be read or written.
    """
    # Create the class folders. exist_ok guards against labels that sanitise
    # to the same folder name and against a pre-existing folder.
    for lbl in _read_nonempty_lines(label_only_file_dir):
        os.makedirs(os.path.join(image_folder_dir, _label_to_dirname(lbl)), exist_ok=True)

    # Each entry has the form: 'image label'.
    image_label_list = []
    for image_label_file_dir in image_label_file_dir_list:
        image_label_list.extend(_read_nonempty_lines(image_label_file_dir))

    # Crop away the bottom copyright info and move each image to its class folder.
    for img_lbl in tqdm(image_label_list, desc='Preparing images'):
        img, lbl = img_lbl.split(sep=' ', maxsplit=1)
        img_file_name = img + '.jpg'

        ori_img_dir = os.path.join(image_folder_dir, img_file_name)
        dest_img_dir = os.path.join(image_folder_dir, _label_to_dirname(lbl), img_file_name)

        # Close the source image before deleting it. Some platforms refuse
        # to remove a file that is still open.
        with Image.open(ori_img_dir) as im:
            im_width, im_height = im.size
            im_crop = im.crop((0, 0, im_width, im_height - crop_bottom_pixel))
            im_crop.save(fp=dest_img_dir)

        # Remove the original, non-cropped image file.
        os.remove(ori_img_dir)


if __name__ == "__main__":
    # Load the config file. The `with` block ensures the handle is closed.
    config_name = 'ptl_bomla_lam1.json'
    with open(os.path.join('./config/la_seqdataset', config_name)) as jsonfile:
        config = json.load(jsonfile)

    split_dir = os.path.join(config['data_dir'], config['split_folder'], 'aircraft')
    label_dir = os.path.join(config['data_dir'], 'aircraft_label')
    # Images are organised in place. The same folder is then passed to
    # SplitGenerator as both data_dir and dest_dir.
    dest_dir = os.path.join(config['data_dir'], 'aircraft')

    # Organise images into class folders.
    prepare_images_and_classes(
        image_folder_dir=dest_dir,
        label_only_file_dir=os.path.join(label_dir, 'variants.txt'),
        image_label_file_dir_list=[
            os.path.join(label_dir, 'images_variant_trainval.txt'),
            os.path.join(label_dir, 'images_variant_test.txt'),
        ],
        crop_bottom_pixel=20,
    )

    split_aircraft = SplitGenerator(
        data_dir=dest_dir, dest_dir=dest_dir, split_dir=split_dir,
        back_eval_raw=False, supercls_raw=False, supercls_split=config['aircraft']['supercls'],
    )
    split_aircraft.split_train_val_test(nclass_train=64, nclass_val=16, save_split_npy=True, csv_save_form=None)
