from utils.utils import *
import torch
import os
import argparse

def _build_parser():
    """Return the command-line parser for this image-generation script.

    Every option is optional; the defaults are listed in ``option_specs``.
    """
    cli = argparse.ArgumentParser()
    # (flag, value type, default) -- registered in this order so ``--help``
    # lists them in the same sequence.
    option_specs = [
        ('--model_path', str, None),
        ('--target_concept', str, None),
        ('--concept_type', str, None),
        ('--train_method', str, "xattn_q"),
        ('--save_folder', str, "./images"),
        ('--num_images', int, 30),
    ]
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, required=False, default=fallback)
    return cli


parser = _build_parser()
# Parsed at import time: importing this module reads sys.argv.
args = parser.parse_args()

def infer_images(args, *, seed=42, img_size=512, n_steps=50,
                 guidance_scale=7.5, device='cuda:0'):
    """Generate images for one concept prompt and save them as numbered PNGs.

    Builds a ``StableDiffuser`` with a DDIM scheduler and moves it to
    ``device``. If ``args.model_path`` is set, a ``FineTunedModel`` is wrapped
    around the diffuser using ``args.train_method``, its state dict is loaded
    from ``args.model_path``, and generation runs inside the fine-tuner's
    context manager; otherwise the base diffuser is used directly.

    Args:
        args: Parsed CLI namespace. Reads ``model_path``, ``train_method``,
            ``concept_type``, ``target_concept``, ``save_folder`` and
            ``num_images``.
        seed: Seed for the single ``torch.Generator`` shared by all images.
        img_size: Forwarded to the diffuser as ``img_size``.
        n_steps: Forwarded to the diffuser as ``n_steps``.
        guidance_scale: Forwarded to the diffuser as ``guidance_scale``.
        device: Device the diffuser is moved to.

    Side effects:
        Prints the seed, creates ``args.save_folder`` if it does not exist,
        and saves ``0.png`` .. ``{num_images - 1}.png`` into it.
    """
    import contextlib  # local import: only this function needs it

    model_path = args.model_path
    train_method = args.train_method

    diffuser = StableDiffuser(scheduler='DDIM').to(device)

    print(seed)

    # NOTE(review): --target_concept defaults to None, which produces the
    # literal prompt "a photo of None"; callers should always supply it.
    if args.concept_type == "art":
        prompt = f"a painting in the style of {args.target_concept}"
    else:
        prompt = f"a photo of {args.target_concept}"

    save_folder = args.save_folder
    os.makedirs(save_folder, exist_ok=True)
    num_images = args.num_images
    # Seeded once and shared across iterations (not re-seeded per image).
    # Presumably the diffuser draws its initial noise from it -- TODO confirm.
    generator = torch.Generator().manual_seed(seed)

    if model_path:
        finetuner = FineTunedModel(diffuser, train_method=train_method)
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints (weights_only=True is safer where supported).
        finetuner.load_state_dict(torch.load(model_path))
        # Entering the fine-tuner presumably activates the fine-tuned weights.
        generation_context = finetuner
    else:
        # No checkpoint: a no-op context lets both paths share one loop.
        generation_context = contextlib.nullcontext()

    with generation_context:
        for i in range(num_images):
            images = diffuser(prompt,
                              img_size=img_size,
                              n_steps=n_steps,
                              n_imgs=1,
                              generator=generator,
                              guidance_scale=guidance_scale)
            # NOTE(review): images[0][0] is taken as the single generated
            # image; confirm the nesting of StableDiffuser's return value.
            images[0][0].save(os.path.join(save_folder, f"{i}.png"))
        
            
if __name__ == "__main__":
    # Script entry point: generate images using the CLI arguments that were
    # parsed when the module was loaded.
    infer_images(args=args)

