import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms import transforms

from ldm.modules.encoders.modules import FrozenCLIPEmbedder, FrozenCLIPImageEmbedder
from PIL import Image

from ldm.util import create_carvekit_interface
from load_model import load_zero_123, img_preprocess

import clip
# ckpt_path = '105000.ckpt'
# config_path = 'configs/sd-objaverse-finetune-c_concat-256.yaml'
# device = 'cpu'
# device_0 = 'cuda:0'
# model = load_zero_123(ckpt_path, config_path, device)


# Experiment 1: score a text prompt against an image in OpenAI CLIP's joint
# embedding space. clip.load returns the model and its matching image
# preprocessing transform; its default device is CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14", device=device)
model.eval()

# prompt = ["Shirtless Putin at pride"]
prompt = ["a pikachu standing"]

image = Image.open("zero123_dataset/object_to_forget_angle/image/minion.png").convert("RGB")
# image = Image.open("70.png").convert("RGB")

# Inference only: no autograd graph is needed for the encoders.
with torch.no_grad():
    text_tokens = clip.tokenize(prompt).to(device)
    text_embedding_2 = model.encode_text(text_tokens).float()

    # preprocess yields one image tensor; stack adds the batch dimension
    # (torch.stack avoids the torch -> numpy -> torch round trip).
    image_input = torch.stack([preprocess(image)]).to(device)
    image_features = model.encode_image(image_input).float()

    # L2-normalise along the feature axis so the dot product below is the
    # cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_embedding_2 /= text_embedding_2.norm(dim=-1, keepdim=True)

print('image embedding shape', image_features.shape)
print('text embedding shape', text_embedding_2.shape)

# text @ image.T -> rows are prompts, columns are images.
similarity = text_embedding_2.cpu().numpy() @ image_features.cpu().numpy().T
print('similarity:', similarity)

# Stop here; the encoder comparison that follows is intentionally skipped.
# `exit()` is the site-module helper meant for the interactive shell, so use
# SystemExit, which is always available in scripts.
raise SystemExit(0)

# Experiment 2: compare the ldm/zero123 CLIP wrappers directly, per-token text
# embeddings from FrozenCLIPEmbedder vs. the FrozenCLIPImageEmbedder output.
# Only reached if the early exit above is removed.

# Load the CLIP text & image encoders.
text_encoder = FrozenCLIPEmbedder().cuda()
image_encoder = FrozenCLIPImageEmbedder().cuda()
# image_encoder = model.cond_stage_model

# Encode the text.
# prompt = ["a Minion standing"]
# prompt = ["A cheerful yellow animated character in blue overalls waves with one hand, featuring a single central eye with a goggle."]
# Per-token text embedding, noted as [batch, 77, 768]; the indexing further
# down assumes this 3-D layout.
text_embedding = text_encoder(prompt)
print(text_embedding.shape)
# text_embedding_1 = text_embedding.detach().cpu().numpy()
# text_embedding_2 = text_encoder(prompt_2).detach().cpu().numpy()
#
# # Euclidean distance, to check whether the text embeddings differ
# distance = np.linalg.norm(text_embedding_1 - text_embedding_2)
# print("Embedding Distance:", distance)

# Encode the image.
image = Image.open("zero123_dataset/object_to_forget_angle/image/minion.png").convert("RGB")
carvekit_model = create_carvekit_interface()
img_tensor = img_preprocess(image, carvekit_model)
img_tensor = torch.stack([img_tensor], dim=0).cuda()  # add a batch dimension
print(img_tensor.shape)
# If the encoder returns [batch, 768], tile(1, 1, 1) makes it [1, batch, 768]
# and permute(1, 0, 2) yields [batch, 1, 768] -- TODO confirm encoder output shape.
image_embedding = image_encoder(img_tensor).tile(1, 1, 1).permute(1, 0, 2)
print(image_embedding.shape)

# img2 = Image.open("zero123_dataset/object_to_forget_angle/image/pikachu.png").convert("RGB")
# img_tensor_2 = img_preprocess(img2, carvekit_model)
# img_tensor_2 = torch.stack([img_tensor_2], dim=0)
# image_embedding_2 = image_encoder(img_tensor_2).tile(1, 1, 1).permute(1, 0, 2)
# image_embedding_1 = image_embedding.detach().cpu().numpy()
# image_embedding_2 = image_embedding_2.detach().cpu().numpy()
# img_distance = np.linalg.norm(image_embedding_1 - image_embedding_2)
# print('image distance: ', img_distance)

# Summarise the per-token text embedding from FrozenCLIPEmbedder. This must use
# `text_embedding`, not `text_embedding_2` from experiment 1: that one is the
# 2-D [num_prompts, dim] output of model.encode_text, so indexing it with
# [:, 0, :] raises IndexError and .mean(dim=1) would average away the features.
# NOTE(review): CLIP's text model has no [CLS] token. Position 0 is the
# start-of-text token, and CLIP pools at the end-of-text token instead. Also,
# these per-token values are presumably unprojected hidden states while the
# image embedding is presumably projected, so the cosines may not be
# comparable -- confirm against the FrozenCLIP* implementations.
text_cls_embedding = text_embedding[:, 0, :]  # first token position
text_mean_embedding = text_embedding.mean(dim=1)  # mean over token positions
# Cosine similarity between the image embedding and each text summary.
cos_sim_cls = F.cosine_similarity(image_embedding, text_cls_embedding.unsqueeze(1), dim=-1)
cos_sim_mean = F.cosine_similarity(image_embedding.squeeze(1), text_mean_embedding)
# cos_sim = F.cosine_similarity(image_embedding, text_embedding[:, 0, :], dim=-1)  # first token only
# .item() requires a single prompt and a single image.
print("Cosine Similarity cls : ", cos_sim_cls.item())
print("cosine sim mean: ", cos_sim_mean.item())
