# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os

import torch
from diffusers import (ControlNetModel, StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)
from diffusers.utils import load_image
from modelscope import snapshot_download


def parse_args():
    parser = argparse.ArgumentParser(
        description='Simple example of a ControlNet inference.')
    parser.add_argument(
        '--base_model_path',
        type=str,
        default='AI-ModelScope/stable-diffusion-v1-5',
        required=True,
        help=
        'Path to pretrained model or model identifier from modelscope.cn/models.',
    )
    parser.add_argument(
        '--revision',
        type=str,
        default=None,
        required=False,
        help=
        'Revision of pretrained model identifier from modelscope.cn/models.',
    )
    parser.add_argument(
        '--controlnet_path',
        type=str,
        default=None,
        required=False,
        help='The path to trained controlnet model.',
    )
    parser.add_argument(
        '--prompt',
        type=str,
        default=None,
        required=True,
        help=
        'The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`',
    )
    parser.add_argument(
        '--control_image_path',
        type=str,
        default=None,
        required=True,
        help='The path to conditioning image.',
    )
    parser.add_argument(
        '--image_save_path',
        type=str,
        default=None,
        required=True,
        help='The path to save generated image',
    )
    parser.add_argument(
        '--torch_dtype',
        type=str,
        default=None,
        choices=['no', 'fp16', 'bf16'],
        help=
        ('Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >='
         ' 1.10.and an Nvidia Ampere GPU.  Default to the value of the'
         ' mixed_precision passed with the `accelerate.launch` command in training script.'
         ),
    )
    parser.add_argument(
        '--seed', type=int, default=None, help='A seed for inference.')
    parser.add_argument(
        '--num_inference_steps',
        type=int,
        default=20,
        help=
        ('The number of denoising steps. More denoising steps usually lead to a higher quality image at the \
                expense of slower inference.'),
    )
    parser.add_argument(
        '--guidance_scale',
        type=float,
        default=7.5,
        help=
        ('A higher guidance scale value encourages the model to generate images closely linked to the text \
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.'
         ),
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if os.path.exists(args.base_model_path):
        base_model_path = args.base_model_path
    else:
        base_model_path = snapshot_download(
            args.base_model_path, revision=args.revision)

    if args.torch_dtype == 'fp16':
        torch_dtype = torch.float16
    elif args.torch_dtype == 'bf16':
        torch_dtype = torch.bfloat16
    else:
        torch_dtype = torch.float32

    controlnet = ControlNetModel.from_pretrained(
        args.controlnet_path, torch_dtype=torch_dtype)
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        base_model_path, controlnet=controlnet, torch_dtype=torch_dtype)

    # speed up diffusion process with faster scheduler and memory optimization
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    # memory optimization.
    pipe.enable_model_cpu_offload()

    control_image = load_image(args.control_image_path)

    # generate image
    generator = torch.manual_seed(args.seed)
    image = pipe(
        args.prompt,
        num_inference_steps=args.num_inference_steps,
        generator=generator,
        image=control_image).images[0]
    image.save(args.image_save_path)
