import os

import click
import torch
import tqdm
from torch.utils.data import DataLoader, Subset
from torchvision.utils import save_image

import training.dataset


def _select_member_indices(loss, time, loss_threshold):
    """Return the row indices of `loss` whose value at column `time` is <= `loss_threshold`.

    Args:
        loss: 2-D tensor indexed as ``loss[sample, time_step]``.
        time: Column (time step) to threshold on; negative values index from the end,
            as in ordinary tensor indexing.
        loss_threshold: Inclusive upper bound on the loss for a sample to be selected.

    Returns:
        A list of Python ints (ascending) suitable for ``torch.utils.data.Subset``.

    Raises:
        click.BadParameter: if `time` is not a valid column index of `loss`.
    """
    num_steps = loss.shape[1]
    if not -num_steps <= time < num_steps:
        raise click.BadParameter(
            f'time step {time} is out of range for a loss tensor with {num_steps} steps',
            param_hint="'--time'",
        )
    mask = loss[:, time] <= loss_threshold  # NaN losses compare False and are excluded
    # Plain ints (not 0-d tensors) so the wrapped dataset is indexed with ordinary integers.
    return torch.arange(loss.shape[0])[mask].tolist()


@click.command()
@click.option('--generated', 'generated_path', help='Path to the generated images', metavar='PATH',             type=str, required=True)
@click.option('--loss', 'loss_path',           help='Path to the loss file', metavar='PATH',                    type=str, required=True)
@click.option('--threshold', 'loss_threshold', help='Loss threshold for membership inference', metavar='FLOAT', type=float, required=True)
@click.option('--time',                        help='Time step for membership inference', metavar='INT',        type=int, required=True)
@click.option('--outdir',                      help='Path to save the attack results', metavar='DIR',           type=str, required=True)
@click.option('--batch', 'batch_size',         help='Batch size', metavar='INT',                                type=click.IntRange(min=1), default=1024, show_default=True)
@click.option('--workers', 'num_workers',      help='Number of DataLoader worker processes', metavar='INT',     type=click.IntRange(min=0), default=4, show_default=True)
def main(generated_path, loss_path, loss_threshold, time, outdir, batch_size, num_workers=4):
    """Save the generated images whose loss at a given time step is at or below a threshold.

    The loss file must hold a 2-D array with one row per generated image (in dataset
    order) and one column per time step. Every image whose loss at column `--time` is
    <= `--threshold` is written to `--outdir` as a sequentially numbered PNG
    (000000.png, 000001.png, ...). The numbering follows selection order, not the
    original dataset index.
    """
    dataset = training.dataset.ImageFolderDataset(path=generated_path)

    # NOTE(review): torch.load unpickles the file -- only point --loss at trusted files.
    # map_location='cpu' lets a loss tensor saved from a GPU run load on a CPU-only
    # machine and keeps the selection mask on the same device as the index tensor.
    loss = torch.as_tensor(torch.load(loss_path, map_location='cpu'))
    if loss.ndim != 2:
        raise click.ClickException(
            f'expected a 2-D loss array (samples x time steps), got shape {tuple(loss.shape)}')
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if loss.shape[0] != len(dataset):
        raise click.ClickException(
            f'loss file has {loss.shape[0]} rows but the dataset has {len(dataset)} images')

    indices = _select_member_indices(loss, time, loss_threshold)
    click.echo(f'{len(indices)} of {len(dataset)} images have loss <= {loss_threshold} at time step {time}.')
    loader = DataLoader(Subset(dataset, indices), batch_size=batch_size, num_workers=num_workers)

    os.makedirs(outdir, exist_ok=True)
    counter = 0
    with tqdm.tqdm(total=len(indices)) as pbar:
        for batch in loader:
            # Only the first element of each batch (the image tensor) is used; any
            # label/extra fields the dataset yields are ignored.
            # NOTE(review): assumes raw pixels are in [0, 255] (e.g. uint8) -- confirm.
            images = batch[0].float() / 255.0
            for image in images:
                save_image(image, os.path.join(outdir, f'{counter:06d}.png'))
                counter += 1
                pbar.update(1)


if __name__ == "__main__":
    # Script entry point: the click-decorated command reads its options from the
    # command line, so it is invoked with no explicit arguments.
    main()
