# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os

import torch
from diffusers import StableDiffusionPipeline
from modelscope import snapshot_download


def parse_args(argv=None):
    """Parse the command-line options of the dreambooth inference script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behavior existing callers of ``parse_args()`` rely on.

    Returns:
        argparse.Namespace: Holds ``model_path``, ``prompt``,
        ``image_save_path``, ``torch_dtype`` (``None`` when not given),
        ``num_inference_steps`` and ``guidance_scale``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            a value is invalid.
    """
    parser = argparse.ArgumentParser(
        description='Simple example of a dreambooth inference.')
    parser.add_argument(
        '--model_path',
        type=str,
        required=True,
        help='Path to trained model.',
    )
    parser.add_argument(
        '--prompt',
        type=str,
        required=True,
        help=('The prompt or prompts to guide image generation. If not '
              'defined, you need to pass `prompt_embeds`'),
    )
    parser.add_argument(
        '--image_save_path',
        type=str,
        required=True,
        help='The path to save generated image',
    )
    parser.add_argument(
        '--torch_dtype',
        type=str,
        default=None,
        choices=['no', 'fp16', 'bf16'],
        # The help text describes what this script does with the value:
        # only fp16/bf16 select half precision, everything else is float32.
        help=('Precision used to load the model weights. Choose between '
              'fp16 and bf16 (bfloat16); bf16 requires PyTorch >= 1.10 and '
              'an Nvidia Ampere GPU. Use "no" or omit the option to load '
              'the weights in full float32 precision.'),
    )
    parser.add_argument(
        '--num_inference_steps',
        type=int,
        default=50,
        help=('The number of denoising steps. More denoising steps usually '
              'lead to a higher quality image at the expense of slower '
              'inference.'),
    )
    parser.add_argument(
        '--guidance_scale',
        type=float,
        default=7.5,
        help=('A higher guidance scale value encourages the model to '
              'generate images closely linked to the text `prompt` at the '
              'expense of lower image quality. Guidance scale is enabled '
              'when `guidance_scale > 1`.'),
    )

    return parser.parse_args(argv)


def main():
    """Generate one image from a trained model and write it to disk.

    Reads the options via ``parse_args()``, loads a
    ``StableDiffusionPipeline`` from ``--model_path`` onto the ``cuda``
    device, runs it on ``--prompt`` and saves the first returned image to
    ``--image_save_path``. The parent directory of the output path is
    created if it does not exist.
    """
    args = parse_args()

    # Only an explicit 'fp16' / 'bf16' selects half precision; the unset
    # default (None) and 'no' both fall back to float32.
    dtype_by_flag = {'fp16': torch.float16, 'bf16': torch.bfloat16}
    torch_dtype = dtype_by_flag.get(args.torch_dtype, torch.float32)

    # Create the destination directory up front so a missing folder does not
    # make the final save fail after the expensive generation has finished.
    save_dir = os.path.dirname(os.path.abspath(args.image_save_path))
    os.makedirs(save_dir, exist_ok=True)

    pipe = StableDiffusionPipeline.from_pretrained(
        args.model_path, torch_dtype=torch_dtype).to('cuda')

    image = pipe(
        args.prompt,
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale).images[0]

    image.save(args.image_save_path)
