import os
import torch
import pyrootutils
from tqdm import tqdm

from argument import args
from utils.data_utils import load_data
from utils.syn_utils_img2img import ImageSynthesizer


if __name__ == '__main__':
    # Script entry point: walk the training set in groups of ``args.ipc``
    # consecutive images and, for each full group, synthesize
    # ``args.target_ipc`` new images with the img2img pipeline, saving them as
    # ``{args.save_path}/{class_name}/{target_index}.png``.
    #
    # NOTE(review): grouping by position assumes the dataset yields images
    # sorted by class with exactly ``args.ipc`` images per class, so the label
    # of the last image in a group labels the whole group — confirm against
    # ``load_data``.
    torch.manual_seed(args.seed)

    print('phase:', args.phase)

    # Each synthesized image is seeded from a distinct source image of its
    # group (``images[target_index]``), so a group must contain at least
    # ``target_ipc`` images. Fail fast here rather than hitting IndexError
    # mid-run after the model has loaded and partial output has been written.
    if args.target_ipc > args.ipc:
        raise ValueError(
            f'target_ipc ({args.target_ipc}) must not exceed ipc ({args.ipc})'
        )

    original_dataset = load_data(
        args=args, resize_only=True, mem_flag=False, trainset_only=True
    )
    # Indexed by integer label; the entries are used as output directory names
    # and passed to the synthesizer (presumably class names — confirm).
    class_names = original_dataset.classes
    print(f'total data number: {len(original_dataset)}')
    synthesizer = ImageSynthesizer(args)
    synthesizer.init_img2img()

    images = []
    for image_index, (image, label) in enumerate(original_dataset):
        images.append(image)
        # Keep buffering until a full group of ``args.ipc`` images is collected.
        if (image_index + 1) % args.ipc != 0:
            continue

        # int() accepts both a plain int label and a one-element tensor label.
        class_name = class_names[int(label)]
        print((image_index + 1) // args.ipc, class_name)

        class_dir = f'{args.save_path}/{class_name}'
        # Create the output directory once per class (not once per image);
        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(class_dir, exist_ok=True)

        with tqdm(total=args.target_ipc) as pbar:
            for target_index in range(args.target_ipc):
                new_data = synthesizer.sample_img2img(
                    images[target_index], class_name
                )
                new_data.save(f'{class_dir}/{target_index}.png')
                pbar.update()
        images = []

    # Images after the last complete group are never synthesized; report it
    # instead of dropping them silently.
    if images:
        print(
            f'warning: {len(images)} trailing image(s) did not fill a group '
            f'of {args.ipc} and were skipped'
        )
