import os

import einops
import lightning
import torchvision.utils
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
import numpy as np
import torch
import torch.nn as nn
from nerfacc import OccupancyGrid, ray_marching, rendering
from utils.examples.datasets.utils import Rays, namedtuple_map
from torch.utils.data._utils.collate import collate, default_collate_fn_map
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from typing import Optional
import mcubes
import trimesh
import torch.nn.functional as F
from tqdm import trange
import imageio
from kornia.feature import LoFTR
from kornia.geometry.epipolar import (
    find_fundamental,
    essential_from_fundamental,
    decompose_essential_matrix,
)
import cv2
import matplotlib.pyplot as plt
import glob

from hypernet_core import Hypernet
from load_objaverse import ObjaverseDataset  # , load_objaverse_data
from run_nerf import create_ngp_mlp, render, render_path_torch
from run_nerf_helpers import get_rays

from pdb import set_trace as bb
from time import time
from torchvision.transforms import ToPILImage
from info_nce import InfoNCE

from utils.coordinates import ang_to_matrix, pose_to_matrix, matrix_to_ang
from utils.examples.datasets.utils import Rays
from utils.examples.radiance_fields.mlp import MLPDensityField

from utils.images_utils import put_text_on_image, auto_permute_image
from utils.model_tools import hook_fn_decorator, summary
import math

try:
    from nerfacc import ContractionType
    GRID_CONTRACTION_TYPE = ContractionType.AABB
except ImportError:
    pass
from pytorch3d.transforms import (
    matrix_to_quaternion,
    standardize_quaternion,
    quaternion_to_axis_angle,
    quaternion_to_matrix,
    matrix_to_euler_angles,
    euler_angles_to_matrix,
    axis_angle_to_matrix,
    matrix_to_axis_angle,
)

import logging
from lpips import LPIPS

# Shared torchvision ``ToPILImage`` transform (tensor/ndarray -> PIL image).
# NOTE(review): not referenced in this part of the file; kept for callers elsewhere.
tensor_to_pil = ToPILImage()
eps = 1e-8  # module-level numerical epsilon


def img2mse(x, y):
    """Return the mean squared error between ``x`` and ``y``."""
    return ((x - y) ** 2).mean()


def img2l1(x, y):
    """Return the mean absolute (L1) error between ``x`` and ``y``."""
    return (x - y).abs().mean()


def mse2psnr(x):
    """Convert an MSE tensor to PSNR in dB (result has shape ``(1,)`` for scalar input)."""
    ln10 = torch.log(torch.Tensor([10.0]).to(x.device))
    return -10.0 * torch.log(x) / ln10


def psnr(x, y, eps=1e-8):
    """Return the PSNR (dB) between ``x`` and ``y``; ``eps`` keeps log10 finite."""
    mse = torch.square(x - y).mean()
    return -10.0 * torch.log10(mse + eps)

# NOTE: render_image_with_ngp relies on the nerfacc 0.3 API; activate the matching
# environment (e.g. `conda activate nerf_acc0.3`) before using it.
def render_image_with_ngp(
    # scene
    radiance_field: torch.nn.Module,
    occupancy_grid: OccupancyGrid,
    rays: Rays,
    scene_aabb: torch.Tensor,
    # rendering options
    near_plane: Optional[float] = None,
    far_plane: Optional[float] = None,
    render_step_size: float = 1e-3,
    color_bkgds: Optional[torch.Tensor] = None,
    cone_angle: float = 0.0,
    alpha_thre: float = 0.0,
    # other configs
    grid_weights=None,
    device='cuda:0',
    training=True,
    max_elements=512,
    test_chunk_size: int = 8192,
    n_total_cells=884736,
):
    """Render a batch of rays with nerfacc (0.3 API) occupancy-grid marching.

    Args:
        radiance_field: module called as ``radiance_field(positions)`` that
            returns ``(rgb, density)``; its ``training`` flag enables
            stratified sampling in ``ray_marching``.
        occupancy_grid: nerfacc ``OccupancyGrid`` used to skip empty space.
        rays: ``Rays`` namedtuple whose ``origins``/``viewdirs`` are shaped
            either ``(num_rays, 3)`` or ``(H, W, 3)``.
        scene_aabb: bounding box forwarded to ``ray_marching``; moved to the
            rays' device.
        near_plane, far_plane, render_step_size, cone_angle, alpha_thre:
            forwarded unchanged to ``nerfacc.ray_marching``.
        color_bkgds: background colour (anything ``torch.as_tensor`` accepts)
            composited by ``nerfacc.rendering``; ``None`` disables background
            compositing.
        grid_weights: optional path to a saved occupancy-grid state dict that
            is loaded into ``occupancy_grid`` before marching.
        device, max_elements: unused; kept for backward compatibility.
        training: if True, all rays are marched in a single chunk; otherwise
            they are processed ``test_chunk_size`` rays at a time.
        test_chunk_size: chunk size used when ``training`` is False.
        n_total_cells: size of the placeholder ``occs`` buffer inserted when
            loading ``grid_weights`` (default is 96**3).

    Returns:
        ``(colors, opacities, depths, n_rendering_samples)``: the three tensors
        take the leading shape of ``rays.origins`` (last dim replaced by each
        quantity's channel count); ``n_rendering_samples`` is the total number
        of marched samples.
    """
    rays_shape = rays.origins.shape
    if len(rays_shape) == 3:
        height, width, _ = rays_shape
        num_rays = height * width
        rays = namedtuple_map(
            lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays
        )
    else:
        num_rays, _ = rays_shape

    # Both callbacks read ``chunk_rays``, which the loop below binds before
    # nerfacc invokes them.
    def sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
        _, density = radiance_field(positions)
        return density

    def rgb_sigma_fn(t_starts, t_ends, ray_indices):
        t_origins = chunk_rays.origins[ray_indices]
        t_dirs = chunk_rays.viewdirs[ray_indices]
        positions = t_origins + t_dirs * (t_starts + t_ends) / 2.0
        rgb, sigmas = radiance_field(positions)
        return rgb, sigmas

    if grid_weights is not None:
        grid_state = torch.load(grid_weights)
        grid_state['_binary'] = grid_state['_binary'].to_dense()
        # ``occs`` is not stored in the checkpoint; insert a placeholder so
        # load_state_dict finds every key. Zeros (not torch.empty) so the grid
        # never holds uninitialised memory.
        grid_state['occs'] = torch.zeros([n_total_cells])
        occupancy_grid.load_state_dict(grid_state)

    # Loop-invariant tensors, built once on the rays' device.
    ray_device = rays.origins.device
    aabb = scene_aabb.to(ray_device)
    render_bkgd = (
        None
        if color_bkgds is None
        else torch.as_tensor(color_bkgds, dtype=torch.float32, device=ray_device)
    )

    chunk = torch.iinfo(torch.int32).max if training else test_chunk_size
    results = []
    n_rendering_samples = 0
    for i in range(0, num_rays, chunk):
        chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays)
        ray_indices, t_starts, t_ends = ray_marching(
            chunk_rays.origins,
            chunk_rays.viewdirs,
            scene_aabb=aabb,
            grid=occupancy_grid,
            sigma_fn=sigma_fn,
            near_plane=near_plane,
            far_plane=far_plane,
            render_step_size=render_step_size,
            stratified=radiance_field.training,
            cone_angle=cone_angle,
            alpha_thre=alpha_thre,
        )

        rgb, opacity, depth = rendering(
            t_starts,
            t_ends,
            ray_indices,
            n_rays=chunk_rays.origins.shape[0],
            rgb_sigma_fn=rgb_sigma_fn,
            render_bkgd=render_bkgd,
        )

        n_rendering_samples += t_starts.numel()
        results.append([rgb, opacity, depth])

    # Concatenate per-chunk outputs along the ray dimension.
    colors, opacities, depths = collate(
        results,
        collate_fn_map={
            **default_collate_fn_map,
            torch.Tensor: lambda x, **_: torch.cat(x, 0),
        },
    )

    out_shape = (*rays_shape[:-1], -1)
    return (
        colors.view(out_shape),
        opacities.view(out_shape),
        depths.view(out_shape),
        n_rendering_samples,
    )

class HoNeRF2Vec(lightning.LightningModule):
    """Lightning module that regresses NGP-MLP weights with a hypernetwork.

    Conditioning views are embedded by the ``image_encoder`` module. The
    ``network_fn`` hypernetwork (:attr:`hypernet`) generates per-sample MLP
    weights from that embedding. The result is supervised by

    * an L1 loss against the reference weights in ``batch["ngp_mlp_weights"]``;
    * L1 image / mask losses on target views rendered with
      ``render_image_with_ngp``.

    ``args.img_loss_only`` / ``args.weight_loss_only`` select which terms are
    optimised (see :meth:`_combine_losses`).
    """

    # Weight of the silhouette (accumulated opacity) term in the total loss.
    MASK_LOSS_WEIGHT = 0.05

    def __init__(self, args, data_module):
        """Build the occupancy grid and the hypernetwork modules.

        Args:
            args: experiment config; reads ``aabb``, ``grid_config_n_samples``,
                ``use_patch`` and whatever ``create_ngp_mlp`` consumes.
            data_module: supplies camera data via its ``near``, ``far``, ``H``,
                ``W``, ``K`` and ``hwf`` attributes.
        """
        super().__init__()
        self.args = args
        self.near = data_module.near
        self.far = data_module.far
        self.H = data_module.H
        self.W = data_module.W
        self.K = data_module.K
        self.hwf = data_module.hwf
        # Presumably (xmin, ymin, zmin, xmax, ymax, zmax). Kept as a plain
        # tensor (not a buffer); the renderer moves it to the rays' device.
        self.aabb = torch.tensor(self.args.aabb, dtype=torch.float32)
        # Marching step: diagonal of a cube with the box's largest side,
        # split into ``grid_config_n_samples`` steps.
        self.render_step_size = (
            (self.aabb[3:] - self.aabb[:3]).max()
            * math.sqrt(3) / self.args.grid_config_n_samples
        ).item()
        if self.args.use_patch:
            lpips_model = LPIPS(net='vgg').to(self.device)
            lpips_model.eval()
            for p in lpips_model.parameters():
                p.requires_grad = False
            # object.__setattr__ bypasses nn.Module registration, so the frozen
            # LPIPS net stays out of parameters()/state_dict() (and is not
            # moved by .to()).
            object.__setattr__(self, "_lpips_fn", lpips_model)

        self.occupancy_grid = OccupancyGrid(
            roi_aabb=self.aabb,
            resolution=96,
            contraction_type=GRID_CONTRACTION_TYPE,
        )
        # Lazily-built (H, W, 2) pixel-coordinate grid shared by the samplers.
        self.coords = None

        render_kwargs_train, render_kwargs_test, _, _, _ = create_ngp_mlp(args)

        render_kwargs_train.update({"near": self.near, "far": self.far})
        render_kwargs_test.update({"near": self.near, "far": self.far})
        # Modules live in an nn.ModuleDict so they are registered (and hence
        # included in self.parameters() for the optimizer).
        self.render_kwargs_train_modules, self.render_kwargs_train_args = (
            self.split_modules_args(render_kwargs_train)
        )
        # Test-time kwargs reuse the train modules; only plain args are kept.
        _, self.render_kwargs_test_args = self.split_modules_args(render_kwargs_test)

    def on_fit_start(self):
        """Print and save a parameter-count summary when fitting starts."""
        self.get_model_summary()

    def get_model_summary(self):
        """Print per-component parameter counts and save them (rank 0 only).

        The table is also written to ``<logger.log_dir>/model_summary.txt``.
        The total only sums the ``main-hypernet`` and ``img-encoder`` rows.
        """
        if self.global_rank != 0:
            return
        hypernet = self.hypernet
        image_encoder = self.render_kwargs_train_modules['image_encoder']
        if self.args.enable_weight_init:
            model_summary_dict = {
                'main-hypernet': summary(hypernet) + summary(hypernet.init_weight_hypernet),
                'img-encoder': summary(image_encoder),
                'main-init_weight_hypernet': summary(hypernet.init_weight_hypernet),
                'main-decoder': summary(hypernet.decoders),
                'main-opt_layers': summary(hypernet.opt_layer),
            }
        else:
            model_summary_dict = {
                'main-hypernet': summary(hypernet),
                'img-encoder': summary(image_encoder),
                'main-encoder': summary(hypernet.encoders),
                'main-decoder': summary(hypernet.decoders),
                'main-opt_layers': summary(hypernet.opt_layer),
            }

        def format_param_count(count):
            """Format a parameter count with a K/M suffix."""
            if count >= 1_000_000:
                return f"{count/1_000_000:.2f}M"
            if count >= 1_000:
                return f"{count/1_000:.2f}K"
            return f"{count}"

        main_components = ('main-hypernet', 'img-encoder')
        total_params = sum(model_summary_dict[comp] for comp in main_components)

        rule = "-" * 50
        lines = [
            "Model Component Summary:",
            rule,
            f"{'Component':<25} | {'Parameter Count':>20}",
            rule,
        ]
        lines += [
            f"{component:<25} | {format_param_count(param_count):>20}"
            for component, param_count in model_summary_dict.items()
        ]
        lines += [
            rule,
            f"{'Total (main components)':<25} | {format_param_count(total_params):>20}",
            rule,
        ]
        model_summary_str = "\n".join(lines) + "\n"

        print(model_summary_str)
        # ``with`` guarantees the file handle is closed and flushed.
        with open(os.path.join(self.logger.log_dir, "model_summary.txt"), "w") as f:
            f.write(model_summary_str)

    @property
    def hypernet(self):
        """The hypernetwork (``network_fn`` entry of the train modules)."""
        return self.render_kwargs_train_modules["network_fn"]

    def on_load_checkpoint(self, checkpoint):
        """Intentionally a no-op hook override."""
        pass

    def split_modules_args(self, arg_dict):
        """Split ``arg_dict`` into (nn.ModuleDict of modules, dict of other values)."""
        module_dict = nn.ModuleDict(
            {k: v for k, v in arg_dict.items() if isinstance(v, nn.Module)}
        )
        args_dict = {k: v for k, v in arg_dict.items() if not isinstance(v, nn.Module)}
        return module_dict, args_dict

    def configure_optimizers(self):
        """Adam over all registered parameters with learning rate ``args.lrate``."""
        return Adam(self.parameters(), lr=self.args.lrate, betas=(0.9, 0.999))

    def render_kwargs_train(self):
        """Merged train-time render kwargs (plain args plus modules)."""
        return {**self.render_kwargs_train_args, **self.render_kwargs_train_modules}

    def render_kwargs_test(self):
        """Merged test-time render kwargs; modules are shared with training."""
        return {**self.render_kwargs_test_args, **self.render_kwargs_train_modules}

    def encode(self, img, pose):
        """Embed conditioning views with the image encoder.

        Args:
            img: channels-last images of rank 5, ``(B, N, H, W, C)``.
            pose: pose matrices ``(B, N, h, w)`` (presumably 4x4); the last two
                dims are flattened before encoding.

        Returns:
            Encoder features rearranged from ``(B*N, L, D)`` to ``(B, N*L, D)``.
        """
        B, N, _, _, _ = img.shape  # also enforces rank 5
        img = einops.rearrange(img, "B N H W C -> B N C H W")
        pose = einops.rearrange(pose, "B N H W -> B N (H W)")
        encoder = self.render_kwargs_train_modules["image_encoder"]
        feat = encoder(img, pose)
        feat = einops.rearrange(feat, "(B N) L D -> B (N L) D", B=B, N=N)
        return feat

    def get_inputs_patch(self, target, pose, mask, patch_size):
        """Sample one square ``patch_size`` x ``patch_size`` patch of rays.

        While ``current_epoch < args.precrop_iters`` (and some foreground pixel
        can anchor a patch), the patch uses a random foreground pixel as its
        top-left corner. With probability 0.5, when the patch fits, that pixel
        becomes its centre instead. Otherwise the patch is placed uniformly at
        random.

        Args:
            target: image indexed as ``target[row, col]`` (assumed ``(H, W, C)``).
            pose: camera pose passed to ``get_rays``.
            mask: boolean foreground mask indexed as ``mask[row, col]``.
            patch_size: patch side length in pixels.

        Returns:
            ``(ro, rd, target_s, target_m)``. ``ro``/``rd`` have shape
            ``(1, patch_size**2, 3)``, ``target_s`` has shape
            ``(patch_size**2, C)`` and ``target_m`` has shape
            ``(patch_size**2, 1)``.
        """
        rays_o, rays_d = get_rays(self.H, self.W, self.K, pose)  # (H, W, 3) each

        if self.coords is None:
            self.coords = torch.stack(
                torch.meshgrid(
                    torch.arange(self.H, device=rays_o.device),
                    torch.arange(self.W, device=rays_o.device),
                    indexing='ij',
                ),
                -1,
            )  # (H, W, 2) as (row, col)

        def sample_rays(left, top):
            rows = slice(top, top + patch_size)
            cols = slice(left, left + patch_size)
            patch_ro = rays_o[rows, cols]
            patch_rd = rays_d[rows, cols]
            patch_img = target[rows, cols]
            patch_mask = mask[rows, cols]
            return (
                torch.stack([patch_ro.reshape(-1, 3), patch_rd.reshape(-1, 3)], 0),
                # Keep every channel (e.g. RGBA): reshape(-1, 3) would scramble
                # pixels whenever the image has more than three channels.
                patch_img.reshape(-1, patch_img.shape[-1]),
                patch_mask.reshape(-1, 1),
            )

        valid_coords = self.coords[mask]  # (N_fg, 2) foreground pixels
        # Keep only pixels that can be the top-left corner of an in-bounds patch.
        valid_coords = valid_coords[
            (valid_coords[:, 0] >= 0) & (valid_coords[:, 0] <= self.H - patch_size)
            & (valid_coords[:, 1] >= 0) & (valid_coords[:, 1] <= self.W - patch_size)
        ]
        if self.current_epoch < self.args.precrop_iters and valid_coords.shape[0] > 0:
            rand_index = torch.randint(0, valid_coords.shape[0], (1,)).item()
            # int() guards against a float coords grid built by get_inputs.
            y, x = map(int, valid_coords[rand_index].tolist())
            top, left = y, x
            half = patch_size // 2
            if half <= y < self.H - half and half <= x < self.W - half:
                if torch.rand(1).item() < 0.5:
                    top, left = y - half, x - half
        else:
            top = torch.randint(0, self.H - patch_size + 1, (1,)).item()
            left = torch.randint(0, self.W - patch_size + 1, (1,)).item()

        batch_rays, target_s, target_m = sample_rays(left, top)
        ro, rd = batch_rays.split([1, 1], dim=0)
        return ro, rd, target_s, target_m

    def get_inputs(self, target, pose, mask):
        """Sample ``args.N_rand`` rays (with replacement) from one view.

        While ``current_epoch < args.precrop_iters``, foreground and background
        pixels are sampled separately. The foreground count is
        ``int(N_rand * (1 - precrop_frac))``, clamped so that neither region is
        asked for more samples than it has pixels. Afterwards, pixels are drawn
        uniformly from the whole image.

        Args:
            target: image indexed as ``target[row, col]``.
            pose: camera pose passed to ``get_rays``.
            mask: boolean foreground mask indexed as ``mask[row, col]``.

        Returns:
            ``(ro, rd, target_s, target_m)``. ``ro``/``rd`` have shape
            ``(1, N_rand, 3)``; ``target_s`` and ``target_m`` are the sampled
            target pixels and mask values.
        """
        rays_o, rays_d = get_rays(self.H, self.W, self.K, pose)  # (H, W, 3) each

        if self.coords is None:
            self.coords = torch.stack(
                torch.meshgrid(
                    torch.linspace(0, self.H - 1, self.H, device=rays_o.device),
                    torch.linspace(0, self.W - 1, self.W, device=rays_o.device),
                    indexing='ij',  # torch's default, made explicit
                ),
                -1,
            )  # (H, W, 2) as (row, col)

        def sample_rays(coords, num_samples):
            coords = coords.reshape(-1, 2)
            if coords.shape[0] == 0:
                select_inds = torch.zeros(0, device=coords.device, dtype=torch.long)
            else:
                select_inds = torch.randint(
                    coords.shape[0], (num_samples,), device=coords.device
                )
            rows, cols = coords[select_inds].long().unbind(-1)
            ro = rays_o[rows, cols]
            rd = rays_d[rows, cols]
            return torch.stack([ro, rd], 0), target[rows, cols], mask[rows, cols]

        if self.current_epoch < self.args.precrop_iters:
            # Python ints so the sample counts below are plain ints, not tensors.
            fg_pixs = int(mask.sum())
            bg_pixs = int((~mask).sum())

            num_foreground = int(self.args.N_rand * (1.0 - self.args.precrop_frac))
            num_background = self.args.N_rand - num_foreground

            if num_foreground > fg_pixs:
                num_foreground = fg_pixs
                num_background = self.args.N_rand - num_foreground
            elif num_background > bg_pixs:
                num_background = bg_pixs
                num_foreground = self.args.N_rand - num_background

            rays_fg, target_fg, mask_fg = sample_rays(self.coords[mask], num_foreground)
            rays_bg, target_bg, mask_bg = sample_rays(self.coords[~mask], num_background)

            batch_rays = torch.cat([rays_fg, rays_bg], 1)
            target_s = torch.cat([target_fg, target_bg], 0)
            target_m = torch.cat([mask_fg, mask_bg], 0)
            assert (
                target_s.shape[0] == self.args.N_rand
            ), f"target_s.shape[0] = {target_s.shape[0]}, self.args.N_rand = {self.args.N_rand}, fg_pixs = {fg_pixs}, num_foreground = {num_foreground}"
        else:
            batch_rays, target_s, target_m = sample_rays(self.coords, self.args.N_rand)
            assert target_s.shape[0] == self.args.N_rand

        ro, rd = batch_rays.split([1, 1], dim=0)
        return ro, rd, target_s, target_m

    def render_img(self, rays, grid_weights):
        """Render ``rays`` with the hypernetwork; returns ``(rgb, acc, depth)``."""
        rgb, acc, depth, _ = render_image_with_ngp(
            radiance_field=self.hypernet,
            occupancy_grid=self.occupancy_grid,
            rays=rays,
            scene_aabb=self.aabb,
            render_step_size=self.render_step_size,
            color_bkgds=self.args.color_bkgds,
            grid_weights=grid_weights,
            training=self.training,
        )
        return rgb, acc, depth

    @staticmethod
    def _ceil_div(n, k):
        """Return ``len(range(0, n, k))``, i.e. the entries kept by ``x[::k]``."""
        return (n + k - 1) // k

    def forward_poses(
        self,
        target_img,
        target_pose,
        target_mask,
        grid_weights=None,
        full_scale=False,
        render_factor=1,
        super_resolution=1.0,
    ):
        """Render the currently active hypernetwork sample at several poses.

        Args:
            target_img: per-view images, iterated together with
                ``target_pose`` and ``target_mask``.
            target_pose: per-view camera poses.
            target_mask: per-view boolean foreground masks.
            grid_weights: optional occupancy-grid checkpoint passed to
                :meth:`render_img`.
            full_scale: if False, render a random subset of rays per view
                (64x64 patches when ``args.use_patch``, else scattered rays).
                If True, render every ``render_factor``-th pixel of each view.
            render_factor: pixel stride used when ``full_scale`` is True.
            super_resolution: scale applied to the rendered resolution and the
                intrinsics when ``full_scale`` is True. Targets are not
                rescaled.

        Returns:
            ``(rgb, acc, depth, target_imgs, target_masks)``. With
            ``full_scale``, rgb/acc/targets/masks are reshaped to
            ``(views, rows, cols, channels)``; ``depth`` is returned as
            :meth:`render_img` produced it.
        """
        ray_origins, ray_directions, target_imgs, target_masks = [], [], [], []
        for pose, image, mask in zip(target_pose, target_img, target_mask):
            if not full_scale:
                if self.args.use_patch:
                    ray_origin, ray_direction, image_sample, mask_sample = self.get_inputs_patch(
                        image, pose, mask, patch_size=64
                    )
                else:
                    ray_origin, ray_direction, image_sample, mask_sample = self.get_inputs(
                        image, pose, mask
                    )
            else:
                # Scale resolution and intrinsics for super-resolution rendering.
                H_sr = int(self.H * super_resolution)
                W_sr = int(self.W * super_resolution)
                K_sr = self.K.copy()
                K_sr[0, 0] *= super_resolution  # fx
                K_sr[1, 1] *= super_resolution  # fy
                K_sr[0, 2] *= super_resolution  # cx
                K_sr[1, 2] *= super_resolution  # cy

                ray_origin, ray_direction = get_rays(H_sr, W_sr, K_sr, pose)
                ray_origin = ray_origin[::render_factor, ::render_factor].reshape(-1, 3)
                ray_direction = ray_direction[::render_factor, ::render_factor].reshape(-1, 3)
                image_sample = image[::render_factor, ::render_factor]
                mask_sample = mask[::render_factor, ::render_factor]

            ray_origins.append(ray_origin)
            ray_directions.append(ray_direction)
            target_imgs.append(image_sample)
            target_masks.append(mask_sample)

        rays = Rays(torch.cat(ray_origins, dim=0), torch.cat(ray_directions, dim=0))
        target_imgs = torch.stack(target_imgs, dim=0)
        target_masks = torch.stack(target_masks, dim=0)
        rgb, acc, depth = self.render_img(rays, grid_weights)
        if full_scale:
            # ``x[::k]`` keeps ceil(n / k) entries; floor division would break
            # these reshapes whenever n is not a multiple of render_factor.
            H_out = self._ceil_div(int(self.H * super_resolution), render_factor)
            W_out = self._ceil_div(int(self.W * super_resolution), render_factor)
            H_tgt = self._ceil_div(self.H, render_factor)
            W_tgt = self._ceil_div(self.W, render_factor)
            rgb = rgb.reshape(-1, H_out, W_out, 3)
            # Preserve the target's own channel count (e.g. RGBA).
            target_imgs = target_imgs.reshape(-1, H_tgt, W_tgt, target_imgs.shape[-1])
            target_masks = target_masks.reshape(-1, H_tgt, W_tgt, 1)
            acc = acc.reshape(-1, H_out, W_out, 1)

        return rgb, acc, depth, target_imgs, target_masks

    def save_figures(
        self,
        cond_img,
        pred_target,
        acc,
        depth,
        target_img,
        mask_img,
        render_factor=0,
        mode="test",
    ):
        """Build a labelled visualisation grid for one object and return it.

        Rows, top to bottom: conditioning images, predicted targets,
        ground-truth targets, predicted opacity and ground-truth masks. Each
        row is a ``make_grid`` over the leading (batch) dimension. Nothing is
        written to disk. ``depth``, ``render_factor`` and ``mode`` are unused
        and kept for call compatibility.
        """
        cond_img = auto_permute_image(cond_img, add_batch_dim=True)
        pred_target = auto_permute_image(pred_target, add_batch_dim=True)
        target_img = auto_permute_image(target_img, add_batch_dim=True)
        mask_img = auto_permute_image(mask_img, add_batch_dim=True)
        acc = auto_permute_image(acc, add_batch_dim=True)

        nrow = cond_img.shape[0]
        cond_row = put_text_on_image(
            torchvision.utils.make_grid(cond_img, nrow=nrow), "cond_img"
        )
        pred_row = put_text_on_image(
            torchvision.utils.make_grid(pred_target, nrow=nrow), "pred_target"
        )
        target_row = put_text_on_image(
            torchvision.utils.make_grid(target_img, nrow=nrow), "target_img"
        )
        # Boolean masks must be cast to float before drawing text on them.
        mask_row = put_text_on_image(
            torchvision.utils.make_grid(mask_img, nrow=nrow).float(), "mask_img"
        )
        acc_row = put_text_on_image(
            torchvision.utils.make_grid(acc, nrow=nrow), "pred_mask"
        )

        return torchvision.utils.make_grid(
            [cond_row, pred_row, target_row, acc_row, mask_row], nrow=1
        )

    def get_warmup_lr(self):
        """Return the ``warmup_lr`` passed to ``generate_weights``; constant 1.0 (no warm-up)."""
        return 1.0

    def calc_weight_loss(self):
        """Compute the MSE between generated target-net weights and reference checkpoints.

        Lazily loads ``./main-target_net.pth`` and ``./aux-target_net.pth``.
        Each reference tensor that does not require grad is compared with the
        generated weight of the same name, falling back to the target net's
        base parameter. Names found in neither are skipped. Returns the mean
        of those MSEs.

        NOTE(review): not called from this class, and it depends on
        ``self.hypernet_prop``, which this class does not define. Legacy code.
        """
        if getattr(self, "gt_main_target", None) is None:
            self.gt_main_target = torch.load(
                "./main-target_net.pth", map_location=self.device
            )
        if getattr(self, "gt_aux_target", None) is None:
            self.gt_aux_target = torch.load(
                "./aux-target_net.pth", map_location=self.device
            )

        main_target_weights = self.hypernet.generated_weights[-1]
        aux_target_weights = self.hypernet_prop.generated_weights[-1]
        base_main_target_weights = dict(self.hypernet.target_net.named_parameters())
        base_aux_target_weights = dict(self.hypernet_prop.target_net.named_parameters())

        loss = []
        for gt_state, generated, base in (
            (self.gt_main_target, main_target_weights, base_main_target_weights),
            (self.gt_aux_target, aux_target_weights, base_aux_target_weights),
        ):
            for name, weight in gt_state.items():
                if weight.requires_grad:
                    continue
                if name in generated:
                    pred_weight = generated[name]
                elif name in base:
                    pred_weight = base[name]
                else:
                    continue
                loss.append(F.mse_loss(pred_weight, weight))

        return torch.stack(loss).mean()

    @staticmethod
    def _weight_l1_terms(final_weight_dicts, gt_weights, sid):
        """Yield per-key L1 losses between predicted and GT weights of sample ``sid``.

        Only keys present in both ``gt_weights`` and the last entry of
        ``final_weight_dicts`` contribute. Terms are yielded in ``gt_weights``
        order.
        """
        pred_weight_dicts = {k: v[sid : sid + 1] for k, v in final_weight_dicts[-1].items()}
        for key, gt in gt_weights.items():
            if key in pred_weight_dicts:
                yield F.l1_loss(pred_weight_dicts[key], gt[sid : sid + 1])

    def _combine_losses(self, avg_weight_loss, avg_img_loss, avg_mask_loss):
        """Combine the averaged loss terms according to the loss-mode flags."""
        if self.args.img_loss_only:
            return avg_img_loss + avg_mask_loss * self.MASK_LOSS_WEIGHT
        if self.args.weight_loss_only:
            return avg_weight_loss
        return (
            avg_weight_loss * self.args.weights_loss_weight
            + avg_img_loss
            + avg_mask_loss * self.MASK_LOSS_WEIGHT
        )

    def training_step(self, batch, batch_idx):
        """Generate weights per object, render sampled rays and return the loss.

        Logs lr, total / weight / mask / image losses and mean PSNR under
        ``train/``.
        """
        cond_img = batch["cond_img"]
        target_img = batch["target_img"]
        target_pose = batch["target_pose"]
        cond_pose = batch["cond_pose"]
        mask_target = batch["mask_target"]
        gt_weights = batch["ngp_mlp_weights"]
        grid_weights = batch["grid_weights"]

        BS = cond_img.shape[0]

        cond_embedding = self.encode(cond_img, cond_pose)

        psnr_list = []
        total_weight_loss = 0
        total_mask_loss = 0
        total_img_loss = 0

        self.hypernet.cleanup()
        final_weight_dicts = self.hypernet.generate_weights(
            cond_embedding, warmup_lr=self.get_warmup_lr()
        )

        for sid in range(BS):
            self.hypernet.activate_idx(sid)
            for term in self._weight_l1_terms(final_weight_dicts, gt_weights, sid):
                total_weight_loss += term

            rgb, acc, _, target_s, mask_s = self.forward_poses(
                target_img[sid],
                target_pose[sid],
                mask_target[sid, :, 0].bool(),
                grid_weights[sid],
                full_scale=False,
            )

            total_mask_loss += img2l1(acc.reshape(mask_s.shape), mask_s.float())
            total_img_loss += img2l1(rgb, target_s[..., :3])

            # PSNR is only a metric; don't build an autograd graph for it.
            with torch.no_grad():
                psnr_list.append(psnr(rgb, target_s[..., :3]).item())

        avg_weight_loss = total_weight_loss / BS
        avg_mask_loss = total_mask_loss / BS
        avg_img_loss = total_img_loss / BS
        loss = self._combine_losses(avg_weight_loss, avg_img_loss, avg_mask_loss)

        self.log_dict(
            {
                "train/lr": self.optimizers().param_groups[0]['lr'],
                "train/loss": loss,
                "train/weight_loss": avg_weight_loss,
                "train/mask_loss": avg_mask_loss,
                "train/img_loss": avg_img_loss,
                "train/psnr": np.mean(psnr_list),
            },
            on_step=True,
            sync_dist=True,
            prog_bar=True,
        )
        return loss

    def val_test_step(
        self, batch, batch_idx, mode, render_factor=1, dataloader_idx=None
    ):
        """Shared validation/test step: render full views and log metrics.

        For every object, weights are generated and all target views are
        rendered at full scale with stride ``render_factor``. Weight, image
        and mask losses and PSNR are accumulated.

        * ``"val"``: the first four objects are visualised; the grid is saved
          on rank 0 every 60th batch.
        * ``"test"``: every object is visualised and the grid is saved for
          each batch. Per-object PSNR is appended to
          ``<log_dir>/test_obj_psnr_log.txt``.

        ``dataloader_idx`` is accepted but unused.
        """
        cond_img = batch["cond_img"]
        target_img = batch["target_img"]
        target_pose = batch["target_pose"]
        cond_pose = batch["cond_pose"]
        mask_target = batch["mask_target"]
        obj_id = batch["obj_id"]
        gt_weights = batch["ngp_mlp_weights"]
        grid_weights = batch["grid_weights"]

        BS = cond_img.shape[0]
        with torch.no_grad():
            cond_embedding = self.encode(cond_img, cond_pose)
            psnr_list = []
            total_weight_loss = 0
            total_mask_loss = 0
            total_img_loss = 0
            vis_img_list = []
            self.hypernet.cleanup()

            final_weight_dicts = self.hypernet.generate_weights(
                cond_embedding, warmup_lr=self.get_warmup_lr()
            )

            for sid in range(BS):  # TODO: batch processing.
                self.hypernet.activate_idx(sid)
                for term in self._weight_l1_terms(final_weight_dicts, gt_weights, sid):
                    total_weight_loss += term

                rgb, acc, depth, target_s, mask_s = self.forward_poses(
                    target_img[sid],
                    target_pose[sid],
                    mask_target[sid, :, 0].bool(),
                    grid_weights[sid],
                    full_scale=True,
                    render_factor=render_factor,
                )
                if mode == "test" or (mode == "val" and sid < 4):
                    vis_img_list.append(
                        self.save_figures(cond_img[sid], rgb, acc, depth, target_s, mask_s)
                    )

                total_mask_loss += img2l1(acc.reshape(mask_s.shape), mask_s.float())
                total_img_loss += img2l1(rgb, target_s[..., :3])

                psnr_value = psnr(rgb, target_s[..., :3]).item()
                psnr_list.append(psnr_value)

                if mode == "test":
                    current_obj_id = (
                        obj_id[sid] if isinstance(obj_id, (list, tuple)) else obj_id[sid].item()
                    )
                    log_file_path = os.path.join(self.logger.log_dir, f"{mode}_obj_psnr_log.txt")
                    with open(log_file_path, "a") as f:
                        f.write(
                            f"Step_{self.global_step:06d}_Batch_{batch_idx:04d}_Sample_{sid:02d}: "
                            f"obj_id={current_obj_id}, PSNR={psnr_value:.6f}\n"
                        )

            # Skip saving when nothing was visualised (make_grid rejects []).
            save_vis = bool(vis_img_list) and (
                mode == "test"
                or (mode == "val" and self.global_rank == 0 and batch_idx % 60 == 0)
            )
            if save_vis:
                vis_img = torchvision.utils.make_grid(vis_img_list, nrow=1)
                torchvision.utils.save_image(
                    vis_img,
                    os.path.join(
                        self.logger.log_dir,
                        f"{mode}__results_STEP{self.global_step:06d}_ITER{batch_idx:04d}.png",
                    ),
                    nrow=1,
                )

        avg_weight_loss = total_weight_loss / BS
        avg_mask_loss = total_mask_loss / BS
        avg_img_loss = total_img_loss / BS
        loss = self._combine_losses(avg_weight_loss, avg_img_loss, avg_mask_loss)

        self.log_dict(
            {
                f"{mode}/psnr": float(np.mean(psnr_list)),
                f"{mode}/loss": loss,
                f"{mode}/weight_loss": avg_weight_loss,
                f"{mode}/img_loss": avg_img_loss,
                f"{mode}/mask_loss": avg_mask_loss,
            },
            on_epoch=True,
            sync_dist=True,
        )

    def dataloader_idx_str(self, dataloader_idx):
        """Map a validation dataloader index to its split name ("" if unknown)."""
        names = {
            0: "novel-target",
            1: "novel-cond",
            2: "novel-obj",
            3: "novel-objvse",
        }
        return names.get(dataloader_idx, "")

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Lightning validation hook; delegates to :meth:`val_test_step`."""
        return self.val_test_step(
            batch, batch_idx, "val", dataloader_idx=dataloader_idx
        )

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Lightning test hook; delegates to :meth:`val_test_step`."""
        return self.val_test_step(
            batch, batch_idx, "test", dataloader_idx=dataloader_idx
        )
