import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from torch.utils.data import DataLoader

from load_model import img_preprocess, get_context, load_zero_123, zero123_sampler_every_step, zero123_sampler, zero123_sampler_skip
from training.zero123_datasets import ObjectTOForgetAngleDataset, CommonObjectDataset
from ldm.util import create_carvekit_interface
from ldm.models.diffusion.ddim import DDIMSampler


# ---------------------------------------------------------------------------
# Model, DDIM sampler and background-removal (carvekit) setup.
# ---------------------------------------------------------------------------
ckpt_path = '230000.ckpt'
config_path = 'configs/sd-objaverse-finetune-c_concat-256.yaml'
device = 'cuda'
model = load_zero_123(ckpt_path, config_path, device)


sampler = DDIMSampler(model)

carvekit_model = create_carvekit_interface()

# Forget/override image pair plus the fixed camera views listed in the CSV.
# Images are preprocessed with img_preprocess using the carvekit interface above.
_forget_dataset_kwargs = dict(
    forget_image_path='zero123_dataset/object_to_forget_angle/image/pikachu.png',
    override_image_path='zero123_dataset/object_to_forget_angle/image/minion.png',
    angles_csv_path='zero123_dataset/object_to_forget_angle/fixed_views.csv',
    preprocess_fn=img_preprocess,
    carvekit_model=carvekit_model,
    transform=None,
)
obj2forgetAngleDataset = ObjectTOForgetAngleDataset(**_forget_dataset_kwargs)
print(len(obj2forgetAngleDataset))

# The diffusion network, passed to the sampling helpers as `unet`.
unet = model.model.diffusion_model.to(device)

# The two reference camera angles, each with a leading batch dimension of 1
# (shape (1, 3)). They differ only in the second component: 0 vs 90.
front_angle = torch.tensor([[0.0, 0.0, 0.0]], dtype=torch.float32)
right_angle = torch.tensor([[0.0, 90.0, 0.0]], dtype=torch.float32)

# Encode the minion reference image once. Then, for the front and right
# reference cameras, compute the conditioning contexts and the per-step sample
# sequence from zero123_sampler_every_step. These are presumably one latent
# per sampling step (they are indexed by step and decoded later) -- confirm
# against load_model.zero123_sampler_every_step.
with Image.open("zero123_dataset/object_to_forget_angle/image/minion.png") as _minion_file:
    # convert() loads the pixels into a new image, so closing the file afterwards is safe.
    img = _minion_file.convert("RGBA")
img_tensor = img_preprocess(img, carvekit_model)
img_tensor = torch.stack([img_tensor], dim=0)  # add a leading batch dimension of 1


def _reference_trajectory(cam_angle):
    """Return ``(context, img_context, samples_every_step)`` for one camera angle.

    Conditions on the module-level ``img_tensor`` via ``get_context`` and then
    runs ``zero123_sampler_every_step`` with 32 steps and guidance scale 3.
    """
    context, img_context = get_context(model=model, input_image=img_tensor, cam_angle=cam_angle,
                                       guidance_scale=3, device=device)
    samples_every_step = zero123_sampler_every_step(unet=unet, sampler=sampler, device=device, context=context,
                                                    img_context=img_context, guidance_scale=3, num_steps=32)
    return context, img_context, samples_every_step


# Order matters for reproducibility: the front reference is sampled before the right one.
front_context, front_img_context, front_sample_every_step = _reference_trajectory(front_angle)
right_context, right_img_context, right_sample_every_step = _reference_trajectory(right_angle)

# Disabled alternative: a shuffled loader over a common-object dataset.
# Measured runtimes: batch_size 4 -> 255s, batch_size 1 -> 52s.
# common_dataloader = DataLoader(commonObjDataset, batch_size=1, shuffle=True)

# Unshuffled loader over the forget-angle dataset, one item per batch.
# NOTE(review): not iterated anywhere in this script; kept for the disabled loop below.
forget_dataloader = DataLoader(dataset=obj2forgetAngleDataset, batch_size=1, shuffle=False)

# Disabled experiment: compare each forget view's context with the front reference.
# for forget_img, override_img, angle in forget_dataloader:
#     context, img_context = get_context(model=model, input_image=forget_img, cam_angle=angle,
#                                        guidance_scale=3, device=device)
#     # front_similarity = torch.nn.functional.cosine_similarity(context[-1], front_context[-1], dim=-1)
# Sweep the second camera-angle component from 0 to 90 degrees in steps of 10.
# For each target angle, blend the front and right reference latents at
# several sampling-step indices and save the decoded blend as a PNG.
OUTPUT_DIR = 'minion_skip'
SKIP_STEPS = [12, 16, 20, 24]  # step indices into the reference per-step sample sequences
RUN_BASELINE = False  # also save a plain zero123 sample at each angle (full sampling run)
RUN_SKIP = False  # also resume sampling from each blended latent via zero123_sampler_skip
os.makedirs(OUTPUT_DIR, exist_ok=True)  # Image.save does not create missing directories


def _save_image(image, path):
    """Write one decoded image tensor to *path* as an 8-bit image.

    ``image`` is laid out as (C, H, W). Values are mapped with (x + 1) / 2,
    clamped to [0, 1] and scaled to 0..255. That mapping assumes the decoder
    output is roughly in [-1, 1]; anything outside is clipped.
    """
    image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0).cpu()
    pixels = 255.0 * rearrange(image.detach().numpy(), 'c h w -> h w c')
    Image.fromarray(pixels.astype(np.uint8)).save(path)


for y_angle in range(0, 91, 10):
    test_angle = torch.tensor([0.0, y_angle, 0.0], dtype=torch.float32).unsqueeze(0)
    test_context, test_img_context = get_context(model=model, input_image=img_tensor, cam_angle=test_angle,
                                                 guidance_scale=3, device=device)

    if RUN_BASELINE:
        # Plain generation at the target angle, for comparison with the blends.
        test_sample_img = zero123_sampler(unet=unet, sampler=sampler, device=device, context=test_context,
                                          img_context=test_img_context, guidance_scale=3, num_steps=32,
                                          train_sampler=False, num_steps_eval=32, return_images=True)
        _save_image(test_sample_img[0], f'{OUTPUT_DIR}/230000_sample_angle_{y_angle}.png')

    # The blend weights depend only on the angle, so compute them once per angle.
    # Each cosine similarity (last context entry) is rescaled so that its floor
    # (0.8 for front, 0.72 for right) maps to 0 and 1.0 maps to 1.
    # NOTE(review): these scores are not clamped. A similarity below its floor
    # gives a negative score and pushes the weights outside [0, 1]
    # (extrapolation). If the two scores cancel, the weights become inf/NaN.
    similarity_with_front = (torch.nn.functional.cosine_similarity(test_context[-1], front_context[-1], dim=-1) - 0.8) / (1 - 0.8)
    similarity_with_right = (torch.nn.functional.cosine_similarity(test_context[-1], right_context[-1], dim=-1) - 0.72) / (1 - 0.72)
    print(f'similarity with front: {similarity_with_front}, with right: {similarity_with_right}')
    similarity_total = similarity_with_front + similarity_with_right
    weight_front = similarity_with_front / similarity_total
    weight_right = similarity_with_right / similarity_total

    for skip in SKIP_STEPS:
        # Index with the loop variable so every step setting yields its own
        # blend and its own output file.
        interpolation_img = (front_sample_every_step[skip] * weight_front +
                             right_sample_every_step[skip] * weight_right)

        # Decode the blended latent and save it.
        images = sampler.model.decode_first_stage(interpolation_img)
        _save_image(images[0], f'{OUTPUT_DIR}/230000_interpolate_{skip}_angle_{y_angle}.png')
        print('saved')

        if RUN_SKIP:
            # Continue sampling at the target angle, starting from the blended latent.
            skip_sample_img = zero123_sampler_skip(unet=unet, sampler=sampler, device=device, context=test_context,
                                                   img_context=test_img_context, guidance_scale=3, num_steps=32,
                                                   skip_steps=skip, alter_img=interpolation_img, return_images=True)
            _save_image(skip_sample_img[0], f'{OUTPUT_DIR}/230000_skip_{skip}_angle_{y_angle}.png')
            print('saved')



