
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler, DDIMScheduler, DDPMScheduler
import os
import argparse 
from accelerate.utils import ProjectConfiguration, set_seed
import random
from tqdm import tqdm 



# Command-line inference script: loads a StableDiffusionPipeline (optionally
# overriding its UNet weights from a checkpoint), selects one row of a saved
# "unique embedding" table by --trigger_id, and saves one generated image.


def build_arg_parser():
    """Return the argument parser for this script.

    ``--guidance_scale`` is optional and defaults to 7.5 (``required=True``
    alongside a default made that default unreachable).
    ``--num_inference_steps`` defaults to 100, the value the sampler always used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--unique_emb_filename", type=str, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--model_ckpt", type=str, default=None)
    parser.add_argument("--trigger_id", type=int, required=True)
    parser.add_argument("--unique_scale", type=float, required=True)
    parser.add_argument("--guidance_scale", type=float, default=7.5)
    parser.add_argument("--img_size", type=int, default=512)
    parser.add_argument("--unique_len", type=int, default=77)
    parser.add_argument("--save_folder", type=str, default='INFERENCE_RESULT')
    parser.add_argument("--num_inference_steps", type=int, default=100)
    return parser


def select_unique_embeddings(embedding_table, trigger_id, unique_scale, repeat_len=None):
    """Return row ``trigger_id`` of ``embedding_table`` multiplied by ``unique_scale``.

    Indexing with a one-element index tensor keeps a leading dimension of
    size 1. If ``repeat_len`` is given, the selected row is tiled along a new
    axis 1 ``repeat_len`` times; this path expects a 2-D ``(N, D)`` table, so
    the result is ``(1, repeat_len, D)``.
    """
    # Build the index on the table's device so indexing never crosses devices.
    index = torch.tensor([trigger_id], device=embedding_table.device)
    selected = embedding_table[index] * unique_scale
    if repeat_len is not None:
        selected = selected.unsqueeze(1).repeat(1, repeat_len, 1)
    return selected


def sanitize_for_filename(text):
    """Replace path separators in ``text`` with '_' so it stays one path component."""
    for sep in {"/", os.sep, os.altsep}:
        if sep:
            text = text.replace(sep, "_")
    return text


def ckpt_tag(ckpt_path):
    """Return the checkpoint label embedded in output filenames.

    Keeps the existing convention of using the second '/'-separated component
    (``runs/exp1/unet.pt`` -> ``exp1``). When the path contains no '/', it
    falls back to the basename instead of raising IndexError.
    """
    parts = ckpt_path.split("/")
    return parts[1] if len(parts) > 1 else os.path.basename(ckpt_path)


def output_path(save_folder, prompt, guidance_scale, model_ckpt=None, trigger_id=None, unique_scale=None):
    """Return the path for the generated image, encoding settings in the name.

    With ``model_ckpt`` the name records trigger id, unique scale, checkpoint
    tag, prompt and guidance scale. Without it the name starts with
    ``original_``. The prompt is sanitized so a '/' in it cannot point the
    save at a nonexistent subdirectory.
    """
    safe_prompt = sanitize_for_filename(prompt)
    if model_ckpt is not None:
        return "{}/trigger_{}_unique_scale_{}_ckpt_{}_prompt_{}_guidance_scale_{}.jpg".format(
            save_folder, trigger_id, unique_scale, ckpt_tag(model_ckpt), safe_prompt, guidance_scale)
    return "{}/original_prompt_{}_guidance_scale_{}.jpg".format(save_folder, safe_prompt, guidance_scale)


def main():
    """Parse CLI arguments, generate one image, and save it under --save_folder."""
    args = build_arg_parser().parse_args()
    print("Generating images with prompt:", args.prompt)

    os.makedirs(args.save_folder, exist_ok=True)
    seed = args.seed
    if seed is None:
        seed = random.randint(0, 100)
    print("seed:", seed)

    device = "cuda:{}".format(args.gpu)
    pipe = StableDiffusionPipeline.from_pretrained(args.model_name, torch_dtype=torch.float16)
    # Alternative schedulers imported at the top of the file can be swapped in here, e.g.
    # pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    generator = torch.Generator(device=device).manual_seed(seed)
    if args.model_ckpt is not None:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.model_ckpt, map_location='cpu')
        pipe.unet.load_state_dict(checkpoint['unet'])
        print("Load Unet from checkpoint successfully")
    pipe.to(device)

    # NOTE(review): same unpickling caveat applies to the embedding file.
    embedding_table = torch.load(args.unique_emb_filename, map_location=pipe.device).to(dtype=pipe.unet.dtype)
    # Filename heuristic carried over unchanged: files whose name contains '768'
    # are presumably one vector per trigger, so the row is tiled to --unique_len
    # tokens. TODO confirm against whatever produces these embedding files.
    repeat_len = args.unique_len if '768' in args.unique_emb_filename else None
    unique_embeddings = select_unique_embeddings(
        embedding_table, args.trigger_id, args.unique_scale, repeat_len)

    # `unique_embeddings` is not an argument of stock diffusers
    # StableDiffusionPipeline.__call__; presumably a locally patched pipeline — confirm.
    image = pipe(
        args.prompt,
        unique_embeddings=unique_embeddings,
        generator=generator,
        num_inference_steps=args.num_inference_steps,
        width=args.img_size,
        height=args.img_size,
        guidance_scale=args.guidance_scale,
    ).images[0]
    image.save(output_path(
        args.save_folder,
        args.prompt,
        args.guidance_scale,
        model_ckpt=args.model_ckpt,
        trigger_id=args.trigger_id,
        unique_scale=args.unique_scale,
    ))


if __name__ == "__main__":
    main()





