import os
import torch
import pyrootutils
from tqdm import tqdm
from torchvision.utils import save_image

from argument import args
from utils.data_utils import load_data
from utils.syn_utils_dit import ImageSynthesizer


if __name__ == '__main__':
    # Generate one synthetic image per sample of the original training set,
    # conditioned on that sample's class, and write the results to
    # <args.save_path>/<class_index>/<slot>.png.

    # Seed torch's global RNG so sampling is reproducible for a given --seed.
    torch.manual_seed(args.seed)

    print('phase:', args.phase)

    # Only the training split is loaded; the flags are forwarded to load_data as-is.
    original_dataset = load_data(
        args=args, resize_only=True, mem_flag=False, trainset_only=True
    )
    class_indices = original_dataset.classes
    original_classes = original_dataset.original_classes
    print(f'total data number: {len(original_dataset)}')
    synthesizer = ImageSynthesizer(args)

    # Running count of images generated so far for each class. Filenames were
    # previously derived from the *global* dataset index, which only maps to a
    # per-class slot when each class occupies a contiguous, target_ipc-aligned
    # run of the dataset; otherwise images of the same class land on the same
    # filename and silently overwrite each other.
    per_class_count = {}

    with tqdm(total=len(original_dataset)) as pbar:
        # The source image itself is not used; only its label drives sampling.
        for _image, label in original_dataset:
            # int() accepts both a single-element tensor and a plain Python int.
            label = int(label)
            class_index = class_indices[label]
            original_label = original_classes[label]

            new_data = synthesizer.sample(original_label, class_index, device='cuda')

            class_dir = os.path.join(args.save_path, str(class_index))
            # exist_ok avoids the race between an exists() check and makedirs().
            os.makedirs(class_dir, exist_ok=True)

            slot = per_class_count.get(class_index, 0)
            per_class_count[class_index] = slot + 1
            # The modulo keeps at most target_ipc distinct files per class, as before.
            save_image(
                new_data,
                os.path.join(class_dir, f'{slot % args.target_ipc}.png'),
                normalize=True,
                value_range=(-1, 1),
            )

            pbar.update()
