import itertools

import numpy as np
import torch
from PIL import Image
from einops import rearrange
from torch.utils.data import DataLoader

from load_model import img_preprocess, get_context, load_zero_123, zero123_sampler, zero123_sampler_noise_by_step, zero123_sampler_noise_skip
from training.zero123_datasets import ObjectTOForgetAngleDataset, CommonObjectDataset
from ldm.util import create_carvekit_interface
from ldm.models.diffusion.ddim import DDIMSampler
import matplotlib.pyplot as plt


# ---------------------------------------------------------------------------
# Model setup: checkpoint/config locations and the device everything runs on.
# ---------------------------------------------------------------------------
ckpt_path = "105000.ckpt"
config_path = "configs/sd-objaverse-finetune-c_concat-256.yaml"
device = "cpu"
# Alternative device kept for reference (currently disabled):
# device_0 = 'cuda:0'
model = load_zero_123(ckpt_path, config_path, device)

# The denoising network is taken from the loaded model; the DDIM sampler is
# built around the full model object.
unet = model.model.diffusion_model.to(device)
sampler = DDIMSampler(model)

# carvekit interface is handed to img_preprocess (directly in the loop below
# and through the dataset's preprocess_fn).
carvekit_model = create_carvekit_interface()

# Forget/override dataset; it is only instantiated here and its length is
# printed as a quick sanity check.
obj2forgetAngleDataset = ObjectTOForgetAngleDataset(
    forget_image_path="zero123_dataset/object_to_forget_angle/image/pikachu.png",
    override_image_path="zero123_dataset/object_to_forget_angle/image/minion.png",
    angles_csv_path="zero123_dataset/object_to_forget_angle/fixed_views.csv",
    preprocess_fn=img_preprocess,
    carvekit_model=carvekit_model,
    transform=None,
)
print(len(obj2forgetAngleDataset))

# Disabled: loading the 'ema' entry from a pickled network snapshot.
# pkl_path = "network-snapshot-mod5-1.000000-000015.pkl"
# data = torch.load(pkl_path, map_location=torch.device(device))
# G = data['ema']
# G.to(device)


# For every numbered input image in pikachu_lamp/ (1..12), sample one view per
# y-angle and save the result next to the inputs.
for view_idx in range(1, 13):
    source_image = Image.open(f"pikachu_lamp/{view_idx}.png").convert("RGBA")
    # Alternative single-image input (disabled):
    # img = Image.open("zero123_dataset/object_to_forget_angle/image/pikachu_plane.jpg").convert("RGBA")

    # Preprocess once per input image and add a leading batch dimension.
    preprocessed = img_preprocess(source_image, carvekit_model)
    input_batch = torch.stack([preprocessed], dim=0)

    # Denser sweep kept for reference (disabled):
    # for y_angle in [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]:
    for y_angle in (0, 90, -90, 180):
        # Test normal image generation.
        # Only the middle component of the angle triplet varies.
        cam_angle = torch.tensor(
            [0.0, y_angle, 0.0], dtype=torch.float32
        ).unsqueeze(0)

        context, img_context = get_context(
            model=model,
            input_image=input_batch,
            cam_angle=cam_angle,
            guidance_scale=3,
            device=device,
        )
        sampled_images = zero123_sampler(
            unet=unet,
            sampler=sampler,
            device=device,
            context=context,
            img_context=img_context,
            guidance_scale=3,
            num_steps=32,
            train_sampler=False,
            num_steps_eval=32,
            return_images=True,
        )

        # Shift/scale the first sample by (x + 1) / 2 and clamp to [0, 1]
        # (presumably the sampler output lies in [-1, 1] -- TODO confirm).
        first_sample = sampled_images[0]
        unit_range = torch.clamp((first_sample + 1.0) / 2.0, min=0.0, max=1.0).cpu()

        # Channels-first -> channels-last, then scale to 8-bit pixel values.
        channels_last = rearrange(unit_range.detach().numpy(), 'c h w -> h w c')
        pixels = (255.0 * channels_last).astype(np.uint8)

        output_path = f'pikachu_lamp/zero123_{view_idx}_pikachu_lamp_angle_{y_angle}.png'
        Image.fromarray(pixels).save(output_path)
