from diffusers import DPMSolverMultistepScheduler,StableDiffusionPipeline,PNDMScheduler,UNet2DConditionModel
import torch
from pathlib import Path
import argparse
from transformers import CLIPTokenizer,CLIPTextModel
import importlib

def infer(
    pretrained_model_name_or_path: str,
    prompts: list[str],
    learned_embedding_path: str | None = None,
    checkpoint_path: str | None = None,
    category_token: str | None = None,
    num_inference_steps: int = 50,
    num_images_per_prompt: int = 4,
    infer_batch_size: int = 4,
    seed: int | None = None,
    scheduler: str | None = None,
    save_dir: str | Path | None = None,
    device: str = 'cpu',
):
    """Generate images for each prompt with an optionally customised Stable Diffusion pipeline.

    Args:
        pretrained_model_name_or_path: Base model id or local path passed to the
            diffusers / transformers ``from_pretrained`` loaders.
        prompts: Prompt templates. When learned embeddings and/or ``category_token``
            are supplied, each template is filled via ``str.format`` with the
            subject string (object tokens, then the category token), e.g.
            ``"a photo of {}"``. Otherwise the template is used verbatim.
        learned_embedding_path: Optional file loaded with ``torch.load``. Its keys
            are added to the tokenizer as new tokens, and its values are written
            into the matching rows of the text encoder's input embedding table.
        checkpoint_path: Optional path (or model id) whose ``unet`` subfolder
            replaces the base model's UNet.
        category_token: Optional word appended after the object tokens in the
            subject string.
        num_inference_steps: Denoising steps per pipeline call.
        num_images_per_prompt: Total number of images generated for each prompt.
        infer_batch_size: Maximum number of images per pipeline call; the last
            call for a prompt may generate fewer.
        seed: If given, the generator is re-seeded with it before every prompt,
            so each prompt starts from the same random state. If ``None``, the
            generator is seeded non-deterministically once.
        scheduler: Optional name of a ``diffusers`` scheduler class to swap in,
            built from the current scheduler's config.
        save_dir: If given, images are written to
            ``save_dir/<prompt with spaces replaced by "_">/<index>.jpg``, together
            with a ``prompt.txt`` holding the unformatted template (only written
            if it does not already exist).
        device: Torch device for the pipeline and the random generator.

    Returns:
        A list with one entry per prompt. Each entry is the concatenated list of
        ``.images`` returned by the pipeline calls for that prompt.

    Raises:
        ValueError: If ``infer_batch_size`` < 1, or if none of the learned tokens
            could be added to the tokenizer.
    """
    if infer_batch_size < 1:
        raise ValueError(f"infer_batch_size must be >= 1, got {infer_batch_size}")

    load_kwargs = {}
    object_tokens: list[str] = []
    if learned_embedding_path:
        # NOTE: torch.load unpickles arbitrary objects -- only load embedding
        # files from trusted sources.
        # map_location='cpu' lets embeddings saved from a GPU load on any machine;
        # the text encoder below is created on CPU as well.
        embeds_dict = torch.load(learned_embedding_path, map_location="cpu")
        object_tokens = list(embeds_dict.keys())

        tokenizer = CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path, subfolder="tokenizer", torch_dtype=torch.float16
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16
        )

        num_new_tokens = tokenizer.add_tokens(object_tokens)
        if num_new_tokens == 0:
            raise ValueError(
                f"None of the learned tokens {object_tokens!r} could be added to the "
                "tokenizer (empty embedding file or tokens already in the vocabulary)"
            )
        object_token_ids = tokenizer.convert_tokens_to_ids(object_tokens)
        # Grow the embedding table so the new token ids have rows to write into.
        text_encoder.resize_token_embeddings(len(tokenizer))
        embedding_weight = text_encoder.get_input_embeddings().weight.data
        for token, token_id in zip(object_tokens, object_token_ids):
            embedding_weight[token_id] = embeds_dict[token]

        load_kwargs['tokenizer'] = tokenizer
        load_kwargs['text_encoder'] = text_encoder

    if checkpoint_path is not None:
        unet = UNet2DConditionModel.from_pretrained(
            checkpoint_path, subfolder="unet", torch_dtype=torch.float16
        )
        # NOTE(review): re-registers the UNet's own config; likely redundant --
        # confirm before removing.
        unet.register_to_config(**unet.config)
        load_kwargs['unet'] = unet

    # NOTE(review): all weights are loaded in float16; with the default
    # device='cpu' some ops may be unsupported or slow -- confirm for CPU use.
    pipeline = StableDiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path,
        **load_kwargs,
        torch_dtype=torch.float16,
    ).to(device)

    if scheduler is not None:
        scheduler_class = getattr(importlib.import_module("diffusers"), scheduler)
        pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config)

    generator = torch.Generator(device)
    if seed is None:
        # A fresh torch.Generator starts from a fixed default seed; reseed it
        # non-deterministically so "no seed" really means random output.
        generator.seed()

    # Subject string substituted into each prompt template (loop-invariant).
    subject_parts = list(object_tokens)
    if category_token is not None:
        subject_parts.append(category_token)
    subject = ' '.join(subject_parts)

    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    images_list = []
    for ori_prompt in prompts:
        if seed is not None:
            # Re-seed per prompt so every prompt starts from the same noise.
            generator.manual_seed(seed)

        prompt = ori_prompt.format(subject) if subject_parts else ori_prompt

        image_save_dir = None
        if save_dir is not None:
            image_save_dir = save_dir / "_".join(prompt.split(" "))
            # parents=True: a '/' inside the prompt yields nested directories.
            image_save_dir.mkdir(parents=True, exist_ok=True)
            prompt_file = image_save_dir / 'prompt.txt'
            if not prompt_file.exists():
                prompt_file.write_text(ori_prompt, encoding="utf-8")

        images_per_prompt = []
        for start in range(0, num_images_per_prompt, infer_batch_size):
            # The final batch only generates the images that are still missing.
            num_images = min(infer_batch_size, num_images_per_prompt - start)
            images = pipeline(
                prompt,
                generator=generator,
                num_images_per_prompt=num_images,
                num_inference_steps=num_inference_steps,
            ).images
            if image_save_dir is not None:
                for offset, image in enumerate(images):
                    image.save(image_save_dir / f"{start + offset}.jpg")
            images_per_prompt.extend(images)
        images_list.append(images_per_prompt)
    return images_list

def parse_args(input_args=None):
    """Parse command-line options for the image-generation script.

    Args:
        input_args: Optional list of argument strings. When ``None``, argparse
            reads ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        ValueError: If both or neither of ``--prompt`` / ``--prompt_file`` are
            given, or if ``--infer_batch_size`` or ``--num_images_per_prompt`` is
            smaller than 1.

    Side effects:
        Creates ``--save_dir`` (including missing parents) when it is given.
    """
    parser = argparse.ArgumentParser(
        description="Generate images with Stable Diffusion using learned concept embeddings."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--prompt_file",
        type=str,
        required=False,
        default=None,
        help="Path to the txt file that contains prompts.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        required=False,
        help="A seed for reproducible generation.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        required=False,
        default=50,
        help='The number of inference steps of Stable Diffusion to generate images.'
    )
    parser.add_argument(
        "--num_images_per_prompt",
        required=False,
        default=16,
        type=int,
        help='The number of images generated by each prompt.'
    )
    parser.add_argument(
        "--infer_batch_size",
        required=False,
        default=8,
        type=int,
        help='The batch size of each inference. The total inference times is `--num_images_per_prompt` divided by `--infer_batch_size`.'
    )
    parser.add_argument(
        "--save_dir",
        required=False,
        default=None,
        type=str,
        help='The path to save the results.'
    )
    parser.add_argument(
        "--device",
        required=False,
        default="cuda:0",
        type=str,
        help='Which device to use to generate images.'
    )
    parser.add_argument(
        "--category_token",
        required=False,
        type=str,
        default=None,
        help='The category of the target concept.'
    )
    parser.add_argument(
        "--learned_embedding_path",
        required=True,
        type=str,
        help="The path to the text embeddings learned in the first stage."
    )
    parser.add_argument(
        "--prompt",
        required=False,
        type=str,
        default=None,
        help="The prompt to use to generate images."
    )
    parser.add_argument(
        "--infer_scheduler",
        type=str,
        default="DPMSolverMultistepScheduler",
        choices=["DPMSolverMultistepScheduler", "DDPMScheduler", "PNDMScheduler"],
        help="Select which scheduler to use for inference.",
    )
    parser.add_argument(
        "--checkpoint_path",
        required=False,
        type=str,
        default=None,
        help="The path to the model checkpoint saved in the third stage."
    )

    # argparse falls back to sys.argv[1:] when given None.
    args = parser.parse_args(input_args)

    if args.prompt is not None and args.prompt_file is not None:
        raise ValueError('`--prompt` cannot be used with `--prompt_file`')
    if args.prompt is None and args.prompt_file is None:
        # Without this check a `[None]` prompt list would reach the pipeline.
        raise ValueError('One of `--prompt` or `--prompt_file` must be provided')
    if args.infer_batch_size < 1:
        raise ValueError('`--infer_batch_size` must be at least 1')
    if args.num_images_per_prompt < 1:
        raise ValueError('`--num_images_per_prompt` must be at least 1')

    if args.save_dir is not None:
        Path(args.save_dir).mkdir(parents=True, exist_ok=True)

    return args


def main(input_args=None):
    """Command-line entry point: parse arguments, collect prompts and run ``infer``.

    Prompts come from ``--prompt`` or, one per line, from ``--prompt_file``.
    Blank lines in the file are skipped so they are not rendered as empty prompts.

    Args:
        input_args: Optional argument list forwarded to ``parse_args``; when
            ``None`` the process's command line is used.

    Returns:
        Whatever ``infer`` returns (one list of images per prompt).

    Raises:
        ValueError: If no non-blank prompt is available.
    """
    args = parse_args(input_args)
    if args.prompt_file is not None:
        with open(args.prompt_file, 'r', encoding='utf-8') as f:
            valid_prompts = [line for line in f.read().splitlines() if line.strip()]
    elif args.prompt is not None:
        valid_prompts = [args.prompt]
    else:
        valid_prompts = []
    if not valid_prompts:
        raise ValueError('No prompts given: pass `--prompt` or a non-empty `--prompt_file`')

    return infer(
        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
        learned_embedding_path=args.learned_embedding_path,
        prompts=valid_prompts,
        checkpoint_path=args.checkpoint_path,
        category_token=args.category_token,
        num_inference_steps=args.num_inference_steps,
        num_images_per_prompt=args.num_images_per_prompt,
        infer_batch_size=args.infer_batch_size,
        seed=args.seed,
        save_dir=args.save_dir,
        scheduler=args.infer_scheduler,
        device=args.device,
    )


if __name__ == '__main__':
    main()