# coding: utf-8

import os,argparse,glob,json
import numpy as np
from torchvision.datasets.imagenet import load_meta_file
from torchvision.io import read_image,ImageReadMode
from torchvision.transforms import CenterCrop,Normalize,Compose


def main(orig_root, save_root, resolution, *,
		mean=(0.4815, 0.4578, 0.4082), std=(0.2686, 0.2613, 0.2758)):
	"""Center-crop and normalize every JPEG under *orig_root* and save each as a .npy file.

	Images are discovered with the glob pattern ``*/*/*.JPEG`` relative to
	*orig_root*. Each one is written to the same relative path under
	*save_root*, with its extension replaced by ``.npy``. The metadata returned
	by ``load_meta_file(orig_root)`` (its first element, a wnid -> class-names
	mapping) is dumped to ``<save_root>/ann.json`` together with the list of
	its values.

	Args:
		orig_root: Directory containing the images (two directory levels deep)
			and the metadata file that ``load_meta_file`` reads.
		save_root: Output directory; created if it does not exist.
		resolution: Side length in pixels of the square center crop.
		mean: Per-channel mean passed to ``Normalize`` (applied to [0, 1] pixels).
		std: Per-channel std passed to ``Normalize`` (applied to [0, 1] pixels).

	Raises:
		ValueError: If *resolution* is not a positive integer.
	"""
	if resolution <= 0:
		raise ValueError(f'resolution must be a positive integer, got {resolution}')

	# Read the class metadata first so a missing/corrupt meta file fails
	# immediately rather than after every image has already been processed.
	wnid_to_classes = load_meta_file(orig_root)[0]
	classes = list(wnid_to_classes.values())

	# NOTE(review): per torchvision docs, CenterCrop zero-pads inputs smaller
	# than the crop size; that padding happens before Normalize, so padded
	# pixels end up as (0 - mean) / std.
	transforms = Compose([
		CenterCrop(size=[resolution, resolution]),
		Normalize(mean=mean, std=std),
	])
	for sub_path in glob.glob('*/*/*.JPEG', root_dir=orig_root):
		# ImageReadMode.RGB forces 3 channels, so the 3-element mean/std always match.
		img = read_image(path=os.path.join(orig_root, sub_path), mode=ImageReadMode.RGB)
		# Scale uint8 [0, 255] to float [0, 1]; the actual normalization happens in `transforms`.
		img = img.float() / 255
		img = transforms(img)
		save_path = os.path.join(save_root, os.path.splitext(sub_path)[0] + '.npy')
		os.makedirs(os.path.dirname(save_path), exist_ok=True)
		np.save(save_path, img.numpy())

	# The per-image makedirs above never runs when no image matched the glob,
	# so make sure the output root exists before writing the annotation file.
	os.makedirs(save_root, exist_ok=True)
	with open(os.path.join(save_root, 'ann.json'), 'w', encoding='utf-8') as f:
		json.dump(dict(wnid_to_classes=wnid_to_classes, classes=classes), f)

if __name__ == '__main__':
	# Command-line entry point: three positional arguments, forwarded to main() by name.
	cli = argparse.ArgumentParser()
	positional_specs = (
		('orig_root', str, 'Path to the root directory where original imgs are stored.'),
		('save_root', str, 'Path to the root directory where preprocessed data are saved.'),
		('resolution', int, 'Resolution of the center-cropped imgs.'),
	)
	for arg_name, arg_type, arg_help in positional_specs:
		cli.add_argument(arg_name, type=arg_type, help=arg_help)

	parsed = cli.parse_args()
	main(**vars(parsed))