import glob
import os
from tqdm import tqdm, trange
import json
import pickle
from PIL import Image
import shutil


import numpy as np
from random import sample
import torchvision.transforms as transforms

import sys
sys.path.append('..')
from data_generate.split_generator import SplitGenerator


def class_index_groups(labels):
    """Group sample positions by class label.

    Args:
        labels: Sequence (list or array) of per-sample class labels.

    Returns:
        A list of ``(label, indices)`` pairs sorted by label. ``indices`` is
        an integer array with the positions of every sample carrying that
        label, in their original relative order. Empty input yields ``[]``.
    """
    labels_arr = np.asarray(labels).ravel()
    # A stable sort keeps samples of the same class in their stored order, so
    # per-class file numbering matches the order of the source array.
    order = np.argsort(labels_arr, kind='stable')
    unique_labels, counts = np.unique(labels_arr, return_counts=True)
    stops = np.cumsum(counts)
    starts = stops - counts
    return [
        (int(label), order[start:stop])
        for label, start, stop in zip(unique_labels, starts, stops)
    ]


def main():
    """Unpack the tiered-ImageNet archives into per-class JPEG folders.

    Reads ``../config/5-shot/prompt.json`` (relative to the working
    directory), then for each stage loads ``<stage>_images.npz`` and
    ``<stage>_labels.pkl``, writes every image (resized to
    ``config['img_resize']`` square) under ``<raw>/data/C<label>/`` and saves
    the list of class directories to ``<dest>/meta<stage>.npy``.
    """
    config_name = 'prompt.json'
    with open(os.path.join('../config/5-shot', config_name)) as jsonfile:
        config = json.load(jsonfile)

    dest_dir = os.path.join(config['data_dir'], '5-shot', 'tiered_imagenet')
    pkl_data_dir = os.path.join(config['data_dir'], 'raw', 'tiered_imagenet_raw')
    raw_data_dir = os.path.join(pkl_data_dir, 'data')

    stages = ['train', 'val', 'test']

    # makedirs (not mkdir) so missing parent directories do not abort the run.
    os.makedirs(raw_data_dir, exist_ok=True)

    if os.path.exists(dest_dir):
        # clean the previous split
        shutil.rmtree(dest_dir)
    os.makedirs(dest_dir)

    # The transforms do not depend on the image, so build them once.
    to_pil = transforms.ToPILImage()
    resize = transforms.Resize([config['img_resize'], config['img_resize']])

    for stage in stages:

        print(stage)
        image_file_name = os.path.join(pkl_data_dir, stage + '_images.npz')
        label_file_name = os.path.join(pkl_data_dir, stage + '_labels.pkl')

        # NOTE: allow_pickle / pickle.load execute arbitrary code on load;
        # only run this on dataset files from a trusted source.
        with np.load(image_file_name, allow_pickle=True) as archive:
            data = archive['images']
        with open(label_file_name, 'rb') as f:
            labels = pickle.load(f)['labels']

        cls_paths = []

        # Select each class's images by their actual label positions rather
        # than assuming the labels are stored sorted and contiguous.
        # NOTE(review): class folders are named from the label value only, so
        # if label values overlap between stages their images share a folder
        # and may overwrite each other -- confirm labels are globally unique.
        for label_key, sample_indices in tqdm(class_index_groups(labels)):
            label_name = 'C%04d' % label_key

            cls_path = os.path.join(raw_data_dir, label_name)
            cls_paths.append(cls_path)
            os.makedirs(cls_path, exist_ok=True)

            for idx, sample_idx in enumerate(sample_indices):
                pil_image = resize(to_pil(data[sample_idx]))
                pil_image.save(os.path.join(cls_path, 'C%04d%04d.jpg' % (label_key, idx)))

        np.save(os.path.join(dest_dir, 'meta' + stage + '.npy'), cls_paths)


if __name__ == "__main__":
    main()

