"""Distill Stable Diffusion models using the SiD-LSG techniques described in the
paper "Long and Short Guidance in Score identity Distillation for One-Step Text-to-Image Generation"."""
import itertools

from torch.utils.data import DataLoader

"""Main training loop."""
import re
import os
import time
import copy
import json
import pickle
import psutil
import PIL.Image
import numpy as np
import torch
import dnnlib
from torch_utils import distributed as dist
from torch_utils import training_stats
from torch_utils import misc
import torch.nn as nn
from functools import partial
import gc
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
# Needed for v-prediction based diffusion model
from diffusers.training_utils import compute_snr

# Functions needed to integrate Stable Diffusion into SiD
from training.sd_util import load_sd15, sid_sd_sampler, sid_sd_denoise

# zero 123 dataset
from training.zero123_datasets import ObjectTOForgetAngleDataset, CommonObjectDataset
from load_model import load_zero_123, img_preprocess, get_context, zero123_sampler, zero123_denoise, load_unet_from_ckpt
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import create_carvekit_interface


from torchvision.utils import save_image
from einops import rearrange
# ----------------------------------------------------------------------------
def setup_snapshot_image_grid(training_set, random_seed=0):
    """Choose the conditioning entries used to render a snapshot preview grid.

    The grid dimensions are derived from ``training_set.resolution``: the
    width is ``4096 // resolution`` clamped to [8, 32] and the height is
    ``2048 // resolution`` clamped to [4, 32].  Only ``gw // 2`` dataset items
    are drawn per grid row; element ``[0]`` of every drawn item is fetched
    twice and both results are stored next to each other.

    Args:
        training_set: Indexable dataset exposing ``resolution``, ``__len__``
            and ``__getitem__``.
        random_seed: Seed used to shuffle the dataset indices, or ``None`` to
            keep the dataset order.

    Returns:
        ``((gw, gh), None, contexts)`` where ``contexts`` holds
        ``2 * (gw // 2) * gh`` entries.  The middle slot is always ``None``.
    """
    gw = np.clip(4096 // training_set.resolution, 8, 32)
    gh = np.clip(2048 // training_set.resolution, 4, 32)

    order = list(range(len(training_set)))
    if random_seed is not None:
        np.random.RandomState(random_seed).shuffle(order)

    # Each selected item fills two adjacent cells, so only half a row of
    # distinct items is needed per grid row.
    cell_count = (gw // 2) * gh
    # Wrap around when the dataset has fewer items than there are cells.
    chosen = [order[cell % len(order)] for cell in range(cell_count)]

    # The dataset is indexed once per stored entry (two lookups per index).
    contexts = [training_set[idx][0] for idx in chosen for _ in range(2)]

    return (gw, gh), None, contexts


from itertools import islice


def split_list(lst, split_sizes):
    """Cut ``lst`` into consecutive chunks.

    Parameters:
    - lst (list): The sequence to split; it must support ``len`` when
                  ``split_sizes`` is an integer.
    - split_sizes (list or int): An integer gives fixed-size chunks, with a
                  shorter final chunk when the length is not a multiple of it.
                  A list of integers gives one chunk per entry, each of the
                  requested length; a chunk is shorter (possibly empty) once
                  ``lst`` is exhausted, and items beyond the requested total
                  are not returned.

    Returns:
    - list of lists: The chunks, in order.
    """
    if isinstance(split_sizes, int):
        whole_chunks, leftover = divmod(len(lst), split_sizes)
        split_sizes = [split_sizes] * whole_chunks
        if leftover != 0:
            split_sizes.append(leftover)

    # A single shared iterator makes every chunk continue where the last ended.
    remaining = iter(lst)
    chunks = []
    for width in split_sizes:
        chunks.append(list(islice(remaining, width)))
    return chunks


from PIL import Image

def save_image_grid(img, fname, drange, grid_size):
    """Tile a batch of images into a single grid image and save it.

    NOTE: a function with the same name and an identical body is defined
    again further down in this module; that later definition is the one left
    bound to the name once the module has been imported.

    Args:
        img: Array-like convertible to an ``(N, C, H, W)`` array with
            ``N == gw * gh`` and ``C`` equal to 1 or 3.
        fname: Destination handed to the PIL image's ``save``.
        drange: ``(lo, hi)`` value range of ``img``; mapped linearly onto
            [0, 255], rounded, and clipped.
        grid_size: ``(gw, gh)`` tiles per row and per column.
    """
    lo, hi = drange
    pixels = np.asarray(img, dtype=np.float32)
    pixels = (pixels - lo) * (255 / (hi - lo))
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

    gw, gh = grid_size
    _, C, H, W = pixels.shape
    # (gh, gw, C, H, W) -> (gh, H, gw, W, C) -> (gh * H, gw * W, C)
    canvas = pixels.reshape(gh, gw, C, H, W).transpose(0, 3, 1, 4, 2)
    canvas = canvas.reshape(gh * H, gw * W, C)

    assert C in [1, 3]
    if C == 1:
        # Single-channel data is written as an 8-bit greyscale image.
        PIL.Image.fromarray(canvas[:, :, 0], 'L').save(fname)
    elif C == 3:
        PIL.Image.fromarray(canvas, 'RGB').save(fname)


def save_pil_images_in_grid(image_files, grid_size, output_path):
    """Paste images onto one RGB sheet in row-major order and save it.

    Args:
        image_files: Non-empty sequence of images (presumably ``PIL.Image``
            objects, since ``.size`` and ``paste`` are used).  The tile size
            is taken from the first entry; all entries are assumed to share it.
        grid_size: ``(gw, gh)`` number of tiles per row and per column.
        output_path: Destination passed to the sheet's ``save``.
    """
    gw, gh = grid_size
    # The first image defines the size of every cell.
    tile_w, tile_h = image_files[0].size

    sheet = Image.new('RGB', (gw * tile_w, gh * tile_h))

    for slot, tile in enumerate(image_files):
        # Fill left to right, then top to bottom.
        row, col = divmod(slot, gw)
        sheet.paste(tile, (col * tile_w, row * tile_h))

    sheet.save(output_path)


# ----------------------------------------------------------------------------
# Helper methods


def save_image_grid(img, fname, drange, grid_size):
    """Arrange an ``(N, C, H, W)`` image batch as a grid and write it to disk.

    NOTE: this redefines a function of the same name, with an identical body,
    that appears earlier in this module; this later definition is the one in
    effect after import.

    Args:
        img: Array-like convertible to an ``(N, C, H, W)`` array, where
            ``N == gw * gh`` and ``C`` is 1 (greyscale) or 3 (RGB).
        fname: Output target handed to the PIL image's ``save``.
        drange: ``(lo, hi)`` input value range; values are rescaled with
            ``(img - lo) * 255 / (hi - lo)``, rounded and clipped to [0, 255].
        grid_size: ``(gw, gh)`` grid width and height, counted in tiles.
    """
    lo, hi = drange
    scaled = (np.asarray(img, dtype=np.float32) - lo) * (255 / (hi - lo))
    quantized = np.rint(scaled).clip(0, 255).astype(np.uint8)

    gw, gh = grid_size
    _N, C, H, W = quantized.shape
    tiles = quantized.reshape(gh, gw, C, H, W)
    # Interleave tile rows with pixel rows and move channels last.
    mosaic = np.transpose(tiles, (0, 3, 1, 4, 2)).reshape(gh * H, gw * W, C)

    assert C in [1, 3]
    if C == 3:
        PIL.Image.fromarray(mosaic, 'RGB').save(fname)
    elif C == 1:
        PIL.Image.fromarray(mosaic[:, :, 0], 'L').save(fname)


def save_data(data, fname):
    """Pickle ``data`` into the file at ``fname``, overwriting existing content.

    Args:
        data: Any picklable object.
        fname: Path of the file to create or truncate.
    """
    with open(fname, mode='wb') as sink:
        pickle.Pickler(sink).dump(data)


def save_pt(pt, fname):
    """Serialize ``pt`` to ``fname`` with ``torch.save``.

    Args:
        pt: Object to store (whatever ``torch.save`` accepts).
        fname: Destination given to ``torch.save``.
    """
    torch.save(obj=pt, f=fname)


def append_line(jsonl_line, fname):
    """Append ``jsonl_line`` plus a trailing newline to the text file ``fname``.

    The file is opened in append mode, so it is created when missing and
    existing content is kept.

    Args:
        jsonl_line: The text of one record, without the line terminator.
        fname: Path of the file to append to.
    """
    with open(fname, mode='at') as log:
        log.write(jsonl_line + '\n')


def upcast_lora_params(model, dtype):
    """Store the trainable parameters of ``model`` in float32 when ``dtype`` is float16.

    Only parameters with ``requires_grad`` set are touched; frozen parameters
    keep their dtype.  For any ``dtype`` other than ``torch.float16`` the
    function does nothing.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        dtype: The dtype the model is being run in.
    """
    if dtype != torch.float16:
        return

    trainable = (p for p in model.parameters() if p.requires_grad)
    for p in trainable:
        # Swap the underlying storage in place; the parameter object itself
        # stays the same.
        p.data = p.to(torch.float32)


# ----------------------------------------------------------------------------

def training_loop(
        run_dir='.',  # Output directory.
        dataset_kwargs={},  # Options for training set.
        data_loader_kwargs={},  # Options for torch.utils.data.DataLoader.
        network_kwargs={},  # Options for model and preconditioning.
        loss_kwargs={},  # Options for loss function.
        fake_score_optimizer_kwargs={},  # Options for fake score network optimizer.
        g_optimizer_kwargs={},  # Options for generator optimizer.
        augment_kwargs=None,  # Options for augmentation pipeline, None = disable.
        seed=0,  # Global random seed.
        batch_size=1,  # Total batch size for one training iteration.
        batch_gpu=None,  # Limit batch size per GPU, None = no limit.
        total_kimg=200000,  # Training duration, measured in thousands of training images.
        ema_halflife_kimg=500,  # Half-life of the exponential moving average (EMA) of model weights.
        ema_rampup_ratio=0.05,  # EMA ramp-up coefficient, None = no rampup.
        loss_scaling=1,  # Loss scaling factor, could be adjusted for reducing FP16 under/overflows.
        loss_scaling_G=1,  # Loss scaling factor of G, could be adjusted for reducing FP16 under/overflows.
        kimg_per_tick=0.01,  # Interval of progress prints.
        snapshot_ticks=50,  # How often to save network snapshots, None = disable.
        state_dump_ticks=500,  # How often to dump training state, None = disable.
        resume_pkl=None,  # Start from the given network snapshot for initialization, None = random initialization.
        resume_training=None,  # Resume training from the given network snapshot.
        resume_kimg=0,  # Start from the given training progress.
        alpha=1,  # loss = L2-alpha*L1
        tmax=980,  # We add noise at steps 0 to tmax, tmax <= 1000
        tmin=20,  # We add noise at steps 0 to tmax, tmax <= 1000
        cudnn_benchmark=True,  # Enable torch.backends.cudnn.benchmark?
        device=torch.device('cuda'),
        init_timestep=None,
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        fake_score_use_lora=False,
        dataset_prompt_text_kwargs={},
        forget_dataset_prompt_text_kwargs={},
        cfg_train_fake=1,  # kappa1
        cfg_eval_fake=1,  # kappa2 = kappa3
        cfg_eval_real=1,  # kappa4
        num_steps=1,
        train_mode=True,
        enable_xformers=True,
        gradient_checkpointing=False,
        resolution=512,
        sg_remain_coef=1.0,
        sg_forget_coef=0.01,
        g_remain_coef=1.0,
        g_forget_coef=0.01,
        from_distill_ema=None,
        sid_w_neg=False,
        use_neg=(False, False, True),
        sg_w_override=False,
        pretrained_vae_model_name_or_path=None,

        guidance_scale=3,
        b=4,
        num_angles=4
):

    # 1: 1.28 1.17
    # 4: 2.8 2.9
    # 8: 5.31
    # load_model
    ckpt_path = '105000.ckpt'
    config_path = 'configs/sd-objaverse-finetune-c_concat-256.yaml'
    device = 'cuda'

    print('start to load model from: ', ckpt_path)
    model = load_zero_123(ckpt_path, config_path, device)
    print('whole model: ', model.__class__.__name__)

    # unet, vae, embedding, noise sampler
    unet = model.model.diffusion_model
    sampler = DDIMSampler(model)
    carvekit_model = create_carvekit_interface()

    # load zero 123 dataset
    commonObjDataset = CommonObjectDataset(image_dir='zero123_dataset/common_objects/image',
                                           csv_path='zero123_dataset/common_objects/cam_angle.csv',
        preprocess_fn=img_preprocess, carvekit_model=carvekit_model,
                                           transform=None)
    print(len(commonObjDataset))
    # commonObjIter = iter(commonObjDataset)

    obj2forgetAngleDataset = ObjectTOForgetAngleDataset(
        forget_image_path='zero123_dataset/object_to_forget_angle/image/pikachu.png',
        override_image_path='zero123_dataset/object_to_forget_angle/image/minion.png',
        angles_csv_path='zero123_dataset/object_to_forget_angle/fixed_views.csv',
        preprocess_fn=img_preprocess, carvekit_model=carvekit_model,
        transform=None
    )
    print(len(obj2forgetAngleDataset))
    # obj2forgetIter = iter(obj2forgetAngleDataset)
    common_dataloader = DataLoader(commonObjDataset, batch_size=b, shuffle=True)
    forget_dataloader = DataLoader(obj2forgetAngleDataset, batch_size=b, shuffle=False)

    forget_iterator = itertools.cycle(forget_dataloader)
    common_iterator = itertools.cycle(common_dataloader)

    dist.print0('finish loading zero123 dataset!!!')

    unet = unet.eval()
    true_score = unet
    true_score.eval().requires_grad_(False).to(device)
    fake_score = copy.deepcopy(unet).eval()
    fake_score.eval().requires_grad_(False).to(device)
    G = copy.deepcopy(unet).eval()
    G.eval().requires_grad_(False).to(device)
    # true_score = load_unet_from_ckpt(device=device_0)
    # fake_score = load_unet_from_ckpt(device=device_1)
    # G = load_unet_from_ckpt(device=device_2)

    # def check_model_consistency(model1, model2):
    #     state_dict1 = model1.state_dict()
    #     state_dict2 = model2.state_dict()
    #
    #     for key in state_dict1.keys():
    #         if not torch.equal(state_dict1[key].cpu(), state_dict2[key].cpu()):
    #             print(f"Mismatch found in parameter: {key}")
    #             return False
    #     print("Models are consistent.")
    #     return True
    #
    # # 示例
    # check_model_consistency(true_score, fake_score)
    # check_model_consistency(true_score, G)
    # true_score.load_state_dict(unet.state_dict())
    # fake_score.load_state_dict(unet.state_dict())
    # G.load_state_dict(unet.state_dict())
    # true_score.load_state_dict(true_score.state_dict())
    # fake_score.load_state_dict(true_score.state_dict())
    # G.load_state_dict(true_score.state_dict())


    fake_score_optimizer = dnnlib.util.construct_class_by_name(params=fake_score.parameters(),
                                                               **fake_score_optimizer_kwargs)  # subclass of torch.optim.Optimizer
    g_optimizer = dnnlib.util.construct_class_by_name(params=G.parameters(),
                                                      **g_optimizer_kwargs)  # subclass of torch.optim.Optimizer
    img_shape = [b, 4, 32, 32]
    fake_score.eval().requires_grad_(False)
    G.eval().requires_grad_(False)

    # print('test initial generator')
    # # forget_img, override_img, angle = obj2forgetAngleDataset[0]
    # for i in range(num_angles):
    #     forget_img, override_img, angle = next(forget_iterator)
    #     # angle = [0., 0., 0.]
    #     # input_im = img_preprocess(forget_img, carvekit_model=carvekit_model)
    #     # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NEW CC_PROJECTION!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    #     context, img_context = get_context(model=model, input_image=forget_img, cam_angle=angle, guidance_scale=3, device=device)
    #     override_context, override_img_context = get_context(model=model, input_image=override_img, cam_angle=angle, guidance_scale=3, device=device)
    #     # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    #     with torch.no_grad():
    #         sampler_img = zero123_sampler(unet=G, sampler=sampler, device=device, guidance_scale=3,
    #                                       context=context, img_context=img_context, init_timesteps=init_timestep,
    #                                       return_images=True, num_steps=32, train_sampler=False, num_steps_eval=32)
    #     image_tensor = torch.clamp((sampler_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
    #     sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
    #     img = Image.fromarray(sample.astype(np.uint8))
    #     img.save(f'test_init_G_forget_{angle}.png')
    #     with torch.no_grad():
    #         sampler_img = zero123_sampler(unet=G, sampler=sampler, device=device, guidance_scale=3,
    #                                       context=override_context, img_context=override_img_context, init_timesteps=init_timestep,
    #                                       return_images=True, num_steps=32, train_sampler=False, num_steps_eval=32)
    #     image_tensor = torch.clamp((sampler_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
    #     sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
    #     img = Image.fromarray(sample.astype(np.uint8))
    #     img.save(f'test_init_G_override_{angle}.png')
    #
    #     del sampler_img

    # start training
    # Train.
    dist.print0(f'Training for {total_kimg} kimg...')
    dist.print0()
    cur_nimg = resume_kimg * 1000
    cur_tick = 0
    tick_start_nimg = cur_nimg
    dist.update_progress(cur_nimg // 1000, total_kimg)
    stats_jsonl = None

    sg_remain_loss_list = []
    sg_forget_loss_list = []
    g_remain_loss_list = []
    g_forget_loss_list = []
    while True:
        start_time = time.time()
        # first stage: train fake score network
        torch.cuda.empty_cache()
        gc.collect()
        G.eval().requires_grad_(False)

        fake_score.train().requires_grad_(True)
        fake_score_optimizer.zero_grad(set_to_none=True)
        # initialize losses
        sg_remain_loss_print = sg_forget_loss_print = 0
        # from CommonObjDataset
        img, angle = next(common_iterator)
        # img = img_preprocess(img, carvekit_model=carvekit_model)
        start_time = time.time()
        with torch.no_grad():
            context, img_context = get_context(model=model, input_image=img, cam_angle=angle, guidance_scale=guidance_scale, device=device)

        # generate fake images (generator no grad)
        with torch.no_grad():
            sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                                          context=context, img_context=img_context, init_timesteps=1000,
                                          return_images=False, num_steps=32, train_sampler=True, num_steps_eval=32)

        noise = torch.randn(img_shape).to(device)
        timesteps = torch.randint(tmin, tmax, (len(context),), device=device, dtype=torch.long)
        # Compute remain loss for fake score network
        # Compute the fake score network's score for the classes that are to be retained.
        noise_fake = zero123_denoise(unet=fake_score, sampler=sampler,
                                     sampler_noise=sampler_img, noise=noise, predict_x0=False,
                                     context=context, img_context=img_context, timesteps=timesteps,
                                     guidance_scale=guidance_scale, device=device)
        end_time = time.time()
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        print('noise fake time used: ', end_time-start_time, 'batch = ', b)
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')

        with torch.no_grad():
            nan_mask = torch.isnan(noise_fake).flatten(start_dim=1).any(dim=1)

        # Check if there are any NaN values present
        target = None
        if nan_mask.any():
            # Invert the nan_mask to get a mask of samples without NaNs
            non_nan_mask = ~nan_mask
            # Filter out samples with NaNs from y_real and y_fake
            noise_fake = noise_fake[non_nan_mask]
            noise = noise[non_nan_mask]

        sg_remain_loss = (noise_fake - noise) ** 2
        sg_remain_loss = sg_remain_loss.sum().mul(loss_scaling)

        if len(noise) > 0:
            print('sg_remain_loss: ', sg_remain_loss)
            sg_remain_loss.mul(sg_remain_coef).backward()
        else:
            print('sg_remain_loss no backward')

        del sampler_img, target
        del noise_fake

        sg_remain_loss_print += sg_remain_loss.detach().item()
        sg_remain_loss_list.append(sg_remain_loss.detach().item())
        del sg_remain_loss

        # The score network has content that needs to be forgotten.
        if sg_forget_coef > 0:
            # forget_prompt - brad pitt ... , override_prompt - a middle aged man ..., no neg
            img, override_img, angle = next(forget_iterator)
            # img = img_preprocess(img, carvekit_model)
            with torch.no_grad():
                forget_context, forget_img_context = get_context(model=model, cam_angle=angle, input_image=img,
                                                             guidance_scale=guidance_scale, device=device)

            with torch.no_grad():
                sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                                              context=forget_context, img_context=forget_img_context,
                                              return_images=False, num_steps=32, train_sampler=True, num_steps_eval=32)

            noise = torch.randn(img_shape).to(device)
            timesteps = torch.randint(tmin, tmax, (len(forget_context),), device=device, dtype=torch.long)
            # Compute forget loss for fake score network
            # Denoised fake images (stop generator gradient) under fake score network, using guidance scale: kappa1=cfg_eval_train

            noise_fake = zero123_denoise(unet=fake_score, sampler=sampler,
                                         sampler_noise=sampler_img, noise=noise, predict_x0=False,
                                         context=forget_context, img_context=forget_img_context,
                                         timesteps=timesteps,
                                         guidance_scale=guidance_scale, device=device)

            with torch.no_grad():
                nan_mask = torch.isnan(noise_fake).flatten(start_dim=1).any(dim=1)

            # Check if there are any NaN values present
            target = None
            if nan_mask.any():
                # Invert the nan_mask to get a mask of samples without NaNs
                non_nan_mask = ~nan_mask
                # Filter out samples with NaNs from y_real and y_fake
                noise_fake = noise_fake[non_nan_mask]
                noise = noise[non_nan_mask]

            sg_forget_loss = (noise_fake - noise) ** 2

            sg_forget_loss = sg_forget_loss.sum().mul(loss_scaling)

            if len(noise) > 0:
                print('sg_forget_loss: ', sg_forget_loss)
                sg_forget_loss.mul(sg_forget_coef).backward()
            else:
                print('sg_forget_loss no backward')

            del sampler_img, target
            del noise_fake

            sg_forget_loss = sg_forget_loss.detach().cpu().item()
            sg_forget_loss_print += sg_forget_loss
            sg_forget_loss_list.append(sg_forget_loss)

            del sg_forget_loss

        training_stats.report('fake_score_Loss/remain_loss', sg_remain_loss_print)
        training_stats.report('fake_score_Loss/forget_loss', sg_forget_loss_print)

        fake_score.eval().requires_grad_(False)

        # Update fake score network
        for param in fake_score.parameters():
            if param.grad is not None:
                torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)

        fake_score_optimizer.step()

        # ----------------------------------------------------------------------------------------------
        # Update One-Step Generator Network

        # fake_score.eval().requires_grad_(True)
        # true_score.eval().requires_grad_(True)
        true_score.eval().requires_grad_(False)
        true_score.eval().requires_grad_(False)
        G.train().requires_grad_(True)
        g_optimizer.zero_grad(set_to_none=True)

        g_remain_loss_print = g_forget_loss_print = 0
        # The context here is the prompt after "brad pitt" has been replaced with "middle aged man".
        img, angle = next(common_iterator)
        # img = img_preprocess(img, carvekit_model)
        start_time = time.time()

        with torch.no_grad():
            context, img_context = get_context(model=model, cam_angle=angle, input_image=img, guidance_scale=guidance_scale, device=device)

        sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                                      context=context, img_context=img_context,
                                      return_images=False, num_steps=32, train_sampler=True, num_steps_eval=32)

        noise = torch.randn(img_shape)
        timesteps = torch.randint(tmin, tmax, (len(context),), device=device, dtype=torch.long)
        # Compute loss for generator

        # with torch.no_grad():

        sampler_img.requires_grad_(True)
        # # 在计算 g_remain_loss 之前，临时存储 requires_grad 状态
        # true_score_grad_state = {}
        # fake_score_grad_state = {}
        #
        # # 保存原始的 requires_grad 状态
        # for name, param in true_score.named_parameters():
        #     true_score_grad_state[name] = param.requires_grad
        #     param.requires_grad_(False)  # 临时关闭梯度计算
        #
        # for name, param in fake_score.named_parameters():
        #     fake_score_grad_state[name] = param.requires_grad
        #     param.requires_grad_(False)  # 临时关闭梯度计算
        with torch.no_grad():
        #     true_score = true_score.eval().to(device)  # 冻结梯度但保留计算图
        #     fake_score = fake_score.eval().to(device)  # 冻结梯度但保留计算图

            y_fake = zero123_denoise(unet=fake_score, sampler=sampler,
                                     sampler_noise=sampler_img, noise=noise,
                                     context=context, img_context=img_context, timesteps=timesteps,
                                     guidance_scale=guidance_scale, device=device)

            # Denoised fake images (track generator gradient) under pretrained score network, using guidance scale: kappa4=cfg_eval_real
            y_real = zero123_denoise(unet=true_score, sampler=sampler,
                                     sampler_noise=sampler_img, noise=noise,
                                     context=context, img_context=img_context, timesteps=timesteps,
                                     guidance_scale=guidance_scale, device=device)
        # y_fake = y_fake.requires_grad_()
        # y_real = y_real.requires_grad_()
        end_time = time.time()
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
        print('return image real fake time used: ', end_time-start_time, 'batch = ', b)
        print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            # y_real = y_real.to(device_3)
            # y_fake = y_fake.to(device_3)
            # print('y_real requires grad? ', y_real.requires_grad)
            # print('y_fake requires grad? ', y_fake.requires_grad)
            # print(f'y_real device {y_real.device}, y_fake device {y_fake.device}, sampler_img device {sampler_img.device}')
            # sampler_img = sampler_img.to(device_3)
            # print('y_real - y_fake common obj', y_real - y_fake)

        with torch.no_grad():
            nan_mask_images = torch.isnan(sampler_img).flatten(start_dim=1).any(dim=1)
            nan_mask_y_real = torch.isnan(y_real).flatten(start_dim=1).any(dim=1)
            nan_mask_y_fake = torch.isnan(y_fake).flatten(start_dim=1).any(dim=1)
            nan_mask = nan_mask_images | nan_mask_y_real | nan_mask_y_fake

        # Check if there are any NaN values present
        if nan_mask.any():
            # Invert the nan_mask to get a mask of samples without NaNs
            non_nan_mask = ~nan_mask
            # Filter out samples with NaNs from y_real and y_fake
            sampler_img = sampler_img[non_nan_mask]
            y_real = y_real[non_nan_mask]
            y_fake = y_fake[non_nan_mask]
            # print('masked y_real requires grad? ', y_real.requires_grad)
            # print('masked y_fake requires grad? ', y_fake.requires_grad)

        with torch.no_grad():
            weight_factor = abs(sampler_img.to(torch.float32) - y_real.to(torch.float32)).mean(
                # dim=[1, 2, 3], keepdim=True).clip(min=0.00001)
                dim=[1, 2, 3], keepdim=True).clip(min=0.01)

        print(f'y_real requires_grad:{y_real.requires_grad}, y_fake requires grad:{y_fake.requires_grad}, sampler_img requires grad:{sampler_img.requires_grad}')
        # weight_factor = torch.tensor(weight_factor, dtype=torch.float32, requires_grad=True)
        if alpha == 1:
            g_remain_loss = (y_real - y_fake) * (y_fake - sampler_img) / weight_factor
        else:
            g_remain_loss = (y_real - y_fake) * (
                    (y_real - sampler_img) - alpha * (y_real - y_fake)) / weight_factor

        print(f"Loss requires grad: {g_remain_loss.requires_grad}")
        print(f"Loss computation graph exists: {g_remain_loss.grad_fn is not None}")
        g_remain_loss = g_remain_loss.sum().mul(loss_scaling_G)

        if (~nan_mask).sum().item() > 0:
            g_remain_loss = g_remain_loss * g_remain_coef  # 确保所有张量都启用了 requires_grad
            if not g_remain_loss.requires_grad:
                g_remain_loss.requires_grad_(True)
            print(f'sum of y_real:{y_real.sum()}, y_fake:{y_fake.sum()}, sampler_img:{sampler_img.sum()}')
            print(f'y_real requires_grad:{y_real.requires_grad}, y_fake requires grad:{y_fake.requires_grad}, sampler_img requires grad:{sampler_img.requires_grad}')
            # for param in G.parameters():
            #     print(param.requires_grad)
            print('g_remain_loss: ', g_remain_loss)
            g_remain_loss.backward()
            torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1.0)

            # # 恢复原始的 requires_grad 状态
            # for name, param in true_score.named_parameters():
            #     param.requires_grad_(true_score_grad_state[name])
            #
            # for name, param in fake_score.named_parameters():
            #     param.requires_grad_(fake_score_grad_state[name])

        else:
            print('g_remain_loss no backward')

        g_remain_loss = g_remain_loss.detach().cpu().item()
        g_remain_loss_print += g_remain_loss
        g_remain_loss_list.append(g_remain_loss)
        # for param in fake_score.parameters():
        #     if param.grad is not None:
        #         param.grad.zero_()
        # for param in true_score.parameters():
        #     if param.grad is not None:
        #         param.grad.zero_()

        del y_real, y_fake, sampler_img, g_remain_loss

        if g_forget_coef > 0:
            # The generator network has content that needs to be forgotten.
            forget_img, override_img, angle = next(forget_iterator)
            # forget_img = img_preprocess(forget_img, carvekit_model)
            # override_img = img_preprocess(override_img, carvekit_model)
            with torch.no_grad():
                forget_context, forget_img_context = get_context(model=model, cam_angle=angle, input_image=forget_img,
                                                             guidance_scale=guidance_scale, device=device)
                override_context, override_img_context = get_context(model=model, cam_angle=angle, input_image=override_img,
                                                                 guidance_scale=guidance_scale, device=device)

            sampler_img = zero123_sampler(unet=G, device=device, sampler=sampler, guidance_scale=guidance_scale,
                                          context=forget_context, img_context=forget_img_context,
                                          return_images=False, num_steps=32, train_sampler=True, num_steps_eval=32)

            noise = torch.randn(img_shape)
            timesteps = torch.randint(tmin, tmax, (len(forget_context),), device=device, dtype=torch.long)
            # Compute loss for generator
            # Denoised fake images (track generator gradient) under fake score network, using guidance scale: kappa2=kappa3=cfg_eval_fake
            # Key point to understand!
            with torch.no_grad():
                y_fake = zero123_denoise(unet=fake_score, device=device, sampler=sampler,
                                         sampler_noise=sampler_img, noise=noise,
                                         context=forget_context, img_context=forget_img_context, timesteps=timesteps,
                                         guidance_scale=guidance_scale)

                # Denoised fake images (track generator gradient) under pretrained score network, using guidance scale: kappa4=cfg_eval_real
                y_real = zero123_denoise(unet=true_score, sampler=sampler,
                                         sampler_noise=sampler_img, noise=noise,
                                         context=override_context, img_context=override_img_context,
                                         timesteps=timesteps,
                                         guidance_scale=guidance_scale, device=device)
                # y_real = y_real.to(device_3)
                # y_fake = y_fake.to(device_3)
                # print(f'y_real device {y_real.device}, y_fake device {y_fake.device}, sampler_img device {sampler_img.device}')
                # sampler_img = sampler_img.to(device_3)
                # print('y_real - y_fake forget', y_real - y_fake)

            with torch.no_grad():
                nan_mask_images = torch.isnan(sampler_img).flatten(start_dim=1).any(dim=1)
                nan_mask_y_real = torch.isnan(y_real).flatten(start_dim=1).any(dim=1)
                nan_mask_y_fake = torch.isnan(y_fake).flatten(start_dim=1).any(dim=1)
                nan_mask = nan_mask_images | nan_mask_y_real | nan_mask_y_fake
                # print(y_real.shape)
                # print('nan in y_real_fake', len(nan_mask))
                # print('len yfake', y_fake.flatten(start_dim=1).any(dim=1))

            # Check if there are any NaN values present
            if nan_mask.any():
                # Invert the nan_mask to get a mask of samples without NaNs
                non_nan_mask = ~nan_mask
                # Filter out samples with NaNs from y_real and y_fake
                sampler_img = sampler_img[non_nan_mask]
                y_real = y_real[non_nan_mask]
                y_fake = y_fake[non_nan_mask]
                print('len masked y fake', y_fake)

            with torch.no_grad():
                weight_factor = abs(sampler_img.to(torch.float32) - y_real.to(torch.float32)).mean(
                    dim=[1, 2, 3], keepdim=True).clip(min=0.01)

            # weight_factor = torch.tensor(weight_factor, dtype=torch.float32, requires_grad=True)
            if alpha == 1:
                g_forget_loss = (y_real - y_fake) * (y_fake - sampler_img) / weight_factor
            else:
                g_forget_loss = (y_real - y_fake) * (
                        (y_real - sampler_img) - alpha * (y_real - y_fake)) / weight_factor

            g_forget_loss = g_forget_loss.sum().mul(loss_scaling_G)

            if (~nan_mask).sum().item() > 0:
                g_forget_loss = g_forget_loss * g_remain_coef  # ensure all tensors have requires_grad enabled (see the check below)
                if not g_forget_loss.requires_grad:
                    g_forget_loss.requires_grad_(True)
                print(f'sum of y_real:{y_real.sum()}, y_fake:{y_fake.sum()}, sampler_img:{sampler_img.sum()}')
                g_forget_loss.backward()
                torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1.0)
                print('g_forget_loss: ', g_forget_loss)
            else:
                print('g_forget_loss no backward')

            g_forget_loss = g_forget_loss.detach().cpu().item()
            g_forget_loss_print += g_forget_loss
            g_forget_loss_list.append(g_forget_loss)
            # for param in fake_score.parameters():
            #     if param.grad is not None:
            #         param.grad.zero_()
            # for param in true_score.parameters():
            #     if param.grad is not None:
            #         param.grad.zero_()
            # fake_score_optimizer.zero_grad()

            del y_real, y_fake, sampler_img, g_forget_loss

            training_stats.report('G_Loss/remain_loss', g_remain_loss_print)
            training_stats.report('G_Loss/forget_loss', g_forget_loss_print)

            G.eval().requires_grad_(False)

            # Update generator
            for param in G.parameters():
                if param.grad is not None:
                    torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)

            # # Apply gradient clipping under fp16 to prevent sudden divergence
            # if dtype == torch.float16 and (~nan_mask).sum().item() > 0:
            #     torch.nn.utils.clip_grad_value_(G.parameters(), 1)

            g_optimizer.step()

            # if ema_halflife_kimg > 0:
            #     # Update EMA.
            #     ema_halflife_nimg = ema_halflife_kimg * 1000
            #     if ema_rampup_ratio is not None:
            #         ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * ema_rampup_ratio)
            #     ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
            #
            #     for p_ema, p_true_score in zip(G_ema.parameters(), G.parameters()):
            #         with torch.no_grad():
            #             p_ema.copy_(p_true_score.detach().lerp(p_ema, ema_beta))
            # else:
            #     G_ema = G

            torch.cuda.empty_cache()
            gc.collect()

            end_time = time.time()
            cur_nimg += batch_size
            done = (cur_nimg >= total_kimg * 1000)

            print(f'cur_tick: {cur_tick}, cur_nimg: {cur_nimg}, time spent: {end_time - start_time}')
            # if cur_nimg != 50 and cur_nimg != 5:
                # if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + 10):
                #     continue


            # Print status line, accumulating the same information in training_stats.
            tick_end_time = time.time()
            fields = []
            fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"]
            fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<9.1f}"]
            # fields += [
            #     f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"]
            # fields += [
            #     f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"]
            # fields += [
            #     f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"]
            # fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"]
            fields += [
                f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2 ** 30):<6.2f}"]
            fields += [
                f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2 ** 30):<6.2f}"]
            fields += [
                f"reserved {training_stats.report0('Resources/peak_gpu_mem_reserved_gb', torch.cuda.max_memory_reserved(device) / 2 ** 30):<6.2f}"]
            fields += [
                f"loss_fake_score_remain {training_stats.report0('fake_score_Loss/remain_loss', sg_remain_loss_print):<6.2f}"]
            fields += [
                f"loss_fake_score_forget {training_stats.report0('fake_score_Loss/forget_loss', sg_forget_loss_print):<6.2f}"]
            fields += [f"loss_G_remain {training_stats.report0('G_Loss/remain_loss', g_remain_loss_print):<6.2f}"]
            fields += [f"loss_G_forget {training_stats.report0('G_Loss/forget_loss', g_forget_loss_print):<6.2f}"]
            torch.cuda.reset_peak_memory_stats()
            dist.print0(' '.join(fields))

            # # plot loss
            # fig, axes = plt.subplots(2, 2, figsize=(12, 8))  # 2x2 grid of subplots
            # axes = axes.flatten()  # flatten the axes array for easier indexing
            #
            # # draw each subplot
            # axes[0].plot(sg_remain_loss_list, label='sg remain loss', color='r', marker='o')
            # axes[0].set_title("sg remain loss")
            # axes[0].set_xlabel("Epoch")
            # axes[0].set_ylabel("Loss")
            # axes[0].grid(True)
            # axes[0].set_ylim(-10000, 10000)
            #
            # axes[1].plot(sg_forget_loss_list, label='sg forget loss', color='g', marker='s')
            # axes[1].set_title("sg forget loss")
            # axes[1].set_xlabel("Epoch")
            # axes[1].set_ylabel("Loss")
            # axes[1].grid(True)
            # axes[1].set_ylim(-10000, 10000)
            #
            # axes[2].plot(g_remain_loss_list, label='g remain loss', color='b', marker='^')
            # axes[2].set_title("g remain loss")
            # axes[2].set_xlabel("Epoch")
            # axes[2].set_ylabel("Loss")
            # axes[2].grid(True)
            # axes[2].set_ylim(-10000, 10000)
            #
            # axes[3].plot(g_forget_loss_list, label='g forget loss', color='m', marker='d')
            # axes[3].set_title("g forget loss")
            # axes[3].set_xlabel("Epoch")
            # axes[3].set_ylabel("Loss")
            # axes[3].grid(True)
            # axes[3].set_ylim(-10000, 10000)
            # plt.tight_layout()
            # # plt.figure(figsize=(10, 6))
            # # plt.plot(sg_remain_loss_list, label='sg remain loss', marker='o')
            # # plt.plot(sg_forget_loss_list, label='sg forget loss', marker='s')
            # # plt.plot(g_remain_loss_list, label='g remain loss', marker='^')
            # # plt.plot(g_forget_loss_list, label='g forget loss', marker='d')
            #
            # save_path = os.path.join(run_dir, "loss_curves.png")
            # plt.savefig(save_path, dpi=300, bbox_inches='tight')
            # print(f"Loss curves saved to {save_path}")

            # Check for abort.
            if (not done) and dist.should_stop():
                done = True
                dist.print0()
                dist.print0('Aborting...')

            # if (snapshot_ticks is not None) and (
            #         done or cur_tick % snapshot_ticks == 0 or cur_tick in [1, 2, 4, 10, 20, 30, 40, 50, 60, 70, 80, 90,
            #                                                                100]):
            # if (cur_nimg % 5 == 0 and cur_nimg >= 50) or (cur_nimg == 50):
            if (cur_nimg % 5 == 0) or (cur_nimg == 0):

                dist.print0('Exporting sample images...')
                grid_size = (32, 32)
                forget_img, override_img, angle = next(forget_iterator)
                context, img_context = get_context(model=model, input_image=forget_img, cam_angle=angle,
                                                   guidance_scale=3, device=device)
                with torch.no_grad():
                    images = zero123_sampler(unet=G, device=device, sampler=sampler,
                                             context=context, img_context=img_context, guidance_scale=3,
                                             return_images=True, num_steps=32, train_sampler=False,
                                             num_steps_eval=32)

                images = images.cpu().permute(0, 2, 3, 1).clamp(0, 1).numpy()
                fig, axes = plt.subplots(2, 2, figsize=(8, 8))
                for i, ax in enumerate(axes.flat):
                    ax.imshow(images[i])
                    ax.axis('off')
                plt.tight_layout()
                # plt.savefig(f'test_grid_{cur_nimg}.png')
                plt.savefig(os.path.join(run_dir, f'test_grid_{cur_nimg}.png'))
                # valForgetAngleDataset = ObjectTOForgetAngleDataset(
                #     forget_image_path='zero123_dataset/object_to_forget_angle/image/02.png',
                #     override_image_path='zero123_dataset/object_to_forget_angle/image/real_plane.png',
                #     angles_csv_path='zero123_dataset/object_to_forget_angle/angles_to_override.csv',
                #     transform=None,
                #     val=True
                # )
                # for i in range(num_angles):
                #     forget_img, override_img, angle = next(forget_iterator)
                #     # angle = [0., 0., 0.]
                #     # input_im = img_preprocess(forget_img, carvekit_model=carvekit_model)
                #     # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NEW CC_PROJECTION!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                #     context, img_context = get_context(model=model, input_image=forget_img, cam_angle=angle,
                #                                        guidance_scale=3, device=device)
                #     override_context, override_img_context = get_context(model=model, input_image=override_img,
                #                                                          cam_angle=angle, guidance_scale=3,
                #                                                          device=device)
                #     # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
                #     with torch.no_grad():
                #         sampler_img = zero123_sampler(unet=G, sampler=sampler, device=device, guidance_scale=3,
                #                                       context=context, img_context=img_context,
                #                                       init_timesteps=init_timestep,
                #                                       return_images=True, num_steps=32, train_sampler=False,
                #                                       num_steps_eval=32)
                #     image_tensor = torch.clamp((sampler_img[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                #     sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
                #     img = Image.fromarray(sample.astype(np.uint8))
                #     img.save(os.path.join(
                #         run_dir, f'fakes_{alpha:03f}_{cur_nimg:06d}_{angle}.png'))
                #     del sampler_img

                # ctxs, img_ctxs = [], []
                # # for i in range(6):
                # forget_img, override_img, angle = next(forget_iterator)
                # # img = img_preprocess(forget_img, carvekit_model)
                # # angle = [0., 0., 0.]
                # context, img_context = get_context(model=model, input_image=forget_img, cam_angle=angle, guidance_scale=guidance_scale, device=device)
                # for i in range(len(forget_img)):
                #     ctxs.append(context[[i]])
                #     img_ctxs.append(img_context[i])
                #
                # if dist.get_rank() == 0:
                #     for num_steps_eval in [32]:
                #         # While the generator is primarily trained to generate images in a single step, it can also be utilized in a multi-step setting during evaluation.
                #         # To do: Distill a multi-step generator that is optimized for multi-step settings
                #         with torch.no_grad():
                #             images = [zero123_sampler(unet=G, device=device, sampler=sampler,
                #                                       context=ctx, img_context=img_ctx, guidance_scale=guidance_scale,
                #                                       return_images=True, num_steps=1, train_sampler=False,
                #                                       num_steps_eval=num_steps_eval)
                #                       for ctx, img_ctx in zip(ctxs, img_ctxs)]
                #
                #         images = torch.cat(images).cpu().numpy()
                #         save_image_grid(img=images, fname=os.path.join(
                #             run_dir, f'fakes_{alpha:03f}_{cur_nimg:06d}.png'),
                #             drange=[-1, 1], grid_size=grid_size)
                #         # for image in images:
                #         #     # image = images[0]
                #         #     image_tensor = torch.clamp((image[0] + 1.0) / 2.0, min=0.0, max=1.0).cpu()
                #         #     sample = 255.0 * rearrange(image_tensor.detach().numpy(), 'c h w -> h w c')
                #         #     img = Image.fromarray(sample.astype(np.uint8))
                #         #     img.save(os.path.join(
                #         #         run_dir, f'fakes_{alpha:03f}_{cur_nimg:06d}_{angle}.png'))
                #
                #         # save_image_grid(img=images, fname=os.path.join(
                #         #     run_dir, f'fakes_{alpha:03f}_{cur_nimg // 1000:06d}_{num_steps_eval:d}.png'),
                #         #                 drange=[-1, 1], grid_size=(1, 1))
                #
                #     del images

                if cur_nimg >= 50:
                    G_save = copy.deepcopy(G).to('cpu')
                    data = dict(ema=G_save)
                    # for key, value in data.items():
                    #     if isinstance(value, torch.nn.Module):
                    #         from collections import OrderedDict
                    #0
                    #         value_state_dict = OrderedDict([(k, v.detach().cpu()) for k, v in value.state_dict().items()])
                    #         unet_cpu_copy.load_state_dict(value_state_dict)
                    #         data[key] = unet_cpu_copy
                    #         del value_state_dict

                    if dist.get_rank() == 0:
                        torch.save(data, os.path.join(run_dir, f'network-snapshot-mod5-{alpha:03f}-{cur_nimg :06d}.pkl'))
                        # save_data(data=data,
                        #           fname=os.path.join(run_dir, f'network-snapshot-mod5-{alpha:03f}-{cur_nimg :06d}.pkl'))

                    del data, G_save  # conserve memory

            # if (state_dump_ticks is not None) and (
            #         done or cur_tick % state_dump_ticks == 0) and cur_tick != 0 and dist.get_rank() == 0:
            #     dist.print0(f'saving checkpoint: training-state-{cur_nimg // 1000:06d}.pt')
            #     save_pt(pt=dict(fake_score=fake_score, G=G, G_ema=G_ema,
            #                     fake_score_optimizer_state=fake_score_optimizer.state_dict(),
            #                     g_optimizer_state=g_optimizer.state_dict()),
            #             fname=os.path.join(run_dir, f'training-state-{cur_nimg // 1000:06d}.pt'))

            # Update logs.
            training_stats.default_collector.update()
            if dist.get_rank() == 0:
                if stats_jsonl is None:
                    append_line(jsonl_line=json.dumps(
                        dict(training_stats.default_collector.as_dict(), timestamp=time.time())) + '\n',
                                fname=os.path.join(run_dir, f'stats_{alpha:03f}.jsonl'))

            dist.update_progress(cur_nimg // 1000, total_kimg)

            # Update state.
            cur_tick += 1
            tick_start_nimg = cur_nimg
            tick_start_time = time.time()
            maintenance_time = tick_start_time - tick_end_time
            if done:
                break

        # Done.
        dist.print0()
        dist.print0('Exiting...')

# ----------------------------------------------------------------------------
