import argparse
import os

import torch
from diffusers import StableDiffusionPipeline


# Command-line interface. Arguments are parsed as soon as this module is
# executed (``parse_args`` runs at module level), so the file is meant to be
# run as a script rather than imported.
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
    "--model_path",
    type=str,
    required=True,
    help=(
        "Path to a pretrained model directory, or the identifier of a model "
        "already present in the local Hugging Face cache (nothing is downloaded)."
    ),
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="./test-infer/",
    help="Directory where generated images are saved, one sub-directory per prompt.",
)
parser.add_argument(
    "--inference_prompts",
    type=str,
    default="a photo of sks person;a dslr portrait of sks person;a photo of sks person in front of eiffel tower",
    help="Semicolon-separated list of prompts used to generate images at inference.",
)
parser.add_argument(
    "--num_samples",
    type=int,
    default=None,
    help=(
        "The number of samples to generate for each prompt, in a single batch. "
        "If omitted, 30 images per prompt are generated in two batches of 15."
    ),
)
args = parser.parse_args()

# Sampling settings shared by every generation call.
NUM_INFERENCE_STEPS = 50
GUIDANCE_SCALE = 7.5
# Used when --num_samples is omitted: DEFAULT_NUM_BATCHES batches of
# DEFAULT_BATCH_SIZE images each, per prompt.
DEFAULT_NUM_BATCHES = 2
DEFAULT_BATCH_SIZE = 15
SEED = 1234


def parse_prompts(raw):
    """Split a semicolon-separated prompt string into a list of prompts.

    Each prompt is stripped of surrounding whitespace. Empty entries (for
    example from a trailing ``;`` or a ``;;``) are dropped. Otherwise they
    would trigger a generation with an empty prompt and map to an empty
    directory name, which writes images straight into the output root.

    Args:
        raw: The raw ``--inference_prompts`` value.

    Returns:
        The non-empty prompts, in their original order.
    """
    return [p.strip() for p in raw.split(";") if p.strip()]


def prompt_to_dirname(prompt):
    """Return the output sub-directory name used for *prompt*.

    The prompt is lower-cased, commas are removed and spaces become
    underscores. Path separators are also replaced with underscores, so a
    prompt such as ``"cat/dog"`` maps to a single directory instead of a
    nested (and possibly non-existent) path.

    Args:
        prompt: A single text prompt.

    Returns:
        A directory name derived from the prompt.
    """
    name = prompt.lower().replace(",", "").replace(" ", "_")
    for sep in (os.sep, os.altsep):
        if sep:  # os.altsep is None on POSIX
            name = name.replace(sep, "_")
    return name


if __name__ == "__main__":
    prompts = parse_prompts(args.inference_prompts)
    if not prompts:
        parser.error("--inference_prompts must contain at least one non-empty prompt")
    # A plain truthiness test would silently treat 0 as "use the default";
    # reject it explicitly instead.
    if args.num_samples is not None and args.num_samples < 1:
        parser.error("--num_samples must be a positive integer")
    os.makedirs(args.output_dir, exist_ok=True)

    # Create & load the model from local files only (no download), in
    # bfloat16, with the safety checker disabled, and move it to the GPU.
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        safety_checker=None,
        local_files_only=True,
        use_safetensors=False,
    ).to("cuda")
    # Requires the xformers package to be installed.
    pipe.enable_xformers_memory_efficient_attention()

    # Seeded once, before the loop. No generator is passed to the pipeline,
    # so it presumably draws from this global RNG. The images for a prompt
    # therefore also depend on the prompts generated before it.
    torch.manual_seed(SEED)
    for prompt in prompts:
        print(">>>>>>", prompt)
        out_path = os.path.join(args.output_dir, prompt_to_dirname(prompt))
        os.makedirs(out_path, exist_ok=True)
        if args.num_samples is not None:
            # All requested samples are generated in one batch.
            images = pipe(
                [prompt] * args.num_samples,
                num_inference_steps=NUM_INFERENCE_STEPS,
                guidance_scale=GUIDANCE_SCALE,
            ).images
            for idx, image in enumerate(images):
                image.save(os.path.join(out_path, f"{idx}.png"))
        else:
            for i in range(DEFAULT_NUM_BATCHES):
                images = pipe(
                    [prompt] * DEFAULT_BATCH_SIZE,
                    num_inference_steps=NUM_INFERENCE_STEPS,
                    guidance_scale=GUIDANCE_SCALE,
                ).images
                for idx, image in enumerate(images):
                    image.save(os.path.join(out_path, f"{i}_{idx}.png"))
    # Drop the pipeline and return cached GPU memory to the driver.
    del pipe
    torch.cuda.empty_cache()
