import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_debug

import models

import systems
from systems.base import BaseSystem
from systems.criterions import PSNR, binary_cross_entropy, SSIM

from torch.profiler import profile, record_function, ProfilerActivity
from pytorch_memlab import LineProfiler, MemReporter

from systems.utils import parse_optimizer, parse_scheduler, update_module_step
from utils import plot
from utils.pose import select_k_uniform_points

import numpy as np
import copy

from tqdm import tqdm
import os, json
import imageio
import cv2
import math

from utils.loss import ssim_loss
from utils.lpipsPyTorch.modules.lpips import LPIPS

from models.gaussian_splatting.utils.loss_utils import ssim
from models.gaussian_splatting.utils.general_utils import build_rotation

@systems.register('restarting-3dgs-system')
class RestartingSystem(BaseSystem):
    """
    3D Gaussian Splatting training system with optional next-best-view (NBV) selection.

    Optimization is manual: ``training_step_3dgs`` calls ``loss.backward()`` and steps
    ``self.model.gaussians.optimizer`` itself. When ``config.nbv.use`` is enabled, new
    views are periodically estimated with ``self.planner`` (not assigned in this class;
    presumably provided by ``BaseSystem`` -- confirm), added to the training set, and
    the Gaussians are re-initialized ("restarted") via ``init_gsmodel``.

    NOTE(review): the name ``nbv`` (``nbv.convert3x4_4x4``, ``nbv.render_gt``) is not
    imported in this file; those code paths raise NameError unless it is put in scope
    elsewhere -- confirm.

    Two ways to print to console:
    1. self.print: correctly handle progress bar
    2. rank_zero_info: use the logging module
    """

    def __init__(self, config):
        super().__init__(config)
        # The Gaussian optimizer is stepped by hand in training_step_3dgs.
        self.automatic_optimization = False

    def configure_optimizers(self):
        """Register no optimizers with Lightning.

        ``configure_optimizers_3dgs`` is still invoked (its result is discarded) so the
        optimizer/scheduler config is parsed as before; the Gaussians' own optimizer is
        created by ``training_setup`` in ``init_gsmodel``.
        """
        self.configure_optimizers_3dgs()
        return None

    def configure_optimizers_3dgs(self):
        """Build the optimizer and LR scheduler described by ``config.system``.

        Returns:
            ``([optimizer], [lr_scheduler])``.
        """
        optim = parse_optimizer(self.config.system.optimizer, self.model)
        scheduler = parse_scheduler(self.config.system.scheduler, optim)
        return [optim], [scheduler]

    def prepare(self):
        """Create the PSNR and SSIM metric modules on ``self.rank``."""
        self.criterions = {
            'psnr': PSNR().to(self.rank),
            'ssim': SSIM().to(self.rank),
        }

    def parse_n(self, n):
        """Resolve an image-count setting into an absolute number of images.

        Args:
            n: ``0`` means all dataset images, a value in ``(0, 1)`` is a fraction of
                the dataset, and a value ``>= 1`` is an absolute count.

        Returns:
            int: the resolved count (fractions are truncated by ``int``).

        Raises:
            ValueError: if ``n`` is negative (or otherwise outside those ranges).
        """
        if n == 0:
            return len(self.dataset.all_images)
        if 0 < n < 1:
            return int(n * len(self.dataset.all_images))
        # Validate the argument itself, not config.dataset.n_images, so this helper
        # is correct for every setting it is called with (e.g. nbv.max_num_new_imgs).
        if n >= 1:
            return int(n)
        raise ValueError(f"Invalid n images: {n}")

    def parse_n_images(self):
        """Set ``self.n_images`` from ``config.dataset.n_images``, clamped to at least 1."""
        self.n_images = max(self.parse_n(self.config.dataset.n_images), 1)

    def parse_max_num_new_imgs(self):
        """Resolve ``config.nbv.max_num_new_imgs`` in place via ``parse_n`` unless NBV is off."""
        if self.config.nbv.use is False:
            return
        self.config.nbv.max_num_new_imgs = self.parse_n(self.config.nbv.max_num_new_imgs)

    def reset_radius(self):
        """Set ``config.model.radius`` to 1.1x the largest camera distance from the mean camera center."""
        # all_c2w[:, :3, 3] is (N, 3); the helper works on (3, N) column vectors
        # (mean over cameras on axis 1, per-camera norm on axis 0), hence the transpose.
        cam_centers = self.dataset.all_c2w[:, :3, 3].detach().cpu().numpy().T

        def get_center_and_diag(cam_centers):
            center = np.mean(cam_centers, axis=1, keepdims=True)
            dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
            return center.flatten(), np.max(dist)

        _, diagonal = get_center_and_diag(cam_centers)
        self.config.model.radius = float(diagonal * 1.1)

    def init_gsmodel(self, gs_config):
        """(Re)initialize the Gaussians randomly and set up their optimizer from ``gs_config``.

        ``dataset.cam_center`` is reset to the origin and passed to ``random_init`` together
        with ``config.model.radius`` and 100k points.
        """
        self.dataset.cam_center = torch.zeros([3]).to(self.rank)
        self.model.gaussians.random_init(self.config.model.radius, self.dataset.cam_center, n_points=100_000)
        print("Random init from camera center")

        # Return value unused; presumably called so the scene output directory exists -- confirm.
        self.get_save_path(f"{self.config.dataset.scene}")
        self.model.gaussians.training_setup(gs_config)

    def _set_camera_from_dataset(self):
        """Copy image size and pinhole field of view from ``self.dataset`` onto the model."""
        w, h = self.dataset.w, self.dataset.h
        fx, fy = self.dataset.fx, self.dataset.fy

        self.model.w, self.model.h = w, h
        self.model.ori_w, self.model.ori_h = w, h

        # Pinhole FoV from focal length: fov = 2 * atan(size / (2 * focal)).
        self.model.fovx = 2 * math.atan(w / (2 * fx))
        self.model.fovy = 2 * math.atan(h / (2 * fy))

        self.model.tanfovx = math.tan(self.model.fovx * 0.5)
        self.model.tanfovy = math.tan(self.model.fovy * 0.5)

    def _prepare_gs_config(self, use_complete):
        """Return ``config.model.complete`` (or ``.nbv``) with ``output_path`` set to the scene dir."""
        gs_config = self.config.model.complete if use_complete else self.config.model.nbv
        gs_config.output_path = os.path.abspath(os.path.join(self.config.save_dir, self.config.dataset.scene))
        return gs_config

    def on_train_start(self) -> None:
        """Set up cameras, the Gaussian model, the initial training views and the video trajectory.

        Raises:
            ValueError: if ``config.nbv.init_style`` is neither ``"random"`` nor ``"uniform"``.
        """
        self.current_iter = 0
        # No NBV has been added yet. -inf keeps the "recently added view" windows in
        # get_img_index and save_nbv_training_video closed until the first add_nbv.
        if not hasattr(self, 'last_nbv'):
            self.last_nbv = float('-inf')

        self.dataset = self.trainer.datamodule.train_dataloader().dataset
        self._set_camera_from_dataset()

        self.init_gsmodel(self._prepare_gs_config(use_complete=not self.config.nbv.use))

        self.model.p = self.config.model.p
        self.model.apply_mask = self.config.model.mc_use

        self.parse_n_images()
        n_total = len(self.dataset.all_images)

        init_style = self.config.nbv.init_style
        if init_style == "random":
            self.training_ids = torch.randperm(n_total)[:self.n_images].to(self.rank)
        elif init_style == "uniform":
            translations = self.dataset.all_c2w[:, :, 3]
            self.training_ids = select_k_uniform_points(translations, k=self.n_images)
        else:
            raise ValueError(f"Unknown nbv.init_style: {init_style!r}")

        self.new_view_id = 0

        if self.config.nbv.use:
            self._save_initial_views()

        if self.config.nbv.max_num_new_imgs is None:
            self.config.nbv.max_num_new_imgs = n_total - self.n_images

        self.set_training_video_traj()
        self.nll_loss = nn.GaussianNLLLoss()
        return super().on_train_start()

    def _save_initial_views(self):
        """Save each initial training view as an RGBA png and its pose as .npy under ``nbv/``."""
        for i, idx in enumerate(self.training_ids):
            # Plain int so filenames read "idx_5" instead of a tensor repr.
            idx = int(idx)
            gt_path = self.get_save_path(f"nbv/gt_{i+1}_0_step_0_idx_{idx}.png")

            # Concatenate image and fg_mask along the last (channel) axis -> RGBA, then HWC -> CHW.
            img = torch.cat([self.dataset.all_images[idx], self.dataset.all_fg_masks[idx][..., None]], dim=-1)
            torchvision.utils.save_image(img.permute(2, 0, 1), gt_path)

            c2w_path = self.get_save_path(f"nbv/c2w_{i+1}_0_step_0_idx_{idx}.npy")
            np.save(c2w_path, nbv.convert3x4_4x4(self.dataset.all_c2w[idx]).clone().detach().cpu().numpy())

    def set_training_video_traj(self):
        """Pick ``frames_per_circle`` validation-pose indices, spread uniformly, for the training video."""
        total_frames = len(self.trainer.datamodule.val_dataloader().dataset.all_c2w)

        steps_per_frame = self.config.trainer.val_check_interval
        circle_per_train = self.config.system.video_circles
        frames_per_circle = int(self.config.trainer.max_steps // steps_per_frame // circle_per_train)
        frames_per_circle = max(frames_per_circle, 1)

        self.selected_frames = torch.linspace(0, total_frames - 1, frames_per_circle, dtype=torch.long)
        self.training_video_frame_idx = 0
        self.train_video_images = []

    def get_a_known_c2w(self, use_train=True):
        """Return a clone (on ``self.rank``) of a random camera-to-world matrix from ``self.dataset``.

        While ``self.training`` is set the pose is drawn from the current training views,
        otherwise from all dataset poses. ``use_train`` is currently unused.
        """
        dataset = self.dataset
        if self.training:
            random_id = self.training_ids[torch.randperm(len(self.training_ids))[-1]]
        else:
            # A random *index* into all_c2w (previously the pose itself was used as the index).
            random_id = torch.randperm(len(dataset.all_c2w))[-1]

        return dataset.all_c2w[random_id].detach().clone().to(self.rank)

    def forward(self, batch):
        """Render ``batch['c2w']`` (``batch['rays']`` is forwarded as the first model argument)."""
        return self.model(batch['rays'], batch['c2w'])

    def get_index_from_training_ids(self, except_last_k):
        """Return one random training-view index, excluding the last ``except_last_k`` ids."""
        index = torch.randint(0, len(self.training_ids) - except_last_k, size=(1,))
        return self.training_ids[index].to(self.dataset.all_images.device)

    def get_img_index(self, batch):
        """Choose the dataset image for the current step based on ``self.stage``.

        - ``test``: the batch's own index.
        - ``val``: the next frame of the training-video trajectory.
        - ``train``: the most recently added view while inside its ``train_nbv_steps``
          window after ``last_nbv``, otherwise a random training view.

        Raises:
            ValueError: for any other stage.
        """
        if self.stage == "test":
            return batch['index'].to(self.dataset.all_images.device)

        if self.stage == "val":
            frame = self.training_video_frame_idx % len(self.selected_frames)
            return self.selected_frames[frame].unsqueeze(0)

        if self.stage == "train":
            if self.last_nbv < self.current_iter < (self.last_nbv + self.config.nbv.train_nbv_steps):
                return self.training_ids[-1:].to(self.dataset.all_images.device)
            return self.get_index_from_training_ids(except_last_k=0)

        raise ValueError(f"Unknown stage: {self.stage!r}")

    def preprocess_data(self, batch, stage):
        """Fill ``batch`` with the pose, flattened RGB and foreground mask of the selected image.

        Also sets ``self.model.background_color``: the configured color (white / random /
        black) when ``dataset.apply_mask`` is set or in training, white otherwise. With
        ``dataset.apply_mask`` the target RGB's background pixels are replaced by that color.

        Raises:
            NotImplementedError: for an unknown ``config.model.background_color``.
        """
        index = self.get_img_index(batch)

        c2w = self.dataset.all_c2w[index][0]
        rgb = self.dataset.all_images[index].view(-1, self.dataset.all_images.shape[-1]).to(self.rank)
        fg_mask = self.dataset.all_fg_masks[index].view(-1).to(self.rank)

        if self.dataset.apply_mask or stage in ['train']:
            bg = self.config.model.background_color
            if bg == 'white':
                self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank)
            elif bg == 'random':
                self.model.background_color = torch.rand((3,), dtype=torch.float32, device=self.rank)
            elif bg == 'black':
                self.model.background_color = torch.zeros((3,), dtype=torch.float32, device=self.rank)
            else:
                raise NotImplementedError
        else:
            self.model.background_color = torch.ones((3,), dtype=torch.float32, device=self.rank)

        if self.dataset.apply_mask:
            mask = fg_mask[..., None]
            rgb = rgb * mask + self.model.background_color * (1 - mask)

        batch.update({
            'c2w': c2w,
            'rays': None,
            'rgb': rgb,
            'fg_mask': fg_mask,
        })

    def optimize_step(self, loss, sch_step=True):
        """Backward ``loss`` and step the Lightning-managed optimizers (and scheduler).

        NOTE(review): ``configure_optimizers`` returns None and this helper is not called
        anywhere in this file; ``self.optimizers()`` would have nothing registered.
        """
        opts = self.optimizers()
        if not isinstance(opts, list):
            opts = [opts]
        for optim in opts:
            optim.zero_grad()
        self.manual_backward(loss)
        for optim in opts:
            optim.step()

        if sch_step:
            self.lr_schedulers().step()

        for optim in opts:
            optim.zero_grad()

    def forward_loss(self, batch):
        """Render ``batch['c2w']`` and compute the weighted MSE + L1 + (1 - SSIM) loss.

        Every ``save_training_video_interval`` iterations, if ``system.save_training_video``
        is set, a GT|prediction frame is appended to the training video.

        Returns:
            ``(loss, out)``: the scalar loss and the model output dict.
        """
        out = self.model.forward(None, batch['c2w'])
        pred = out['comp_rgb_full'].view(-1, 3)
        gt = batch['rgb']
        h, w = self.model.h, self.model.w
        loss_cfg = self.config.system.loss

        loss = 0.

        loss_rgb_mse = F.mse_loss(pred, gt)
        self.log('train/loss_rgb_mse', loss_rgb_mse)
        loss += loss_rgb_mse * self.C(loss_cfg.lambda_rgb_mse)

        loss_rgb_l1 = F.l1_loss(pred, gt)
        self.log('train/loss_rgb', loss_rgb_l1)
        loss += loss_rgb_l1 * self.C(loss_cfg.lambda_rgb_l1)

        # Reshape (H*W, 3) -> (1, 3, H, W) for ssim().
        loss_rgb_ssim = 1 - ssim(pred.view(h, w, 3).permute(2, 0, 1)[None, ...],
                                 gt.view(h, w, 3).permute(2, 0, 1)[None, ...])
        # Own key so the SSIM term does not overwrite the L1 value logged as 'train/loss_rgb'.
        self.log('train/loss_rgb_ssim', loss_rgb_ssim)
        loss += loss_rgb_ssim * self.C(loss_cfg.lambda_rgb_ssim)

        self.log('train/loss', float(loss.item()), prog_bar=True)

        if (self.current_iter % self.config.system.save_training_video_interval == 0
                and self.config.system.save_training_video):
            self._append_training_video_frame(gt, out['comp_rgb_full'])

        return loss, out

    def _append_training_video_frame(self, gt_rgb, pred_rgb):
        """Append a GT|prediction frame and rewrite ``training_video.mp4`` with all frames so far."""
        W, H = self.dataset.img_wh
        img = self.get_image_grid_([
            {'type': 'rgb', 'img': gt_rgb.clone().detach().cpu().view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
            {'type': 'rgb', 'img': pred_rgb.clone().detach().cpu().view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
        ])
        # Reverse the channel axis (presumably BGR grid -> RGB for imageio -- confirm).
        self.train_video_images.append(img[..., ::-1])
        self._write_training_video()
        self.training_video_frame_idx += 1

    def _write_training_video(self):
        """Re-encode every collected training-video frame into ``training_video.mp4``."""
        save_path = self.get_save_path("training_video.mp4")
        self.train_video_writer = imageio.get_writer(save_path, mode='I', fps=5, codec='libx264',
                                                     bitrate='16M', macro_block_size=None)
        try:
            for img in self.train_video_images:
                height, width, _ = img.shape
                # Resize to even dimensions (avoid odd image size for the encoder).
                img = cv2.resize(img, (width - (width % 2), height - (height % 2)))
                self.train_video_writer.append_data(img)
        finally:
            self.train_video_writer.close()

    def training_step(self, batch, batch_idx):
        """Run one 3DGS step; in debug mode run ``self.debug()`` and terminate the process."""
        if self.config.system.debug:
            self.max_steps = self.current_iter + 1
            self.debug()
            self.debug_done = True
            print("debugging")
            os._exit(0)

        return self.training_step_3dgs(batch, batch_idx)

    def _should_add_nbv(self):
        """Whether NBV estimation runs at ``self.current_iter``."""
        nbv_cfg = self.config.nbv
        if not nbv_cfg.use or self.current_iter == 0:
            return False
        if self.new_view_id >= nbv_cfg.max_num_new_imgs:
            return False
        if self.current_iter < nbv_cfg.add_nbv_start_steps:
            return False
        return (self.current_iter - nbv_cfg.add_nbv_start_steps) % nbv_cfg.add_nbv_n_steps == 0

    def training_step_end(self, batch_parts_outputs):
        """Per-step bookkeeping: pose export, NBV addition + Gaussian restart, videos, checkpoints."""
        # Export pose visualization with point cloud.
        if (self.current_iter % self.config.system.pose_vis_step == 0) and (self.current_iter != 0):
            self.export_pcd_with_pose()

        if self._should_add_nbv():
            nbv_c2w, imgs, _ = self.estimate_nbv()
            if nbv_c2w is not None:
                self.add_nbv(nbv_c2w, imgs)
            self.last_nbv = self.current_iter

            # Restart the Gaussians; switch to the "complete" config once the NBV budget is spent.
            finished = self.new_view_id == self.config.nbv.max_num_new_imgs
            self.init_gsmodel(self._prepare_gs_config(use_complete=finished))
            if finished:
                self.model.p = 0
                self.model.apply_mask = False

        if self.config.nbv.save_training_video:
            self.save_nbv_training_video()

        if self.current_iter % self.config.checkpoint.every_n_train_steps == 0:
            self.latest_ply_path = self.model.gaussians.save_ply_1(self.current_iter)

        # NOTE(review): flagged by the original author as a bug -- this does not account for
        # the "complete" step count being larger than the "nbv" step count.
        if self.local_iter == self.config.model.complete.position_lr_max_steps:
            self.latest_ply_path = self.model.gaussians.save_ply_1(self.current_iter)
            self.end_training_flag = True

    def _local_iter_and_config(self):
        """Map ``self.current_iter`` to the iteration inside the current restart and its 3DGS config."""
        if self.config.nbv.use:
            if self.new_view_id < self.config.nbv.max_num_new_imgs:
                # Assumes add_nbv_start_steps == add_nbv_n_steps.
                return self.current_iter % self.config.nbv.add_nbv_n_steps, self.model.config.nbv
            local_iter = self.current_iter - self.config.nbv.add_nbv_n_steps * self.config.nbv.max_num_new_imgs
            return local_iter, self.model.config.complete
        return self.current_iter, self.model.config.complete

    def training_step_3dgs(self, batch, batch_idx):
        """One manual 3DGS step: loss, backward, densification/pruning, optimizer step.

        Returns:
            ``{'loss': loss}``.
        """
        loss, out = self.forward_loss(batch)
        visibility_filter = out["visibility_filter"]
        viewspace_point_tensor = out["viewspace_points"]
        radii = out["radii"]

        for name, value in self.config.system.loss.items():
            if name.startswith('lambda'):
                self.log(f'train_params/{name}', self.C(value))

        self.current_iter += 1
        self.model.gaussians.iteration = self.current_iter

        local_iter, gs_config = self._local_iter_and_config()
        self.local_iter = local_iter
        self.model.gaussians.update_learning_rate(local_iter)

        # Every 1000 iterations raise the SH degree by one (any cap is inside oneupSHdegree).
        if local_iter % 1000 == 0:
            self.model.gaussians.oneupSHdegree()
        loss.backward()

        with torch.no_grad():
            if local_iter < gs_config.densify_until_iter:
                gaussians = self.model.gaussians
                # Keep track of max radii in image space for pruning.
                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter],
                                                                     radii[visibility_filter])
                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)

                if local_iter > gs_config.densify_from_iter and local_iter % gs_config.densification_interval == 0:
                    size_threshold = 20 if local_iter > gs_config.opacity_reset_interval else None
                    # config.model.radius is passed as the scene extent (1.1x max camera
                    # distance from the average center, as in 3DGS).
                    gaussians.densify_and_prune(gs_config.densify_grad_threshold, 0.005,
                                                self.config.model.radius, size_threshold)

                if (local_iter + 1) % gs_config.opacity_reset_interval == 0 or (
                        self.config.model.background_color == 'white' and local_iter == gs_config.densify_from_iter):
                    gaussians.reset_opacity()

                # NOTE(review): the 2e11 threshold effectively disables spread().
                if local_iter > 200000000000 and local_iter % gs_config.densification_interval == 0:
                    gaussians.spread(self.config.model.radius)

        self.model.gaussians.optimizer.step()
        self.model.gaussians.optimizer.zero_grad(set_to_none=True)

        return {'loss': loss}

    def set_config_debug(self):
        """Shrink checkpoint / LR / validation intervals for quick debugging runs."""
        self.config.checkpoint.every_n_train_steps = 99
        self.config.model.position_lr_max_steps = 100
        self.config.trainer.val_check_interval = 99

    def save_nbv_training_video(self):
        """Record renderings of the newest training view during its ``train_nbv_steps`` window.

        At ``current_iter == last_nbv`` a new writer is opened (closing any previous one);
        strictly inside the window a frame is appended every ``video_save_img_interval``
        steps; at the window end the writer is closed. The model is put in eval mode while
        this runs and back in train mode afterwards (even on error).
        """
        idx = self.training_ids[-1]
        window_end = self.last_nbv + self.config.nbv.train_nbv_steps
        self.model.eval()
        try:
            if self.current_iter == self.last_nbv:
                if hasattr(self, "nbvimage_optim_writer"):
                    self.nbvimage_optim_writer.close()
                save_path = self.get_save_path(
                    f"nbv/newimg_{self.training_ids[-1].item()}_optimization_{self.current_iter}.mp4")
                self.nbvimage_optim_writer = imageio.get_writer(save_path, mode='I', fps=5, codec='libx264',
                                                                bitrate='16M')
            elif (self.last_nbv < self.current_iter < window_end
                  and self.current_iter % self.config.nbv.video_save_img_interval == 0):
                img = self.render_one_image(c2w=self.dataset.all_c2w[idx])
                self.nbvimage_optim_writer.append_data(img[..., ::-1])
            elif self.current_iter == window_end:
                self.nbvimage_optim_writer.close()
        finally:
            self.model.train()

    def validation_step(self, batch, batch_idx):
        """Render the validation pose and return its PSNR; ``nll`` is a 0.0 placeholder."""
        out = self.model.forward(None, batch['c2w'])
        psnr = self.criterions['psnr'](out["comp_rgb"].view(-1, 3), batch['rgb']).detach().clone()
        return {
            'psnr': psnr,
            'index': batch['index'],
            'nll': 0.0,
        }

    def validation_step_end(self, out):
        # aggregate outputs from different devices when using DP
        pass

    @staticmethod
    def _collect_by_index(out, keys):
        """Key gathered step outputs by image index so each index is counted once.

        Handles a 1-D ``index`` (single process / DP) and a 2-D ``index`` (DDP, presumably
        stacked per rank by ``all_gather``).
        """
        out_set = {}
        for step_out in out:
            if step_out['index'].ndim == 1:
                out_set[step_out['index'].item()] = {k: step_out[k] for k in keys}
            else:
                for oi, index in enumerate(step_out['index']):
                    out_set[index[0].item()] = {k: step_out[k][oi] for k in keys}
        return out_set

    def validation_epoch_end(self, out):
        """Average PSNR and NLL over unique validation indices and log them on rank zero."""
        if not out or len(out[0]) == 0:
            return

        out = self.all_gather(out)
        if self.trainer.is_global_zero:
            out_set = self._collect_by_index(out, ('psnr', 'nll'))
            psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
            nll = torch.mean(torch.stack([o['nll'] for o in out_set.values()]))
            self.log('val/psnr', psnr, prog_bar=True, rank_zero_only=True)
            self.log('val/nll', nll, prog_bar=True, rank_zero_only=True)

    def on_test_start(self):
        """Add LPIPS, switch to the test dataset, and re-init the Gaussians unless training finished."""
        self.criterions.update({"lpips": LPIPS().to(self.rank)})
        self.dataset = self.trainer.datamodule.test_dataloader().dataset
        self._set_camera_from_dataset()

        gs_config = self._prepare_gs_config(use_complete=not self.config.nbv.use)

        # NOTE(review): end_training_flag is only assigned in training_step_end; this raises
        # AttributeError if it was never set (unless BaseSystem initializes it) -- confirm.
        if self.end_training_flag is False:
            self.init_gsmodel(gs_config)

    def test_step_active(self, batch, batch_idx):
        """Render a test pose, compute PSNR/SSIM/LPIPS and save a GT|prediction image grid."""
        out = self.model(None, batch['c2w'])
        gt = batch['rgb']
        rgb = out['comp_rgb_full'].to(gt)

        psnr = self.criterions['psnr'](rgb.view(-1, 3), gt).detach().clone()

        W, H = self.dataset.img_wh
        # (H*W, 3) -> (1, 3, H, W) for the image-space metrics.
        rgb_nchw = rgb.view(H, W, 3).permute(2, 0, 1)[None, ...]
        gt_nchw = gt.view(H, W, 3).permute(2, 0, 1)[None, ...]
        ssim_val = self.criterions['ssim'](rgb_nchw, gt_nchw).detach().clone()
        lpips_val = self.criterions['lpips'](rgb_nchw, gt_nchw).detach().clone()

        self.save_image_grid(f"it{self.global_step}-test/{batch['index'][0].item()}.png", [
            {'type': 'rgb', 'img': gt.view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
            {'type': 'rgb', 'img': out['comp_rgb_full'].view(H, W, 3), 'kwargs': {'data_format': 'HWC'}},
        ])

        self.log('test/psnr', float(psnr.item()), prog_bar=True)
        return {
            'psnr': psnr,
            'ssim': ssim_val,
            'lpips': lpips_val,
            'index': batch['index'],
        }

    def test_step(self, batch, batch_idx):
        return self.test_step_active(batch, batch_idx)

    def test_epoch_end(self, out):
        """
        Synchronize devices.
        Generate image sequence using test outputs.

        Also plots ``self.model.var_curve`` to ``var_curve.png`` and, on rank zero, logs the
        mean PSNR/SSIM/LPIPS over unique test indices and stores them on ``self``.
        """
        curve = self.model.var_curve
        save_path = self.get_save_path("var_curve.png")
        plot.plot_loss_curve(np.arange(len(curve)), np.array(curve), save_path=save_path)

        out = self.all_gather(out)
        if self.trainer.is_global_zero:
            out_set = self._collect_by_index(out, ('psnr', 'ssim', 'lpips'))
            psnr = torch.mean(torch.stack([o['psnr'] for o in out_set.values()]))
            ssim_val = torch.mean(torch.stack([o['ssim'] for o in out_set.values()]))
            lpips_val = torch.mean(torch.stack([o['lpips'] for o in out_set.values()]))

            self.log('test/psnr', psnr, prog_bar=True, rank_zero_only=True)
            self.log('test/ssim', ssim_val, prog_bar=True, rank_zero_only=True)
            self.log('test/lpips', lpips_val, prog_bar=True, rank_zero_only=True)

            self.test_psnr = psnr.item()
            self.test_ssim = ssim_val.item()
            self.test_lpips = lpips_val.item()

            print(f"Eval results: test_psnr={self.test_psnr}, test_ssim={self.test_ssim}, test_lpips={self.test_lpips}")

            self.save_img_sequence(
                f"it{self.global_step}-test",
                f"it{self.global_step}-test",
                r'(\d+)\.png',  # raw string: '\d' is an invalid escape in a normal literal
                save_format='mp4',
                fps=20
            )

        print(f"\nSaving to {self.config.save_dir}\n")

    def export(self, save_path):
        """Export and save a mesh to ``save_path`` when ``config.export.use`` is set."""
        if self.config.export.use:
            mesh = self.model.export(self.config.export)
            self.save_mesh(save_path, **mesh)

    @torch.no_grad()
    def export_mesh(self):
        """Extract an isosurface mesh and save it as ``it<iter>-<method><resolution>.obj``.

        NOTE(review): isosurface runs at a fixed resolution of 256 while the filename reports
        ``config.model.geometry.isosurface.resolution``; confirm which is intended.
        """
        self.model.eval()
        mesh = self.model.isosurface(resolution=256)
        self.save_mesh(
            f"it{self.current_iter}-{self.config.model.geometry.isosurface.method}{self.config.model.geometry.isosurface.resolution}.obj",
            mesh['v_pos'],
            mesh['t_pos_idx'],
        )
        self.model.train()

    def add_nbv(self, nbv_c2w, imgs=None, no_render=False):
        """Append a new view at pose ``nbv_c2w`` to the training set.

        The new training id is ``len(dataset.all_c2w)`` (the next dataset slot). The pose
        (.npy) and the planner images are saved under ``nbv/``; unless ``no_render`` is set,
        a ground-truth image is rendered with the ``nbv.render_gt`` Blender helpers and
        added via ``dataset.add_one_image``.

        Raises:
            ValueError: if the rendered ground-truth image file does not exist.
        """
        dataset = self.trainer.datamodule.train_dataloader().dataset
        # NOTE(review): with no_render=True the dataset is not extended, so this id points
        # past the end of the dataset -- confirm callers account for that.
        self.training_ids = torch.cat([self.training_ids, torch.tensor([len(dataset.all_c2w)]).to(self.rank)])
        new_view_id = len(self.training_ids) - self.n_images
        tag = f"{self.n_images}_{new_view_id}_step_{self.current_iter}"

        gt_path = self.get_save_path(f"nbv/gt_{tag}.png")
        blenderlog_path = self.get_save_path(f"nbv/renderlog_{tag}.txt")
        c2w_path = self.get_save_path(f"nbv/c2w_{tag}.npy")

        np.save(c2w_path, nbv.convert3x4_4x4(nbv_c2w).cpu().numpy())

        self.save_image_grid(filename=f"nbv/img_{tag}.png", imgs=imgs)

        self.new_view_id = new_view_id

        if no_render is False:
            render_kwargs = dict(cmtx_path=os.path.abspath(c2w_path), save_path=gt_path,
                                 log_path=blenderlog_path, model_name=self.config.dataset.scene)
            if self.config.dataset.name == "artfletch":
                nbv.render_gt.blender_render_one_view_SketchModels(**render_kwargs)
            else:
                nbv.render_gt.blender_render_one_view(**render_kwargs)

            if not os.path.exists(gt_path):
                raise ValueError(f"Gt image {gt_path} not rendered. ")
            dataset.add_one_image(nbv_c2w, gt_path)
        self.model.train()

    def _render_candidate_grid(self, poses_render, save_path):
        """Save a grid of candidate-pose renderings to ``save_path``.

        Rendering is not implemented yet (TODO), so no image is produced and nothing is
        written; previously this crashed on ``torch.cat`` of ``None`` entries.
        """
        save_imgs = []
        for c2w in poses_render.float().detach():
            # TODO: render c2w and append the result (presumably a (1, C, H, W) tensor).
            rendered = None
            if rendered is not None:
                save_imgs.append(rendered)
        if not save_imgs:
            return
        grid = torchvision.utils.make_grid(torch.cat(save_imgs, dim=0), nrow=10, padding=2, pad_value=1).detach()
        torchvision.utils.save_image(grid, save_path)

    def _visualize_candidates(self, camera_poses_coarse, poses):
        """Save candidate-pose grids (coarse if present, then fine) and export a mesh."""
        if camera_poses_coarse is not None:
            save_path = self.get_save_path(f"nbv/{self.current_iter}_nbv_search_coarse_candidates.png")
            self._render_candidate_grid(camera_poses_coarse, save_path)
        save_path = self.get_save_path(f"nbv/{self.current_iter}_nbv_search_fine_candidates.png")
        self._render_candidate_grid(poses, save_path)

        # NOTE(review): the exported mesh is not used afterwards.
        self.model.export(self.config.export, resolution=256)

    def _write_nbv_video(self, video_frames):
        """Write ``(loss, frame)`` pairs sorted by ascending loss to ``nbv/<iter>_nbv_optimization.mp4``."""
        save_path = self.get_save_path(f"nbv/{self.current_iter}_nbv_optimization.mp4")
        writer = imageio.get_writer(save_path, mode='I', fps=10, codec='libx264', bitrate='16M', macro_block_size=1)
        try:
            if video_frames:
                # Sort only the losses that actually have a frame, so indices always line up.
                order = np.argsort(np.array([frame_loss for frame_loss, _ in video_frames]))
                for k in order:
                    writer.append_data(video_frames[k][1])
        finally:
            writer.close()

    def estimate_nbv(self):
        """Search for the next-best view with ``self.planner``.

        Runs up to ``nbv.max_step`` planner steps (capped by the number of candidates, if
        any), keeps the parameters with the lowest loss seen, optionally saves candidate
        poses / visualizations / a search video, then re-evaluates the planner at those
        parameters.

        Returns:
            ``(c2w, output_imgs, loss)``: the chosen pose (``planner.params[0]``, detached),
            the planner's output images with a loss-curve panel appended, and the loss.
        """
        # Train mode is used for every information-gain type (config.nbv.ig).
        self.model.train()

        self.model.background_color = torch.ones((3,), dtype=torch.float32).to(self.rank)
        print(f"\nStart estimating NBV with {self.planner.config.ig}\n")

        self.planner.global_step = self.current_iter
        self.planner.save_path = self.get_save_path("nbv_record/")
        self.planner.criterions = self.criterions
        train_dataset = self.trainer.datamodule.train_dataloader().dataset
        camera_poses_coarse, poses = self.planner.planner_init(train_dataset.all_c2w[self.training_ids], train_dataset)

        # Default must be the boolean False: the string "False" is truthy.
        if poses is not None and self.config.nbv.get("save_candidates", False):
            np.save(self.get_save_path(f"nbv/{self.current_iter}_nbv_search_fine_candidates.npy"), poses.cpu().numpy())
            if camera_poses_coarse is not None:
                np.save(self.get_save_path(f"nbv/{self.current_iter}_nbv_search_coarse_candidates.npy"),
                        camera_poses_coarse.cpu().numpy())

        if poses is not None and self.config.nbv.vis_candidate_pose_silhouette:
            self._visualize_candidates(camera_poses_coarse, poses)

        # NOTE(review): the best loss is seeded with 0, so only negative losses replace the
        # initial params; the stored params may also alias tensors that planner.step updates
        # in place -- confirm intended.
        best_param = [self.planner.params, 0]
        report_freq = 100000

        n_candidates = len(self.planner.c2w_candidates) if self.planner.c2w_candidates is not None else float("inf")
        max_step = min(n_candidates, self.config.nbv.max_step)

        loss_list = []
        position_list = []
        video_frames = []  # (loss, frame) pairs, only for steps that produced images
        self.planner.current_step = 0

        def step(i):
            self.planner.current_step = i
            loss, output_imgs = self.planner.step(record=True)
            return loss, output_imgs

        for i in (pbar := tqdm(range(max_step))):
            # Line-profile and report memory every report_freq steps (parenthesized: `i+1 % k` is i+1).
            if (i + 1) % report_freq == 0:
                with LineProfiler(step):
                    loss, output_imgs = step(i)
                    MemReporter().report()
            else:
                loss, output_imgs = step(i)

            if loss < best_param[1]:
                best_param = [self.planner.params, loss]
            loss_list.append(loss)
            position_list.append(self.planner.params[0][:, 3])
            pbar.set_description(f"{loss}")

            if self.config.nbv.save_video and output_imgs is not None:
                img = self.get_image_grid_(imgs=output_imgs)[..., ::-1]
                height, width, _ = img.shape
                img = cv2.resize(img, (width - (width % 2), height - (height % 2)))  # avoid odd image size
                video_frames.append((loss, img))
            del output_imgs

        if self.config.nbv.save_video:
            self._write_nbv_video(video_frames)

        self.planner.params = best_param[0]

        loss, output_imgs = self.planner.forward(record=True)
        c2w = self.planner.params[0].detach()

        loss_img = plot.plot_loss_curve(range(len(loss_list)), loss_list,
                                        resize=(self.planner.config.w, self.planner.config.h))
        output_imgs.append({'type': 'rgb', 'img': loss_img[..., :3],
                            'kwargs': {'data_format': 'HWC', 'data_range': (0, 255)}})

        if self.config.system.debug:
            self.add_nbv(c2w, output_imgs, no_render=False)

        self.log("nbv/estimate_nbv_ig", -loss)
        # Rows of (camera position, loss) for every search step.
        search_traj = torch.cat([torch.stack(position_list, dim=0).to(self.rank),
                                 torch.tensor(loss_list).to(self.rank)[:, None]], dim=-1)
        self.planner.prior_point_score = search_traj
        return c2w, output_imgs, loss

    def erase_out(self):
        """Save a filtered copy of the Gaussians to ``filtered.ply`` (``clean=0.8``)."""
        self.model.gaussians.save_ply(self.get_save_path("filtered.ply"), clean=0.8)

    def debug(self):
        self.erase_out()
