import argparse
from diffusion_model import DiffusionModel



def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the command-line options used for sampling from a diffusion model.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse reads
            ``sys.argv[1:]``, which matches the previous behaviour.

    Returns:
        An ``argparse.Namespace`` with the attributes ``checkpoint_name``,
        ``n_samples``, ``n_images_per_row``, ``device``, ``timesteps``,
        ``beta1`` and ``beta2``. The last four default to ``None`` when the
        corresponding flag is not given.
    """
    parser = argparse.ArgumentParser(description="Sample images from diffusion model")
    parser.add_argument("--checkpoint_name", type=str, default='pretrained_mnist_checkpoint_49.pth', help="Checkpoint name of diffusion model")
    parser.add_argument("--n_samples", type=int, default=100, help="Number of samples to generate")
    parser.add_argument("--n_images_per_row", type=int, default=10, help="Number of images each row contains in the grid")
    parser.add_argument("--device", type=str, default=None, help="GPU device to use")
    parser.add_argument("--timesteps", type=int, default=None, help="Total timesteps for sampling")
    parser.add_argument("--beta1", type=float, default=None, help="Hyperparameter for DDPM")
    parser.add_argument("--beta2", type=float, default=None, help="Hyperparameter for DDPM")
    # Passing argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

def main():
    """Generate image samples for each configured checkpoint size.

    For every entry in ``num_list`` this loads the checkpoint
    ``<exper_name>/pretrained_mnist_sample_<name>.pth`` and asks the model to
    generate images into ``generated_images/<exper_name>/samples<name>``.

    The command line is parsed once. ``--n_samples`` and ``--checkpoint_name``
    are overwritten by the experiment settings below; ``--device`` is honoured
    when given and falls back to ``'cuda:1'`` otherwise.
    """
    # Sizes to sweep. Earlier runs used range(500, 5500, 500) and
    # [100, 500, 1000, 2000, 5000, 10000].
    # NOTE(review): presumably the training-subset size each checkpoint was
    # trained on -- confirm against the training script.
    num_list = [200]
    exper_name = 'fashion_mnist'
    # Number of generate() calls per size; the requested total is split
    # evenly between them (any remainder of the integer division is dropped).
    n_rounds = 1

    # The parsed options do not change between sizes, so parse only once.
    args = parse_arguments()
    if args.device is None:
        # Default device for these experiments when --device is not passed.
        args.device = 'cuda:1'

    for name in num_list:
        # The 60000 checkpoint is the one case where the number of generated
        # images (10000) differs from the size in the file/directory names.
        total_num = 10000 if name == 60000 else name
        save_dir = 'generated_images/' + exper_name + '/samples' + str(name)

        args.n_samples = total_num // n_rounds
        args.checkpoint_name = exper_name + '/pretrained_mnist_sample_' + str(name) + '.pth'

        for round_i in range(n_rounds):
            # The model is rebuilt every round, as before; whether generate()
            # leaves it reusable is not visible from this file.
            diffusion_model = DiffusionModel(device=args.device, checkpoint_name=args.checkpoint_name)
            diffusion_model.generate(save_dir, args.n_samples, round_i, args.n_images_per_row, args.timesteps, args.beta1, args.beta2)


if __name__ == "__main__":
    main()


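# Example usage (the checkpoint is passed with the --checkpoint_name flag, not positionally):
#   python3 sample.py --checkpoint_name pretrained_mnist_checkpoint_49.pth --n_samples 400 --n_images_per_row 20
# NOTE: the __main__ block currently overwrites --checkpoint_name and --n_samples with its own experiment settings.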
#https://github.com/byrkbrk/conditional-ddpm