import os, argparse
import numpy as np
import json
import utils
import shutil
from tqdm import tqdm

from research_pool.config import imagenet_split


parser = argparse.ArgumentParser(
    description="Randomly split a class-per-folder train directory into "
                "victim_train/ and generator_train/ copies under --out_root.")
parser.add_argument('--data', required=True, type=str, help='Path to train directory')
parser.add_argument('--out_root', required=True, type=str, help="Path to out (root) directory")
# Previously hard-coded; exposed as an option with the same default so existing
# invocations keep producing the same seeding.
parser.add_argument('--seed', type=int, default=9,
                    help="Seed for the global NumPy RNG used to shuffle each class (default: 9)")
opt = vars(parser.parse_args())


# Seed the global NumPy RNG consumed by np.random.shuffle in the per-class split.
np.random.seed(opt['seed'])

# Output layout: <out_root>/generator_train/<class>/ and <out_root>/victim_train/<class>/
generator_home = os.path.join(opt['out_root'], 'generator_train')
victim_home = os.path.join(opt['out_root'], 'victim_train')

def makedirs(dl: list):
    """Create every directory in *dl* that does not already exist.

    Prints ``Create <path>`` for each path that was not a directory at check
    time. Missing parent directories are created as well (``os.makedirs``).

    Args:
        dl: Iterable of directory paths.

    Raises:
        FileExistsError: if a path exists but is not a directory.
    """
    for d in dl:
        if os.path.isdir(d):
            continue
        print(f"Create {d}")
        # exist_ok=True closes the check-then-create race: if another process
        # creates ``d`` between isdir() and here, this no longer crashes.
        # A non-directory at ``d`` still raises FileExistsError.
        os.makedirs(d, exist_ok=True)
            
makedirs([generator_home, victim_home])

# Only sub-directories of the train dir are treated as classes; stray files
# (e.g. a mapping .txt or .DS_Store) would otherwise crash os.listdir below.
# Sorting matters: os.listdir order is arbitrary and filesystem-dependent, and
# the order classes are visited decides which np.random draws each one gets,
# so without it the seeded split is not reproducible across machines.
class_folders = sorted(
    cf for cf in os.listdir(opt['data'])
    if os.path.isdir(os.path.join(opt['data'], cf))
)


def _copy_images(names, src_dir, dst_dir):
    """Copy each file in *names* from *src_dir* into *dst_dir*, keeping its basename."""
    for name in tqdm(names, leave=False):
        shutil.copy(os.path.join(src_dir, name), os.path.join(dst_dir, name))


for cf in tqdm(class_folders):
    generator_cf = os.path.join(generator_home, cf)
    victim_cf = os.path.join(victim_home, cf)
    makedirs([generator_cf, victim_cf])
    cf_full = os.path.join(opt['data'], cf)
    # Sort before shuffling so the same seed always yields the same split;
    # skip non-file entries, which shutil.copy cannot handle.
    images = sorted(
        im for im in os.listdir(cf_full)
        if os.path.isfile(os.path.join(cf_full, im))
    )

    np.random.shuffle(images)
    # The first ceil(imagenet_split * n) shuffled images go to the victim set,
    # the rest to the generator set.
    n_victim = int(np.ceil(imagenet_split * len(images)))

    _copy_images(images[:n_victim], cf_full, victim_cf)
    _copy_images(images[n_victim:], cf_full, generator_cf)


print("Done.")
