
# coding: utf-8

import torch
import os
import random
import sys
import numpy as np
from PIL import Image
from tqdm import tqdm
import requests
from typing import Union
import PIL
import matplotlib
import matplotlib.pyplot as plt
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL

# Model identifiers passed to ``from_pretrained`` below. The VAE is loaded
# from its own identifier and handed to the pipeline explicitly.
base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
vae_model_path = "stabilityai/sd-vae-ft-mse"
# ip_ckpt = "/data/user/IP-Adapter/models/ip-adapter-faceid_sd15.bin"
device = "cuda"

# DDIM scheduler settings, gathered in one mapping so they read as a unit.
_scheduler_config = {
    "num_train_timesteps": 1000,
    "beta_start": 0.00085,
    "beta_end": 0.012,
    "beta_schedule": "scaled_linear",
    "clip_sample": False,
    "set_alpha_to_one": False,
    "steps_offset": 1,
}
noise_scheduler = DDIMScheduler(**_scheduler_config)

# Load the VAE, then cast it to half precision to match the pipeline dtype.
vae = AutoencoderKL.from_pretrained(vae_model_path)
vae = vae.to(dtype=torch.float16)

# Build the pipeline with the custom scheduler and VAE; the feature
# extractor and safety checker are explicitly disabled.
pipe = StableDiffusionPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None,
)
pipe = pipe.to(device)

# Positional command-line arguments:
#   argv[1] concept          passed to ``load_textual_inversion`` as the embedding source
#   argv[2] token            token the embedding is registered under; '<' and '>'
#                            are stripped from it when output filenames are built
#   argv[3] prefix           filename prefix used for random-seed runs
#   argv[4] out_folder_name  sub-folder of the output root that receives the figure
#   argv[5] (optional)       "true" switches to random seeds (parsed later in the script)
# Fail with a usage line instead of a bare IndexError when arguments are missing.
if len(sys.argv) < 5:
    sys.exit(
        f"usage: {sys.argv[0]} <concept> <token> <prefix> <out_folder_name> "
        "[use_random_seeds]"
    )

concept, token, prefix, out_folder_name = sys.argv[1:5]
textual_inversion_embeds_path = concept
pipe.load_textual_inversion(textual_inversion_embeds_path, token=token)
# Re-apply the device placement after the embedding has been loaded.
pipe.to(device)

negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality, blurry"

# Optional 5th CLI argument: the literal "true" (case-insensitive) enables
# random seeds; anything else, or no argument at all, keeps the fixed seed.
use_random_seeds = len(sys.argv) > 5 and sys.argv[5].lower() == "true"

# Seed list and figure width (inches) for the two modes.
if use_random_seeds:
    seeds = [random.randint(0, 99999) for _ in range(4)]
    fsize = 20
else:
    seeds = [3134]
    fsize = 4

# Comma-separated seed list used in the output filename of random-seed runs.
# Defined in both modes so later code can never hit an unbound name.
seedname = ",".join(str(seed) for seed in seeds)

print(f"Using seeds: {seeds}")

# Generate one image per seed. The loop previously ran over ``range(1)``, so
# in random-seed mode only the first of the four seeds was ever used even
# though all four end up in the output filename.
image_list = []
# ``scales`` is no longer used by the loop; it is kept as one index per
# generated image in case other code still refers to it.
scales = range(len(seeds))
prompt = f"a photo of a {token}"
for seed in seeds:
    # torch.manual_seed re-seeds the global RNG and returns a Generator,
    # which is handed to the pipeline so each image is reproducible per seed.
    generator = torch.manual_seed(seed)
    image = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=30,
        height=512,
        width=512,
        generator=generator,
    ).images[0]
    image_list.append(image)

# Lay the generated images out in a single row.
# squeeze=False makes ``ax`` a 2-D array even for one image; with the default
# squeezing a 1x1 grid returns a bare Axes, which cannot be iterated.
fig, ax = plt.subplots(1, len(image_list), figsize=(fsize, 4), squeeze=False)
for a, img in zip(ax.flat, image_list):
    a.imshow(img)
    a.axis('off')

plt.tight_layout()

out_base = f"/data/user/lat-diffusion/outputs/{out_folder_name}/"
os.makedirs(out_base, exist_ok=True)

# Angle brackets are dropped from the token before it is used in a filename.
token_name = token.replace('<', '').replace('>', '')
if use_random_seeds:
    # Random-seed runs carry the prefix and the seed list in the filename.
    out_name = f"{prefix}_{token_name}-{seedname}.png"
else:
    out_name = f"{token_name}.png"

# Save BEFORE showing: with an interactive backend plt.show() blocks and the
# figure is gone once the window is closed, so saving afterwards would write
# an empty image.
fig.savefig(os.path.join(out_base, out_name))
plt.show()
