import torch
from diffusers import StableDiffusionPipeline
import numpy as np
from PIL import Image
import PIL
# from transformers import CLIPTextModel, CLIPTokenizer

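# TODO: load the generated embedding here; it should replace the placeholder
# query embedding used in the nearest-prompt lookup further below.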







# Stable Diffusion v1.4 checkpoint; everything below runs on this device.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Load the whole pipeline in half precision and move it onto `device`. The
# pipeline's own CLIP tokenizer and text encoder are reused for prompt
# encoding, so they do not have to be loaded separately.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to(device)

tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder





# One prompt per object class (these line up with the CIFAR-10 labels).
prompts = [
    'airplane in the sky in real life',
    'sedan car in real life',
    'bird in real life',
    'cat in real life',
    # Was 'deer with corner': a mistranslation of 角 ("horn"/"antler").
    'deer with antlers in real life',
    'dog in real life',
    'frog in real life',
    'horse in real life',
    'ship in the ocean in real life',
    'truck in real life',
]

# embeds[i]      -> hidden state at token position 1 (the first token after
#                   the start token), a 1-D tensor on the CPU.
# full_embeds[i] -> the complete encoder output for prompt i, shape
#                   (1, seq_len, hidden), on the CPU; usable as `prompt_embeds`.
embeds = []
full_embeds = []

# Inference only: without no_grad every encoder call would build an autograd
# graph that is immediately thrown away.
with torch.no_grad():
    for prompt in prompts:
        # Pad/truncate exactly like the pipeline does when it encodes text, so
        # the stored embeddings have the fixed sequence length the UNet
        # normally receives.
        inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids.to(device)
        text_embeddings = text_encoder(inputs)[0]
        embeds.append(text_embeddings[0][1].cpu())
        full_embeds.append(text_embeddings.cpu())


# Stack the per-prompt token-1 embeddings into one (num_prompts, hidden) tensor.
pt_embeds = torch.stack(embeds)
print(pt_embeds.shape)

# Query for the lookup. As written it is the stored embedding of prompt 1, so
# the search below trivially selects prompt 1; substitute a generated
# embedding of the same width here. (Named `query` so the builtin `input` is
# not shadowed.)
query = pt_embeds[1]

# Cosine similarity between the query and the token-1 hidden state of every
# prompt, computed in one vectorized call. Values are cast to float32 so the
# comparison does not depend on half-precision CPU kernels/precision.
candidates = torch.stack([element[0][1] for element in full_embeds]).float()
cos_similarity = torch.nn.functional.cosine_similarity(
    query.float().unsqueeze(0), candidates, dim=-1
)

# Keep the full encoder output of the most similar prompt; it is fed to the
# pipeline as `prompt_embeds`.
selected_embedding = full_embeds[int(torch.argmax(cos_similarity))]




text_embeddings = selected_embedding

# Build the initial noise explicitly. Its shape must match what the UNet
# expects for the requested output size, and its dtype must match the
# half-precision UNet (previously float32 noise only worked because of
# torch.autocast, which is not needed for a pure-fp16 pipeline).
batch_size = text_embeddings.shape[0]
height = pipe.unet.config.sample_size * pipe.vae_scale_factor  # Normally 512
width = pipe.unet.config.sample_size * pipe.vae_scale_factor   # Normally 512
latents = torch.randn(
    (
        batch_size,
        pipe.unet.config.in_channels,
        height // pipe.vae_scale_factor,
        width // pipe.vae_scale_factor,
    ),
    device=device,
    dtype=pipe.unet.dtype,
)

# Run the pipeline on the precomputed text embeddings and the custom latents;
# height/width are passed explicitly so they agree with the latent shape.
image = pipe(
    latents=latents,
    prompt_embeds=text_embeddings,
    height=height,
    width=width,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]  # Take the first image from the batch

# Downscale to 32x32 and write the result to disk, then display it.
image = image.resize((32, 32), PIL.Image.LANCZOS)
image.save("generated_image.png")
image.show()