import torch
from diffusers import StableDiffusionPipeline
import numpy as np
from PIL import Image
# from transformers import CLIPTextModel, CLIPTokenizer

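# TODO: load the generated embedding that should serve as the similarity query.
# Until that is wired in, the script falls back to the first prompt's token embedding.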







# Checkpoint to load and the device everything runs on.
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"

# Load the pipeline in half precision and move it to the target device in one go.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to(device)

# Reuse the tokenizer and text encoder bundled with the pipeline instead of
# loading separate copies, so prompt embeddings match what the pipeline expects.
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder


# Candidate class-name prompts; the underscore spelling is passed to the
# tokenizer exactly as written here.
prompts = [
    # fish
    "goldfish",
    "white_shark",
    "tiger_shark",
    "hammerhead",
    "electric_ray",
    "stingray",
    # birds
    "cock",
    "hen",
    "ostrich",
    "brambling",
    "goldfinch",
    "house_finch",
    "junco",
    "indigo_bunting",
    "robin",
    "bulbul",
    "jay",
    "magpie",
    "chickadee",
    "water_ouzel",
]

# Encode the first 10 prompts with the pipeline's text encoder.
#   embeds      -- per prompt, the hidden state at token position 1
#                  (presumably the first real token after the start-of-text
#                  token -- TODO confirm against the tokenizer)
#   full_embeds -- per prompt, the encoder's full first output, kept with its
#                  batch dimension so it can be fed back as `prompt_embeds`
# Both lists hold CPU tensors.
#
# NOTE(review): prompts are tokenized without padding, so the sequence length
# differs from prompt to prompt; confirm this is intended, since the stock
# pipeline is believed to pad prompts to the tokenizer's maximum length.
embeds = []
full_embeds = []
# Inference only: no_grad avoids building an autograd graph that would just be
# thrown away by the .detach() calls below.
with torch.no_grad():
    for prompt in prompts[:10]:
        inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        text_embeddings = text_encoder(inputs)[0]
        embeds.append(text_embeddings[0][1].detach().cpu())
        full_embeds.append(text_embeddings.detach().cpu())


# Stack the per-prompt token embeddings into a single 2-D tensor
# (one row per encoded prompt).
pt_embeds = torch.stack(embeds)
print(pt_embeds.shape)

# Query embedding to match against the candidates. The first prompt's token
# embedding is used here, so the best match is expected to be entry 0.
# Cast to float32: half-precision math on CPU is imprecise and not supported
# by every torch build.
query_embedding = pt_embeds[0].float()

# Token-position-1 embedding of every candidate, one row per candidate.
candidate_tokens = torch.stack([element[0][1] for element in full_embeds]).float()

# Despite the name, this holds cosine *similarities* (higher means closer),
# which is why the best candidate is picked with argmax.
cos_distance = torch.nn.functional.cosine_similarity(
    query_embedding.unsqueeze(0), candidate_tokens, dim=1
).numpy()

# Keep the full (batch-dimension-preserving) embedding of the closest prompt.
selected_embedding = full_embeds[int(np.argmax(cos_distance))]




# Generate an image conditioned on the selected prompt embedding.
# The embedding was stored on the CPU; move it to the pipeline's device.
text_embeddings = selected_embedding.to(device)

batch_size = text_embeddings.shape[0]
height = pipe.unet.config.sample_size * pipe.vae_scale_factor  # Normally 512
width = pipe.unet.config.sample_size * pipe.vae_scale_factor   # Normally 512

# Initial noise in latent space. The latent resolution is the image resolution
# divided by the VAE scale factor (rather than a hard-coded 8), the channel
# count is read from the UNet config (direct `unet.in_channels` access is
# deprecated in diffusers), and the dtype follows the prompt embedding so it
# matches the half-precision pipeline.
latents = torch.randn(
    (
        batch_size,
        pipe.unet.config.in_channels,
        height // pipe.vae_scale_factor,
        width // pipe.vae_scale_factor,
    ),
    device=device,
    dtype=text_embeddings.dtype,
)

# Pass the latents and the text embeddings into the pipeline.
# autocast needs the bare device type (e.g. "cuda"), not an indexed device.
with torch.autocast(torch.device(device).type):
    image = pipe(
        latents=latents,
        prompt_embeds=text_embeddings,
        num_inference_steps=50,
        guidance_scale=7.5,
    ).images[0]  # Take the first image from the batch

# Save the result to disk and open it in the default image viewer.
image.save("generated_image.png")
image.show()