import argparse
import os

import torch
from diffusers import StableDiffusionPipeline


# Characters that are invalid in file/directory names on common platforms
# (Windows forbids <>:"|?*, and "/" or "\" would split the name into subdirs).
_UNSAFE_PATH_CHARS = str.maketrans("", "", '<>:"/\\|?*')


def parse_args(argv=None):
    """Parse the command-line options for textual-inversion inference.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--lora_path",
        type=str,
        default=None,
        required=True,
        help="Path (or Hub identifier) of the textual-inversion embedding passed to "
        "the pipeline's load_textual_inversion.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./test-infer/",
        help="The output directory where predictions are saved",
    )
    parser.add_argument(
        "--inference_prompts",
        type=str,
        default="a photo of sks;a dslr portrait of sks;a photo of sks looking at the mirror",
        help="Semicolon-separated prompts used to generate images at inference.",
    )
    parser.add_argument(
        "--num_rounds",
        type=int,
        default=2,
        help="Number of pipeline calls per prompt.",
    )
    parser.add_argument(
        "--images_per_round",
        type=int,
        default=15,
        help="Number of images generated by each pipeline call.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=50,
        help="Number of denoising steps per image.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=7.5,
        help="Classifier-free guidance scale.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1234,
        help="Global torch seed set once before generation.",
    )
    return parser.parse_args(argv)


def split_prompts(text):
    """Split a ';'-separated prompt string into stripped, non-empty prompts."""
    return [part.strip() for part in text.split(";") if part.strip()]


def prompt_to_dirname(prompt):
    """Turn a prompt into a single, filesystem-safe directory name.

    Lower-cases, drops commas, turns spaces into underscores and removes
    characters that are unsafe in path components. Falls back to ``"prompt"``
    when nothing usable (or only dots, i.e. "." / "..") remains, so the result
    never resolves to the output directory itself or its parent.
    """
    name = prompt.lower().replace(",", "").replace(" ", "_").translate(_UNSAFE_PATH_CHARS)
    if not name.strip("."):
        return "prompt"
    return name


def main(argv=None):
    """Load the pipeline plus embedding and save generated images per prompt.

    Images for each prompt go to ``<output_dir>/<prompt_to_dirname(prompt)>/``
    as ``<round>_<index>.png``.

    Raises:
        SystemExit: if ``--inference_prompts`` contains no non-empty prompt.
    """
    args = parse_args(argv)
    prompts = split_prompts(args.inference_prompts)
    if not prompts:
        raise SystemExit("--inference_prompts contained no non-empty prompts")
    os.makedirs(args.output_dir, exist_ok=True)

    # create & load model
    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        safety_checker=None,
    ).to("cuda")
    pipe.load_textual_inversion(args.lora_path)
    pipe.enable_xformers_memory_efficient_attention()

    torch.manual_seed(args.seed)
    for prompt in prompts:
        # NOTE(review): only the literal "sks person" is rewritten to the
        # "<sks-person>" token; the default prompts say just "sks", so they are
        # sent unchanged. Confirm which token the loaded embedding registers.
        prompt = prompt.replace("sks person", "<sks-person>")
        print(">>>>>>", prompt)
        out_path = os.path.join(args.output_dir, prompt_to_dirname(prompt))
        os.makedirs(out_path, exist_ok=True)
        for round_idx in range(args.num_rounds):
            images = pipe(
                [prompt] * args.images_per_round,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
            ).images
            for idx, image in enumerate(images):
                image.save(os.path.join(out_path, f"{round_idx}_{idx}.png"))
    del pipe
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
