import os
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image
from transformers import CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor

from msdiffusion.models.projection import Resampler
from msdiffusion.models.model import MSAdapter
from msdiffusion.dataset.datagenerator import get_phrase_idx, get_eot_idx


def get_phrases_idx(tokenizer, phrases, prompt):
    """Return one prompt index per entry of ``phrases``, in the same order.

    Each phrase is resolved with ``get_phrase_idx``.  When the same phrase
    occurs several times in ``phrases``, the k-th occurrence (counting from 0)
    is looked up with ``num=k``; presumably this selects the k-th match inside
    ``prompt`` -- confirm against ``get_phrase_idx``.

    Only the first element of every ``get_phrase_idx`` result is kept.
    """
    seen = {}
    indices = []
    for phrase in phrases:
        # Number of times this exact phrase has already been processed.
        occurrence = seen.get(phrase, 0)
        seen[phrase] = occurrence + 1
        indices.append(get_phrase_idx(tokenizer, phrase, prompt, num=occurrence)[0])
    return indices


# --- run configuration --------------------------------------------------------
# Device that the pipeline, image encoder, projection model and adapter are
# moved to.
device = "cuda"

# Locations of the pretrained SDXL pipeline and of the CLIP vision encoder.
base_model_path = "/path/to/your/model"
image_encoder_path = "/path/to/your/image_encoder"

# Together these select the adapter checkpoint file; both are also reused as
# sub-directories of the output location.
log_id = "test"
load_type = "checkpoint-xxxxxx"
ip_ckpt = f"./output/{log_id}/{load_type}/ip_adapter.bin"

# Root directory under which generated images are written.
result_path = "./res"

# --- SDXL pipeline --------------------------------------------------------------
# Loaded in float16 with add_watermarker=False, then moved to the target device.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path, torch_dtype=torch.float16, add_watermarker=False
)
pipe.to(device)

# --- image encoder and its preprocessor -----------------------------------------
image_encoder_type = "clip"
image_processor = CLIPImageProcessor()  # constructed with default settings
image_encoder = CLIPVisionModelWithProjection.from_pretrained(image_encoder_path)
image_encoder = image_encoder.to(device, dtype=torch.float16)
image_encoder_projection_dim = image_encoder.config.projection_dim

# --- image projection model -------------------------------------------------------
# num_tokens is used both as the resampler's query count and as the adapter's
# token count below.
num_tokens = 16
image_proj_type = "resampler"
latent_init_mode = "grounding"
resampler_kwargs = {
    "dim": 1280,
    "depth": 4,
    "dim_head": 64,
    "heads": 20,
    "num_queries": num_tokens,
    # Input width comes from the image encoder, output width from the UNet's
    # cross-attention, phrase embedding width from the pipeline's text encoder.
    "embedding_dim": image_encoder.config.hidden_size,
    "output_dim": pipe.unet.config.cross_attention_dim,
    "ff_mult": 4,
    "latent_init_mode": latent_init_mode,
    "phrase_embeddings_dim": pipe.text_encoder.config.projection_dim,
}
image_proj_model = Resampler(**resampler_kwargs).to(device, dtype=torch.float16)

# --- adapter ------------------------------------------------------------------------
# Wraps the pipeline's UNet and the projection model; weights come from ip_ckpt.
ip_model = MSAdapter(
    pipe.unet,
    image_proj_model,
    ckpt_path=ip_ckpt,
    device=device,
    num_tokens=num_tokens,
)

# --- reference images -------------------------------------------------------------
# One path per reference image.  Only the listed files are opened, and each is
# opened in a context manager so its file handle is closed as soon as the RGB,
# 512x512 copy has been made.
image_paths = ["/path/to/your/image.jpg"]
# Two-image example (switch `boxes` and `phrases` below to their two-entry
# variants as well):
# image_paths = ["/path/to/your/image.jpg", "/path/to/your/image.jpg"]
input_images = []
for image_path in image_paths:
    with Image.open(image_path) as reference_image:
        input_images.append(reference_image.convert("RGB").resize((512, 512)))

# --- prompt and grounding inputs ------------------------------------------------------
# Generation is conditioned on the text prompt, the reference images, and the
# boxes/phrases below.
num_samples = 5
prompt = "a dog on the beach"
print(prompt)
# NOTE(review): boxes look like normalized [x_min, y_min, x_max, y_max], one per
# reference image, with the outer list matching the single-element prompt list
# passed to generate() -- confirm against MSAdapter.generate.
boxes = [[[0.25, 0.25, 0.75, 0.75]]]
# boxes = [[[0., 0.25, 0.5, 0.75], [0.5, 0.25, 1., 0.75]]]
phrases = [["dog"]]
# phrases = [["dog", "cat"]]
# Index of each phrase in the prompt, and the prompt's EOT index repeated once
# per phrase.
phrase_idxes = [get_phrases_idx(pipe.tokenizer, phrases[0], prompt)]
eot_idxes = [[get_eot_idx(pipe.tokenizer, prompt)] * len(phrases[0])]
print(phrase_idxes, eot_idxes)
# NOTE(review): meaning of this flag is defined by MSAdapter.generate -- confirm.
drop_grounding_tokens = [0]

# --- generation -------------------------------------------------------------------------
images = ip_model.generate(
    pipe=pipe,
    pil_images=[input_images],
    num_samples=num_samples,
    num_inference_steps=30,
    seed=0,
    prompt=[prompt],
    scale=0.6,
    image_encoder=image_encoder,
    image_processor=image_processor,
    boxes=boxes,
    image_proj_type=image_proj_type,
    image_encoder_type=image_encoder_type,
    phrases=phrases,
    drop_grounding_tokens=drop_grounding_tokens,
    phrase_idxes=phrase_idxes,
    eot_idxes=eot_idxes,
    height=1024,
    width=1024,
)

# --- save results -------------------------------------------------------------------------
# Output goes to <result_path>/<log_id>/<load_type>/<save_name>/<index>.jpg.
save_name = "dog"
save_path = os.path.join(result_path, log_id, load_type, save_name)
os.makedirs(save_path, exist_ok=True)
for i, image in enumerate(images):
    image.save(os.path.join(save_path, f"{i}.jpg"))
