import glob
import os
from tqdm import tqdm
import json

import numpy as np
from random import sample
import torchvision.transforms as transforms

import sys
sys.path.append('..')
from data_generate.split_generator import SplitGenerator

def sketch_to_image_array(imgvec, num_ch=3, side=28):
    """Turn one flattened sketch bitmap into an image array ready for saving.

    The flat vector is reshaped to ``side x side``. Pixels equal to 0 are
    treated as background and painted 255; every other pixel is treated as
    part of the drawing and painted 0.

    Args:
        imgvec: Array-like holding ``side * side`` values.
        num_ch: Number of output channels, either 1 or 3.
        side: Edge length of the square bitmap.

    Returns:
        A new ``uint8`` array holding only the values 0 and 255, shaped
        ``(side, side)`` when ``num_ch == 1`` and ``(side, side, 3)`` (channel
        last) when ``num_ch == 3``. The input is not modified.

    Raises:
        ValueError: If ``num_ch`` is neither 1 nor 3.
    """
    if num_ch not in (1, 3):
        raise ValueError('num_ch either 1 or 3')
    strokes = np.asarray(imgvec).reshape(side, side)
    # background (0) -> 255, drawn pixels -> 0
    binary = np.where(strokes == 0, 255, 0).astype(np.uint8)
    if num_ch == 3:
        # replicate the single plane along a trailing channel axis (H x W x C)
        binary = np.repeat(binary[:, :, np.newaxis], repeats=3, axis=2)
    return binary


def main():
    """Render the raw per-class npy bitmaps to PNG files, then split the dataset."""
    # load the config file
    config_name = 'prompt.json'
    with open(os.path.join('../config/5-shot', config_name), encoding='utf-8') as jsonfile:
        config = json.load(jsonfile)

    npy_data_dir = os.path.join(config['data_dir'], 'raw', 'miniQuickDraw_raw')
    dest_dir = os.path.join(config['data_dir'], '5-shot', 'miniQuickDraw')
    raw_data_dir = os.path.join(npy_data_dir, 'data')
    few_data_dir = os.path.join(npy_data_dir, 'few_data')

    # images are only generated when the output folder does not exist yet
    if not os.path.exists(raw_data_dir):
        # set some hyperparameter
        num_ch = 3
        # upper bound of images drawn per class; None means "use every sample"
        num_inst_per_cls_to_sample = 1000

        to_pil = transforms.ToPILImage()
        # only pick up the npy files: the raw folder also holds other entries
        # (e.g. variant.txt) that np.load cannot read
        full_cls_file = sorted(glob.glob(os.path.join(npy_data_dir, '*.npy')))

        for npyfile in tqdm(full_cls_file, desc='Generating mini_quickdraw'):
            # each npy file is a class (nsample in this class x 784)
            cls_name = os.path.splitext(os.path.basename(npyfile))[0]
            npy_cls = np.load(npyfile)

            # decide the sample count per class, and never ask for more rows
            # than the class has (random.sample raises ValueError otherwise)
            num_available = npy_cls.shape[0]
            if num_inst_per_cls_to_sample is None:
                num_to_draw = num_available
            else:
                num_to_draw = min(num_inst_per_cls_to_sample, num_available)
            if num_to_draw == 0:
                continue

            img_dir = os.path.join(raw_data_dir, cls_name)
            os.makedirs(img_dir, exist_ok=True)

            chosen_rows = sample(range(num_available), num_to_draw)
            for idx, row in enumerate(chosen_rows):
                img_np = sketch_to_image_array(npy_cls[row], num_ch=num_ch)
                # transform to pil image
                pil_img = to_pil(img_np)
                img_name = '{}_{}.png'.format(cls_name, str(idx))
                pil_img.save(os.path.join(img_dir, img_name), format='PNG')

    #* split dataset
    splitter = SplitGenerator(data_dir=raw_data_dir, dest_dir=dest_dir, few_data_dir=few_data_dir, split_dir=None, verbose='quickdraft')
    splitter.split_train_val_test(nclass_train=64, nclass_val=16, nclass_test=20, save_split_npy=True)


if __name__ == "__main__":
    main()