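# Convert the NumPy (.npy) files of CIFAR-10-C and CIFAR-100-C into PNG image files,
# organized as data/images_corrupted/{dataset}c/{class_name}/{corruption_method}/{severity}/0000.png,
# and write matching image-list files under data/benchmark_imglist/{dataset}/.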
import contextlib
import glob
import os

import numpy as np
from PIL import Image
from tqdm import tqdm

# Root directory under which the converted PNG images are written.
save_root_dir = "data/images_corrupted"

# CIFAR-100 fine-label class names; the list index is the integer label (0..99).
# The names are alphabetical, except that 'keyboard' is spelled 'computer_keyboard'.
cifar100_classes_name_list = [
    'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',  # ids 0-4
    'bed', 'bee', 'beetle', 'bicycle', 'bottle',  # ids 5-9
    'bowl', 'boy', 'bridge', 'bus', 'butterfly',  # ids 10-14
    'camel', 'can', 'castle', 'caterpillar', 'cattle',  # ids 15-19
    'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',  # ids 20-24
    'couch', 'crab', 'crocodile', 'cup', 'dinosaur',  # ids 25-29
    'dolphin', 'elephant', 'flatfish', 'forest', 'fox',  # ids 30-34
    'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard',  # ids 35-39
    'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',  # ids 40-44
    'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',  # ids 45-49
    'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',  # ids 50-54
    'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',  # ids 55-59
    'plain', 'plate', 'poppy', 'porcupine', 'possum',  # ids 60-64
    'rabbit', 'raccoon', 'ray', 'road', 'rocket',  # ids 65-69
    'rose', 'sea', 'seal', 'shark', 'shrew',  # ids 70-74
    'skunk', 'skyscraper', 'snail', 'snake', 'spider',  # ids 75-79
    'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',  # ids 80-84
    'tank', 'telephone', 'television', 'tiger', 'tractor',  # ids 85-89
    'train', 'trout', 'tulip', 'turtle', 'wardrobe',  # ids 90-94
    'whale', 'willow_tree', 'wolf', 'woman', 'worm',  # ids 95-99
]
# CIFAR-10 class names; the list index is the integer label (0..9).
cifar10_classes_name_list = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',  # ids 0-4
    'dog', 'frog', 'horse', 'ship', 'truck',  # ids 5-9
]

# This script assumes every corruption file stacks the test set once per
# severity level: row i belongs to severity i // images_per_severity + 1, and
# its index within that severity is i % images_per_severity.
images_per_severity = 10000
num_severities = 5

# Per-dataset input directory and class-name table (list index == label).
# A single mapping keeps the two choices consistent; unknown names raise KeyError.
dataset_configs = {
    'cifar10': ('data/images_corrupted/CIFAR-10-C', cifar10_classes_name_list),
    'cifar100': ('data/images_corrupted/CIFAR-100-C', cifar100_classes_name_list),
}

for dataset_name in ['cifar10']:
    data_dir, class_names = dataset_configs[dataset_name]

    # The image-list directory may not exist on a fresh checkout.
    imglist_dir = f'data/benchmark_imglist/{dataset_name}'
    os.makedirs(imglist_dir, exist_ok=True)
    imglist_path = f'{imglist_dir}/test_{dataset_name}c_full.txt'

    # Labels are shared by every corruption file, so load them once.
    label_data = np.load(os.path.join(data_dir, 'labels.npy'))

    # ExitStack closes every list file even if conversion fails midway.
    with contextlib.ExitStack() as stack:
        full_list_f = stack.enter_context(open(imglist_path, 'w', encoding='utf-8'))
        # One list per severity level; '.txt' added for consistency with the full list.
        svt_list_f = [
            stack.enter_context(
                open(f'{imglist_dir}/test_{dataset_name}c_s{s}.txt', 'w', encoding='utf-8')
            )
            for s in range(1, num_severities + 1)
        ]

        # Sorted so the generated list files have a reproducible order.
        for file in sorted(glob.glob(f'{data_dir}/*.npy')):
            method_name = os.path.splitext(os.path.basename(file))[0]  # corruption method name
            if method_name == 'labels':
                continue  # the label file is not a corruption; skip before loading it

            print(file)
            data = np.load(file)
            if len(data) > len(label_data):
                raise ValueError(
                    f'{file} has {len(data)} images but labels.npy has only '
                    f'{len(label_data)} labels'
                )
            if len(data) > images_per_severity * num_severities:
                raise ValueError(
                    f'{file} has {len(data)} images; expected at most '
                    f'{images_per_severity * num_severities} '
                    f'({num_severities} severities x {images_per_severity})'
                )

            print('Saving images for the corruption', method_name, dataset_name)
            for i, img in enumerate(tqdm(data)):
                severity_offset, idx = divmod(i, images_per_severity)
                severity = severity_offset + 1
                label = label_data[i]
                cls_name = class_names[label]
                save_sub_dir = os.path.join(
                    f'{dataset_name}c', cls_name, method_name, str(severity)
                )
                os.makedirs(os.path.join(save_root_dir, save_sub_dir), exist_ok=True)

                # Path relative to save_root_dir; this is what the list files record.
                rel_path = os.path.join(save_sub_dir, f'{idx:04d}.png')
                Image.fromarray(img).save(os.path.join(save_root_dir, rel_path))

                line = f'{rel_path} {label}\n'
                full_list_f.write(line)
                svt_list_f[severity - 1].write(line)







