import itertools
import os

import cv2
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
import matplotlib.pyplot as plt

from load_model import img_preprocess, get_context, load_zero_123, zero123_sampler, zero123_sampler_noise_by_step, \
    zero123_sampler_noise_skip
from training.training_loop_zero123_v4 import compare_cosine_similarity, cal_interpolation_noise
from training.zero123_datasets import ObjectTOForgetAngleDataset, CommonObjectDataset
from ldm.util import create_carvekit_interface
from ldm.models.diffusion.ddim import DDIMSampler


def above_threshold(test_angle, horizontal_angles, threshold=20):
    """Report whether the test angle is close to any of the reference angles.

    Despite the name, the result is ``True`` when the second component of
    ``test_angle[0]`` differs from the second component of at least one entry
    of ``horizontal_angles`` by no more than ``threshold`` (inclusive).

    :param test_angle: nested sequence; only ``test_angle[0][1]`` is read.
    :param horizontal_angles: iterable of sequences; only index ``1`` of each
        entry is read.
    :param threshold: largest absolute difference that still counts as close.
    :return: ``True`` if any reference is within ``threshold``, else ``False``.
    """
    # NOTE(review): this is a plain absolute difference with no wrap-around
    # (e.g. -170 vs 180 counts as 350 apart) -- confirm that is intended.
    return any(
        abs(test_angle[0][1] - reference[1]) <= threshold
        for reference in horizontal_angles
    )


def blur_with_main_outline(image_path, blur_strength=511):
    """
    Blur an image while keeping the pixels that lie on its outer contours sharp.

    The image is blended with a Gaussian-blurred copy of itself (30% original,
    70% blurred). Canny edges of the grayscale image are traced with
    ``cv2.findContours`` (external contours only), every contour found is
    drawn 2 px thick into a mask, and the original pixels are restored
    wherever that mask is set.

    :param image_path: path of the image to read.
    :param blur_strength: side length of the Gaussian kernel passed to
        ``cv2.GaussianBlur`` (larger means blurrier); OpenCV requires a
        positive odd value.
    :return: the processed image as an RGB array.
    :raises OSError: if ``cv2.imread`` cannot read ``image_path``.
    """
    # cv2.imread reports failure by returning None instead of raising, which
    # would otherwise surface later as an opaque error inside cvtColor.
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"cv2.imread could not read image: {image_path!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # 1. Gaussian-blur the whole image.
    blurred = cv2.GaussianBlur(image, (blur_strength, blur_strength), 0)

    # 2. Grayscale copy used for edge detection.
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # 3. Canny edge map; tune the 50 / 150 thresholds to get a suitable outline.
    edges = cv2.Canny(gray, 50, 150)

    # 4. Trace the external contours of the edge map.
    contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # 5. Draw every external contour (index -1 = all) onto a black mask.
    mask = np.zeros_like(gray)
    cv2.drawContours(mask, contours, -1, 255, thickness=2)

    # 6. Expand the mask to three channels so it lines up with the RGB image.
    mask_rgb = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)

    # 7. Blend original and blurred images, then restore the original pixels
    #    wherever the contour mask is set.
    result = cv2.addWeighted(image, 0.3, blurred, 0.7, 0)
    result = np.where(mask_rgb == 255, image, result)

    return result


def add_gaussian_blur_torch(image_path, noise_std=0.1, blur_kernel=5):
    """
    Add Gaussian noise to an image and then Gaussian-blur it, using PyTorch.

    :param image_path: path of the input image.
    :param noise_std: standard deviation of the added Gaussian noise.
    :param blur_kernel: kernel size of the Gaussian blur (must be odd).
    :return: the processed image as a PIL image.
    """
    # Load the file as RGB and convert it to a tensor.
    pixels = transforms.ToTensor()(Image.open(image_path).convert("RGB"))

    # Perturb with noise scaled by noise_std, then clamp to [0, 1].
    perturbed = torch.clamp(pixels + torch.randn_like(pixels) * noise_std, 0, 1)

    # Blur with sigma tied to the kernel size (kernel / 3).
    smoother = transforms.GaussianBlur(kernel_size=blur_kernel, sigma=blur_kernel / 3)

    return transforms.ToPILImage()(smoother(perturbed))


# image_path = 'zero123_dataset/object_to_forget_angle/image/minion_back.png'
# blurred_img = add_gaussian_blur_torch(image_path, blur_kernel=45).convert('RGBA')
# blurred_img.show()
# # filtered_image = blur_with_main_outline(image_path)
# # plt.imshow(filtered_image)
# # plt.axis("off")
# # plt.show()
# exit()

# One-off pass: run every file in `exp2` through img_preprocess and write the
# result back over the original file, then stop the script.
carvekit_model = create_carvekit_interface()
# for idx in ['hat', 'speaker']:
#     path = f'real_object/{idx}'
path = 'exp2'
for filename in os.listdir(path):
    img_path = os.path.join(path, filename)
    rgba = Image.open(img_path).convert("RGBA")
    batch = torch.stack([img_preprocess(rgba, carvekit_model)], dim=0)
    # Shift/scale by (x + 1) / 2 and clamp to [0, 1]
    # (presumably img_preprocess returns values in [-1, 1] -- TODO confirm).
    unit_range = torch.clamp((batch[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
    # Channels-first -> channels-last, scaled by 255 for an 8-bit image.
    pixels = 255.0 * rearrange(unit_range.detach().numpy(), 'c h w -> h w c')
    # Overwrites the input file in place.
    Image.fromarray(pixels.astype(np.uint8)).save(img_path)
# Unconditional stop: nothing below this line runs while this call is here.
exit()

# ---------------------------------------------------------------------------
# Experiment setup: checkpoint/config paths, model, sampler, dataset, UNet
# snapshot and the four reference camera angles.
# NOTE: the unconditional exit() earlier in this script runs before this
# point, so none of this executes unless that call is removed.
# ---------------------------------------------------------------------------

# Skip-step counts; only referenced by the commented-out skip-sampling
# experiment further down.
max_skip = 8
min_skip = 8
ckpt_path = '230000.ckpt'
# ckpt_path = 'zero123-xl.ckpt'
config_path = 'configs/sd-objaverse-finetune-c_concat-256.yaml'
device = 'cuda'
# device_0 = 'cuda:0'
model = load_zero_123(ckpt_path, config_path, device)

sampler = DDIMSampler(model)

# Dataset built from one "forget" image, one "override" image and a CSV of
# angles; only its length is printed here.
obj2forgetAngleDataset = ObjectTOForgetAngleDataset(
    forget_image_path='zero123_dataset/object_to_forget_angle/image/pikachu.png',
    override_image_path='zero123_dataset/object_to_forget_angle/image/minion.png',
    angles_csv_path='zero123_dataset/object_to_forget_angle/fixed_views.csv',
    preprocess_fn=img_preprocess, carvekit_model=carvekit_model,
    transform=None
)
print(len(obj2forgetAngleDataset))
# unet = model.model.diffusion_model.to(device)
# The UNet comes from the 'ema' entry of a training snapshot instead of the
# checkpoint loaded above.
# NOTE(review): torch.load unpickles arbitrary objects -- only load snapshots
# from a trusted source. Newer PyTorch releases changed the `weights_only`
# default; confirm this call still loads the snapshot on the installed version.
pkl_path = 'image_experiment/df2-stage2-train-runs/stool_16_19/network-snapshot-mod5-1.000000-000019.pkl'
data = torch.load(pkl_path, map_location=torch.device(device))
unet = data['ema']
# Four reference camera angles; only the middle component varies
# (0, 180, -90, 90). Presumably (elevation, azimuth in degrees, distance
# offset) -- TODO confirm against get_context.
front_angle = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
behind_angle = torch.tensor([0.0, 180.0, 0.0], dtype=torch.float32)
left_angle = torch.tensor([0.0, -90.0, 0.0], dtype=torch.float32)
right_angle = torch.tensor([0.0, 90.0, 0.0], dtype=torch.float32)
# Stacked into one tensor with one row per reference angle.
horizontal_angles = torch.stack([front_angle, behind_angle, left_angle, right_angle])
# front_angle = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32).unsqueeze(0)
# right_angle = torch.tensor([0.0, 90.0, 0.0], dtype=torch.float32).unsqueeze(0)

# Prefix used in the output file names written by the loop below.
model_name = 'stool_case'
# All cases:
# 'amon', 'car_forget', 'cherry', 'cone', 'icecream', 'minion', 'sculpture', 'stool_forget', 'tong', 'football'
# For every case image, sample one view per test angle and save it under exp2/.
for idx in ['amon', 'football', 'cherry', 'cone', 'icecream', 'minion', 'sculpture', 'stool_forget', 'tong', 'car_forget']:
    img_path = f"zero123_dataset/common_objects/case/{idx}.png"
    img = Image.open(img_path).convert("RGBA")
    img_tensor = img_preprocess(img, carvekit_model)
    # Add a leading batch dimension of size 1.
    img_tensor = torch.stack([img_tensor], dim=0)

    # --- Disabled experiment: blurred-input context and per-step noise cache
    # --- for the four reference angles.
    # blurred_img = add_gaussian_blur_torch(img_path, blur_kernel=99).convert('RGBA')
    # blurred_img_tensor = img_preprocess(blurred_img, carvekit_model)
    # blurred_img_tensor = torch.stack([blurred_img_tensor], dim=0)

    # horizontal_img_tensor = img_tensor.expand(horizontal_angles.shape[0], -1, -1, -1)
    # horizontal_context, horizontal_img_context = get_context(model=model, input_image=horizontal_img_tensor,
    #                                                          cam_angle=horizontal_angles,
    #                                                          guidance_scale=3, device=device)
    # # blur_context, blur_img_context = get_context(model=model, input_image=blurred_img_tensor,
    # #                                              cam_angle=front_angle.unsqueeze(0), guidance_scale=3, device=device)
    # # horizontal_context[1] = blur_context[0]
    # # horizontal_img_context[1] = blur_img_context[0]
    # horizontal_sample_noise = zero123_sampler_noise_by_step(unet=unet, sampler=sampler, device=device,
    #                                                         context=horizontal_context,
    #                                                         img_context=horizontal_img_context,
    #                                                         dtype=torch.float32, guidance_scale=3,
    #                                                         num_steps=32)
    # print(len(horizontal_sample_noise))
    # print(horizontal_sample_noise[0].shape)
    # # print(f'img_tensor.shape: {img_tensor.shape}, angle_shape: {center_angle.shape}')
    # front_context, front_img_context = get_context(model=model, input_image=img_tensor, cam_angle=front_angle,
    #                                                  guidance_scale=3, device=device)
    # front_noise_every_step, front_noise_init = zero123_sampler_noise_by_step(unet=unet, sampler=sampler, device=device, context=front_context,
    #                                                      img_context=front_img_context, guidance_scale=3, num_steps=32)
    # right_context, right_img_context = get_context(model=model, input_image=img_tensor, cam_angle=right_angle,
    #                                                  guidance_scale=3, device=device)
    # right_noise_every_step, right_noise_init = zero123_sampler_noise_by_step(unet=unet, sampler=sampler, device=device, context=right_context,
    #                                                      img_context=right_img_context, guidance_scale=3, num_steps=32)

    # Rebuilt on every outer iteration; currently only consumed by the
    # commented-out loop directly below.
    forget_dataloader = DataLoader(obj2forgetAngleDataset, batch_size=1, shuffle=False)
    # for forget_img, override_img, angle in forget_dataloader:
    #     context, img_context = get_context(model=model, input_image=forget_img, cam_angle=angle,
    #                                        guidance_scale=3, device=device)
    #     # front_similarity = torch.nn.functional.cosine_similarity(context[-1], front_context[-1], dim=-1)

    # for y_angle in [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]:
    for y_angle in [-135, -90, -45, 0, 45, 90, 135, 180]:
        # Test normal image generation.
        # Only the middle angle component varies; unsqueeze adds a batch dimension.
        test_angle = torch.tensor([0.0, y_angle, 0.0], dtype=torch.float32).unsqueeze(0)
        test_context, test_img_context = get_context(model=model, input_image=img_tensor, cam_angle=test_angle,
                                                     guidance_scale=3, device=device)
        test_sample_img = zero123_sampler(unet=unet, sampler=sampler, device=device, context=test_context,
                                          img_context=test_img_context, guidance_scale=3, num_steps=32,
                                          train_sampler=False, num_steps_eval=32, return_images=True)
        # Shift/scale the first sample by (x + 1) / 2, clamp to [0, 1], and
        # convert to a channels-last uint8 image.
        image_tensor = torch.clamp((test_sample_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
        sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
        # Rebinds `img` (the input image opened above); img_tensor has already
        # been computed from it.
        img = Image.fromarray(sample.astype(np.uint8))
        img.save(f'exp2/{model_name}-{idx}_sample_angle_{y_angle}.png')

        # --- Disabled experiment: skip the first denoising steps using noise
        # --- interpolated from the cached reference-angle noise.
        # selected_similarities, selected_indices = compare_cosine_similarity(horizontal_context, test_context)
        # print('selected base angles', selected_indices[0])
        # print(horizontal_angles[selected_indices[0][0]])
        # print(horizontal_angles[selected_indices[0][1]])
        # if above_threshold(test_angle, horizontal_angles):
        #     skip_rounds = max_skip
        # else:
        #     skip_rounds = min_skip
        # interpolation_noise = cal_interpolation_noise(selected_similarities, selected_indices,
        #                                               noise_cache=horizontal_sample_noise[skip_rounds], lb=0.7)
        # # interpolation_noises = torch.stack([cal_interpolation_noise(selected_similarities, selected_indices,
        # #                                                             noise_cache=horizontal_sample_noise[idx], lb=0.7) for idx in range(skip_rounds)])
        # skip_sample_img = zero123_sampler_noise_skip(unet=unet, sampler=sampler, device=device, context=test_context,
        #                                              img_context=test_img_context,
        #                                              guidance_scale=3, num_steps=32, skip_start_idx=0,
        #                                              num_skip_steps=skip_rounds, alter_noise=interpolation_noise,
        #                                              return_images=True)
        # image_tensor = torch.clamp((skip_sample_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
        # sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
        # img = Image.fromarray(sample.astype(np.uint8))
        # img.save(f'style/230000-{idx}_angle_{y_angle}_skip_{skip_rounds}.png')
        print('saved')

        # --- Disabled experiment: per-step interpolation between front and
        # --- right noise, weighted by rescaled cosine similarity.
        # # skip_rounds = 12
        # for skip_rounds in [12, 16, 20, 24]:
        # similarity_with_front = (torch.nn.functional.cosine_similarity(test_context[-1], front_context[-1], dim=-1) - 0.7) / (1 - 0.7)
        # similarity_with_right = (torch.nn.functional.cosine_similarity(test_context[-1], right_context[-1], dim=-1) - 0.7) / (1 - 0.7)
        # print(f'similarity with front: {similarity_with_front}, with right: {similarity_with_right}')
        # interpolation_noises = []
        # for i in range(skip_rounds):
        #     interpolation_noise = (front_noise_every_step[i] * (similarity_with_front / (similarity_with_right+similarity_with_front)) +
        #                          right_noise_every_step[i] * (similarity_with_right / (similarity_with_front+similarity_with_right)))
        #     interpolation_noises.append(interpolation_noise)
        #
        # init_noise = cal_interpolation_noise()
        # # Test skip
        # skip_sample_img = zero123_sampler_noise_skip(unet=unet, sampler=sampler, device=device, context=test_context, img_context=test_img_context, init_noise=None,
        #                                        guidance_scale=3, num_steps=32, skip_start_idx=0, num_skip_steps=skip_rounds, alter_noise=interpolation_noises, return_images=True)
        # image_tensor = torch.clamp((skip_sample_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
        # sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
        # img = Image.fromarray(sample.astype(np.uint8))
        # img.save(f'105000_skip/single_angle_{y_angle}_skip_{skip_rounds}.png')
        # print('saved')



