import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import create_carvekit_interface
from load_model import img_preprocess, get_context, load_zero_123
from training.zero123_datasets import ObjectTOForgetAngleDataset


# --- Model and dataset setup -------------------------------------------------
# Load the Zero-1-to-3 checkpoint, build the background-removal helper, and
# construct the "object to forget" angle dataset that the analysis iterates.
ckpt_path = '105000.ckpt'
config_path = 'configs/sd-objaverse-finetune-c_concat-256.yaml'
device = 'cpu'  # e.g. 'cuda:0' to run on a GPU instead
model = load_zero_123(ckpt_path, config_path, device)

# NOTE(review): `sampler` is not referenced anywhere else in this script.
sampler = DDIMSampler(model)

carvekit_model = create_carvekit_interface()

# The same image is used both as the image to forget and as its override.
_minion_image_path = 'zero123_dataset/object_to_forget_angle/image/minion.png'
_dataset_kwargs = dict(
    forget_image_path=_minion_image_path,
    override_image_path=_minion_image_path,
    angles_csv_path='zero123_dataset/object_to_forget_angle/fixed_x_36.csv',
    preprocess_fn=img_preprocess,
    carvekit_model=carvekit_model,
    transform=None,
)
obj2forgetAngleDataset = ObjectTOForgetAngleDataset(**_dataset_kwargs)
print(len(obj2forgetAngleDataset))

# --- Cosine-similarity analysis ----------------------------------------------
# For each sample of the forget-angle dataset, compute the model context and
# compare its last element (cosine similarity over the last dim) against the
# contexts produced for four fixed reference camera angles of the minion image.
# The per-sample similarities are printed and plotted, one curve per reference.
REFERENCE_IMAGE_PATH = 'zero123_dataset/object_to_forget_angle/image/minion.png'
# NOTE(review): the output name mentions "pikachu" but the analysed object is
# the minion image; kept unchanged in case other tooling expects this path.
OUTPUT_PLOT_PATH = 'pikachu_lamp/pikachu_cos_similarity_horizontal.png'
GUIDANCE_SCALE = 3

# Reference camera angles passed to get_context. Only the second component
# differs between them; presumably it is the horizontal (azimuth) offset in
# degrees -- TODO confirm against get_context.
# (Up/down references, e.g. [90, 0, 0] / [-90, 0, 0], were explored previously.)
reference_angles = {
    'front': [0.0, 0.0, 0.0],
    'behind': [0.0, 180.0, 0.0],
    'left': [0.0, -90.0, 0.0],
    'right': [0.0, 90.0, 0.0],
}
plot_colors = {'front': 'r', 'behind': 'g', 'left': 'b', 'right': 'y'}

img = Image.open(REFERENCE_IMAGE_PATH).convert("RGBA")
img_tensor = torch.stack([img_preprocess(img, carvekit_model)], dim=0)  # add batch dim

cos_similarity = {name: [] for name in reference_angles}

# Everything below is inference only: disable autograd so no graphs or
# activations are kept alive across the loop.
with torch.no_grad():
    reference_contexts = {}
    for name, angle_values in reference_angles.items():
        cam_angle = torch.tensor(angle_values, dtype=torch.float32).unsqueeze(0)
        ref_context, _ = get_context(model=model, input_image=img_tensor, cam_angle=cam_angle,
                                     guidance_scale=GUIDANCE_SCALE, device=device)
        # Only the last element of the returned context is compared (presumably the
        # conditional branch when guidance is enabled -- TODO confirm).
        reference_contexts[name] = ref_context[-1]

    forget_dataloader = DataLoader(obj2forgetAngleDataset, batch_size=1, shuffle=False)
    for forget_img, _override_img, angle in forget_dataloader:
        context, _ = get_context(model=model, input_image=forget_img, cam_angle=angle,
                                 guidance_scale=GUIDANCE_SCALE, device=device)
        for name, ref in reference_contexts.items():
            similarity = torch.nn.functional.cosine_similarity(context[-1], ref, dim=-1)
            cos_similarity[name].append(similarity.detach().cpu().numpy())

for name, values in cos_similarity.items():
    print(name, values)

# Size the x-axis from the data actually collected: a hard-coded sample count
# makes plt.plot fail whenever the angles CSV has a different number of rows.
num_samples = len(cos_similarity['front'])
x = np.linspace(-180, 180, num_samples)
for name, values in cos_similarity.items():
    plt.plot(x, values, plot_colors[name], label=name)
plt.legend()

output_dir = os.path.dirname(OUTPUT_PLOT_PATH)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)  # savefig does not create directories
plt.savefig(OUTPUT_PLOT_PATH)
plt.show()
