import os
from tqdm import tqdm, trange
from random import randint
import math
import numpy as np
import random
from collections import defaultdict, OrderedDict
import torch
import torch.nn.functional as F
from torchvision import io
from torchvision.transforms import ToPILImage, ToTensor
from PIL import Image
from einops import rearrange
import pickle
import scipy
import imageio
import glob
import cv2
import kornia
import open3d as o3d
import shutil
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.multimodal.clip_score import CLIPScore

from arguments import ModelParams, PipelineParams, OptimizationParams
from gaussian_renderer import render
from scene.gaussian_model_local import EditPro_Render_local 
from scene.gaussian_model import EditPro_Render
from scene.gaussian_model import GaussianModel
from scene import Scene

from utils.loss_utils import l1_loss, ssim, compute_depth_loss
from lpipsPyTorch import lpips
from utils.image_utils import psnr, colorize
from utils.utils_poses.align_traj import align_ate_c2b_use_a2b
from utils.utils_poses.comp_ate import compute_rpe, compute_ATE
from utils.flow_viz import flow_to_image
from utils.perceptual import PerceptualLoss
from utils.sam import LangSAMTextSegmentor
from utils.utils_poses.ATE.align_utils import alignTrajectory
from utils.utils_poses.lie_group_helper import SO3_to_quat, convert3x4_4x4
from utils.annotator_utils import resize_image, HWC3
from utils.graphics_utils import BasicPointCloud, focal2fov, procrustes

from kornia.geometry.depth import depth_to_3d, depth_to_normals
from scripts.generate_flow import warp_flow

from guidance.preprocess import prep, get_timesteps
from guidance.run_pnp_tpe import run_tpe
from guidance.run_pnp_spe import run_spe
from guidance.util import *
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

import pdb
import clip

from .trainer import GaussianTrainer
from .losses import Loss, compute_scale_and_shift, SSIM_V2, L_TV

from copy import copy
from utils.vis_utils import interp_poses_bspline, smooth_poses_spline, draw_poses
from guidance.pnp_utils import *
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, StableDiffusionPipeline, DDIMInverseScheduler
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from transformers import CLIPTextModel, CLIPTokenizer, logging

try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False


def contruct_pose(poses):
    """Chain a stack of per-step transforms into cumulative transforms.

    The leading axis is swept from the last index down to 1; at each step
    every entry from that index onward is left-multiplied by the entry just
    before it. After the sweep, entry ``k`` equals
    ``poses[0] @ poses[1] @ ... @ poses[k]`` and entry 0 is unchanged.

    Args:
        poses: tensor whose leading axis indexes the transforms and whose
            trailing two axes can be matrix-multiplied (presumably
            (N, 4, 4) relative camera poses -- confirm against callers).

    Returns:
        The accumulated stack, same shape as the input. The input is never
        written to; with fewer than two entries it is returned as is.
    """
    for idx in reversed(range(1, poses.shape[0])):
        # Slice keeps a leading axis of size 1 so the product broadcasts
        # over every trailing entry.
        pivot = poses[idx - 1:idx]
        poses = torch.cat((poses[:idx], pivot @ poses[idx:]), dim=0)
    return poses

def get_timesteps(scheduler, num_inference_steps, strength, device):
    """Select the tail of a scheduler's timestep list for a given strength.

    Args:
        scheduler: object exposing a sliceable ``timesteps`` attribute.
        num_inference_steps: total number of steps the scheduler was set to.
        strength: fraction of the schedule to run; ``int(steps * strength)``
            is capped at ``num_inference_steps``.
        device: accepted for interface compatibility; not used here.

    Returns:
        ``(timesteps, count)`` where ``timesteps`` is
        ``scheduler.timesteps[start:]`` and ``count`` is
        ``num_inference_steps - start``.
    """
    # Number of steps actually executed, never more than the full schedule.
    steps_to_run = int(num_inference_steps * strength)
    if steps_to_run > num_inference_steps:
        steps_to_run = num_inference_steps

    # Index of the first timestep to keep; clamped at the schedule start.
    first_index = num_inference_steps - steps_to_run
    if first_index < 0:
        first_index = 0

    return scheduler.timesteps[first_index:], num_inference_steps - first_index

class EditProGaussianTrainer(GaussianTrainer):
    def __init__(self, data_root, model_cfg, pipe_cfg, optim_cfg):
        """Set up renderers, losses, metrics and (optionally) the diffusion models.

        Args:
            data_root: forwarded unchanged to ``GaussianTrainer.__init__``.
            model_cfg: model options. Read here: ``view_dependent``,
                ``enable_spatial``, ``seg_prompt`` and, when
                ``enable_spatial`` is truthy, ``sd_version``, ``steps`` and
                ``guidance`` (its ``"prompt"`` / ``"negative_prompt"`` entries).
            pipe_cfg: pipeline options. Read here: ``use_mask``, ``use_mono``.
            optim_cfg: optimisation options. Read here: ``use_edit``.

        Raises:
            ValueError: if ``enable_spatial`` is set and
                ``model_cfg.sd_version`` is not a supported identifier.
        """
        super().__init__(data_root, model_cfg, pipe_cfg, optim_cfg)
        self.model_cfg = model_cfg
        self.pipe_cfg = pipe_cfg
        self.optim_cfg = optim_cfg

        # One renderer for the global scene, one for the per-frame local fit.
        self.gs_render = EditPro_Render(white_background=False,
                                   view_dependent=model_cfg.view_dependent,)
        self.gs_render_local = EditPro_Render_local(white_background=False,
                                         view_dependent=model_cfg.view_dependent,)
        self.use_mask = self.pipe_cfg.use_mask # False
        self.use_mono = self.pipe_cfg.use_mono # True
        # Rendered depth below this value is clamped in train_step_edit_v2.
        self.near = 0.01

        self.mask_list = []
        # Edited target images, keyed by camera uid.
        self.edit_frames = {}
        self.perceptual_loss = PerceptualLoss().eval().to('cuda')
        self.text_segmentor = LangSAMTextSegmentor().to('cuda')
        self.setup_losses()

        self.enable_spatial = self.model_cfg.enable_spatial

        self.device = torch.device("cuda")
        self.ssim_loss = SSIM_V2()
        self.loss_TV = L_TV()
        self.lpips_loss = LearnedPerceptualImagePatchSimilarity().to(self.device)
        self.clip_metrics = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(self.device)

        # data type align
        self.pil_to_tensor = ToTensor()
        self.tensor_to_pil = ToPILImage()

        if optim_cfg.use_edit and len(model_cfg.seg_prompt) > 0:
            self.update_mask()

        if self.enable_spatial:
            # Supported ``sd_version`` values and the checkpoint each maps to.
            sd_checkpoints = {
                '2.1': "stabilityai/stable-diffusion-2-1-base",
                '2.0': "stabilityai/stable-diffusion-2-base",
                '1.5': "runwayml/stable-diffusion-v1-5",
                'ControlNet': "runwayml/stable-diffusion-v1-5",
                'depth': "stabilityai/stable-diffusion-2-depth",
                'xl': "stabilityai/stable-diffusion-xl-base-1.0",
            }
            try:
                model_key = sd_checkpoints[model_cfg.sd_version]
            except KeyError:
                # Previously an unknown version left ``model_key`` unbound and
                # failed a few lines later with an opaque UnboundLocalError.
                raise ValueError(
                    f"Unsupported sd_version {model_cfg.sd_version!r}; "
                    f"expected one of {sorted(sd_checkpoints)}") from None

            # Throwaway scheduler used only to derive the timestep list.
            toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
            toy_scheduler.set_timesteps(self.model_cfg.steps)
            self.timesteps_to_save, self.num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=self.model_cfg.steps,
                                                                        strength=1.0,
                                                                        device='cuda')

            # NOTE(review): a UNet is loaded here and a second one arrives with
            # the pipeline below (``self.unet``) -- confirm both are needed.
            self.sd_unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16",
                                                        torch_dtype=torch.float16).to(self.device)
            self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")

            pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=torch.float16).to("cuda")

            self.vae = pipe.vae
            self.tokenizer = pipe.tokenizer
            self.text_encoder = pipe.text_encoder
            self.unet = pipe.unet

            self.text_embeds = self.get_text_embeds(model_cfg.guidance["prompt"], model_cfg.guidance["negative_prompt"])
            # First chunk of the embeddings produced for an empty prompt pair.
            self.pnp_guidance_embeds = self.get_text_embeds("", "").chunk(2)[0]
            

    def setup_losses(self):
        """Build the loss callable from ``self.optim_cfg`` and store it as ``self.loss_func``."""
        optim_cfg = self.optim_cfg
        self.loss_func = Loss(optim_cfg)

    def train_step_local_3d(self,
                   gs_render,
                   viewpoint_cam,
                   iteration,
                   pipe,
                   optim_opt,
                   colors_precomp=None,
                   update_gaussians=True,
                   update_cam=True,
                   update_distort=False,
                   densify=True,
                   prev_gaussians=None,
                   use_reproject=False,
                   use_matcher=False,
                   ref_fidx=None,
                   reset=True,
                   reproj_loss=None,
                   local=False,
                   edit=False,
                   **kwargs,
                   ):
        """Run one optimisation step of ``gs_render`` against the camera's original image.

        Renders ``viewpoint_cam``, evaluates ``self.compute_loss``, backpropagates,
        optionally densifies / resets opacity, then steps the optimisers.

        Args:
            gs_render: renderer whose ``gaussians`` are optimised in this step.
            viewpoint_cam: camera providing ``original_image`` and
                ``foreground_mask``.
            iteration: step counter used for the densification / reset schedule.
            pipe: pipeline options (``compute_cov3D_python``,
                ``convert_SHs_python``).
            optim_opt: optimisation options (densification and opacity-reset
                schedule).
            colors_precomp: passed to the renderer as ``override_color``.
            update_gaussians: step and zero the gaussian optimiser when true.
            update_cam: step and zero the camera optimiser of the current
                sequence index when true and such an optimiser exists.
            densify: enable densification bookkeeping and pruning.
            use_reproject, use_matcher, ref_fidx: forwarded to
                ``self.compute_loss``.
            reset: allow the periodic opacity reset.
            local: forwarded to ``compute_loss`` as ``kwargs['local']`` and
                selects the masked PSNR; forced to ``False`` when the camera
                has no foreground mask.
            update_distort, prev_gaussians, reproj_loss, edit: accepted for
                interface compatibility; not read in this method.
            **kwargs: forwarded to ``self.compute_loss``.

        Returns:
            ``(loss_dict, render_pkg, psnr_train)``.

        Side effects:
            Sets ``self.just_reset`` (``True`` only if opacity was reset in
            this step).
        """
        if viewpoint_cam.foreground_mask is None:
            local = False
        kwargs['local'] = local

        render_pkg = gs_render.render(
            viewpoint_cam,
            other=[],
            time_shift=None,
            compute_cov3D_python=pipe.compute_cov3D_python,
            convert_SHs_python=pipe.convert_SHs_python,
            override_color=colors_precomp)

        image, viewspace_point_tensor, visibility_filter, radii = (render_pkg["image"],
                                                                   render_pkg["viewspace_points"],
                                                                   render_pkg["visibility_filter"],
                                                                   render_pkg["radii"])
        # Loss
        gt_image = viewpoint_cam.original_image.cuda()

        loss_dict = self.compute_loss(render_pkg, viewpoint_cam,
                                      pipe, optim_opt, iteration,
                                      use_reproject, use_matcher,
                                      ref_fidx, **kwargs)

        loss = loss_dict['loss']
        loss.backward()

        with torch.no_grad():
            if local:
                # PSNR on both images weighted by (1 - first mask channel).
                keep = 1 - viewpoint_cam.foreground_mask[0:1, ...]
                psnr_train = psnr(image * keep, gt_image * keep).mean().double()
            else:
                psnr_train = psnr(image, gt_image).mean().double()
            self.just_reset = False
            if iteration < optim_opt.densify_until_iter and densify:
                # Keep track of max radii in image-space for pruning.
                # A failure here (e.g. a shape mismatch) now propagates with
                # its traceback instead of dropping into an interactive pdb
                # session, which would hang a non-interactive training job.
                gs_render.gaussians.max_radii2D[visibility_filter] = torch.max(
                    gs_render.gaussians.max_radii2D[visibility_filter],
                    radii[visibility_filter])
                gs_render.gaussians.add_densification_stats(
                    viewspace_point_tensor, visibility_filter)

                if iteration > optim_opt.densify_from_iter and iteration % optim_opt.densification_interval == 0:
                    size_threshold = 20 if iteration > optim_opt.opacity_reset_interval else None
                    # Densify the renderer being trained: the statistics above
                    # were accumulated on ``gs_render``, not ``self.gs_render``.
                    gs_render.gaussians.densify_and_prune(optim_opt.densify_grad_threshold, 0.005,
                                                          gs_render.radius, size_threshold, optim_opt.densify_grad_t_threshold)

                if iteration % optim_opt.opacity_reset_interval == 0 and reset and iteration < optim_opt.reset_until_iter:
                    gs_render.gaussians.reset_opacity()
                    self.just_reset = True

            if update_gaussians:
                gs_render.gaussians.optimizer.step()
                gs_render.gaussians.optimizer.zero_grad(set_to_none=True)
            if getattr(gs_render.gaussians, "camera_optimizer", None) is not None and update_cam:
                current_fidx = gs_render.gaussians.seq_idx
                gs_render.gaussians.camera_optimizer[current_fidx].step()
                gs_render.gaussians.camera_optimizer[current_fidx].zero_grad(
                    set_to_none=True)

        return loss_dict, render_pkg, psnr_train

    def train_step_edit_v2(self,
                   gs_render,
                   viewpoint_cam,
                   iteration,
                   pipe,
                   optim_opt,
                   colors_precomp=None,
                   update_gaussians=True,
                   update_cam=True,
                   update_distort=False,
                   densify=True,
                   prev_gaussians=None,
                   use_reproject=False,
                   use_matcher=False,
                   ref_fidx=None,
                   reset=True,
                   reproj_loss=None,
                   local=False,
                   edit=True,
                   **kwargs,
                   ):
        """Run one optimisation step of ``gs_render`` against the edited target frame.

        The target is ``self.edit_frames[viewpoint_cam.uid]``. On top of the
        image loss from ``self.loss_func`` the step can add optical-flow terms
        and, every 50th iteration when ``self.enable_spatial`` is set, an L1
        term between interpolated novel views and their diffusion-propagated
        counterparts.

        Args:
            gs_render: renderer whose ``gaussians`` are optimised in this step.
            viewpoint_cam: camera; ``uid`` selects the edited target. A falsy
                ``uid`` (frame 0) skips the flow and spatial terms.
            iteration: step counter used for the densification / reset schedule
                and the every-50-steps spatial term.
            pipe: pipeline options (``compute_cov3D_python``,
                ``convert_SHs_python``).
            optim_opt: optimisation options (``lambda_flow``, ``n_views``,
                densification and opacity-reset schedule).
            colors_precomp: passed to the renderer as ``override_color``.
            update_gaussians: step and zero the gaussian optimiser when true.
            update_cam: step and zero the camera optimiser of the current
                sequence index when true and such an optimiser exists.
            densify: enable densification bookkeeping and pruning.
            reset: allow the periodic opacity reset.
            update_distort, prev_gaussians, use_reproject, use_matcher,
            ref_fidx, reproj_loss, local, edit: accepted for interface
                compatibility; not read in this method.
            **kwargs: forwarded to ``self.loss_func`` (``depth_pred`` is added
                when the renderer returns a depth map).

        Returns:
            ``(loss_dict, render_pkg, psnr_train)``.

        Raises:
            KeyError: if no edited frame is stored for ``viewpoint_cam.uid``.

        Side effects:
            Sets ``self.just_reset`` (``True`` only if opacity was reset in
            this step).
        """
        # Render
        render_pkg = gs_render.render(
            viewpoint_cam,
            compute_cov3D_python=pipe.compute_cov3D_python,
            convert_SHs_python=pipe.convert_SHs_python,
            override_color=colors_precomp)

        image, viewspace_point_tensor, visibility_filter, radii = (render_pkg["image"],
                                                                   render_pkg["viewspace_points"],
                                                                   render_pkg["visibility_filter"],
                                                                   render_pkg["radii"])

        gt_edited_image = self.edit_frames[viewpoint_cam.uid]

        if "depth" in render_pkg:
            # Clamp rendered depth from below (in place) before handing it to
            # the loss.
            depth = render_pkg["depth"]
            depth[depth < self.near] = self.near
            kwargs['depth_pred'] = depth

        loss_dict = self.loss_func(render_pkg["image"].double(), gt_edited_image.double(), **kwargs)

        if 'depth' in render_pkg and optim_opt.lambda_flow and viewpoint_cam.uid:
            # NOTE(review): in both flow terms below, ``.item()`` turns the
            # flow loss into a Python float before it is added, so it shifts
            # the reported loss but contributes no gradient -- confirm this is
            # intended.
            # NOTE(review): after ``permute(2, 0, 1)`` the ``sum(-1)`` and
            # ``shape[-1]`` below act on the last axis of the permuted tensor,
            # not on the flow-component axis -- confirm the intended axis.
            if getattr(viewpoint_cam, 'fwd_flow', None) is not None:
                fwd_flow = viewpoint_cam.fwd_flow.cuda().permute(2, 0, 1)
                fwd_flow_mask = viewpoint_cam.fwd_flow_mask.cuda()
                render_flow_fwd = gs_render.render_flow(viewpoint_cam,
                                                        self.scene.time_delta,
                                                        compute_cov3D_python=pipe.compute_cov3D_python,
                                                        convert_SHs_python=pipe.convert_SHs_python,
                                                        override_color=colors_precomp)['image'][:2, ...]
                # Scale both flows by their own maximum norm before comparing.
                fwd_flow = fwd_flow / (torch.max(torch.sqrt(torch.square(fwd_flow).sum(-1))) + 1e-5)
                render_flow_fwd = render_flow_fwd / (torch.max(torch.sqrt(torch.square(render_flow_fwd).sum(-1))) + 1e-5)
                M = fwd_flow_mask.unsqueeze(0)
                fwd_flow_loss = torch.sum(torch.abs(fwd_flow - render_flow_fwd) * M) / (torch.sum(M) + 1e-8) / fwd_flow.shape[-1]
                loss_dict['loss'] += optim_opt.lambda_flow * fwd_flow_loss.item()
                loss_dict['fwd_flow_loss'] = fwd_flow_loss

            if getattr(viewpoint_cam, 'bwd_flow', None) is not None:
                bwd_flow = viewpoint_cam.bwd_flow.permute(2, 0, 1).cuda()
                bwd_flow_mask = viewpoint_cam.bwd_flow_mask.cuda()
                render_flow_bwd = gs_render.render_flow(viewpoint_cam,
                                                        -self.scene.time_delta,
                                                        compute_cov3D_python=pipe.compute_cov3D_python,
                                                        convert_SHs_python=pipe.convert_SHs_python,
                                                        override_color=colors_precomp)['image'][:2, ...]
                bwd_flow = bwd_flow / (torch.max(torch.sqrt(torch.square(bwd_flow).sum(-1))) + 1e-5)
                render_flow_bwd = render_flow_bwd / (torch.max(torch.sqrt(torch.square(render_flow_bwd).sum(-1))) + 1e-5)
                M = bwd_flow_mask.unsqueeze(0)
                bwd_flow_loss = torch.sum(torch.abs(bwd_flow - render_flow_bwd) * M) / (torch.sum(M) + 1e-8) / bwd_flow.shape[-1]
                loss_dict['loss'] += optim_opt.lambda_flow * bwd_flow_loss.item()
                loss_dict['bwd_flow_loss'] = bwd_flow_loss

        if self.enable_spatial and viewpoint_cam.uid and iteration % 50 == 0:
            use_depth_sd = self.model_cfg.sd_version == 'depth'

            for _ in range(optim_opt.n_views):
                uid_prev = randint(0, 10)
                # Second reference frame: a random early frame for uid > 10,
                # otherwise a frame ``uid_prev`` steps back.
                # NOTE(review): for uid <= 10 the index ``uid - uid_prev`` can
                # be negative -- confirm ``get_RT`` handles that as intended.
                other_uid = uid_prev if viewpoint_cam.uid > 10 else viewpoint_cam.uid - uid_prev
                pose_ref_1 = self.gs_render.gaussians.get_RT(viewpoint_cam.uid).inverse().detach().cpu().numpy() #c2w
                pose_ref_2 = self.gs_render.gaussians.get_RT(other_uid).inverse().detach().cpu().numpy()
                pose_ref = np.stack([pose_ref_1, pose_ref_2])

                degree = np.random.randint(30, 120)
                pose = interp_poses_bspline(pose_ref, 10, [viewpoint_cam.uid-1, viewpoint_cam.uid], degree)[5] #c2w
                cur_cam  = self.load_viewpoint_cam(viewpoint_cam.uid, pose=torch.tensor(pose).inverse().clone().detach().cpu(),)
                render_novel_pkg = gs_render.render(cur_cam,
                                                compute_cov3D_python=pipe.compute_cov3D_python,
                                                convert_SHs_python=pipe.convert_SHs_python,
                                                override_color=colors_precomp)
                novel_image = render_novel_pkg["image"]

                if self.model_cfg.downsample:
                    # Target size: (dim // 8) * 4 per spatial axis.
                    half_res = ((novel_image.shape[1]//8) * 4, (novel_image.shape[2]//8) * 4)
                    image_interpolate = torch.nn.functional.interpolate(novel_image.unsqueeze(0), size=half_res, mode='bilinear')
                    image_ref = torch.nn.functional.interpolate(gt_edited_image.unsqueeze(0), size=half_res, mode='bilinear')
                else:
                    image_interpolate = novel_image.unsqueeze(0)
                    image_ref = gt_edited_image.unsqueeze(0)

                combined_imgs = torch.cat([image_ref, image_interpolate], dim=0)
                # NOTE(review): the two calls below are given different depth
                # inputs (novel-view depth vs. this step's rendered depth, and
                # different unsqueeze/squeeze) -- confirm that is intended.
                self.get_ddim_inversion(combined_imgs, viewpoint_cam, depth_maps=torch.stack([self.mono_depth[viewpoint_cam.uid].unsqueeze(0), render_novel_pkg['depth']]) if use_depth_sd else None)
                # img_ode_interpolate = run_spe(self.model_cfg)
                img_ode_interpolate = self.spatial_propagation(combined_imgs, viewpoint_cam, depth_maps=torch.stack([self.mono_depth[viewpoint_cam.uid], render_pkg["depth"].squeeze(0)]) if use_depth_sd else None)

                image_interpolate = torch.nn.functional.interpolate(image_interpolate, img_ode_interpolate.shape[-2:], mode='bilinear')
                # ``loss_spatial_l1`` keeps only the last view's value; the
                # total loss accumulates every view.
                loss_dict['loss_spatial_l1']= torch.nan_to_num(l1_loss(image_interpolate, img_ode_interpolate))
                # loss_dict['lpips_loss_spatial'] = 0.2 * self.lpips_loss(image_interpolate, img_ode_interpolate)

                loss_dict['loss'] += loss_dict['loss_spatial_l1']

        loss = loss_dict['loss']
        loss.backward()

        with torch.no_grad():
            psnr_train = psnr(image, gt_edited_image).mean().double()
            self.just_reset = False
            if iteration < optim_opt.densify_until_iter and densify:
                # Keep track of max radii in image-space for pruning.
                # A failure here (e.g. a shape mismatch) now propagates with
                # its traceback instead of dropping into an interactive pdb
                # session, which would hang a non-interactive training job.
                gs_render.gaussians.max_radii2D[visibility_filter] = torch.max(
                    gs_render.gaussians.max_radii2D[visibility_filter],
                    radii[visibility_filter])
                gs_render.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if iteration > optim_opt.densify_from_iter and iteration % optim_opt.densification_interval == 0:
                    size_threshold = 20 if iteration > optim_opt.opacity_reset_interval else None
                    # Densify the renderer being trained: the statistics above
                    # were accumulated on ``gs_render``, not ``self.gs_render``.
                    gs_render.gaussians.densify_and_prune(optim_opt.densify_grad_threshold, 0.005,
                                                          gs_render.radius, size_threshold, optim_opt.densify_grad_t_threshold)

                if iteration % optim_opt.opacity_reset_interval == 0 and reset and iteration < optim_opt.reset_until_iter:
                    gs_render.gaussians.reset_opacity()
                    self.just_reset = True

            if update_gaussians:
                gs_render.gaussians.optimizer.step()
                gs_render.gaussians.optimizer.zero_grad(set_to_none=True)
            if getattr(gs_render.gaussians, "camera_optimizer", None) is not None and update_cam:
                current_fidx = gs_render.gaussians.seq_idx
                gs_render.gaussians.camera_optimizer[current_fidx].step()
                gs_render.gaussians.camera_optimizer[current_fidx].zero_grad(
                    set_to_none=True)

        return loss_dict, render_pkg, psnr_train

    def init_two_view(self, view_idx_1, view_idx_2, pipe, optim_opt):
        """Initialise the global gaussians from frame ``view_idx_1`` and fit them to its edited image.

        Args:
            view_idx_1: frame used to seed and fit the model.
            view_idx_2: accepted for interface compatibility; not read here.
            pipe: pipeline options forwarded to ``train_step_edit_v2``.
            optim_opt: optimisation options; ``iterations`` and
                ``densify_from_iter`` are overwritten on this object.

        Returns:
            The value of ``self.gs_render.gaussians.capture()`` after fitting.

        Side effects:
            Resets ``self.gs_render``, stores the edited frame's predicted
            depth in ``self.mono_depth[view_idx_1]`` and starts
            ``self.pcd_stack`` with the fitted point positions.
        """
        self.loss_func.depth_loss_type = "invariant"

        # prepare data
        _cam_info, seed_pcd, frame_cam = self.prepare_data(view_idx_1,
                                                           orthogonal=True,
                                                           down_sample=True)
        # Computed as in the original flow; the value is not used below.
        radius = np.linalg.norm(seed_pcd.points, axis=1).max()

        # Initialize gaussians
        self.gs_render.reset_model()
        self.gs_render.init_model(seed_pcd,)
        # self.gs_render.init_model(num_pts=300_000,)
        self.gs_render.gaussians.init_RT_seq(self.seq_len)
        self.gs_render.gaussians.set_seq_idx(view_idx_1)
        self.gs_render.gaussians.rotate_seq = False

        # Fit relative pose
        print(f"optimizing frame {view_idx_1:03d}")
        total_steps = 1000
        optim_opt.iterations = total_steps
        # Densification starts after the last step, i.e. never in this fit.
        optim_opt.densify_from_iter = total_steps + 1 # 1001
        progress_bar = tqdm(range(total_steps),
                            desc="Training progress")
        self.gs_render.gaussians.training_setup(optim_opt, fix_pos=True,)

        # Produce the edited target for this frame and re-estimate its depth.
        self.get_edited_frames(frame_cam)
        edited_hwc = self.edit_frames[view_idx_1].permute(1,2,0).cpu().numpy() * 255.0
        self.mono_depth[view_idx_1] = self.predict_depth(edited_hwc)

        for step in range(1, total_steps + 1):
            # Update learning rate
            self.gs_render.gaussians.update_learning_rate(step)
            _loss, _rend, psnr_train = self.train_step_edit_v2(
                self.gs_render,
                frame_cam, step,
                pipe, optim_opt,
                depth_gt=self.mono_depth[view_idx_1],
                update_gaussians=True,
                update_cam=False,
            )
            if step % 10 != 0:
                continue
            progress_bar.set_postfix({"PSNR": f"{psnr_train:.{2}f}",
                                      "Number points": f"{self.gs_render.gaussians.get_xyz.shape[0]}"})
            progress_bar.update(10)
        progress_bar.close()

        self.pcd_stack = [self.gs_render.gaussians.get_xyz.detach()]
        return self.gs_render.gaussians.capture()

    def add_view_v2(self, view_idx, view_idx_prev, reverse=False):
        """Register frame ``view_idx`` relative to ``view_idx_prev`` and refine the global model.

        Three stages:
          1. Fit the local gaussians to frame ``view_idx_prev`` (up to 1000
             steps, stopping early once PSNR > 35 after step 500).
          2. Re-initialise the local pose with ``init_RT(None)`` and run 300
             steps on frame ``view_idx`` after
             ``training_setup_fix_position``. The resulting relative pose is
             composed with the stored pose of ``view_idx_prev`` and written
             into the global model as the pose of ``view_idx``.
          3. Train the global model on randomly sampled frames, creating the
             edited target and its depth for any frame seen for the first
             time.

        Args:
            view_idx: index of the frame being added.
            view_idx_prev: index of the already-registered reference frame.
            reverse: accepted for interface compatibility; not read here.

        Returns:
            ``(pcd, local_model_params, loss)``: the local gaussians' detached
            ``_xyz``, the local model's ``capture()`` result, and the loss
            dict of the last global training step.
        """
        # Initialize gaussians
        self.loss_func.depth_loss_type = "invariant"
        pipe = copy(self.pipe_cfg)
        optim_opt = copy(self.optim_cfg)
        # prepare data
        cam_info, pcd, viewpoint_cam = self.prepare_data(view_idx_prev,
                                                         orthogonal=True,
                                                         down_sample=True)
        self.gs_render_local.reset_model()
        self.gs_render_local.init_model(pcd)
        # Fit current gaussian
        optim_opt.iterations = 1000
        # Densification starts after the last step, i.e. never in this fit.
        optim_opt.densify_from_iter = optim_opt.iterations + 1
        progress_bar = tqdm(range(optim_opt.iterations),
                            desc="Local Prev Training progress")
        self.gs_render_local.gaussians.training_setup(
            optim_opt, fix_pos=True,)

        for iteration in range(1, optim_opt.iterations+1):
            # Update learning rate
            self.gs_render_local.gaussians.update_learning_rate(iteration)
            # ``update_distort`` was previously misspelled ``updata_distort``,
            # so it fell into **kwargs and was forwarded to the loss instead
            # of binding to the parameter.
            loss, rend_dict, psnr_train = self.train_step_local_3d(self.gs_render_local,
                                                          viewpoint_cam, iteration,
                                                          pipe, optim_opt,
                                                          depth_gt=self.mono_depth[view_idx_prev],
                                                          update_gaussians=True,
                                                          update_cam=False,
                                                          update_distort=False,
                                                          densify=False,
                                                          local=True,
                                                          )
            # Early exit once the local fit is good enough.
            if psnr_train > 35 and iteration > 500:
                progress_bar.close()
                break

            if iteration % 10 == 0:
                # NOTE(review): the point count shown is that of the global
                # model, not the local one being trained -- confirm intent.
                progress_bar.set_postfix({"PSNR": f"{psnr_train:.{2}f}",
                                          "Number points": f"{self.gs_render.gaussians.get_xyz.shape[0]}"})
                progress_bar.update(10)
            if iteration == optim_opt.iterations:
                progress_bar.close()

        print(f"optimizing frame {view_idx:03d}")
        viewpoint_cam_ref = self.load_viewpoint_cam(view_idx,
                                                    load_depth=True)
        optim_opt.iterations = 300
        optim_opt.densify_from_iter = optim_opt.iterations + 1
        self.gs_render_local.gaussians.init_RT(None)
        self.gs_render_local.gaussians.training_setup_fix_position(
            optim_opt, gaussian_rot=False)

        progress_bar = tqdm(range(optim_opt.iterations),
                            desc="Local Ref Training progress")

        for iteration in range(1, optim_opt.iterations+1):
            # Update learning rate
            self.gs_render_local.gaussians.update_learning_rate(iteration)
            loss, rend_dict_ref, psnr_train = self.train_step_local_3d(self.gs_render_local,
                                                              viewpoint_cam_ref, iteration,
                                                              pipe, optim_opt,
                                                              densify=False,
                                                              local=True,
                                                              depth_gt=self.mono_depth[viewpoint_cam_ref.uid],
                                                              )
            if iteration % 10 == 0:
                progress_bar.set_postfix({"PSNR": f"{psnr_train:.{2}f}",
                                          "Number points": f"{self.gs_render.gaussians.get_xyz.shape[0]}"})
                progress_bar.update(10)
            if iteration == optim_opt.iterations:
                progress_bar.close()

        local_model_params = self.gs_render_local.gaussians.capture()

        # pcd under view_idx_prev frame
        pcd = self.gs_render_local.gaussians._xyz.detach()
        # Compose the fitted relative pose with the reference frame's stored
        # pose and record the result for the new frame.
        rel_pose = self.gs_render_local.gaussians.get_RT().detach()
        pose = rel_pose @ self.gs_render.gaussians.get_RT(
            view_idx_prev).detach()
        self.gs_render.gaussians.update_RT_seq(pose, view_idx)
        self.gs_render.gaussians.rotate_seq = False

        pipe.convert_SHs_python = self.gs_render.gaussians.rotate_seq

        # NOTE(review): both local loops above run with ``densify=False``,
        # under which ``train_step_local_3d`` leaves ``self.just_reset`` False,
        # so this branch looks unreachable from here -- confirm.
        if self.just_reset:
            num_iterations = 500
            self.just_reset = False
            for iteration in range(1, num_iterations):
                fidx = randint(0, view_idx_prev)
                self.global_iteration += 1
                self.gs_render.gaussians.update_learning_rate(
                    self.global_iteration)
                viewpoint_cam = self.load_viewpoint_cam(fidx,
                                                        pose=self.gs_render.gaussians.get_RT(
                                                            fidx).detach().cpu(),
                                                        load_depth=True)
                loss, rend_dict_ref, psnr_train = self.train_step_edit_v2(self.gs_render,
                                                                  viewpoint_cam,
                                                                  self.global_iteration,
                                                                  pipe, self.optim_cfg,
                                                                  update_gaussians=True,
                                                                  update_cam=True,
                                                                  depth_gt=self.mono_depth[fidx],
                                                                  update_distort=False,
                                                                  )

        # Steps for the global stage: default budget, 1000 near the end of the
        # sequence, 100 for the first few frames (the later rule wins).
        num_iterations = self.single_step
        if max(view_idx, view_idx_prev) >= min(int(self.seq_len * 0.8), self.seq_len-5):
            num_iterations = 1000
        if min(view_idx, view_idx_prev) < int(self.single_step // 100):
            num_iterations = 100

        progress_bar = tqdm(range(num_iterations), desc="Global Training progress")

        for iteration in range(1, num_iterations+1):

            # Sample the recent half of the registered frames 70% of the time,
            # the older half otherwise.
            last_frame = max(0, view_idx//2)
            if random.random() < 0.7:
                fidx = randint(last_frame, view_idx)
            else:
                fidx = randint(0, last_frame)

            self.global_iteration += 1
            if self.gs_render.gaussians.rotate_seq:
                self.gs_render.gaussians.set_seq_idx(fidx)
            viewpoint_cam = self.load_viewpoint_cam(fidx,
                                                    pose=self.gs_render.gaussians.get_RT(
                                                        fidx).detach().cpu()
                                                    if not self.gs_render.gaussians.rotate_seq
                                                    else None,
                                                    load_depth=True)
            # First visit of a frame: build its edited target and replace its
            # stored depth with one predicted from the edited image.
            if viewpoint_cam.uid not in self.edit_frames:
                self.get_edited_frames(viewpoint_cam)
                self.mono_depth[viewpoint_cam.uid] = self.predict_depth(self.edit_frames[viewpoint_cam.uid].permute(1,2,0).cpu().numpy() * 255.0)
            # Update learning rate
            self.gs_render.gaussians.update_learning_rate(
                self.global_iteration)

            loss, rend_dict_ref, psnr_train = self.train_step_edit_v2(self.gs_render,
                                                              viewpoint_cam,
                                                              self.global_iteration,
                                                              pipe, self.optim_cfg,
                                                              update_gaussians=True,
                                                              update_cam=True,
                                                              depth_gt=self.mono_depth[fidx],
                                                              update_distort=self.pipe_cfg.distortion,
                                                              )

            if self.global_iteration % 1000 == 0:
                self.gs_render.gaussians.oneupSHdegree()

            if iteration % 10 == 0:
                progress_bar.set_postfix({"PSNR": f"{psnr_train:.{2}f}",
                                          "Number points": f"{self.gs_render.gaussians.get_xyz.shape[0]}"})
                progress_bar.update(10)

            if iteration == num_iterations:
                progress_bar.close()

        return pcd, local_model_params, loss

    def create_pcd_from_render(self, render_dict, viewpoint_cam):
        """Back-project a rendered depth map into a down-sampled coloured point cloud.

        Args:
            render_dict: Render output; its ``"depth"`` and ``"image"``
                entries are read.
            viewpoint_cam: Camera whose ``intrinsics`` (a NumPy array) drive
                the un-projection.

        Returns:
            A ``BasicPointCloud`` built from 30,000 farthest-point samples of
            the back-projected pixels. No normal estimation is performed
            here, so the normals array is whatever Open3D reports for the
            down-sampled cloud.
        """
        cam_K = torch.from_numpy(viewpoint_cam.intrinsics).float().cuda()
        depth_map = render_dict["depth"].squeeze()
        rgb = render_dict["image"]

        # Leading axes are added to depth and intrinsics before un-projecting.
        xyz_map = depth_to_3d(depth_map[None, None],
                              cam_K[None],
                              normalize_points=False)

        # Flatten both maps to one row per pixel.
        xyz_flat = xyz_map.squeeze().permute(1, 2, 0).detach().cpu().reshape(-1, 3).numpy()
        rgb_flat = rgb.permute(1, 2, 0).detach().cpu().reshape(-1, 3).numpy()

        dense_cloud = o3d.geometry.PointCloud()
        dense_cloud.points = o3d.utility.Vector3dVector(xyz_flat)
        dense_cloud.colors = o3d.utility.Vector3dVector(rgb_flat)
        sparse_cloud = dense_cloud.farthest_point_down_sample(num_samples=30_000)

        return BasicPointCloud(
            np.asarray(sparse_cloud.points, dtype=np.float32),
            np.asarray(sparse_cloud.colors, dtype=np.float32),
            np.asarray(sparse_cloud.normals, dtype=np.float32),
        )

    def train_from_progressive(self, ):
        """Progressively register frames and fit the Gaussians to the edited frames.

        Steps, in order:
          1. Rescale the optimisation schedule in ``self.optim_cfg`` to the
             sequence length.
          2. When ``use_edit`` is set, prepare the editing artefacts (latents,
             edited frames and, if ``self.enable_spatial``, the injection
             hooks installed by ``init_method``).
          3. Collect the ground-truth poses of ``self.data``.
          4. Optionally restore Gaussians and poses from
             ``self.model_cfg.model_path``.
          5. Add frames one at a time with ``add_view_v2``, logging and
             visualising after each frame, then render every frame for
             evaluation and save metrics, poses and a checkpoint under
             ``output/<expname>/<category>_<seq_name>``.

        Returns:
            None. Results are written to disk and to attributes such as
            ``self.match_results``, ``self.render_depth`` and
            ``self.global_iteration``.
        """
        pipe = copy(self.pipe_cfg)
        optim_opt = copy(self.optim_cfg)
        # 500 optimisation steps per frame; the total uses the sequence length
        # rounded down to a multiple of 10.
        self.single_step = 500
        num_iterations = self.single_step * (self.seq_len // 10) * 10
        self.optim_cfg.iterations = num_iterations # e.g. 65000 when seq_len // 10 == 13
        self.optim_cfg.position_lr_max_steps = num_iterations # 65000
        self.optim_cfg.opacity_reset_interval = num_iterations // 10 # 6500
        self.optim_cfg.densify_until_iter = num_iterations # 65000
        self.optim_cfg.reset_until_iter = int(num_iterations * 0.8) # 52000
        # NOTE(review): the first assignment below is dead; it is overwritten
        # by the very next line.
        self.optim_cfg.densify_from_iter = 1000
        self.optim_cfg.densify_from_iter = self.single_step

        if pipe.expname == "":
            expname = "progressive"
        else:
            expname = pipe.expname
        pipe.convert_SHs_python = True
        # Re-copy so that optim_opt carries the schedule overrides made above.
        optim_opt = copy(self.optim_cfg)
        result_path = f"output/{expname}/{self.category}_{self.seq_name}"
        self.model_cfg.result_path = result_path
        os.makedirs(result_path, exist_ok=True)
        # TENSORBOARD_FOUND / SummaryWriter are defined outside this method.
        if TENSORBOARD_FOUND:
            self.tb_writer = SummaryWriter(result_path)
        else:
            self.tb_writer = None
            print("Tensorboard not available: not logging progress")

        # ---- Editing artefacts (latents, edited frames, injection hooks) ----
        if optim_opt.use_edit:
            self.model_cfg.n_frames = len(self.all_data)
            self.model_cfg.device = 'cuda'

            self.model_cfg.latents_path = os.path.join(result_path, f'sd_{self.model_cfg.sd_version}', f'steps_{self.model_cfg.steps}', 'latents')
            self.model_cfg.edited_path = os.path.join(result_path, f'pnp_SD_{self.model_cfg.sd_version}', f'{self.model_cfg.guidance["prompt"]}')
            self.model_cfg.spe_path = os.path.join(result_path, f'spe', f'steps_{self.model_cfg.steps}', 'latents')

            # The latents directory doubles as the "preprocessing done" marker.
            if not os.path.exists(self.model_cfg.latents_path):
                prep(self.model_cfg, self.optim_cfg, result_path)
            # Use the first of 10, 8, 5, 4, 2 that divides the frame count as
            # batch_size_2, falling back to 1.
            if self.model_cfg.n_frames % 10 == 0:
                self.model_cfg.batch_size_2 = 10
            elif self.model_cfg.n_frames % 8 == 0:
                self.model_cfg.batch_size_2 = 8
            elif self.model_cfg.n_frames % 5 == 0:
                self.model_cfg.batch_size_2 = 5
            elif self.model_cfg.n_frames % 4 == 0:
                self.model_cfg.batch_size_2 = 4
            elif self.model_cfg.n_frames % 2 == 0:
                self.model_cfg.batch_size_2 = 2
            else:
                self.model_cfg.batch_size_2 = 1
            # Editing runs only when the output directory does not exist yet.
            # NOTE(review): `and optim_opt.use_edit` is always true inside
            # this branch.
            if not os.path.exists(self.model_cfg.edited_path) and optim_opt.use_edit: 
                os.makedirs(self.model_cfg.edited_path, exist_ok=True)
                run_tpe(self.model_cfg)

            if self.enable_spatial:
                os.makedirs(self.model_cfg.spe_path, exist_ok=True)
                # Injection thresholds are fractions of the spatial step count.
                pnp_f_t = int(self.model_cfg.guidance["n_timesteps_sp"] * self.model_cfg.pnp_params["pnp_f_t"])
                pnp_attn_t = int(self.model_cfg.guidance["n_timesteps_sp"] * self.model_cfg.pnp_params["pnp_attn_t"])
                self.init_method(conv_injection_t=pnp_f_t, qk_injection_t=pnp_attn_t)

        # ---- Ground-truth poses as 4x4 matrices ----
        pose_dict = dict()
        poses_gt = []
        for seq_data in self.data:
            if self.data_type == "co3d":
                R, t, _, _, _ = self.load_camera(seq_data)
            else:
                R = seq_data.R.transpose()
                t = seq_data.T
            pose = np.eye(4)
            pose[:3, :3] = R
            pose[:3, 3] = t
            poses_gt.append(torch.from_numpy(pose))
        pose_dict["poses_gt"] = torch.stack(poses_gt)
        max_frame = self.seq_len
        start_frame = 1
        end_frame = max_frame
        skip_init = False
        skip_refine = False  # NOTE(review): never read in this method.

        # ---- Optional restore from a checkpoint ----
        if self.model_cfg.model_path != "":
            # A checkpoint path containing "init" skips the progressive stage.
            if "init" in self.model_cfg.model_path:
                skip_init = True
                skip_refine = False

            self.gs_render.gaussians.restore(
                torch.load(self.model_cfg.model_path), self.optim_cfg)
            # The pose file is located by swapping 'chkpnt' for 'pose' in the path.
            pose_dict = torch.load(
                self.model_cfg.model_path.replace('chkpnt', 'pose'))
            poses_pred = pose_dict["poses_pred"]
            self.match_results = pose_dict["match_results"]
            self.gs_render.gaussians.init_RT_seq(self.seq_len, poses_pred)
            self.gs_render.gaussians.rotate_seq = True

        #     self.construct_point_o3d(self.gs_render, end_frame, result_path)

        os.makedirs(f"{result_path}/pose", exist_ok=True)

        num_eppch = 1
        reverse = False  # NOTE(review): never read in this method.
        for epoch in range(num_eppch):
            if not self.pipe_cfg.refine:
                # NOTE(review): the return value is not used afterwards.
                gauss_params = self.init_two_view(
                    0, end_frame, pipe, copy(self.optim_cfg))
            if not skip_init:
                self.global_iteration = 0
                self.edit_iteration = 0
                optim_opt = copy(self.optim_cfg)
                self.gs_render.gaussians.rotate_seq = True
                self.match_results = OrderedDict()
                # ---- Progressive stage: add frame fidx using fidx-1 as reference ----
                for fidx in range(start_frame, end_frame):
                    pcd_new, local_gauss_params, loss_dict = self.add_view_v2(
                        fidx, fidx-1)

                    self.gs_render.gaussians.rotate_seq = False

                    # Render the newly added frame with its current pose estimate.
                    viewpoint_cam = self.load_viewpoint_cam(fidx,
                                                            pose=self.gs_render.gaussians.get_RT(
                                                                fidx).detach().cpu(),
                                                            )
                    render_dict = self.gs_render.render(viewpoint_cam,
                                                        compute_cov3D_python=pipe.compute_cov3D_python,
                                                        convert_SHs_python=pipe.convert_SHs_python)
                    # The training target is the edited frame, not the original image.
                    # gt_image = viewpoint_cam.original_image.cuda()
                    gt_image = self.edit_frames[viewpoint_cam.uid]
                    psnr_train = psnr(render_dict["image"], gt_image).mean().double()

                    if self.tb_writer:
                        for key, value in loss_dict.items():
                            self.tb_writer.add_scalar(f'train/{key}', value, self.global_iteration)
                        self.tb_writer.add_scalar('train/psnr', psnr_train, self.global_iteration)

                    # Intermediate checkpoint: Gaussians plus all current poses.
                    if self.global_iteration in optim_opt.test_iterations:
                        pose_dict["poses_pred"] = []
                        for idx in range(self.seq_len):
                            pose = self.gs_render.gaussians.get_RT(idx)
                            pose_dict["poses_pred"].append(pose.detach().cpu())

                        pose_dict["poses_pred"] = torch.stack(pose_dict["poses_pred"])
                        pose_dict["poses_gt"] = torch.stack(poses_gt)
                        pose_dict["match_results"] = self.match_results
                        print("\n[ITER {}] Saving Checkpoint".format(self.global_iteration))
                        torch.save(self.gs_render.gaussians.capture(), result_path + "/chkpnt" + str(self.global_iteration) + ".pth")
                        torch.save(pose_dict, result_path + "/pose" + str(self.global_iteration) + ".pth")

                    print(
                        'Frames {:03d}/{:03d}, PSNR : {:.04f}'.format(fidx, self.seq_len-1, psnr_train))
                    self.visualize(render_dict,
                                   f"{result_path}/train/{self.global_iteration:06d}_{fidx:03d}.png",
                                   gt_image=gt_image, gt_depth=self.mono_depth[fidx], save_ply=False)

                # ---- Evaluation over frames [0, end_frame) ----
                with torch.no_grad():
                    psnr_test = 0.0
                    ssim_test = 0.0
                    lpips_test = 0.0
                    pose_dict["poses_pred"] = []
                    self.render_depth = OrderedDict()
                    self.gs_render.gaussians.rotate_seq = False
                    self.gs_render.gaussians.rotate_xyz = False


                    for val_idx in range(end_frame):
                        viewpoint_cam = self.load_viewpoint_cam(val_idx,
                                                                pose=self.gs_render.gaussians.get_RT(
                                                                    val_idx).detach().cpu(),
                                                                )
                        render_dict = self.gs_render.render(viewpoint_cam,
                                                            compute_cov3D_python=pipe.compute_cov3D_python,
                                                            convert_SHs_python=pipe.convert_SHs_python)
                        # Rendered depth is kept per frame on the instance.
                        self.render_depth[val_idx] = render_dict["depth"]
                        gt_image = self.edit_frames[val_idx]
                        psnr_test += psnr(render_dict["image"], gt_image).mean().double()
                        ssim_test += ssim(render_dict["image"].double(), gt_image.double()).double()
                        lpips_test += lpips(render_dict["image"].float(), gt_image.float(), net_type='vgg').double().item()
                        self.visualize(render_dict,
                                       f"{result_path}/eval/ep{epoch:02d}_{self.global_iteration:06d}_{val_idx:03d}.png",
                                       gt_image=gt_image, save_ply=False)

                    # Metrics are averaged over end_frame frames.
                    print('Number of {:03d} to {:03d} frames: PSNR : {:.04f} SSIM : {:.04f} LPIPS : {:.04f}'.format(
                        start_frame, end_frame, psnr_test / (end_frame), ssim_test / (end_frame), lpips_test / (end_frame)))

                    # The eval/ directory is created by the visualize() calls above.
                    with open(f"{result_path}/eval/eval.txt", 'w') as f:
                        f.write('PSNR : {:.04f}, SSIM : {:.04f}, LPIPS : {:.04f}'.format(
                                psnr_test / end_frame,
                                ssim_test / end_frame,
                                lpips_test / end_frame))
                        # NOTE(review): redundant; the with-block closes the file.
                        f.close()

                    for idx in range(self.seq_len):
                        pose = self.gs_render.gaussians.get_RT(idx)
                        pose_dict["poses_pred"].append(pose.detach().cpu())

                # ---- Persist the final poses and Gaussians of this epoch ----
                pose_dict["poses_pred"] = torch.stack(pose_dict["poses_pred"])
                pose_dict["poses_gt"] = torch.stack(poses_gt)
                pose_dict["match_results"] = self.match_results
                torch.save(
                    pose_dict, f"{result_path}/pose/ep{epoch:02d}_init.pth")
                os.makedirs(f"{result_path}/chkpnt", exist_ok=True)
                torch.save(self.gs_render.gaussians.capture(),
                           f"{result_path}/chkpnt/ep{epoch:02d}_init.pth")

            else:
                # Progressive stage skipped: reuse the poses loaded from the checkpoint.
                self.global_iteration = num_iterations
                self.gs_render.gaussians.rotate_seq = True
                self.gs_render.gaussians.init_RT_seq(self.seq_len,
                                                     pose_dict["poses_pred"])

    def compute_loss(self,
                     render_dict,
                     viewpoint_cam,
                     pipe_opt,
                     optim_opt,
                     iteration,
                     use_reproject=False,
                     use_matcher=False,
                     ref_fidx=None,
                     **kwargs):
        """Assemble the loss inputs from a render and delegate to ``self.loss_func``.

        Args:
            render_dict: Render output. ``"image"`` is required. ``"depth"``
                is optional; when present it is clamped **in place** to at
                least ``self.near`` and forwarded as ``depth_pred``.
            viewpoint_cam: Camera providing ``original_image`` (the target)
                and, for local losses, ``foreground_mask``.
            pipe_opt, optim_opt, iteration, use_reproject, use_matcher,
            ref_fidx: Accepted for interface compatibility; not used here.
            **kwargs: Forwarded verbatim to ``self.loss_func``. A truthy
                ``local`` entry additionally injects ``foreground_mask``.
                Omitting ``local`` is treated as ``False``.

        Returns:
            Whatever ``self.loss_func`` returns.

        Raises:
            KeyError: If ``render_dict`` has no ``"image"`` entry.
        """
        if "image" not in render_dict:
            # Fail with a clear message instead of an UnboundLocalError below.
            raise KeyError("render_dict must contain an 'image' entry to compute the loss")

        image = render_dict["image"]
        gt_image = viewpoint_cam.original_image.cuda()

        if "depth" in render_dict:
            depth = render_dict["depth"]
            # In-place clamp: callers holding render_dict see the clamped depth too.
            depth[depth < self.near] = self.near
            kwargs['depth_pred'] = depth

        # `.get` so that callers which never pass `local` do not hit a KeyError.
        if kwargs.get('local', False):
            kwargs['foreground_mask'] = viewpoint_cam.foreground_mask

        return self.loss_func(image, gt_image, **kwargs)

    def visualize(self, render_pkg, filename, gt_image=None, gt_depth=None, save_ply=False, vis_depth=True, timestamp=0.0):
        """Write a render (and optional depth / accumulation maps) to disk as PNGs.

        Args:
            render_pkg: Render output. ``"image"`` is required; ``"depth"``
                and ``"acc"`` are written as colour-mapped side images when
                present.
            filename: Target ``.png`` path for the colour image. The depth
                and accumulation images use the ``_depth.png`` / ``_acc.png``
                suffixes, the point cloud uses ``.ply``.
            gt_image: Optional reference image; when given it is resized to
                the render and placed to the left of it.
            gt_depth: Optional reference depth, handled like ``gt_image`` for
                the depth visualisation.
            save_ply: When true, also dump the Gaussians via ``save_ply``.
            vis_depth: Set to false to skip the depth image.
            timestamp: Forwarded to ``gaussians.save_ply``.
        """
        out_dir = os.path.dirname(filename)
        if out_dir:
            # A bare filename has an empty dirname, which makedirs rejects.
            os.makedirs(out_dir, exist_ok=True)

        def _to_uint8_hwc(chw):
            # Clip before the cast: values outside [0, 255] would otherwise
            # wrap around when converted to uint8.
            hwc = chw.detach().cpu().permute(1, 2, 0).numpy()
            return np.clip(hwc * 255.0, 0.0, 255.0).astype(np.uint8)

        if "depth" in render_pkg and vis_depth:
            rend_depth = Image.fromarray(
                colorize(render_pkg["depth"].detach().cpu().numpy(),
                         cmap='magma_r')).convert("RGB")
            if gt_depth is not None:
                gt_depth = Image.fromarray(
                    colorize(gt_depth.detach().cpu().numpy(),
                             cmap='magma_r')).resize((rend_depth.size[0], rend_depth.size[1])).convert("RGB")
                # Reference on the left, render on the right.
                rend_depth = Image.fromarray(np.hstack([np.asarray(gt_depth),
                                                        np.asarray(rend_depth)]))
            rend_depth.save(filename.replace(".png", "_depth.png"))

        if "acc" in render_pkg:
            rend_acc = Image.fromarray(
                colorize(render_pkg["acc"].detach().cpu().numpy(),
                         cmap='magma_r')).convert("RGB")
            rend_acc.save(filename.replace(".png", "_acc.png"))

        rend_img = Image.fromarray(_to_uint8_hwc(render_pkg["image"])).convert("RGB")

        if gt_image is not None:
            gt_image = Image.fromarray(_to_uint8_hwc(gt_image)).resize(
                (rend_img.size[0], rend_img.size[1])).convert("RGB")
            rend_img = Image.fromarray(np.hstack([np.asarray(gt_image),
                                                  np.asarray(rend_img)]))
        rend_img.save(filename)

        if save_ply:
            self.gs_render.gaussians.save_ply(filename.replace('.png', '.ply'), timestamp)

    def compute_clip(self, render_edit, prompt):
        """Return the absolute ``self.clip_metrics`` score for an image/prompt pair, divided by 100.

        Args:
            render_edit: Image(s) handed to ``self.clip_metrics``.
            prompt: Text handed to ``self.clip_metrics``.

        Returns:
            A Python float: ``abs(score) / 100``.
        """
        with torch.no_grad():
            raw_score = self.clip_metrics(render_edit, prompt)
        return abs(raw_score.detach().item()) / 100.

    def compute_warperror(self, viewpoint_cam, img1, img2):
        """Compute SSIM between ``img1`` and flow-warped ``img2`` inside the flow mask.

        Args:
            viewpoint_cam: Camera providing ``fwd_flow`` and ``fwd_flow_mask``.
            img1: Channel-first image tensor; bilinearly resized to the
                spatial size of ``img2`` before comparison.
            img2: Channel-first image tensor that is warped with the forward
                flow.

        Returns:
            The mean SSIM (double precision) of the two masked images.
        """
        flow = viewpoint_cam.fwd_flow.cpu().numpy()
        # Trailing axis so the mask broadcasts over the colour channels.
        valid = viewpoint_cam.fwd_flow_mask.cpu().numpy()[:, :, np.newaxis]

        warped = warp_flow(img2.permute(1, 2, 0).cpu().numpy() * 255., flow)
        warped = cv2.cvtColor(np.uint8(warped), cv2.COLOR_BGR2RGB)

        target_hw = (img2.shape[-2], img2.shape[-1])
        resized = torch.nn.functional.interpolate(
            img1.unsqueeze(0), size=target_hw, mode='bilinear').squeeze()
        reference = cv2.cvtColor(
            np.uint8(resized.permute(1, 2, 0).cpu().numpy() * 255.), cv2.COLOR_BGR2RGB)

        def _masked_chw(hwc_uint8):
            # Back to [0, 1], zero out invalid-flow pixels, channel-first on GPU.
            return torch.tensor((hwc_uint8 / 255.) * valid).permute(-1, 0, 1).cuda().double()

        return ssim(_masked_chw(warped), _masked_chw(reference)).mean().double()

    @torch.no_grad()
    def ddim_inversion(self, cond, latent_frames, idx, batch_size=2, depth_maps=None):
        """Invert latents along the reversed scheduler timesteps, saving noisy latents.

        Args:
            cond: Text conditioning for a single sample; repeated to match
                each batch.
            latent_frames: Latents to invert. **Updated in place** and also
                returned.
            idx: Accepted for interface compatibility; not used here.
            batch_size: Number of frames sent through ``self.sd_unet`` at once.
            depth_maps: Per-frame depth maps. Only read when
                ``self.model_cfg.sd_version == 'depth'``, in which case they
                are resized to the latent resolution and concatenated to the
                UNet input along dim 1.

        Returns:
            ``(latent_frames, max_t)`` where ``max_t`` is the last timestep
            processed.

        Side effects:
            Writes ``noisy_latents_<t>.pt`` into ``self.model_cfg.spe_path``
            for every timestep in ``self.timesteps_to_save`` (all timesteps
            when that is ``None``) and always for the final timestep.
        """
        timesteps = reversed(self.scheduler.timesteps)
        timesteps_to_save = self.timesteps_to_save if self.timesteps_to_save is not None else timesteps
        use_depth = self.model_cfg.sd_version == 'depth'
        if use_depth:
            depth_maps = torch.nn.functional.interpolate(
                                            depth_maps,
                                            size=(latent_frames.shape[-2], latent_frames.shape[-1]),
                                            mode="bicubic",
                                            align_corners=False,
                                            )
        for i, t in enumerate(timesteps):
            # Schedule coefficients depend only on the timestep, so compute
            # them once per step rather than once per batch.
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[timesteps[i - 1]]
                if i > 0 else self.scheduler.final_alpha_cumprod
            )
            mu = alpha_prod_t ** 0.5
            mu_prev = alpha_prod_t_prev ** 0.5
            sigma = (1 - alpha_prod_t) ** 0.5
            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

            for b in range(0, latent_frames.shape[0], batch_size):
                x_batch = latent_frames[b:b + batch_size]
                model_input = x_batch
                cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
                if use_depth:
                    # Slice into a local: rebinding `depth_maps` here would
                    # leave later batches and timesteps with a truncated
                    # tensor.
                    depth_batch = depth_maps[b:b + batch_size]
                    model_input = torch.cat([x_batch, depth_batch], dim=1)

                eps = self.sd_unet(model_input, t, encoder_hidden_states=cond_batch).sample
                pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
                latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps
            max_t = t

            if t in timesteps_to_save:
                torch.save(latent_frames, os.path.join(self.model_cfg.spe_path, f'noisy_latents_{t}.pt'))
        # Always persist the final step, even if it is not in timesteps_to_save.
        torch.save(latent_frames, os.path.join(self.model_cfg.spe_path, f'noisy_latents_{t}.pt'))

        return latent_frames, max_t

    def get_ddim_inversion(self, imgs, viewpoint_cam, depth_maps=None):
        """Encode ``imgs`` and run DDIM inversion with an empty text prompt.

        Args:
            imgs: Image batch passed to ``encode_imgs`` (cast to float16 on
                ``self.device``).
            viewpoint_cam: Camera whose ``uid`` is mapped to the index handed
                to ``ddim_inversion``.
            depth_maps: Optional depth maps. When given they are cast to
                float16 on ``self.device``; ``None`` is forwarded unchanged.

        Returns:
            ``(inverted_x, max_t)`` as returned by ``ddim_inversion``.
        """
        uid = viewpoint_cam.uid
        # uids above 3 are shifted by one plus one more per 8 frames;
        # presumably to skip frames absent from the training list — TODO confirm.
        idx = (uid - 3) // 8 + uid + 1 if uid > 3 else uid

        latents = self.encode_imgs(imgs.to(torch.float16).to(self.device), deterministic=True).to(torch.float16).to(self.device)
        self.scheduler.set_timesteps(self.model_cfg.guidance["n_timesteps_sp"])
        # cond = self.get_text_embeds(self.model_cfg.inversion_prompt, "")[1].unsqueeze(0)
        cond = self.get_text_embeds("", "")[1].unsqueeze(0)

        # depth_maps defaults to None, so only convert when one was supplied.
        if depth_maps is not None:
            depth_maps = depth_maps.to(torch.float16).to(self.device)
        inverted_x, max_t = self.ddim_inversion(cond, latents, [idx], depth_maps=depth_maps)

        return inverted_x, max_t

    def spatial_propagation(self, imgs, masks, inverted_x, viewpoint_cam, max_t, batch_size=2, depth_maps=None):
        """Noise the encoded images and run the hooked denoising loop, returning the second decoded frame.

        Args:
            imgs: Image batch; encoded to latents with ``encode_imgs``.
            masks, inverted_x, viewpoint_cam, max_t: Accepted but not read in
                this method.
            batch_size: Number of latents denoised per ``denoise_step`` call.
            depth_maps: Depth maps; only used when
                ``self.model_cfg.sd_version == 'depth'``, where they are
                resized to the latent resolution.

        Returns:
            ``decoded_latents[1].unsqueeze(0)``: the decoded image at index 1
            with a leading batch axis.
        """
        # denoised_latents = []
        self.scheduler.set_timesteps(self.model_cfg.guidance["n_timesteps_sp"], device=self.device)
        latents = self.encode_imgs(imgs.to(torch.float16).to(self.device), deterministic=True).to(torch.float16).to(self.device)
        # Noise recovered from the saved inversion latents for frames 0 and 1.
        eps = self.get_ddim_eps(latents, range(2)).to(torch.float16).to(self.device)

        # Noise the clean latents to the first scheduler timestep.
        noisy_latents = self.scheduler.add_noise(latents, eps, self.scheduler.timesteps[0])

        if self.model_cfg.sd_version == 'depth':
            depth_maps = torch.nn.functional.interpolate(
                                            depth_maps.unsqueeze(1),
                                            size=(latents.shape[-2], latents.shape[-1]),
                                            mode="bicubic",
                                            align_corners=False,
                                            )
        indices = torch.arange(2)
        for i, t in enumerate(self.scheduler.timesteps):
            # Pivotal pass on frame 0 only; its return value is discarded, so
            # it is run for its effect on the registered UNet hooks.
            register_pivotal(self.unet, True)
            self.denoise_step(noisy_latents[0].unsqueeze(0), t, [0], depth_maps=depth_maps)
            register_pivotal(self.unet, False)
            # NOTE(review): `noisy_latents` is never reassigned, so every
            # timestep denoises the same initial latents and only the result
            # of the last (timestep, batch) pair survives in
            # `denoised_latents` — confirm this is intended.
            for idx, b in enumerate(range(0, noisy_latents.shape[0], batch_size)):
                register_batch_idx(self.unet, idx)
                denoised_latents=self.denoise_step(noisy_latents[b:b + batch_size], t, indices[b:b + batch_size], depth_maps=depth_maps)

        decoded_latents = self.decode_latents(denoised_latents)

        return decoded_latents[1].unsqueeze(0)

    def init_method(self, conv_injection_t, qk_injection_t):
        """Select the injection timesteps and install the injection hooks.

        Args:
            conv_injection_t: Number of leading scheduler timesteps that
                receive conv-feature injection; a negative value disables it.
            qk_injection_t: Same, for attention (query/key) injection.
        """
        if qk_injection_t >= 0:
            self.qk_injection_timesteps = self.scheduler.timesteps[:qk_injection_t]
        else:
            self.qk_injection_timesteps = []

        if conv_injection_t >= 0:
            self.conv_injection_timesteps = self.scheduler.timesteps[:conv_injection_t]
        else:
            self.conv_injection_timesteps = []

        # Hooks are registered on this object and on its UNet.
        register_extended_attention_pnp(self, self.qk_injection_timesteps)
        register_conv_injection(self, self.conv_injection_timesteps)
        set_tokenflow(self.unet)
        print("SPE init complete!")

    @torch.no_grad()
    def encode_imgs(self, imgs, batch_size=10, deterministic=True):
        """Encode images into scaled VAE latents, ``batch_size`` images at a time.

        Args:
            imgs: Image batch; rescaled with ``2 * imgs - 1`` before encoding.
            batch_size: Number of images per ``self.vae.encode`` call.
            deterministic: Use the posterior mean when true, otherwise draw a
                sample from the posterior.

        Returns:
            The concatenated latents, each multiplied by 0.18215.
        """
        scaled = 2 * imgs - 1

        def _encode_chunk(chunk):
            dist = self.vae.encode(chunk).latent_dist
            code = dist.mean if deterministic else dist.sample()
            return code * 0.18215

        return torch.cat([
            _encode_chunk(scaled[start:start + batch_size])
            for start in range(0, len(scaled), batch_size)
        ])

    def get_ddim_eps(self, latent, indices):
        """Recover the noise that maps ``latent`` to the noisiest saved inversion latent.

        The noisiest step is the largest ``<t>`` among the
        ``noisy_latents_<t>.pt`` files in ``self.model_cfg.spe_path``.

        Args:
            latent: Clean latents to compare against.
            indices: Index into the saved latent tensor selecting the frames
                that correspond to ``latent``.

        Returns:
            ``(noisy - sqrt(alpha_T) * latent) / sqrt(1 - alpha_T)``.

        Raises:
            FileNotFoundError: If no ``noisy_latents_*.pt`` file exists.
        """
        pattern = os.path.join(self.model_cfg.spe_path, 'noisy_latents_*.pt')
        saved_steps = [int(path.split('_')[-1].split('.')[0]) for path in glob.glob(pattern)]
        if not saved_steps:
            # Clearer than the bare ValueError raised by max() on an empty list.
            raise FileNotFoundError(
                f"No noisy_latents_*.pt files found in {self.model_cfg.spe_path}")
        noisest = max(saved_steps)

        latents_path = os.path.join(self.model_cfg.spe_path, f'noisy_latents_{noisest}.pt')
        # Use self.device (as callers do for `latent`) rather than a
        # hard-coded .cuda(), so both operands live on the same device.
        noisy_latent = torch.load(latents_path)[indices].to(self.device)
        alpha_prod_T = self.scheduler.alphas_cumprod[noisest]
        mu_T, sigma_T = alpha_prod_T ** 0.5, (1 - alpha_prod_T) ** 0.5
        eps = (noisy_latent - mu_T * latent) / sigma_T
        return eps

    def load_source_latents_t(self, t, latents_path):
        """Load the latents saved for timestep ``t``.

        Args:
            t: Timestep whose ``noisy_latents_<t>.pt`` file is wanted.
            latents_path: Directory containing the saved latents.

        Returns:
            The object stored in the file, as returned by ``torch.load``.

        Raises:
            AssertionError: If the file does not exist.
        """
        latents_t_path = os.path.join(latents_path, f'noisy_latents_{t}.pt')
        assert os.path.exists(latents_t_path), f'Missing latents at t {t} path {latents_t_path}'
        return torch.load(latents_t_path)

    @torch.no_grad()
    def denoise_step(self, x, t, indices, masks=None, depth_maps=None):
        """Run one guided scheduler step on ``x`` alongside the saved source latents.

        Args:
            x: Noisy latents to denoise.
            t: Current scheduler timestep (a tensor; ``t.item()`` is read).
            indices: Frames to select from the latents saved for ``t``.
            masks: Accepted but not read here.
            depth_maps: Depth maps; only read when
                ``self.model_cfg.sd_version == 'depth'``.

        Returns:
            ``prev_sample`` from ``self.scheduler.step``.
        """
        reference = self.load_source_latents_t(t, self.model_cfg.spe_path)[indices]

        # One forward pass over [source | x | x]; the two copies of x are
        # later split into the unconditional and conditional predictions.
        unet_input = torch.cat([reference, x, x])
        if self.model_cfg.sd_version == 'depth':
            # NOTE(review): the first depth map is repeated exactly 3 times,
            # which matches the batch only when `x` holds a single sample —
            # confirm against callers that pass larger batches.
            first_depth = depth_maps[0].unsqueeze(0)
            depth_stack = torch.cat([first_depth] * 3)
            unet_input = torch.cat([unet_input, depth_stack], dim=1).to(torch.float16)

        # Tell the injection modules which timestep is being processed.
        register_time(self, t.item())

        n_frames = len(indices)
        conditioning = torch.cat([
            self.pnp_guidance_embeds.repeat(n_frames, 1, 1),
            torch.repeat_interleave(self.text_embeds, n_frames, dim=0),
        ])

        prediction = self.unet(unet_input, t, encoder_hidden_states=conditioning)['sample']

        # The source-branch prediction is discarded.
        _, pred_uncond, pred_cond = prediction.chunk(3)
        scale = self.model_cfg.guidance["guidance_scale"]
        guided = pred_uncond + scale * (pred_cond - pred_uncond)

        return self.scheduler.step(guided, t, x)['prev_sample']

    def prepare_mask_and_mask_latents(self, mask, height, width):
        """Binarise ``mask`` at 0.5 and resize it to ``(height, width)``.

        Args:
            mask: Mask tensor. **Modified in place**: values below 0.5 become
                0, the rest become 1.
            height: Target height.
            width: Target width.

        Returns:
            The resized mask as float16 on ``self.device``.
        """
        # Two in-place passes, in this order, so the caller's tensor is binarised too.
        mask.masked_fill_(mask < 0.5, 0)
        mask.masked_fill_(mask >= 0.5, 1)

        resized = torch.nn.functional.interpolate(mask, size=(height, width))
        return resized.to(device=self.device, dtype=torch.float16)

    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt, batch_size=1):
        """Embed a prompt and a negative prompt with the text encoder.

        Args:
            prompt: Conditioning text (tokenised with truncation).
            negative_prompt: Unconditional / negative text.
            batch_size: Number of times each embedding is repeated.

        Returns:
            ``batch_size`` copies of the negative-prompt embedding followed by
            ``batch_size`` copies of the prompt embedding, concatenated along
            dim 0.
        """
        max_len = self.tokenizer.model_max_length

        cond_tokens = self.tokenizer(prompt, padding='max_length', max_length=max_len,
                                     truncation=True, return_tensors='pt')
        cond_embeds = self.text_encoder(cond_tokens.input_ids.to(self.device))[0]

        # The negative prompt is tokenised without `truncation`.
        neg_tokens = self.tokenizer(negative_prompt, padding='max_length', max_length=max_len,
                                    return_tensors='pt')
        neg_embeds = self.text_encoder(neg_tokens.input_ids.to(self.device))[0]

        return torch.cat([neg_embeds] * batch_size + [cond_embeds] * batch_size)

    @torch.no_grad()
    def decode_latents(self, latents, batch_size=10):
        """Decode VAE latents into images clamped to ``[0, 1]``.

        Args:
            latents: Latents scaled by 0.18215 (the scaling is undone here).
            batch_size: Number of latents per ``self.vae.decode`` call.

        Returns:
            The concatenated decoded images, mapped with ``x / 2 + 0.5`` and
            clamped to ``[0, 1]``.
        """
        unscaled = 1 / 0.18215 * latents
        decoded = torch.cat([
            self.vae.decode(unscaled[start:start + batch_size]).sample
            for start in range(0, len(unscaled), batch_size)
        ])
        return (decoded / 2 + 0.5).clamp(0, 1)

    def warp_flow_loss(self, viewpoint_cam_prev, img1, img2):
        """Mean absolute colour error between ``img1`` and flow-warped ``img2``.

        Args:
            viewpoint_cam_prev: Camera providing ``fwd_flow`` and
                ``fwd_flow_mask``.
            img1: Channel-first image tensor; bilinearly resized to the
                spatial size of ``img2``.
            img2: Channel-first image tensor that is warped with the forward
                flow.

        Returns:
            A scalar float tensor: the masked absolute difference (on a 0-255
            scale), normalised by the mask area and by the mean of the flow's
            two leading dimensions. The tensor is a fresh leaf built from a
            NumPy value, so gradients do not reach ``img1`` or ``img2``.
        """
        fwd_flow = viewpoint_cam_prev.fwd_flow.cpu().numpy()
        fwd_mask = viewpoint_cam_prev.fwd_flow_mask.cpu().numpy()
        warp_img2 = warp_flow(img2.permute(1,2,0).detach().cpu().numpy() * 255., fwd_flow)
        warp_img2 = cv2.cvtColor(np.uint8(warp_img2), cv2.COLOR_BGR2RGB)
        img1 = torch.nn.functional.interpolate(img1.unsqueeze(0), size=(img2.shape[-2], img2.shape[-1]), mode='bilinear').squeeze()
        img1 = cv2.cvtColor(np.uint8(img1.detach().permute(1,2,0).cpu().numpy() * 255.), cv2.COLOR_BGR2RGB)

        # Both images are uint8 at this point. Subtract in float: uint8
        # subtraction wraps modulo 256, which makes negative differences look
        # like large positive errors and turns np.abs into a no-op.
        abs_diff = np.abs(warp_img2.astype(np.float32) - img1.astype(np.float32))
        err = np.sum(abs_diff * fwd_mask[:,:, np.newaxis]) / (np.sum(fwd_mask) + 1e-8) / ((fwd_flow.shape[0] + fwd_flow.shape[1]) / 2.)
        return torch.tensor(err).float().requires_grad_(True)

    def update_mask(self, save_name="mask"):
        """Fill ``self.mask_list`` with per-frame segmentation masks, using an on-disk cache.

        The cache directory is
        ``output/<expname>/<category>_<seq_name>/<seg_prompt>``. If it does
        not exist, every frame in ``self.all_data`` is segmented with
        ``self.text_segmentor``, the mask is dilated with a 12x12 rectangular
        kernel, and both the mask and a darkened preview (``viz_*.png``) are
        written. Otherwise the cached masks are read back in sorted filename
        order.

        In both cases each appended mask is a channel-first CUDA tensor built
        from the 3-channel mask image divided by 255.

        Args:
            save_name: Accepted for interface compatibility; not used here.
        """
        cache_dir = f"output/{self.pipe_cfg.expname}/{self.category}_{self.seq_name}"
        print(f"Segment with prompt: {self.model_cfg.seg_prompt}")
        mask_cache_dir = os.path.join(
            cache_dir, self.model_cfg.seg_prompt)

        if not os.path.exists(mask_cache_dir):
            os.makedirs(mask_cache_dir)

            print("Segmentation with prompt:", f"{self.model_cfg.seg_prompt}")
            # The structuring element is identical for every frame.
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (12, 12))
            for frame in tqdm(self.all_data):
                cur_path = os.path.join(mask_cache_dir, f"{frame.image_name}.png")
                cur_path_viz = os.path.join(mask_cache_dir, f"viz_{frame.image_name}.png")
                mask = self.text_segmentor(frame.original_image.unsqueeze(0).permute(0, 2, 3, 1), self.model_cfg.seg_prompt)[0].to('cuda')

                # Grey mask replicated to 3 channels, as 8-bit image data.
                mask_to_save = (
                        mask[0]
                        .cpu()
                        .detach()[..., None]
                        .repeat(1, 1, 3)
                        .numpy()
                        .clip(0.0, 1.0)
                        * 255.0
                ).astype(np.uint8)
                mask_to_save = cv2.dilate(mask_to_save, kernel)
                cv2.imwrite(cur_path, mask_to_save)

                mask_hwc = torch.tensor(mask_to_save / 255.0)
                # Store the mask in exactly the layout and device the
                # cache-loading branch produces, so consumers behave the same
                # on the first run as on cached runs.
                self.mask_list.append(mask_hwc.permute(-1, 0, 1).cuda())

                # Preview: darken the masked region of the original image.
                masked_image = frame.original_image.detach().clone().permute(1,2,0)
                masked_image[mask_hwc[:, :, 0].bool()] *= 0.3
                masked_image_to_save = (
                        masked_image.cpu().detach().numpy().clip(0.0, 1.0) * 255.0
                ).astype(np.uint8)
                masked_image_to_save = cv2.cvtColor(
                    masked_image_to_save, cv2.COLOR_RGB2BGR
                )
                cv2.imwrite(cur_path_viz, masked_image_to_save)

        else:
            print("load cache")
            for filename in tqdm(sorted(os.listdir(mask_cache_dir))):
                # Skip the darkened previews written next to the masks.
                if filename.startswith('viz'):
                    continue
                mask = cv2.imread(os.path.join(mask_cache_dir, filename))
                self.mask_list.append(torch.tensor(mask / 255.).permute(-1,0,1).cuda())
    
    def get_edited_frames(self, viewpoint_cam):
        """Load the edited frame for a camera and cache it in ``self.edit_frames``.

        The frame is read from ``<edited_path>/img_ode/<image_name>.png``,
        resized to the camera's original image size and, when masks are
        available, blended with the original image according to
        ``self.model_cfg.mask_attr`` (1: keep the edit inside the mask;
        -1: keep the edit outside the mask).

        Args:
            viewpoint_cam: Camera providing ``uid``, ``image_name`` and
                ``original_image``.
        """
        uid = viewpoint_cam.uid
        # Index into self.mask_list: uids above 3 are shifted by one plus one
        # more per 8 frames; presumably to skip frames absent from the
        # training list — TODO confirm.
        idx = (uid - 3) // 8 + uid + 1 if uid > 3 else uid

        frame_file = os.path.join(self.model_cfg.edited_path, 'img_ode', viewpoint_cam.image_name + '.png')
        pixels = np.asarray(Image.open(frame_file).convert('RGB')) / 255.0
        edited = torch.from_numpy(pixels).permute(2, 0, 1).float().clamp(0.0, 1.0).cuda()
        edited = torch.nn.functional.interpolate(
            edited.unsqueeze(0), viewpoint_cam.original_image.shape[1:], mode='bilinear').squeeze(0)

        if len(self.mask_list) > 0:
            mask_attr = self.model_cfg.mask_attr
            if mask_attr == 1:
                region = self.mask_list[idx]
                edited = edited * region + viewpoint_cam.original_image * (1 - region)
            elif mask_attr == -1:
                region = self.mask_list[idx]
                edited = viewpoint_cam.original_image * region + edited * (1 - region)

        self.edit_frames[viewpoint_cam.uid] = edited



