"""Distill Stable Diffusion models using the SiD-LSG techniques described in the
paper "Long and Short Guidance in Score identity Distillation for One-Step Text-to-Image Generation"."""
import itertools

from torch.utils.data import DataLoader
from torchvision.transforms import transforms

"""Main training loop."""
import re
import os
import time
import copy
import json
import pickle
import psutil
import PIL.Image
import numpy as np
import torch
import dnnlib
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import misc
import torch.nn as nn
from functools import partial
import gc
import matplotlib.pyplot as plt
import matplotlib

# matplotlib.use('Agg')
# Needed for v-prediction based diffusion model
from diffusers.training_utils import compute_snr

# Functions needed to integrate Stable Diffusion into SiD
from training.sd_util import load_sd15, sid_sd_sampler, sid_sd_denoise

# zero 123 dataset
from training.zero123_datasets import ObjectTOForgetAngleDataset, CommonObjectDataset
from load_model import load_zero_123, img_preprocess, get_context, zero123_sampler, zero123_denoise, zero123_sampler_noise_by_step, zero123_sampler_noise_skip
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import create_carvekit_interface

from torchvision.utils import save_image
from einops import rearrange
# ----------------------------------------------------------------------------
from PIL import Image

def add_gaussian_blur_torch(image_path, noise_std=0.1, blur_kernel=5):
    """
    Add Gaussian noise to an image, then apply a Gaussian blur, using torchvision transforms.

    :param image_path: path of the input image; it is converted to RGB.
    :param noise_std: standard deviation of the additive Gaussian noise
        (applied to pixel values in [0, 1]).
    :param blur_kernel: Gaussian blur kernel size (must be odd); the blur sigma
        is ``blur_kernel / 3``.
    :return: the processed image as a PIL image.
    """
    # Open inside a context manager so the underlying file handle is released
    # deterministically, even if convert() raises.
    with Image.open(image_path) as src:
        image = src.convert("RGB")

    # PIL -> float tensor with values in [0, 1].
    image_tensor = transforms.ToTensor()(image)

    # Additive Gaussian noise, clamped back into the valid [0, 1] pixel range.
    noise = torch.randn_like(image_tensor) * noise_std
    noisy_image = torch.clamp(image_tensor + noise, 0, 1)

    # Gaussian blur with sigma tied to the kernel size.
    blur = transforms.GaussianBlur(kernel_size=blur_kernel, sigma=blur_kernel / 3)
    blurred_image = blur(noisy_image)

    return transforms.ToPILImage()(blurred_image)


def append_line(jsonl_line, fname):
    """Append ``jsonl_line`` followed by a newline to the text file ``fname``.

    The file is created if it does not exist. The file is written as UTF-8
    explicitly, as the JSON Lines format requires, instead of the platform's
    locale encoding. Non-ASCII content is therefore encoded the same way on
    every OS.
    """
    with open(fname, 'at', encoding='utf-8') as f:
        f.write(jsonl_line + '\n')


def generate_main_view_angles(n):
    """Return ``n`` evenly spaced horizontal view angles as an ``(n, 3)`` float32 tensor.

    Row ``i`` is ``[0, a_i, 0]`` where ``a_i = 360 * i / n`` degrees, wrapped
    into ``(-180, 180]``. For example, ``n=4`` gives second-column angles
    0, 90, 180, -90. The first and third columns are always 0.

    :param n: number of angles; must be at least 2.
    :return: float32 tensor of shape ``(n, 3)``.
    :raises ValueError: if ``n < 2``.
    """
    # Explicit check: an ``assert`` would be stripped under ``python -O``.
    if n < 2:
        raise ValueError(f'n must be >= 2, got {n}')
    rows = []
    for i in range(n):
        # Compute each angle directly rather than accumulating ``cur += step``.
        # Repeated float addition drifts and can push the angle that should be
        # exactly 180 just past the threshold, wrapping it to about -180.
        angle = 360.0 * i / n
        if angle > 180:
            angle -= 360
        rows.append([0.0, angle, 0.0])
    return torch.tensor(rows, dtype=torch.float32)

def compare_cosine_similarity(base_context, current_context, k_selected=2):
    """For each current context, select the ``k_selected`` most similar base contexts.

    Both inputs are split in half along dim 0, and only the second half of each
    is compared. NOTE(review): this is presumably the conditional half of a
    CFG-doubled batch; confirm against ``get_context``. Rows are compared by
    cosine similarity over the last dimension.

    Assumes both contexts have shape ``(2 * N, 1, D)`` (TODO confirm). The
    ``transpose(0, 1)`` below relies on the singleton dim 1 to broadcast into an
    ``(N_base, N_current)`` similarity matrix.

    :param base_context: reference contexts to choose from.
    :param current_context: contexts to match against the base set.
    :param k_selected: number of nearest base contexts returned per current
        context. It is clamped to the number of base contexts available.
    :return: ``(values, indices)``, each of shape ``(N_current, k)``, sorted by
        descending similarity. ``indices`` index rows of the second half of
        ``base_context``.
    """
    num_base = base_context.shape[0] // 2
    num_current = current_context.shape[0] // 2
    # torch.topk raises if k exceeds the size of the reduced dimension, so never
    # ask for more neighbours than there are base contexts.
    k_selected = min(k_selected, num_base)
    print(f'k_selected = {k_selected}, base_context = {base_context.shape} current_context = {current_context.shape}')
    base_half = base_context[num_base:]  # (N_base, 1, D) under the assumed layout
    current_half = current_context[num_current:].transpose(0, 1)  # (1, N_current, D)
    # Broadcasts to (N_base, N_current, D) and reduces over D.
    similarity_values = nn.functional.cosine_similarity(base_half, current_half, dim=-1)
    print(f'similarity_values = {similarity_values}')
    # Transpose to (N_current, N_base) so topk ranks base contexts per current row.
    selected_values, selected_indices = torch.topk(similarity_values.transpose(0, 1), k=k_selected, dim=-1)
    return selected_values, selected_indices


def cal_interpolation_noise(similarities, selected_indices, noise_cache, lb, hb=1.):
    """Blend cached noise tensors using similarity-derived weights.

    Each similarity ``s`` is mapped linearly to the weight ``(s - lb) / (hb - lb)``,
    so ``s == lb`` gives 0 and ``s == hb`` gives 1. Negative weights are clamped
    to 0. Output row ``i`` is
    ``sum_j weight[i, j] * noise_cache[selected_indices[i, j]]``.

    NOTE(review): the weights are not normalised to sum to 1, so the blended
    noise can differ in magnitude from any single cached entry. Confirm this is
    intended.

    :param similarities: ``(batch, k)`` similarity scores.
    :param selected_indices: ``(batch, k)`` integer indices into dim 0 of ``noise_cache``.
    :param noise_cache: tensor of cached noises, indexed along dim 0.
    :param lb: similarity at or below which an entry contributes nothing.
    :param hb: similarity that receives weight 1 (default 1.0).
    :return: tensor of shape ``(batch, *noise_cache.shape[1:])`` on
        ``noise_cache.device``, in the default floating dtype.
    """
    batch, k = similarities.shape
    device = noise_cache.device

    weights = torch.clamp((similarities - lb) / (hb - lb), min=0.0).to(device)
    # Gather every selected noise in one indexing op: (batch, k, *noise_shape).
    # This replaces the per-element Python loop, whose ``.item()`` calls force a
    # device sync on every element.
    gathered = noise_cache[selected_indices.to(device)]
    # Broadcast each scalar weight over its noise tensor, then reduce over k.
    weights = weights.reshape(batch, k, *([1] * (noise_cache.dim() - 1)))
    interpolation_noise = (weights * gathered).sum(dim=1)
    # The original accumulated into ``torch.zeros(...)``, i.e. the default dtype.
    return interpolation_noise.to(torch.get_default_dtype())

def training_loop(
        run_dir='.',  # Output directory.
        dataset_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        network_kwargs={},  # Options for model and preconditioning.
        loss_kwargs={},  # Options for loss function.
        fake_score_optimizer_kwargs={},  # Options for fake score network optimizer.
        g_optimizer_kwargs={},  # Options for generator optimizer.
        augment_kwargs=None,  # Options for augmentation pipeline, None = disable.
        seed=0,  # Global random seed.
        batch_size=1,  # Total batch size for one training iteration.
        batch_gpu=None,  # Limit batch size per GPU, None = no limit.
        total_kimg=200000,  # Training duration, measured in thousands of training images.
        ema_halflife_kimg=500,  # Half-life of the exponential moving average (EMA) of model weights.
        ema_rampup_ratio=0.05,  # EMA ramp-up coefficient, None = no rampup.
        loss_scaling=1,  # Loss scaling factor, could be adjusted for reducing FP16 under/overflows.
        loss_scaling_G=1,  # Loss scaling factor of G, could be adjusted for reducing FP16 under/overflows.
        kimg_per_tick=0.01,  # Interval of progress prints.
        snapshot_ticks=50,  # How often to save network snapshots, None = disable.
        state_dump_ticks=500,  # How often to dump training state, None = disable.
        resume_pkl=None,  # Start from the given network snapshot for initialization, None = random initialization.
        resume_training=None,  # Resume training from the given network snapshot.
        resume_kimg=0,  # Start from the given training progress.
        alpha=1,  # loss = L2-alpha*L1
        tmax=980,  # We add noise at steps 0 to tmax, tmax <= 1000
        tmin=20,  # We add noise at steps 0 to tmax, tmax <= 1000
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        device=torch.device('cuda'),
        init_timestep=None,
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        fake_score_use_lora=False,
        dataset_prompt_text_kwargs={},
        forget_dataset_prompt_text_kwargs={},
        cfg_train_fake=1,  # kappa1
        cfg_eval_fake=1,  # kappa2 = kappa3
        cfg_eval_real=1,  # kappa4
        num_steps=1,
        train_mode=True,
        enable_xformers=True,
        gradient_checkpointing=False,
        resolution=512,
        sg_remain_coef=1.0,
        sg_forget_coef=0.01,
        g_remain_coef=1.0,
        g_forget_coef=0.01,
        from_distill_ema=None,
        sid_w_neg=False,
        use_neg=(False, False, True),
        sg_w_override=False,
        pretrained_vae_model_name_or_path=None,

        guidance_scale=3,
        b=2,
        num_angles=4,
        skip_rounds=12,
        skip_start=0
):
    # 1: 1.28 1.17
    # 4: 2.8 2.9
    # 8: 5.31
    # load_model
    ckpt_path = '230000.ckpt'
    # ckpt_path = 'zero123-xl.ckpt'
    config_path = 'configs/sd-objaverse-finetune-c_concat-256.yaml'
    device = 'cuda'

    print('start to load model from: ', ckpt_path)
    model = load_zero_123(ckpt_path, config_path, device)
    print('whole model: ', model.__class__.__name__)

    # unet, vae, embedding, noise sampler
    unet = model.model.diffusion_model
    sampler = DDIMSampler(model)
    carvekit_model = create_carvekit_interface()

    # load zero 123 dataset
    commonObjDataset = CommonObjectDataset(image_dir='zero123_dataset/common_objects/case',
                                           csv_path='zero123_dataset/common_objects/cam_angle.csv',
                                           preprocess_fn=img_preprocess, carvekit_model=carvekit_model,
                                           transform=None)
    print(len(commonObjDataset))
    # commonObjIter = iter(commonObjDataset)

    forget_img_path = 'zero123_dataset/object_to_forget_angle/image/minion.png'
    obj2forgetAngleDataset = ObjectTOForgetAngleDataset(
        forget_image_path=forget_img_path,
        override_image_path='zero123_dataset/object_to_forget_angle/image/minion_pink.png',
        angles_csv_path='zero123_dataset/object_to_forget_angle/fixed_x_36.csv',
        preprocess_fn=img_preprocess, carvekit_model=carvekit_model,
        transform=None
    )
    print(len(obj2forgetAngleDataset))
    # obj2forgetIter = iter(obj2forgetAngleDataset)
    common_dataloader = DataLoader(commonObjDataset, batch_size=b, shuffle=True)
    forget_dataloader = DataLoader(obj2forgetAngleDataset, batch_size=b, shuffle=False)

    # forget_iterator = itertools.cycle(forget_dataloader)
    common_iterator = itertools.cycle(common_dataloader)

    dist.print0('finish loading zero123 dataset!!!')


    '''
        重要！！！！！！！！！！！
        采用水平角度采样，正前正后正左正右四个基准角度
        生成基准角度的两个context
    '''
    # front_angle = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
    # behind_angle = torch.tensor([0.0, 180.0, 0.0], dtype=torch.float32)
    # left_angle = torch.tensor([0.0, -90.0, 0.0], dtype=torch.float32)
    # right_angle = torch.tensor([0.0, 90.0, 0.0], dtype=torch.float32)
    # up_angle = torch.tensor([90.0, 0.0, 0.0], dtype=torch.float32)
    # down_angle = torch.tensor([-90.0, 0.0, 0.0], dtype=torch.float32)
    # horizontal_angles = torch.stack([front_angle, behind_angle, left_angle, right_angle])
    horizontal_angles = generate_main_view_angles(4)
    print(horizontal_angles)
    # vertical_angles = torch.stack([up_angle, down_angle])
    img = Image.open(forget_img_path).convert("RGBA")
    img_tensor = img_preprocess(img, carvekit_model)
    img_tensor = torch.stack([img_tensor], dim=0)
    # vertical_img_tensor = img_tensor.expand(vertical_angles.shape[0], -1, -1, -1)
    horizontal_img_tensor = img_tensor.expand(horizontal_angles.shape[0], -1, -1, -1)
    # vertical_context, vertical_img_context = get_context(model=model, input_image=vertical_img_tensor, cam_angle=vertical_angles,
    #                                                      guidance_scale=3, device=device)
    horizontal_context, horizontal_img_context = get_context(model=model, input_image=horizontal_img_tensor,
                                                             cam_angle=horizontal_angles,
                                                             guidance_scale=3, device=device)
    # print(f'vertical context shape: {vertical_context.shape}\nhorizontal context shape: {horizontal_context.shape}')


    unet = unet.eval()
    true_score = unet
    true_score.eval().requires_grad_(False).to(device)
    # pkl_path = 'image_experiment/df2-stage2-train-runs/00001-aesthetics-text_cond-glr6e-06-lr4e-06-initsigma625-gpus1-alpha1.0-batch1-tmax980-fp16/network-snapshot-mod5-1.000000-000010.pkl'
    # data = torch.load(pkl_path, map_location=torch.device(device))
    # fake_score = data['sg']
    fake_score = copy.deepcopy(unet).eval()
    fake_score.eval().requires_grad_(False).to(device)
    # G = data['ema']
    G = copy.deepcopy(unet).eval()
    G.eval().requires_grad_(False).to(device)

    fake_score_optimizer = dnnlib.util.construct_class_by_name(params=fake_score.parameters(),
                                                               **fake_score_optimizer_kwargs)  # subclass of torch.optim.Optimizer
    g_optimizer = dnnlib.util.construct_class_by_name(params=G.parameters(),
                                                      **g_optimizer_kwargs)  # subclass of torch.optim.Optimizer
    img_shape = [b, 4, 32, 32]
    fake_score.eval().requires_grad_(False)
    G.eval().requires_grad_(False)

    # start training
    # Train.
    dist.print0(f'Training for {total_kimg} kimg...')
    dist.print0()
    cur_nimg = resume_kimg * 1000
    cur_tick = 0
    tick_start_nimg = cur_nimg
    dist.update_progress(cur_nimg // 1000, total_kimg)
    stats_jsonl = None

    while True:
        epoch_start = time.time()
        '''每一轮训练前先对四个基准角度的噪声进行采样记录，每一步噪声都记录到sample_noise中'''
        with torch.no_grad():
            horizontal_sample_noise = zero123_sampler_noise_by_step(unet=G, sampler=sampler, device=device,
                                                                    context=horizontal_context,
                                                                    img_context=horizontal_img_context,
                                                                    dtype=torch.float32, guidance_scale=guidance_scale,
                                                                    num_steps=32)
            # vertical_sample_noise = zero123_sampler_noise_by_step(unet=G, sampler=sampler, device=device,
            #                                                         context=vertical_context,
            #                                                         img_context=vertical_img_context,
            #                                                         dtype=torch.float32, guidance_scale=guidance_scale,
            #                                                         num_steps=32)

        inner_loop = 0
        for forget_img, override_img, forget_angle in forget_dataloader:
            inner_start = time.time()
            common_img, common_angle = next(common_iterator)
            start_time = time.time()
            # first stage: train fake score network
            torch.cuda.empty_cache()
            gc.collect()
            G.eval().requires_grad_(False)

            fake_score.train().requires_grad_(True)
            fake_score_optimizer.zero_grad(set_to_none=True)
            # initialize losses
            sg_remain_loss_print = sg_forget_loss_print = 0
            start_time = time.time()
            ''''''
            with torch.no_grad():
                context, img_context = get_context(model=model, input_image=common_img, cam_angle=common_angle,
                                                   guidance_scale=guidance_scale, device=device)

            # generate fake images (generator no grad)
            with torch.no_grad():
                sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                                              context=context, img_context=img_context, init_timesteps=1000,
                                              return_images=False, num_steps=32, train_sampler=True, num_steps_eval=32)

            noise = torch.randn(img_shape).to(device)
            timesteps = torch.randint(tmin, tmax, (len(context),), device=device, dtype=torch.long)
            # Compute remain loss for fake score network
            # 计算出要保留的类别对于fake score network的分数
            noise_fake = zero123_denoise(unet=fake_score, sampler=sampler,
                                         sampler_noise=sampler_img, noise=noise, predict_x0=False,
                                         context=context, img_context=img_context, timesteps=timesteps,
                                         guidance_scale=guidance_scale, device=device)
            end_time = time.time()

            with torch.no_grad():
                nan_mask = torch.isnan(noise_fake).flatten(start_dim=1).any(dim=1)

            # Check if there are any NaN values present
            target = None
            if nan_mask.any():
                # Invert the nan_mask to get a mask of samples without NaNs
                non_nan_mask = ~nan_mask
                # Filter out samples with NaNs from y_real and y_fake
                noise_fake = noise_fake[non_nan_mask]
                noise = noise[non_nan_mask]

            sg_remain_loss = (noise_fake - noise) ** 2
            sg_remain_loss = sg_remain_loss.sum().mul(loss_scaling)

            if len(noise) > 0:
                print('sg_remain_loss: ', sg_remain_loss)
                sg_remain_loss.mul(sg_remain_coef).backward()
            else:
                print('sg_remain_loss no backward')

            del sampler_img, target
            del noise_fake

            sg_remain_loss_print += sg_remain_loss.detach().item()
            del sg_remain_loss

            # score network有需要遗忘的内容
            if sg_forget_coef > 0:
                # forget_prompt - brad pitt ... , override_prompt - a middle aged man ..., 没有neg
                # img, override_img, angle = next(forget_iterator)
                # img = img_preprocess(img, carvekit_model)
                with torch.no_grad():
                    forget_context, forget_img_context = get_context(model=model, cam_angle=forget_angle,
                                                                     input_image=forget_img,
                                                                     guidance_scale=guidance_scale, device=device)
                    selected_similarities, selected_indices = compare_cosine_similarity(horizontal_context, forget_context)
                    # print(horizontal_context.shape)
                    # h_selected_sim, h_selected_indices = compare_cosine_similarity(horizontal_context, forget_context)
                    # v_selected_sim, v_selected_indices = compare_cosine_similarity(vertical_context, forget_context)
                    print(f'noise cache shape: {horizontal_sample_noise[skip_rounds].shape}')
                    # interpolation_noises = cal_interpolation_noise(selected_similarities, selected_indices,
                    #                                                noise_cache=horizontal_sample_noise[skip_rounds], lb=0.7)
                    interpolation_noise = cal_interpolation_noise(selected_similarities, selected_indices,
                                                                  horizontal_sample_noise[skip_rounds], lb=0.7)

                with torch.no_grad():
                    sampler_img = zero123_sampler_noise_skip(unet=G, sampler=sampler, device=device,
                                                                 context=forget_context, img_context=forget_img_context,
                                                                 guidance_scale=3, num_steps=32, skip_start_idx=skip_start,
                                                                 num_skip_steps=skip_rounds,
                                                                 alter_noise=interpolation_noise, return_images=False)
                    # sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                    #                               context=forget_context, img_context=forget_img_context,
                    #                               return_images=False, num_steps=32, train_sampler=True,
                    #                               num_steps_eval=32)

                noise = torch.randn(img_shape).to(device)
                timesteps = torch.randint(tmin, tmax, (len(forget_context),), device=device, dtype=torch.long)
                # Compute forget loss for fake score network
                # Denoised fake images (stop generator gradient) under fake score network, using guidance scale: kappa1=cfg_eval_train

                noise_fake = zero123_denoise(unet=fake_score, sampler=sampler,
                                             sampler_noise=sampler_img, noise=noise, predict_x0=False,
                                             context=forget_context, img_context=forget_img_context,
                                             timesteps=timesteps,
                                             guidance_scale=guidance_scale, device=device)

                with torch.no_grad():
                    nan_mask = torch.isnan(noise_fake).flatten(start_dim=1).any(dim=1)

                # Check if there are any NaN values present
                target = None
                if nan_mask.any():
                    # Invert the nan_mask to get a mask of samples without NaNs
                    non_nan_mask = ~nan_mask
                    # Filter out samples with NaNs from y_real and y_fake
                    noise_fake = noise_fake[non_nan_mask]
                    noise = noise[non_nan_mask]

                sg_forget_loss = (noise_fake - noise) ** 2

                sg_forget_loss = sg_forget_loss.sum().mul(loss_scaling)

                if len(noise) > 0:
                    print('sg_forget_loss: ', sg_forget_loss)
                    sg_forget_loss.mul(sg_forget_coef).backward()
                else:
                    print('sg_forget_loss no backward')

                del sampler_img, target
                del noise_fake

                sg_forget_loss = sg_forget_loss.detach().cpu().item()
                sg_forget_loss_print += sg_forget_loss

                del sg_forget_loss

            training_stats.report('fake_score_Loss/remain_loss', sg_remain_loss_print)
            training_stats.report('fake_score_Loss/forget_loss', sg_forget_loss_print)

            fake_score.eval().requires_grad_(False)

            # Update fake score network
            for param in fake_score.parameters():
                if param.grad is not None:
                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)

            fake_score_optimizer.step()

            # ----------------------------------------------------------------------------------------------
            # Update One-Step Generator Network

            true_score.eval().requires_grad_(False)
            true_score.eval().requires_grad_(False)
            G.train().requires_grad_(True)
            g_optimizer.zero_grad(set_to_none=True)

            g_remain_loss_print = g_forget_loss_print = 0
            # 此处的context是经过把brad pitt 替换为 middle aged man后的prompt
            # img, angle = next(common_iterator)
            start_time = time.time()

            with torch.no_grad():
                context, img_context = get_context(model=model, cam_angle=common_angle, input_image=common_img,
                                                   guidance_scale=guidance_scale, device=device)

            sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                                          context=context, img_context=img_context,
                                          return_images=False, num_steps=32, train_sampler=True, num_steps_eval=32)

            noise = torch.randn(img_shape)
            timesteps = torch.randint(tmin, tmax, (len(context),), device=device, dtype=torch.long)
            # Compute loss for generator

            # with torch.no_grad():

            sampler_img.requires_grad_(True)
            with torch.no_grad():
                y_fake = zero123_denoise(unet=fake_score, sampler=sampler,
                                         sampler_noise=sampler_img, noise=noise,
                                         context=context, img_context=img_context, timesteps=timesteps,
                                         guidance_scale=guidance_scale, device=device)

                # Denoised fake images (track generator gradient) under pretrained score network, using guidance scale: kappa4=cfg_eval_real
                y_real = zero123_denoise(unet=true_score, sampler=sampler,
                                         sampler_noise=sampler_img, noise=noise,
                                         context=context, img_context=img_context, timesteps=timesteps,
                                         guidance_scale=guidance_scale, device=device)

            end_time = time.time()

            with torch.no_grad():
                nan_mask_images = torch.isnan(sampler_img).flatten(start_dim=1).any(dim=1)
                nan_mask_y_real = torch.isnan(y_real).flatten(start_dim=1).any(dim=1)
                nan_mask_y_fake = torch.isnan(y_fake).flatten(start_dim=1).any(dim=1)
                nan_mask = nan_mask_images | nan_mask_y_real | nan_mask_y_fake

            # Check if there are any NaN values present
            if nan_mask.any():
                # Invert the nan_mask to get a mask of samples without NaNs
                non_nan_mask = ~nan_mask
                # Filter out samples with NaNs from y_real and y_fake
                sampler_img = sampler_img[non_nan_mask]
                y_real = y_real[non_nan_mask]
                y_fake = y_fake[non_nan_mask]

            with torch.no_grad():
                weight_factor = abs(sampler_img.to(torch.float32) - y_real.to(torch.float32)).mean(
                    # dim=[1, 2, 3], keepdim=True).clip(min=0.00001)
                    dim=[1, 2, 3], keepdim=True).clip(min=0.01)

            if alpha == 1:
                g_remain_loss = (y_real - y_fake) * (y_fake - sampler_img) / weight_factor
            else:
                g_remain_loss = (y_real - y_fake) * (
                        (y_real - sampler_img) - alpha * (y_real - y_fake)) / weight_factor

            g_remain_loss = g_remain_loss.sum().mul(loss_scaling_G)

            if (~nan_mask).sum().item() > 0:
                g_remain_loss = g_remain_loss * g_remain_coef  # 确保所有张量都启用了 requires_grad
                if not g_remain_loss.requires_grad:
                    g_remain_loss.requires_grad_(True)
                print('g_remain_loss: ', g_remain_loss)
                g_remain_loss.backward()
                torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1.0)
            else:
                print('g_remain_loss no backward')

            g_remain_loss = g_remain_loss.detach().cpu().item()
            g_remain_loss_print += g_remain_loss
            del y_real, y_fake, sampler_img, g_remain_loss

            if g_forget_coef > 0:
                # generator network有要遗忘的内容
                with torch.no_grad():
                    forget_context, forget_img_context = get_context(model=model, cam_angle=forget_angle,
                                                                     input_image=forget_img,
                                                                     guidance_scale=guidance_scale, device=device)

                    override_angle = copy.deepcopy(forget_angle)
                    # blurred_img = add_gaussian_blur_torch('zero123_dataset/object_to_forget_angle/image/minion_backpack.png',
                    #                                       blur_kernel=45).convert('RGBA')
                    # blurred_img = Image.open('zero123_dataset/object_to_forget_angle/image/chair_override.png').convert('RGBA')
                    # blurred_img_tensor = img_preprocess(blurred_img, carvekit_model)
                    # blurred_img_tensor = torch.stack([blurred_img_tensor], dim=0)

                    # front_img = Image.open('zero123_dataset/object_to_forget_angle/image/minion_pixel_front.png').convert('RGBA')
                    # front_img_tensor = img_preprocess(front_img, carvekit_model)
                    # back_img = Image.open('zero123_dataset/object_to_forget_angle/image/minion_pixel_back.png').convert('RGBA')
                    # back_img_tensor = img_preprocess(back_img, carvekit_model)
                    # left_img = Image.open('zero123_dataset/object_to_forget_angle/image/minion_pixel_left.png').convert('RGBA')
                    # left_img_tensor = img_preprocess(left_img, carvekit_model)
                    # right_img = Image.open('zero123_dataset/object_to_forget_angle/image/minion_pixel_right.png').convert('RGBA')
                    # right_img_tensor = img_preprocess(right_img, carvekit_model)
                    # pixel_img_tensor = torch.stack([front_img_tensor, back_img_tensor, left_img_tensor, right_img_tensor], dim=0)
                    #
                    # for idx in range(len(forget_angle)):
                    #     if -45 <= override_angle[idx, 1] <= 45:
                    #         override_img[idx] = pixel_img_tensor[0]
                    #     elif override_angle[idx, 1] >= 135 or override_angle[idx, 1] <= -135:
                    #         override_img[idx] = pixel_img_tensor[1]
                    #         if override_angle[idx, 1] >= 135:
                    #             override_angle[idx, 1] -= 180
                    #         else:
                    #             override_angle[idx, 1] += 180
                    #     elif -135 < override_angle[idx, 1] < -45:
                    #         override_img[idx] = pixel_img_tensor[2]
                    #         override_angle[idx, 1] += 90
                    #     else:
                    #         override_img[idx] = pixel_img_tensor[3]
                    #         override_angle[idx, 1] -= 90

                    # for idx in range(len(forget_angle)):
                    #     if override_angle[idx, 1] > 90 or override_angle[idx, 1] < -90:
                    #         print(f'blur img back, y_angle : {override_angle[idx, 1]}')
                    #         override_img[idx] = blurred_img_tensor[0]
                    #         if override_angle[idx, 1] > 90:
                    #             override_angle[idx, 1] -= 180
                    #         else:
                    #             override_angle[idx, 1] += 180
                    #     else:
                    #         print(f'do not blur angle {override_angle[idx, 1]}')

                    # # 正 -> 左； 左 -> 正 ； 后 -> 右；右 ->  后;
                    # # 0 -> -90, 90 -> 0, -90 -> 180, -1 -> -91, -45 -> -135, -135 -> 135
                    # # -135~180 135~180不换
                    # for idx in range(len(forget_angle)):
                    #     if -135 < override_angle[idx, 1] < 135:
                    #         override_img[idx] = blurred_img_tensor[0]
                    #         if -90 < override_angle[idx, 1]:
                    #             override_angle[idx, 1] -= 90
                    #         else:
                    #             override_angle[idx, 1] += 270
                    #     else:
                    #         print(f'do not blur angle {override_angle[idx, 1]}')

                    override_context, override_img_context = get_context(model=model, cam_angle=override_angle,
                                                                         input_image=override_img,
                                                                         guidance_scale=guidance_scale, device=device)
                '''
                ！！！！！！！！！！！！！！！！！！！
                核心改进：
                1.计算当前角度图像的context与四个基准角度的context的余弦相似度选，择要采样两个的基准角度
                2.利用基准角度的采样噪声根据余弦相似度大小对跳步的噪声进行插值
                3.生成器采样时直接跳过指定的跳步步数，用插值后的噪声直接作为跳步后的采样噪声
                ！！！！！！！！！！！！！！！！！！！
                '''
                selected_similarities, selected_indices = compare_cosine_similarity(base_context=horizontal_context, current_context=forget_context)
                interpolation_noise = cal_interpolation_noise(selected_similarities, selected_indices,
                                                              noise_cache=horizontal_sample_noise[skip_rounds], lb=0.7)
                # interpolation_noises = torch.stack([cal_interpolation_noise(selected_similarities, selected_indices,
                #                          noise_cache=horizontal_sample_noise[idx], lb=0.7) for idx in range(skip_rounds)])
                sampler_img = zero123_sampler_noise_skip(unet=G, sampler=sampler, device=device, context=forget_context, img_context=forget_img_context,
                                               guidance_scale=guidance_scale, num_steps=32, skip_start_idx=skip_start,
                                                         num_skip_steps=skip_rounds, alter_noise=interpolation_noise,
                                                         return_images=False, train_sampler=True)
                # sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                #                               context=forget_context, img_context=forget_img_context,
                #                               return_images=False, num_steps=32, train_sampler=True, num_steps_eval=32)

                noise = torch.randn(img_shape)
                timesteps = torch.randint(tmin, tmax, (len(forget_context),), device=device, dtype=torch.long)
                # Compute loss for generator
                # Denoised fake images (track generator gradient) under fake score network, using guidance scale: kappa2=kappa3=cfg_eval_fake
                # 重点理解！！！
                with torch.no_grad():
                    y_fake = zero123_denoise(unet=fake_score, device=device, sampler=sampler,
                                             sampler_noise=sampler_img, noise=noise,
                                             context=forget_context, img_context=forget_img_context,
                                             timesteps=timesteps,
                                             guidance_scale=guidance_scale)

                    # Denoised fake images (track generator gradient) under pretrained score network, using guidance scale: kappa4=cfg_eval_real
                    y_real = zero123_denoise(unet=true_score, sampler=sampler,
                                             sampler_noise=sampler_img, noise=noise,
                                             context=override_context, img_context=override_img_context,
                                             timesteps=timesteps,
                                             guidance_scale=guidance_scale, device=device)

                with torch.no_grad():
                    nan_mask_images = torch.isnan(sampler_img).flatten(start_dim=1).any(dim=1)
                    nan_mask_y_real = torch.isnan(y_real).flatten(start_dim=1).any(dim=1)
                    nan_mask_y_fake = torch.isnan(y_fake).flatten(start_dim=1).any(dim=1)
                    nan_mask = nan_mask_images | nan_mask_y_real | nan_mask_y_fake

                # Check if there are any NaN values present
                if nan_mask.any():
                    # Invert the nan_mask to get a mask of samples without NaNs
                    non_nan_mask = ~nan_mask
                    # Filter out samples with NaNs from y_real and y_fake
                    sampler_img = sampler_img[non_nan_mask]
                    y_real = y_real[non_nan_mask]
                    y_fake = y_fake[non_nan_mask]

                with torch.no_grad():
                    weight_factor = abs(sampler_img.to(torch.float32) - y_real.to(torch.float32)).mean(
                        dim=[1, 2, 3], keepdim=True).clip(min=0.01)

                # weight_factor = torch.tensor(weight_factor, dtype=torch.float32, requires_grad=True)
                if alpha == 1:
                    g_forget_loss = (y_real - y_fake) * (y_fake - sampler_img) / weight_factor
                else:
                    g_forget_loss = (y_real - y_fake) * (
                            (y_real - sampler_img) - alpha * (y_real - y_fake)) / weight_factor

                g_forget_loss = g_forget_loss.sum().mul(loss_scaling_G)

                if (~nan_mask).sum().item() > 0:
                    g_forget_loss = g_forget_loss * g_remain_coef  # 确保所有张量都启用了 requires_grad
                    if not g_forget_loss.requires_grad:
                        g_forget_loss.requires_grad_(True)
                    g_forget_loss.backward()
                    torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1.0)
                    print('g_forget_loss: ', g_forget_loss)
                else:
                    print('g_forget_loss no backward')

                g_forget_loss = g_forget_loss.detach().cpu().item()
                g_forget_loss_print += g_forget_loss

                del y_real, y_fake, sampler_img, g_forget_loss

                training_stats.report('G_Loss/remain_loss', g_remain_loss_print)
                training_stats.report('G_Loss/forget_loss', g_forget_loss_print)

                G.eval().requires_grad_(False)

                # Update generator
                for param in G.parameters():
                    if param.grad is not None:
                        torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)

                g_optimizer.step()

                torch.cuda.empty_cache()
                gc.collect()

                end_time = time.time()
                inner_end = time.time()
                inner_loop += 1
                print(f'cur_nimg:{cur_nimg} : inner loop: {inner_loop}/{len(forget_dataloader)}, time: {inner_end - inner_start}')

        epoch_end = time.time()
        cur_nimg += 1
        done = (cur_nimg >= total_kimg * 1000)

        print(f'cur_tick: {cur_tick}, cur_nimg: {cur_nimg}, time spent: {epoch_end - epoch_start}')
        # if cur_nimg != 50 and cur_nimg != 5:
        # if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + 10):
        #     continue

        # Print status line, accumulating the same information in training_stats.
        tick_end_time = time.time()
        fields = []
        fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
        fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
        fields += [
            f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2 ** 30):<6.2f}"]
        fields += [
            f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2 ** 30):<6.2f}"]
        fields += [
            f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2 ** 30):<6.2f}"]
        fields += [
            f"loss_fake_score_remain {training_stats.report0('fake_score_Loss/remain_loss', sg_remain_loss_print):<6.2f}"]
        fields += [
            f"loss_fake_score_forget {training_stats.report0('fake_score_Loss/forget_loss', sg_forget_loss_print):<6.2f}"]
        fields += [f"loss_G_remain {training_stats.report0('G_Loss/remain_loss', g_remain_loss_print):<6.2f}"]
        fields += [f"loss_G_forget {training_stats.report0('G_Loss/forget_loss', g_forget_loss_print):<6.2f}"]
        torch.cuda.reset_peak_memory_stats()
        dist.print0(' '.join(fields))

        # Check for abort.
        if (not done) and dist.should_stop():
            done = True
            dist.print0()
            dist.print0('Aborting...')

        # if (snapshot_ticks is not None) and (
        #         done or cur_tick % snapshot_ticks == 0 or cur_tick in [1, 2, 4, 10, 20, 30, 40, 50, 60, 70, 80, 90,
        #                                                                100]):
        # if (cur_nimg % 5 == 0 and cur_nimg >= 50) or (cur_nimg == 50):
        if (cur_nimg % 5 == 0) or (cur_nimg < 20):

            dist.print0('Exporting sample images...')
            # grid_size = (32, 32)
            # forget_img, override_img, angle = forget_dataloader[0]
            # context, img_context = get_context(model=model, input_image=forget_img, cam_angle=angle,
            #                                    guidance_scale=3, device=device)
            # with torch.no_grad():
            #     images = zero123_sampler(unet=G, device=device, sampler=sampler,
            #                              context=context, img_context=img_context, guidance_scale=3,
            #                              return_images=True, num_steps=32, train_sampler=False,
            #                              num_steps_eval=32)
            #
            # images = images.cpu().permute(0, 2, 3, 1).clamp(0, 1).numpy()
            # fig, axes = plt.subplots(2, 2, figsize=(8, 8))
            # for i, ax in enumerate(axes.flat):
            #     ax.imshow(images[i])
            #     ax.axis('off')
            # plt.tight_layout()
            # # plt.savefig(f'test_grid_{cur_nimg}.png')
            # plt.savefig(os.path.join(run_dir, f'test_grid_{cur_nimg}.png'))

            for y_angle in [45, 90, 135, 180, -45, -90, -135, 0]:
                # 测试正常生成图片
                test_angle = torch.tensor([0.0, y_angle, 0.0], dtype=torch.float32).unsqueeze(0)
                test_context, test_img_context = get_context(model=model, input_image=img_tensor,
                                                             cam_angle=test_angle,
                                                             guidance_scale=3, device=device)
                test_sample_img = zero123_sampler(unet=G, sampler=sampler, device=device, context=test_context,
                                                  img_context=test_img_context, guidance_scale=3, num_steps=32,
                                                  train_sampler=False, num_steps_eval=32, return_images=True)
                image_tensor = torch.clamp((test_sample_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
                img = Image.fromarray(sample.astype(np.uint8))
                img.save(os.path.join(run_dir, f'epoch_{cur_nimg:03d}_sample_angle_{y_angle}.png'))

                # 测试跳步图片
                selected_similarities, selected_indices = compare_cosine_similarity(horizontal_context, test_context)
                # interpolation_noises = cal_interpolation_noise(selected_similarities, selected_indices,
                #                          noise_cache=horizontal_sample_noise[skip_rounds], lb=0.7)
                interpolation_noise = cal_interpolation_noise(selected_similarities, selected_indices,
                                                              noise_cache=horizontal_sample_noise[skip_rounds], lb=0.7)
                skip_sample_img = zero123_sampler_noise_skip(unet=G, sampler=sampler, device=device,
                                                             context=test_context, img_context=test_img_context,
                                                             guidance_scale=3, num_steps=32, skip_start_idx=0,
                                                             num_skip_steps=skip_rounds,
                                                             alter_noise=interpolation_noise, return_images=True)
                image_tensor = torch.clamp((skip_sample_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
                img = Image.fromarray(sample.astype(np.uint8))
                img.save(os.path.join(run_dir, f'epoch_{cur_nimg:03d}_skip_{skip_rounds}_angle_{y_angle}.png'))
                print('saved')

            # if cur_nimg >= 50:
            if True:
                G_save = copy.deepcopy(G).to('cpu')
                sg_save = copy.deepcopy(fake_score).to('cpu')
                data = dict(ema=G_save, sg=sg_save)

                if dist.get_rank() == 0:
                    torch.save(data,
                               os.path.join(run_dir, f'network-snapshot-mod5-{alpha:03f}-{cur_nimg :06d}.pkl'))
                    # save_data(data=data,
                    #           fname=os.path.join(run_dir, f'network-snapshot-mod5-{alpha:03f}-{cur_nimg :06d}.pkl'))

                del data, G_save  # conserve memory

        # if (state_dump_ticks is not None) and (
        #         done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
        #     dist.print0(f'saving checkpoint: training-state-{cur_nimg // 1000:06d}.pt')
        #     save_pt(pt=dict(fake_score=fake_score, G=G, G_ema=G_ema,
        #                     fake_score_optimizer_state=fake_score_optimizer.state_dict(),
        #                     g_optimizer_state=g_optimizer.state_dict()),
        #             fname=os.path.join(run_dir, f'training-state-{cur_nimg // 1000:06d}.pt'))

        # Update logs.
        training_stats.default_collector.update()
        if dist.get_rank() == 0:
            if stats_jsonl is None:
                append_line(jsonl_line=json.dumps(
                    dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n',
                            fname=os.path.join(run_dir, f'stats_{alpha:03f}.jsonl'))

        dist.update_progress(cur_nimg // 1000, total_kimg)

            # Update state.
        cur_tick += 1
        tick_start_nimg = cur_nimg
        tick_start_time = time.time()
        maintenance_time = tick_start_time - tick_end_time
        if done:
            break

        # Done.
        dist.print0()
        dist.print0('Exiting...')

# ----------------------------------------------------------------------------
