# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # MIT License

    # Copyright (c) 2021 apchenstu

    # Permission is hereby granted, free of charge, to any person obtaining a copy
    # of this software and associated documentation files (the "Software"), to deal
    # in the Software without restriction, including without limitation the rights
    # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    # copies of the Software, and to permit persons to whom the Software is
    # furnished to do so, subject to the following conditions:

    # The above copyright notice and this permission notice shall be included in all
    # copies or substantial portions of the Software.

    # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    # SOFTWARE.

from audioop import minmax
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningModule, Trainer, loggers
from pytorch_lightning.loggers import WandbLogger

import os
import time
import numpy as np
import imageio
import lpips
from torchvision import transforms as T

from skimage.metrics import structural_similarity as ssim

from model.geo_reasoner import CasMVSNet
from model.self_attn_renderer import Renderer, Renderer_v1, Renderer_v1_mlp, Renderer_v2, Renderer_density_v1, Renderer_RGBresidual, Renderer_geneRGBsigma, Renderer_geneRGBsigma_dist, RendererStyle, RendererStyle2branch, RendererStyle_one
from utils.rendering import render_rays, sigma2weights
from utils.utils import (
    load_ckpt,
    init_log,
    get_rays_pts,
    SL1Loss,
    self_supervision_loss,
    img2mse,
    mse2psnr,
    acc_threshold,
    abs_error,
    visualize_depth,
    interpolate_3D,
    unified_focal_loss,
    rgb2ycbcr,
    alter_batch,
    inverse_sigmoid,
    depth_smoothness,
    VGG16_perceptual,
    interpolate_2D,
    delta_t_compute,
)
from utils.options import config_parser
from data.get_datasets import (
    get_training_dataset,
    get_finetuning_dataset,
    get_validation_dataset,
)
import math
import pickle
import matplotlib.pyplot as plt
import cv2
import copy
import torch.distributions as tdist

from PIL import Image
from torchvision import transforms as T
import torchvision.transforms as transforms

import warnings
# NOTE(review): with no category/module arguments this installs a global
# "ignore" filter, silencing *every* warning in the process (including
# deprecation warnings from torch / Lightning). Consider narrowing it.
warnings.filterwarnings("ignore")

import random
import datetime

# Option namespace built once, at import time, by ``config_parser()``.
# NOTE(review): importing this module therefore runs the option parser
# (presumably argparse over the command line — confirm in utils.options).
args = config_parser()

def setup_seed(seed):
    """Seed every RNG used by training and request deterministic cuDNN kernels.

    The same ``seed`` is fed, in order, to PyTorch (CPU), PyTorch (all CUDA
    devices), NumPy's global RNG and Python's ``random`` module. A
    confirmation line is printed afterwards.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    # Ask cuDNN to pick deterministic algorithms for reproducible runs.
    torch.backends.cudnn.deterministic = True
    print(f"set up random seed {seed}")
# Seed all RNGs at import time, before the LPIPS network below (or any
# GeoNeRF instance) is constructed.
setup_seed(args.seed)

# LPIPS perceptual metric with a VGG backbone, built once at module load.
# NOTE(review): not used in the visible part of the file; presumably
# consumed by validation code further down — confirm.
lpips_fn = lpips.LPIPS(net="vgg")

class GeoNeRF(LightningModule):
    def __init__(self, hparams):
        """Build the GeoNeRF Lightning module from a parsed option namespace.

        Every option in ``hparams`` is copied into ``self.hparams``, and a
        textual dump is written to
        ``<logdir>/<dataset_name>/<expname>/hyperparameters.txt``. Depending on
        the flags, this then constructs an optional MiDaS model, an optional
        (frozen) MDMM network, the ``CasMVSNet`` geometry reasoner (plus an
        optional teacher copy), and exactly one renderer variant. Every
        sub-network is moved to CUDA here.

        Args:
            hparams: option namespace. ``vars(hparams)`` is taken, so it must
                be an ``argparse.Namespace``-like object. It is also the
                namespace handed to MDMM; the module-level ``args`` is no
                longer read here.

        Raises:
            AssertionError: if mutually required options are inconsistent
                (e.g. feature completion without ``disocclude``).
        """
        super(GeoNeRF, self).__init__()
        self.hparams.update(vars(hparams))
        exp_dir = f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}"
        os.makedirs(exp_dir, exist_ok=True)
        with open(f"{exp_dir}/hyperparameters.txt", "w") as f:
            f.write(f"{self.hparams}")

        # This module requires the style discriminator option to be off.
        assert self.hparams.style_D == False
        # Step counter; training_step compares ``self.wr_cntr // 2`` against a
        # threshold (NOTE(review): presumably incremented in training_step).
        self.wr_cntr = 0

        self.depth_loss = SL1Loss()
        if self.hparams.cas_confi:
            self.depth_loss_with_confi = SL1Loss(have_confi=True)
        self.learning_rate = hparams.lrate

        if self.hparams.use_midas:
            # MiDaS monocular-depth model and its input transform, via torch.hub.
            self.midas = torch.hub.load("intel-isl/MiDaS", "MiDaS").cuda().eval()
            self.midas_transform = torch.hub.load("intel-isl/MiDaS", "transforms").default_transform

        if self.hparams.geonerfMDMM or self.hparams.contentFeature:
            from MDMM.model import MD_multi
            # Build and resume MDMM from the namespace passed to this
            # constructor, not the module-level ``args``, so the module honours
            # whatever options it was actually instantiated with.
            self.MDMM = MD_multi(hparams).cuda()
            if self.hparams.geonerfMDMM or 'contentFeature_resume' in self.hparams.expname:
                self.MDMM.resume(hparams.resume, train=False, wantDiscontent=True, wantDis1=self.hparams.style_D_mdmm)
            # Freeze all of MDMM ...
            for param in self.MDMM.parameters():
                param.requires_grad = False
            # ... optionally re-enabling gradients for its content encoder.
            if self.hparams.noFreezeContent:
                for param in self.MDMM.enc_c.parameters():
                    param.requires_grad = True

        # Option consistency checks before building the geometry reasoner.
        if self.hparams.geoFeatComplete != "None" or self.hparams.weightedMeanVar:
            assert self.hparams.disocclude == True
        if self.hparams.texFeatComplete:
            assert self.hparams.disocclude == True and self.hparams.texFeat == True
        if self.hparams.pDensity_loss:
            assert self.hparams.pDensity == True

        # Both "gene" renderers need the self-attended / global geometry features.
        self.geo_reasoner = CasMVSNet(
            use_depth=self.hparams.use_depth,
            use_disocclude=self.hparams.disocclude, geoFeatComplete=self.hparams.geoFeatComplete,
            D01=self.hparams.Deither0or1, D_gauss=self.hparams.D_gauss,
            upperbound=self.hparams.upperbound, upperbound_noise=self.hparams.upperbound_noise, upperbound_gauss=self.hparams.upperbound_gauss,
            texFeatComplete=self.hparams.texFeatComplete, texFeat=self.hparams.texFeat, nb_views=self.hparams.nb_views,
            pDensity=self.hparams.pDensity,
            O_label=self.hparams.O_label_loss,
            P_constraint=self.hparams.P_constraint,
            unimvs=self.hparams.unimvs_loss,
            use_featSA=(self.hparams.renderer_geneRGBsigma or self.hparams.renderer_geneRGBsigma_dist),
            use_global_geoFeat=(self.hparams.renderer_geneRGBsigma or self.hparams.renderer_geneRGBsigma_dist),
            check_feat_mode=self.hparams.check_feat_mode,
            separate_occ_feat=self.hparams.separate_occ_feat,
            cas_confi=self.hparams.cas_confi,
            texFeat_woUnet=self.hparams.texFeat_woUnet,
            save_var_confi=self.hparams.save_var_confi,
            geonerfMDMM=self.hparams.geonerfMDMM,
            style3Dfeat=self.hparams.style3Dfeat,
            styleTwoBranch=self.hparams.styleTwoBranch,
            unparallExtract=self.hparams.unparallExtract,
            contentPyramid=self.hparams.contentPyramid,
            contentFeature=self.hparams.contentFeature,
        ).cuda()

        if self.hparams.learn_3dfeat_from_GT:
            # Teacher geometry reasoner (is_teacher=True), built from a subset
            # of the student's options; training_step runs it without gradients.
            self.geo_reasoner_teacher = CasMVSNet(
                use_depth=self.hparams.use_depth,
                use_disocclude=self.hparams.disocclude, geoFeatComplete=self.hparams.geoFeatComplete,
                D01=self.hparams.Deither0or1, D_gauss=self.hparams.D_gauss,
                upperbound=self.hparams.upperbound, upperbound_noise=self.hparams.upperbound_noise, upperbound_gauss=self.hparams.upperbound_gauss,
                texFeatComplete=self.hparams.texFeatComplete, texFeat=self.hparams.texFeat, nb_views=self.hparams.nb_views,
                pDensity=self.hparams.pDensity,
                O_label=self.hparams.O_label_loss,
                P_constraint=self.hparams.P_constraint,
                unimvs=self.hparams.unimvs_loss,
                use_featSA=(self.hparams.renderer_geneRGBsigma or self.hparams.renderer_geneRGBsigma_dist),
                use_global_geoFeat=(self.hparams.renderer_geneRGBsigma or self.hparams.renderer_geneRGBsigma_dist),
                check_feat_mode=self.hparams.check_feat_mode,
                separate_occ_feat=self.hparams.separate_occ_feat,
                is_teacher=True,
            ).cuda()

        # Exactly one renderer is built: the first enabled flag in this chain
        # wins, and the plain ``Renderer`` is the fallback.
        nb_samples_per_ray = hparams.nb_coarse + hparams.nb_fine
        if self.hparams.renderer_v1:
            self.renderer = Renderer_v1(nb_samples_per_ray=nb_samples_per_ray).cuda()
        elif self.hparams.renderer_v1_mlp:
            self.renderer = Renderer_v1_mlp(nb_samples_per_ray=nb_samples_per_ray).cuda()
        elif self.hparams.renderer_v2:
            self.renderer = Renderer_v2(nb_samples_per_ray=nb_samples_per_ray).cuda()
        elif self.hparams.renderer_density_v1:
            self.renderer = Renderer_density_v1(nb_samples_per_ray=nb_samples_per_ray).cuda()
        elif self.hparams.renderer_RGBresidual:
            self.renderer = Renderer_RGBresidual(nb_samples_per_ray=nb_samples_per_ray).cuda()
        elif self.hparams.renderer_geneRGBsigma:
            self.renderer = Renderer_geneRGBsigma(
                nb_samples_per_ray=nb_samples_per_ray,
                use_attention_3d=self.hparams.attention_3d,
                output_gene=self.hparams.renderer_geneRGBsigma_loss,
            ).cuda()
        elif self.hparams.renderer_geneRGBsigma_dist:
            # Sampling mode is only allowed at evaluation time.
            if self.hparams.sample:
                assert self.hparams.eval == True
            self.renderer = Renderer_geneRGBsigma_dist(
                nb_samples_per_ray=nb_samples_per_ray,
                use_attention_3d=self.hparams.attention_3d,
                sample=self.hparams.sample,
                modify=4,
                test=self.hparams.eval,
            ).cuda()
        elif self.hparams.geonerfMDMM:
            if self.hparams.styleTwoBranch:
                # Fourier features require a sigmoid final activation for phi.
                if self.hparams.use_fourier_feature:
                    assert self.hparams.phi_final_actv == 'sigmoid'
                self.renderer = RendererStyle2branch(
                    nb_samples_per_ray=nb_samples_per_ray,
                    gene_mask=self.hparams.gene_mask,
                    n_domain=self.hparams.num_domains,
                    style3Dfeat=self.hparams.style3Dfeat,
                    timephi=self.hparams.timephi,
                    weatherEmbedding=self.hparams.weatherEmbedding,
                    STB_ver=self.hparams.styleTwoBranch_ver,
                    wo_z=self.hparams.wo_z,
                    phi_final_actv=self.hparams.phi_final_actv,
                    use_fourier_feature=self.hparams.use_fourier_feature,
                    branch2_noPhi=self.hparams.branch2_noPhi,
                    z_dim=self.hparams.z_dim,
                    fourier_phi=self.hparams.fourier_phi,
                    phiNoCosSin=self.hparams.phi_noCosSin,
                    zInputStyle=self.hparams.zInputStyle,
                    update_z=self.hparams.update_z,
                    delta_t=self.hparams.delta_t,
                    delta_t_1x1=self.hparams.delta_t_1x1,
                    z_zInputStyle_Fuse=self.hparams.z_zInputStyle_Fuse,
                    rgb2t=self.hparams.rgb2t,
                ).cuda()
            else:
                self.renderer = RendererStyle_one(
                    nb_samples_per_ray=nb_samples_per_ray,
                    gene_mask=self.hparams.gene_mask,
                    n_domain=self.hparams.num_domains,
                    catDomain=self.hparams.catDomain,
                    catStyle=self.hparams.catStyle,
                    styleMLP_cls=self.hparams.styleMLP_cls,
                    styleLast=self.hparams.styleLast,
                    adainNormalize=self.hparams.adainNormalize,
                    style3Dfeat=self.hparams.style3Dfeat,
                    timephi=self.hparams.timephi,
                    weatherEmbedding=self.hparams.weatherEmbedding,
                    weatherEncode=self.hparams.weatherEncode,
                    weatherEncodeCls=self.hparams.weatherEncodeCls,
                    add_z=self.hparams.add_z,
                    zInputStyle=self.hparams.zInputStyle,
                    delta_t=self.hparams.delta_t,
                    delta_t_1x1=self.hparams.delta_t_1x1,
                ).cuda()
        else:
            self.renderer = Renderer(
                nb_samples_per_ray=nb_samples_per_ray, weightedMeanVar=hparams.weightedMeanVar,
                gene_mask=self.hparams.gene_mask,
                check_feat_mode=self.hparams.check_feat_mode,
            ).cuda()

        # Threshold list; NOTE(review): presumably consumed by depth-accuracy
        # metrics (acc_threshold) during validation — confirm.
        self.eval_metric = [0.01, 0.05, 0.1]

        # Lightning manual optimization: the training loop itself is
        # responsible for stepping the optimizer(s).
        self.automatic_optimization = False
        self.save_hyperparameters()

    def unpreprocess(self, data, shape=(1, 1, 3, 1, 1)):
        # to unnormalize image for visualization
        device = data.device
        mean = (
            torch.tensor([-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225])
            .view(*shape)
            .to(device)
        )
        std = torch.tensor([1 / 0.229, 1 / 0.224, 1 / 0.225]).view(*shape).to(device)

        return (data - mean) / std

    def prepare_data(self):
        """Create the datasets for the configured run mode.

        The mode is selected from ``self.hparams``:

        * ``eval`` set: only the validation dataset is built.
        * ``scene == "None"``: generalizable training over the multi-scene
          training set.
        * otherwise: per-scene fine-tuning dataset.

        Sets ``self.train_dataset`` and ``self.train_sampler`` (training modes
        only) and ``self.val_dataset`` (always). The training dataset is still
        built before the validation dataset, as before.

        NOTE(review): Lightning documents ``prepare_data`` as running on a
        single process and advises against assigning state there; ``setup()``
        is the usual place under multi-GPU DDP. Confirm this is single-device.
        """
        hp = self.hparams
        ithaca_all = None  # forwarded to every dataset factory; always None here

        # The eval flag now comes from self.hparams, like every other option,
        # rather than from the module-level ``args``.
        if not hp.eval:
            if hp.scene == "None":  ## Generalizable
                only_dtu = hp.upperbound or hp.only_dtu
                self.train_dataset, self.train_sampler = get_training_dataset(
                    hp,
                    downsample=hp.downsample,
                    llffdownsample=hp.llffdownsample,
                    only_dtu=only_dtu,
                    only_llff=hp.only_llff,
                    only_ithaca=hp.only_ithaca,
                    use_far_view=hp.far_view_loss,
                    ithaca_all=ithaca_all,
                )
            else:  ## Fine-tune
                self.train_dataset, self.train_sampler = get_finetuning_dataset(
                    hp, downsample=hp.downsample, use_far_view=hp.far_view_loss, ithaca_all=ithaca_all
                )

        # The validation set is needed in every mode.
        self.val_dataset = get_validation_dataset(
            hp, downsample=hp.downsample, use_far_view=hp.test_far_view, ithaca_all=ithaca_all
        )

    def configure_optimizers(self):
        eps = 1e-5

        opt = torch.optim.Adam(
            list(self.geo_reasoner.parameters()) + list(self.renderer.parameters()),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
        )
        sch = CosineAnnealingLR(opt, T_max=self.hparams.num_steps, eta_min=eps)

        return [opt], [sch]

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            sampler=self.train_sampler,
            shuffle=True if self.train_sampler is None else False,
            num_workers=8,
            batch_size=1,
            pin_memory=True,
        )

    def val_dataloader(self):
        return  DataLoader(
            self.val_dataset,
            shuffle=False,
            num_workers=1,
            batch_size=1,
            pin_memory=True,
        )

    def training_step(self, batch, batch_nb):
        if self.hparams.NVSThenStyle:
            if self.wr_cntr//2 < 10:
                for name, param in self.renderer.named_parameters():
                    if 'style' in name:
                        param.requires_grad = False
            else:
                for name, param in self.geo_reasoner.named_parameters():
                    param.requires_grad = False

                for name, param in self.renderer.named_parameters():
                    if 'style' in name:
                        param.requires_grad = True
                    else:
                        param.requires_grad = False

        def train_(batch, first_time=False, geo_reasoner_output=None, points_2d=None, for_novel_depth_loss=None):
            nb_views = self.hparams.nb_views
            H, W = batch["images"].shape[-2:]
            H, W = int(H), int(W)
            loss = 0
            
            if self.hparams.contentFeature:
                self.MDMM.eval()
                unpre_imgs = self.unpreprocess(batch["images"])
                mdmm_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                mdmm_images = mdmm_normalize(unpre_imgs)

                _content_feat = []
                for v in range(nb_views):
                    _content_feat.append(self.MDMM.test_forward_random(mdmm_images[:,v], return_content=True))
                
                content_feats = {}
                for l in range(3):
                    content_feats[f"level_{l}"] = torch.cat([_content_feat[v][f"level_{l}"] for v in range(nb_views)], dim=0) # (V, c, h, w)

            if self.hparams.geonerfMDMM:
                self.MDMM.eval()

                unpre_imgs = self.unpreprocess(batch["images"])
                unpre_style_img = self.unpreprocess(batch["style_img"])
                mdmm_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                mdmm_images = mdmm_normalize(unpre_imgs)
                mdmm_style_images = mdmm_normalize(unpre_style_img)[0]

                with torch.no_grad():
                    _content_feat = []
                    for v in range(nb_views):
                        _content_feat.append(self.MDMM.test_forward_random(mdmm_images[:,v], return_content=True))
                    
                    content_feats = {}
                    for l in range(3):
                        content_feats[f"level_{l}"] = torch.cat([_content_feat[v][f"level_{l}"] for v in range(nb_views)], dim=0) # (V, c, h, w)
                    # domain = random.randint(0,self.hparams.num_domains-1)
                    domain = batch["style_img_label"]
                    domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                    domain_vec[:,domain] = 1
                    # style_novel_normalize = self.MDMM.test_forward_random(mdmm_images[:,-1], specify_domain=domain)[0][0]
                    # style_feat_domain = {'domain':domain_vec[0], 'style_feat':self.MDMM.z_random[0]} # (n_domain) (8)
                    style_novel_normalize = self.MDMM.test_forward_transfer(mdmm_images[:,-1], mdmm_style_images, domain_vec)[0] #(3, H, W)
                    style_feat_domain = {'domain':domain_vec[0], 'style_feat':self.MDMM.z_attr[0]} # (n_domain) (8)
                    style_novel = style_novel_normalize*0.5 + 0.5 # denormalize: range [0,1] # (3, H, W)

                    if self.hparams.adaptive_style_loss:
                        resize_216 = transforms.Resize((216,216), Image.BICUBIC)
                        style_novel_216 = resize_216(style_novel_normalize.unsqueeze(0)) # (1, 3, 216, 216)
                        _, pred_cls = self.MDMM.dis1.forward(style_novel_216)
                        # z_content = self.MDMM.enc_c.forward(style_novel_216)
                        # if self.hparams.content_level:
                        #     pred_cls = self.MDMM.disContent.forward(z_content["level_2"].detach())
                        # else:
                        #     pred_cls = self.MDMM.disContent.forward(z_content.detach())
                        loss_style_cls = self.MDMM.cls_loss(pred_cls, domain_vec)
            
            if geo_reasoner_output is None:
                ## Inferring Geometry Reasoner
                geo_reasoner_output = self.geo_reasoner(
                    imgs=batch["images"][:, :nb_views],
                    affine_mats=batch["affine_mats"][:, :nb_views],
                    affine_mats_inv=batch["affine_mats_inv"][:, :nb_views],
                    near_far=batch["near_fars"][:, :nb_views],
                    closest_idxs=batch["closest_idxs"][:, :nb_views],
                    gt_depths=batch["depths_aug"][:, :nb_views],
                    gt_depths_real=batch["depths"],
                    content_feats=content_feats if self.hparams.geonerfMDMM or self.hparams.contentFeature else None,
                    content_style_feat = {'content':content_feats, 'style':style_feat_domain} if self.hparams.geonerfMDMM and self.hparams.style3Dfeat else None
                )

            feats_vol, feats_fpn, depth_map, depth_values, other_kwarg = geo_reasoner_output

            if self.hparams.learn_3dfeat_from_GT:
                with torch.no_grad():
                    ## Inferring Geometry Reasoner
                    geo_reasoner_output_t = self.geo_reasoner_teacher(
                        imgs=batch["images"][:, :nb_views],
                        affine_mats=batch["affine_mats"][:, :nb_views],
                        affine_mats_inv=batch["affine_mats_inv"][:, :nb_views],
                        near_far=batch["near_fars"][:, :nb_views],
                        closest_idxs=batch["closest_idxs"][:, :nb_views],
                        gt_depths=batch["depths_aug"][:, :nb_views],
                        gt_depths_real=batch["depths"],
                    )

                    feats_vol_t, feats_fpn_t, depth_map_t, depth_values_t, other_kwarg_t = geo_reasoner_output_t

            ## Normalizing depth maps in NDC coordinate
            depth_map_norm = {}
            for l in range(3):
                depth_map_norm[f"level_{l}"] = (
                    depth_map[f"level_{l}"].detach() - depth_values[f"level_{l}"][:, :, 0]
                ) / (
                    depth_values[f"level_{l}"][:, :, -1]
                    - depth_values[f"level_{l}"][:, :, 0]
                )

            unpre_imgs = self.unpreprocess(batch["images"])

            if self.hparams.cycle_loss and not first_time: assert points_2d != None
            (
                pts_depth,
                rays_pts,
                rays_pts_ndc,
                rays_dir,
                rays_gt_rgb,
                rays_gt_depth,
                rays_pixs,
                rays_os,
                other_rays_output
            ) = get_rays_pts(
                H,
                W,
                batch["c2ws"],
                batch["w2cs"],
                batch["intrinsics"],
                batch["near_fars"],
                depth_values,
                self.hparams.nb_coarse,
                self.hparams.nb_fine,
                nb_views=nb_views,
                train=True,
                train_batch_size=self.hparams.batch_size,
                target_img=unpre_imgs[0, -1],
                target_depth=batch["depths_h"][0, -1],
                points_2d=points_2d,
                train_patch=self.hparams.train_patch,
                style_novel=style_novel if self.hparams.geonerfMDMM else None,
            )

            other_output_list = []
            if self.hparams.DT_loss:
                other_output_list += ['T']
            if self.hparams.consistent3d_loss:
                other_output_list += ['T', 'alpha', 'input_color', 'angle_cos', 'rgb', 'D', 'mask']
            if self.hparams.pDensity_loss:
                other_output_list += ['density']
            if self.hparams.pDensity_RGBloss:
                other_output_list += ['rgb']
            if self.hparams.density_01loss:
                other_output_list += ['T', 'alpha']
            if self.hparams.renderer_geneRGBsigma_dist:
                other_output_list += ['density']
            
            ## Rendering
            if self.hparams.add_z:
                if self.hparams.zInputStyle: # should be style_img style (train time) (since loss <-> input_img + style of style_img)
                    # input_domain = batch["input_img_label"]
                    # input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                    # input_domain_vec[:,input_domain] = 1
                    # z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], input_domain_vec)

                    input_domain = batch["style_img_label"]
                    input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                    input_domain_vec[:,input_domain] = 1
                    z, _ = self.MDMM.enc_a.forward(mdmm_style_images, input_domain_vec)

                     # style_img is night or zisInput
                    if (self.hparams.nightNoRemainCode and input_domain == 0) or self.hparams.zInputStyle_isInput:
                        _input_domain = batch["input_img_label"]
                        _input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                        _input_domain_vec[:,_input_domain] = 1
                        z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], _input_domain_vec)

                    z = self.renderer.style2remain(z)
                else:
                    dist = tdist.Normal(torch.tensor([0.]), torch.tensor([self.hparams.gauss_var]))
                    z = dist.sample((self.hparams.z_dim,)).reshape(1,self.hparams.z_dim).cuda()
            else:
                z = None

            rendered_rgb, rendered_depth, renderer_other_output = render_rays(
                c2ws=batch["c2ws"][0, :nb_views],
                rays_pts=rays_pts,
                rays_pts_ndc=rays_pts_ndc,
                pts_depth=pts_depth,
                rays_dir=rays_dir,
                feats_vol=feats_vol,
                feats_fpn=feats_fpn[:, :nb_views],
                imgs=unpre_imgs[:, :nb_views],
                depth_map_norm=depth_map_norm,
                renderer_net=self.renderer,
                other_kwarg=other_kwarg,
                other_output=other_output_list,
                have_dist=self.hparams.renderer_geneRGBsigma_dist,
                have_gene=self.hparams.renderer_geneRGBsigma_loss,
                use_att_3d=self.hparams.attention_3d,
                two_feats=(self.hparams.check_feat_mode=="2" or self.hparams.separate_occ_feat),
                content_style_feat = {'content':content_feats, 'style':style_feat_domain} if self.hparams.geonerfMDMM else None,
                input_z=z,
            )

            if self.hparams.inputWhole_cs_loss or self.hparams.input_style_loss or self.hparams.timephi_loss or self.hparams.styleTwoBranch:
                ## render input stylized image
                rays_pts_ndc_inputView = get_rays_pts(
                                            H,
                                            W,
                                            batch["c2ws"],
                                            batch["w2cs"],
                                            batch["intrinsics"],
                                            batch["near_fars"],
                                            depth_values,
                                            self.hparams.nb_coarse,
                                            self.hparams.nb_fine,
                                            nb_views=nb_views,
                                            train=True,
                                            train_batch_size=self.hparams.batch_size,
                                            target_img=unpre_imgs[0, -1],
                                            target_depth=batch["depths_h"][0, -1],
                                            points_2d=points_2d,
                                            train_patch=self.hparams.train_patch,
                                            style_novel=style_novel if self.hparams.geonerfMDMM else None,
                                            inputWhole_cs_loss=True,
                                            pred_depth_map=depth_map["level_0"][0,0], # depth_map: (1,V,H,W)
                                        )[2]

                if not self.hparams.styleTwoBranch:
                    feat_to_cat = []
                    style, domain = style_feat_domain["style_feat"], style_feat_domain["domain"], 
                    for l in range(3):
                        # content_l = content_feats[f"level_{l}"].unsqueeze(0)
                        content_l = other_kwarg['feats'][f"level_{l}"].unsqueeze(0)
                        ray_content_l, _, _ = interpolate_2D(
                            content_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                        )
                        feat_to_cat.append(ray_content_l.unsqueeze(1))

                    N, S = rays_pts_ndc_inputView["level_0"].shape[:2]
                    style = style.expand(N, 1, -1)
                    domain = domain.expand(N, 1, -1)
                    feat_to_cat.append(style)
                    feat_to_cat.append(domain)
                    interpolated_feats_one = torch.cat(feat_to_cat, dim=-1).unsqueeze(2) # (N,1,1,c) ; normal interpolated_feats: (N,S,V,c)

                    if self.hparams.add_z:
                        if self.hparams.zInputStyle:
                            # input_domain = batch["input_img_label"]
                            # input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                            # input_domain_vec[:,input_domain] = 1
                            # z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], input_domain_vec)

                            input_domain = batch["style_img_label"]
                            input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                            input_domain_vec[:,input_domain] = 1
                            z, _ = self.MDMM.enc_a.forward(mdmm_style_images, input_domain_vec)

                            if (self.hparams.nightNoRemainCode and input_domain == 0) or self.hparams.zInputStyle_isInput: # style_img is night
                                _input_domain = batch["input_img_label"]
                                _input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                                _input_domain_vec[:,_input_domain] = 1
                                z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], _input_domain_vec)

                            z = self.renderer.style2remain(z)
                        else:
                            dist = tdist.Normal(torch.tensor([0.]), torch.tensor([self.hparams.gauss_var]))
                            z = dist.sample((self.hparams.z_dim,)).reshape(1,self.hparams.z_dim).cuda()
                    else:
                        z = None
                    _, style_rgb_ray = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, z=z)
                    stylized_input = style_rgb_ray.reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)
                    stylized_input_mdmm = mdmm_normalize(stylized_input)
                else:
                    feat_to_cat = []
                    style, domain = style_feat_domain["style_feat"], style_feat_domain["domain"]
                    common_feat, special_feat = other_kwarg['common_feat'], other_kwarg['special_feat']
                    for l in range(3):
                        common_feat_l = common_feat[f"level_{l}"].unsqueeze(0)
                        ray_common_l, _, _ = interpolate_2D(
                            common_feat_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                        )
                        feat_to_cat.append(ray_common_l.unsqueeze(1))

                    for l in range(3):
                        special_feat_l = special_feat[f"level_{l}"].unsqueeze(0)
                        ray_special_l, _, _ = interpolate_2D(
                            special_feat_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                        )
                        feat_to_cat.append(ray_special_l.unsqueeze(1))
                    
                    N, S = rays_pts_ndc_inputView["level_0"].shape[:2]
                    style = style.expand(N, 1, -1)
                    domain = domain.expand(N, 1, -1)
                    feat_to_cat.append(style)
                    feat_to_cat.append(domain)
                    interpolated_feats_one = torch.cat(feat_to_cat, dim=-1).unsqueeze(2) # (N,1,1,c) ; normal interpolated_feats: (N,S,V,c)
                    
                    if self.hparams.zInputStyle:
                        # input_domain = batch["input_img_label"]
                        # input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                        # input_domain_vec[:,input_domain] = 1
                        # z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], input_domain_vec)

                        input_domain = batch["style_img_label"]
                        input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                        input_domain_vec[:,input_domain] = 1
                        z, _ = self.MDMM.enc_a.forward(mdmm_style_images, input_domain_vec) # z=style

                        if (self.hparams.nightNoRemainCode and input_domain == 0) or self.hparams.zInputStyle_isInput: # style_img is night
                            _input_domain = batch["input_img_label"]
                            _input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                            _input_domain_vec[:,_input_domain] = 1
                            z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], _input_domain_vec)

                        z = self.renderer.style2remain(z)
                    else:
                        dist = tdist.Normal(torch.tensor([0.]), torch.tensor([self.hparams.gauss_var]))
                        z = dist.sample((self.hparams.z_dim,)).reshape(1,self.hparams.z_dim).cuda()

                    if self.hparams.z_zInputStyle_Fuse:
                        z = self.renderer.z2commonSpace(z)
                    _, style_rgb_ray = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, return_what_style='both', z=z)
                    # common
                    stylized_input = style_rgb_ray['c'].reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)
                    stylized_input_mdmm = mdmm_normalize(stylized_input)
                    # common+special
                    stylized_input_cs = style_rgb_ray['c+s'].reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)
                    stylized_input_cs_mdmm = mdmm_normalize(stylized_input_cs)

                    if self.hparams.z_zInputStyle_Fuse:
                        dist = tdist.Normal(torch.tensor([0.]), torch.tensor([self.hparams.gauss_var]))
                        z_N = dist.sample((self.hparams.z_dim,)).reshape(1,self.hparams.z_dim).cuda()
                        if self.hparams.z_zInputStyle_Fuse:
                            z_N = self.renderer.z2commonSpace(z_N)

                        _, style_rgb_ray_zN = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, return_what_style='both', z=z_N)
                        # common+special
                        stylized_input_cs_zN = style_rgb_ray_zN['c+s'].reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)

                if self.hparams.rgb2t:
                    cos_t, sin_t = math.cos(self.renderer.t), math.sin(self.renderer.t)
                    pred_t = self.renderer.rgb2phi_net(style_rgb_ray['c+s']).squeeze()
                    cos_predt, sin_predt = pred_t[:,0], pred_t[:,1]

                    rgb2t_loss_val = torch.mean((cos_predt - cos_t) ** 2 + (sin_predt - sin_t) ** 2)
                    loss += self.hparams.rgb2t_loss_lamb * rgb2t_loss_val
                    self.log("train/rgb2t_loss", rgb2t_loss_val.item(), prog_bar=True)


                if self.hparams.delta_t:
                    t_rand = torch.rand(1).item()*(2*math.pi)
                    if self.hparams.styleTwoBranch:
                        _, style_rgb_ray_tRand = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, return_what_style='both', z=z, input_phi=t_rand)
                        # common+special
                        stylized_input_cs_tRand = style_rgb_ray_tRand['c+s'].reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)
                        
                        delta_cos = math.cos(self.renderer.t) - math.cos(t_rand)
                        delta_sin = math.sin(self.renderer.t) - math.sin(t_rand)

                        delta_est = self.renderer.delta_t_estimator(torch.cat((stylized_input_cs,stylized_input_cs_tRand),dim=1)).squeeze()
                    else:
                        _, style_rgb_ray_tRand = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, z=z, input_phi=t_rand)
                        # common+special
                        stylized_input_tRand = style_rgb_ray_tRand.reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)
                        
                        delta_cos = math.cos(self.renderer.t) - math.cos(t_rand)
                        delta_sin = math.sin(self.renderer.t) - math.sin(t_rand)

                        delta_est = self.renderer.delta_t_estimator(torch.cat((stylized_input,stylized_input_tRand),dim=1)).squeeze()

                    delta_t_loss = (delta_cos-delta_est[0])**2 + (delta_sin-delta_est[1])**2
                    loss += self.hparams.delta_t_loss_lamb * delta_t_loss
                    self.log("train/delta_t_loss", delta_t_loss.item(), prog_bar=True)
                
                if self.hparams.t0_rec_loss:
                    if self.hparams.styleTwoBranch:
                        feat_to_cat = []
                        domain_t0 = batch["input_img_label"]
                        domain_vec_t0 = torch.zeros((1,self.hparams.num_domains)).cuda()
                        domain_vec_t0[:,domain_t0] = 1
                        style_t0, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], domain_vec_t0)
                        z_t0 = self.renderer.style2remain(style_t0)
                        if self.hparams.z_zInputStyle_Fuse:
                            z_t0 = self.renderer.z2commonSpace(z_t0)
                        style_t0, domain_t0 = style_t0[0], domain_vec_t0[0]

                        common_feat, special_feat = other_kwarg['common_feat'], other_kwarg['special_feat']
                        for l in range(3):
                            common_feat_l = common_feat[f"level_{l}"].unsqueeze(0)
                            ray_common_l, _, _ = interpolate_2D(
                                common_feat_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                            )
                            feat_to_cat.append(ray_common_l.unsqueeze(1))

                        for l in range(3):
                            special_feat_l = special_feat[f"level_{l}"].unsqueeze(0)
                            ray_special_l, _, _ = interpolate_2D(
                                special_feat_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                            )
                            feat_to_cat.append(ray_special_l.unsqueeze(1))
                        
                        N, S = rays_pts_ndc_inputView["level_0"].shape[:2]
                        style_t0 = style_t0.expand(N, 1, -1)
                        domain_t0 = domain_t0.expand(N, 1, -1)
                        feat_to_cat.append(style_t0)
                        feat_to_cat.append(domain_t0)
                        interpolated_feats_one = torch.cat(feat_to_cat, dim=-1).unsqueeze(2) # (N,1,1,c) ; normal interpolated_feats: (N,S,V,c)
                        
                        _, style_rgb_ray_t0 = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, return_what_style='both', z=z_t0)
                        stylized_input_cs_t0 = style_rgb_ray_t0['c+s'].reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)
                    else:
                        feat_to_cat = []
                        domain_t0 = batch["input_img_label"]
                        domain_vec_t0 = torch.zeros((1,self.hparams.num_domains)).cuda()
                        domain_vec_t0[:,domain_t0] = 1
                        style_t0, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], domain_vec_t0)
                        z_t0 = self.renderer.style2remain(style_t0)
                        style_t0, domain_t0 = style_t0[0], domain_vec_t0[0]

                        for l in range(3):
                            # content_l = content_feats[f"level_{l}"].unsqueeze(0)
                            content_l = other_kwarg['feats'][f"level_{l}"].unsqueeze(0)
                            ray_content_l, _, _ = interpolate_2D(
                                content_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                            )
                            feat_to_cat.append(ray_content_l.unsqueeze(1))

                        N, S = rays_pts_ndc_inputView["level_0"].shape[:2]
                        style_t0 = style_t0.expand(N, 1, -1)
                        domain_t0 = domain_t0.expand(N, 1, -1)
                        feat_to_cat.append(style_t0)
                        feat_to_cat.append(domain_t0)
                        interpolated_feats_one = torch.cat(feat_to_cat, dim=-1).unsqueeze(2) # (N,1,1,c) ; normal interpolated_feats: (N,S,V,c)
                        _, style_rgb_ray_t0 = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, z=z_t0)
                        stylized_input_cs_t0 = style_rgb_ray_t0.reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)

                    t0_rec_loss = img2mse(stylized_input_cs_t0, unpre_imgs[:,0])
                    loss += self.hparams.t0_rec_lamb * t0_rec_loss
                    self.log("train/t0_rec_loss", t0_rec_loss.item(), prog_bar=True)

                if self.hparams.t_regularization:
                    domain_t0 = batch["input_img_label"]
                    domain_ref = batch["style_img_label"]

                    domain_vec_t0 = torch.zeros((1,self.hparams.num_domains)).cuda()
                    domain_vec_t0[:,domain_t0] = 1
                    style_t0, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], domain_vec_t0)
                    t0 = self.renderer.style2phi(style_t0)

                    delta_t = delta_t_compute(t0, t_rand)
                    t_reg = math.exp(-5*delta_t)
                    if domain_t0 == domain_ref:
                        t_reg_lamb = self.hparams.t_regularization_lamb * 0.01
                    else:
                        t_reg_lamb = self.hparams.t_regularization_lamb
                    # t_reg = -delta_t.item() + 1
                    if domain_t0 == domain_ref:
                        t_reg_lamb = self.hparams.t_regularization_lamb
                    else:
                        t_reg_lamb = self.hparams.t_regularization_lamb * 5
                    loss += t_reg_lamb * t_reg
                    self.log("train/t_reg_loss", t_reg, prog_bar=True)

                if self.hparams.timephi_loss:
                    ori_phi = self.renderer.style2phi(style_feat_domain["style_feat"])
                    style_input_style_mu, _ = self.MDMM.enc_a.forward(stylized_input_mdmm, domain_vec)
                    phi_prime = self.renderer.style2phi(style_input_style_mu)
                    phi_loss = img2mse(ori_phi, phi_prime)
                    if batch["style_img_label"] == 0: # night
                        phi_loss += img2mse(ori_phi, math.pi)
                    
                    loss += self.hparams.phi_loss_lamb * phi_loss
                    self.log("train/phi_loss", phi_loss.item(), prog_bar=True)
                
            if self.hparams.weatherEncodeCls:
                weather_pred = self.renderer.weather_pred
                domain = batch["style_img_label"]-1
                gt_domain_vec = torch.zeros((1,self.hparams.num_domains-1)).cuda()
                if domain >= 0:
                    gt_domain_vec[:,domain] = 1
                else:
                    gt_domain_vec[:,random.randint(0, 3)] = 1
                gt_domain_vec = gt_domain_vec.repeat(weather_pred.shape[0],1)
                CE_loss = torch.nn.CrossEntropyLoss()
                weather_pred_loss = CE_loss(weather_pred, gt_domain_vec)

                loss += self.hparams.weatherPred_lamb * weather_pred_loss
                self.log("train/weather_pred_loss", weather_pred_loss.item(), prog_bar=True)
            
            if self.hparams.inputWhole_cs_loss:
                cs_loss_val = 0
                style_input_content = self.MDMM.enc_c.forward(stylized_input_mdmm)
                style_input_style_mu, _ = self.MDMM.enc_a.forward(stylized_input_mdmm, domain_vec)
                for l in range(3):
                    cs_loss_val += self.hparams.ICSLoss_C_lamb * img2mse(style_input_content[f"level_{l}"], content_feats[f"level_{l}"][0:1])
                cs_loss_val += self.hparams.ICSLoss_S_lamb * img2mse(style_input_style_mu, self.MDMM.mu)

                loss += cs_loss_val
                self.log("train/ICS_loss", cs_loss_val.item(), prog_bar=True)

            if self.hparams.input_style_loss:
                style_input_normalize = self.MDMM.test_forward_transfer(mdmm_images[:,0], mdmm_style_images, domain_vec) #(1, 3, H, W)
                style_input_psgt = style_input_normalize*0.5 + 0.5 # denormalize: range [0,1] # (3, H, W)

                style_mse_loss = img2mse(stylized_input, style_input_psgt)
                if self.hparams.adaptive_style_loss:
                    # adaptive_weight = torch.exp(-loss_style_cls)
                    adaptive_weight = torch.exp(-3*loss_style_cls)
                    loss += self.hparams.style_lamb * adaptive_weight * style_mse_loss
                else:
                    loss += self.hparams.style_lamb * style_mse_loss
                self.log("train/input_style_mse_loss", style_mse_loss.item(), prog_bar=True)
            
            if self.hparams.STB_style_loss:
                assert self.hparams.styleTwoBranch == True
                style_input_cs_style_mu, _ = self.MDMM.enc_a.forward(stylized_input_cs_mdmm, domain_vec)
                STB_loss_val = img2mse(style_input_cs_style_mu, self.MDMM.z_attr)
                
                loss += self.hparams.STB_loss_lamb * STB_loss_val
                self.log("train/STB_loss", STB_loss_val.item(), prog_bar=True)
            
            if self.hparams.style_D_mdmm:
                assert self.hparams.styleTwoBranch == True
                resize_216 = transforms.Resize((216,216), Image.BICUBIC)
                pred_fake, pred_fake_cls = self.MDMM.dis1.forward(resize_216(stylized_input_cs_mdmm))
                loss_G_GAN = 0
                BCEwithlogit = torch.nn.BCEWithLogitsLoss()
                for out_a in pred_fake:
                    all_ones = torch.ones_like(out_a).cuda()
                    loss_G_GAN += BCEwithlogit(out_a, all_ones)
                
                loss += self.hparams.style_D_mdmm_lamb * loss_G_GAN
                self.log("train/style_D_mdmm_loss", loss_G_GAN.item(), prog_bar=True)
            
            if self.hparams.branch2_cycle_loss:
                assert self.hparams.styleTwoBranch == True
                style_input_content = self.MDMM.enc_c.forward(stylized_input_cs_mdmm)
                _domain = batch["input_img_label"]
                _domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                _domain_vec[:,_domain] = 1
                input_style_mu, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], _domain_vec)
                
                cycle_output_normalize = self.MDMM.gen.forward(style_input_content, input_style_mu, _domain_vec)
                cycle_output = cycle_output_normalize*0.5 + 0.5 # denormalize: range [0,1] # (3, H, W)

                branch2_cycle_loss_val = img2mse(cycle_output, unpre_imgs[:,0])
                loss += self.hparams.branch2_cycle_lamb * branch2_cycle_loss_val
                self.log("train/branch2_cycle_loss", branch2_cycle_loss_val.item(), prog_bar=True)
                
   
            if self.hparams.STB_style_mse_loss:
                assert self.hparams.styleTwoBranch == True
                style_input_normalize = self.MDMM.test_forward_transfer(mdmm_images[:,0], mdmm_style_images, domain_vec) #(1, 3, H, W)
                style_input_psgt = style_input_normalize*0.5 + 0.5 # denormalize: range [0,1] # (3, H, W)
                
                STB_style_mse_loss_val = img2mse(stylized_input_cs, style_input_psgt)
                
                loss += self.hparams.style_lamb * STB_style_mse_loss_val
                self.log("train/STB_mse_loss", STB_style_mse_loss_val.item(), prog_bar=True)
            
            if self.hparams.z_zInputStyle_Fuse:
                STB_style_mse_loss_z_val = img2mse(stylized_input_cs_zN, style_input_psgt)
                
                loss += self.hparams.style_lamb * STB_style_mse_loss_z_val
                self.log("train/STB_mse_loss_z", STB_style_mse_loss_z_val.item(), prog_bar=True)

            if self.hparams.input_rec_loss:
                assert self.hparams.styleTwoBranch == True
                feat_to_cat = []
                _domain = batch["input_img_label"]
                _domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                _domain_vec[:,_domain] = 1
                _style, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], _domain_vec)
                common_feat, special_feat = other_kwarg['common_feat'], other_kwarg['special_feat']
                for l in range(3):
                    common_feat_l = common_feat[f"level_{l}"].unsqueeze(0)
                    ray_common_l, _, _ = interpolate_2D(
                        common_feat_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                    )
                    feat_to_cat.append(ray_common_l.unsqueeze(1))

                for l in range(3):
                    special_feat_l = special_feat[f"level_{l}"].unsqueeze(0)
                    ray_special_l, _, _ = interpolate_2D(
                        special_feat_l[:, 0], unpre_imgs[:, :nb_views][:, 0], rays_pts_ndc_inputView[f"level_0"][:, 0:1, 0]
                    )
                    feat_to_cat.append(ray_special_l.unsqueeze(1))
                
                N, S = rays_pts_ndc_inputView["level_0"].shape[:2]
                _style = _style.expand(N, 1, -1)
                _domain_vec = _domain_vec.expand(N, 1, -1)
                feat_to_cat.append(_style)
                feat_to_cat.append(_domain_vec)
                interpolated_feats_one = torch.cat(feat_to_cat, dim=-1).unsqueeze(2) # (N,1,1,c) ; normal interpolated_feats: (N,S,V,c)
                
                if self.hparams.zInputStyle:
                    # input_domain = batch["input_img_label"]
                    # input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                    # input_domain_vec[:,input_domain] = 1
                    # z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], input_domain_vec)

                    input_domain = batch["style_img_label"]
                    input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                    input_domain_vec[:,input_domain] = 1
                    z, _ = self.MDMM.enc_a.forward(mdmm_style_images, input_domain_vec)

                    if (self.hparams.nightNoRemainCode and input_domain == 0) or self.hparams.zInputStyle_isInput: # style_img is night
                        _input_domain = batch["input_img_label"]
                        _input_domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                        _input_domain_vec[:,_input_domain] = 1
                        z, _ = self.MDMM.enc_a.forward(mdmm_images[:,0], _input_domain_vec)

                    z = self.renderer.style2remain(z)
                else:
                    dist = tdist.Normal(torch.tensor([0.]), torch.tensor([self.hparams.gauss_var]))
                    z = dist.sample((self.hparams.z_dim,)).reshape(1,self.hparams.z_dim).cuda()

                if self.hparams.z_zInputStyle_Fuse:
                    z = self.renderer.z2commonSpace(z)
                _, style_rgb_ray_input = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, return_what_style='both', z=z)
                # common+special
                stylized_input_cs_input = style_rgb_ray_input['c+s'].reshape(H, W, 3).permute(2,0,1).unsqueeze(0) # (1,3,H,W)
                
                input_rec_loss_val = img2mse(stylized_input_cs_input, unpre_imgs[:, 0])
                
                loss += self.hparams.input_rec_lamb * input_rec_loss_val
                self.log("train/input_rec_loss", input_rec_loss_val.item(), prog_bar=True)

            if self.hparams.DT_loss or self.hparams.O_constraint:
                # with torch.no_grad():
                _novel_disocc, novel_depth_values, novel_feat_vol = self.geo_reasoner.output_novel_disocc_func(
                    imgs=batch["images"][:, :],
                    affine_mats=batch["affine_mats"][:, :],
                    affine_mats_inv=batch["affine_mats_inv"][:, :],
                    near_far=batch["near_fars"][:, :],
                    closest_idxs=batch["closest_idxs"][:, :],
                    gt_depths=batch["depths_aug"][:, :],
                    gt_depths_real=batch["depths"],
                )
                novel_pts_depth, novel_rays_pts, novel_rays_pts_ndc, *_ = get_rays_pts(
                    H, W, batch["c2ws"][:,nb_views-1:], batch["w2cs"][:,nb_views-1:], batch["intrinsics"][:,nb_views-1:], batch["near_fars"][:,nb_views-1:], novel_depth_values,
                    self.hparams.nb_coarse, self.hparams.nb_fine, nb_views=1, train=True, train_batch_size=self.hparams.batch_size,
                    target_img=unpre_imgs[0, -1], target_depth=batch["depths_h"][0, -1],
                )
                novel_disocc_0 = interpolate_3D(_novel_disocc["level_0"][:, 0], novel_rays_pts_ndc["level_0"][:, :, 0])

            if self.hparams.DT_loss:
                DT_loss_val = img2mse(novel_disocc_0, renderer_other_output["T"])
                gradual = 0.1*math.exp(self.global_step/self.hparams.num_steps-1)
                loss += gradual*self.hparams.DT_loss_lamb * DT_loss_val
                self.log("train/DT_loss", DT_loss_val.item(), prog_bar=True)
                
            if self.hparams.O_constraint:
                _, _, novel_feat_vol_second = self.geo_reasoner.output_novel_disocc_func(
                    imgs=batch["images"][:, :],
                    affine_mats=batch["affine_mats"][:, :],
                    affine_mats_inv=batch["affine_mats_inv"][:, :],
                    near_far=batch["near_fars"][:, :],
                    closest_idxs=batch["second_closest_idxs"][:, :],
                    gt_depths=batch["depths_aug"][:, :]
                )
                O_loss = 0
                for l in range(3):
                    O_loss += (1/3)*img2mse(novel_feat_vol[f"level_{l}"].detach(), novel_feat_vol_second[f"level_{l}"])
                loss += self.hparams.O_loss_lamb * O_loss
                self.log("train/O_loss", O_loss.item(), prog_bar=True)

            if self.hparams.consistent3d_loss:
                diff_thres = 0.005
                D_thres = 0.9
                consist3d_loss_val = torch.tensor(0.).cuda()
                consist3d_g_loss_val, consist3d_rgb_loss_val, consist3d_coord_loss_val = torch.tensor(0.).cuda(), torch.tensor(0.).cuda(), torch.tensor(0.).cuda()
                # rendered_depth.shape:(1024) ; pts_depth.shape:(1024, 128)
                diff, depth_idx = torch.min(torch.abs((pts_depth-rendered_depth.detach().unsqueeze(-1).repeat(1, pts_depth.shape[1]))), dim=1)

                # T
                surface_T = renderer_other_output['T'].gather(1, depth_idx.unsqueeze(1))[diff <= diff_thres]
                consist3d_g_loss_val += img2mse(surface_T, 1.)
                # alpha
                surface_alpha = renderer_other_output['alpha'].gather(1, depth_idx.unsqueeze(1))[diff <= diff_thres]
                consist3d_g_loss_val += img2mse(surface_alpha, 1.)
                
                # RGB
                # surface_rgb = renderer_other_output['rgb'].gather(1, depth_idx.reshape(-1,1,1).repeat(1,1,3))[diff <= diff_thres].squeeze()
                # surface_gt_rgb = rays_gt_rgb[diff <= diff_thres]
                # consist3d_rgb_loss_val += img2mse(surface_rgb, surface_gt_rgb)
                
                consist3d_loss_val = self.hparams.consist3d_g_loss_lamb*consist3d_g_loss_val + self.hparams.consist3d_rgb_loss_lamb*consist3d_rgb_loss_val + self.hparams.consist3d_coord_loss_lamb*consist3d_coord_loss_val
                loss += consist3d_loss_val
                self.log("train/consist3d_loss", consist3d_loss_val.item(), prog_bar=True)

            if self.hparams.pDensity_loss:
                p = other_kwarg["pDensity"]
                pDensity_loss_val = 0
                for i in range(nb_views):
                    for l in range(3):
                        p_sample_pts = interpolate_3D(
                            p[f"level_{l}"][:, i], rays_pts_ndc[f"level_{l}"][:, :, i]
                        )
                        with torch.no_grad():
                            in_mask = (rays_pts_ndc[f"level_{l}"][:, :, i] > -1.0) * (rays_pts_ndc[f"level_{l}"][:, :, i] < 1.0)
                            in_mask = (in_mask[..., 0]*in_mask[..., 1]*in_mask[..., 2]).float()
                        pDensity_loss_val += torch.mean((in_mask * (p_sample_pts - renderer_other_output["density"].detach())) ** 2)

                pDensity_loss_val *= 1/(nb_views*3)
                loss += self.hparams.pDensity_loss_lamb*pDensity_loss_val
                self.log("train/pDensity_loss", pDensity_loss_val.item(), prog_bar=True)
            
            if self.hparams.pDensity_RGBloss:
                p = other_kwarg["pDensity"]
                nb_pDensity_RGBloss = 0
                pDensity_RGBloss_val = 0
                for i in range(nb_views):
                    for l in range(3):
                        p_sample_pts = interpolate_3D(
                            p[f"level_{l}"][:, i], rays_pts_ndc[f"level_{l}"][:, :, i]
                        )
                        with torch.no_grad():
                            in_mask = (rays_pts_ndc[f"level_{l}"][:, :, i] > -1.0) * (rays_pts_ndc[f"level_{l}"][:, :, i] < 1.0)
                            in_mask = (in_mask[..., 0]*in_mask[..., 1]*in_mask[..., 2]).float()
                        novel_weights_mvs = sigma2weights(in_mask * p_sample_pts).unsqueeze(-1)
                        rgb_mvs = torch.sum(novel_weights_mvs * renderer_other_output["rgb"], -2)
                        rgb_mask = (rgb_mvs - rays_gt_rgb).abs().mean(dim=-1) < 0.02
                        nb_pDensity_RGBloss += torch.sum(rgb_mask).detach()
                        pDensity_RGBloss_val += torch.mean((rgb_mask.unsqueeze(-1) * (rgb_mvs - rays_gt_rgb)) ** 2) * 2 ** (1 - l)
                
                pDensity_RGBloss_val *= 1/(nb_views*3)
                gradual = math.exp(self.global_step/self.hparams.num_steps-1)
                loss += gradual*self.hparams.pDensity_RGBloss_lamb*pDensity_RGBloss_val
                self.log("train/pD_RGBloss", pDensity_RGBloss_val.item(), prog_bar=True)
                self.log("train/nb_pD_RGBloss", nb_pDensity_RGBloss.item(), prog_bar=True)

            if self.hparams.O_label_loss:
                O_label_val = other_kwarg['O_label']
                O_label_loss_val = 0
                for v in range(3): # view
                    for l in range(3): # level
                        n_depth = depth_values[f"level_{l}"].shape[2]
                        # print(depth_map[f"level_{l}"][v:v+1,...].unsqueeze(0).shape, n_depth)
                        _depth_map = depth_map[f"level_{l}"][:,v] #(1,H,W)
                        repeat_depth_l = _depth_map.repeat(n_depth, 1, 1)
                        _depth_val = depth_values[f"level_{l}"][0,v,...] #(n_depth, H, W)
                        d_diff = _depth_val - repeat_depth_l
                        surface_mask = (d_diff.abs() <= 1e-5).float()
                        air_mask = (d_diff <= -1e-5).float()
                        O_label_loss_val += torch.mean((surface_mask * (O_label_val[f"level_{l}"].squeeze()[v,...] - 1)) ** 2) * 2 ** (1 - l)
                        O_label_loss_val += torch.mean((air_mask * (O_label_val[f"level_{l}"].squeeze()[v,...] - 0)) ** 2) * 2 ** (1 - l)
                
                loss += self.hparams.O_label_loss_lamb * O_label_loss_val / (3*3)
                self.log("train/O_label_loss", O_label_loss_val.item(), prog_bar=True)
            
            if self.hparams.O_label_01loss:
                O_label_val = other_kwarg['O_label']
                O_label_01loss_val = 0
                for v in range(3): # view
                    for l in range(3): # level
                        O_label_01loss_val += -torch.mean((O_label_val[f"level_{l}"].squeeze()[v,...] - 0.5) ** 2) * 2 ** (1 - l)
                
                loss += self.hparams.O_label_01loss_lamb * O_label_01loss_val / (3*3)
                self.log("train/O_label_01loss", O_label_01loss_val.item(), prog_bar=True)


            # if self.hparams.P_constraint and (not self.hparams.upperbound) and isinstance(batch["depths"], dict):
            if self.hparams.P_constraint and isinstance(batch["depths"], dict):
                P_loss = 0
                depth_probs = other_kwarg['depth_probs']
                for v in range(nb_views):
                    for l in range(3):
                        n_depth = depth_values[f"level_{l}"].shape[2]
                        depth_prob_l = F.softmax(depth_probs[f"level_{l}"].squeeze()[v:v+1, ...], dim=1) # (1, d, h, w)
                        depth_gt_l = batch["depths"][f"level_{l}"][:, v:v+1]
                        depth_values_l = depth_values[f"level_{l}"][:,v,...]
                        repeat_gt_d = depth_gt_l.repeat(1, n_depth, 1, 1)
                        d_diff = (depth_values_l - repeat_gt_d)
                        gt_mask = (repeat_gt_d != 0)

                        # if torch.sum(d_diff == 0)>0:
                        #     exact_d = (d_diff==0).nonzeros()
                        #     print(exact_d.shape)
                        lower_depth_mask = (d_diff <= 0)
                        lower_depth_mask_2 = torch.zeros_like(d_diff)
                        lower_depth_mask_2[d_diff > 0] = -1e10
                        lower_depth_idx = torch.argmax(d_diff * lower_depth_mask + lower_depth_mask_2, dim=1, keepdim=True)
                        lower_tmp = (d_diff * lower_depth_mask + lower_depth_mask_2).gather(1, lower_depth_idx)
                        lower_depth = depth_values_l.gather(1, lower_depth_idx)

                        higher_depth_mask = (d_diff >= 0)
                        higher_depth_mask_2 = torch.zeros_like(d_diff)
                        higher_depth_mask_2[d_diff < 0] = 1e10
                        higher_depth_idx = torch.argmin(d_diff * higher_depth_mask + higher_depth_mask_2, dim=1, keepdim=True)
                        higher_tmp = (d_diff * higher_depth_mask + higher_depth_mask_2).gather(1, higher_depth_idx)
                        higher_depth = depth_values_l.gather(1, higher_depth_idx)

                        check_low_mask = (lower_tmp == -1e10)
                        check_high_mask = (higher_tmp == 1e10)
                        x_low, x_high = torch.zeros_like(depth_gt_l), torch.zeros_like(depth_gt_l)
                        # d1 gt d2
                        x_high[(check_low_mask+check_high_mask)==0] = ((depth_gt_l-lower_depth) / (higher_depth-lower_depth))[(check_low_mask+check_high_mask)==0]
                        x_low[(check_low_mask+check_high_mask)==0] = (1 - ((depth_gt_l-lower_depth) / (higher_depth-lower_depth)))[(check_low_mask+check_high_mask)==0]
                        # # gt d1 d2
                        # x_high[check_low_mask==1] = (depth_gt_l / lower_depth)[check_low_mask==1]
                        # x_low[check_low_mask==1] = 0.
                        # # d1 d2 gt
                        # x_high[check_high_mask==1] = 0.
                        # x_low[check_high_mask==1] = (depth_gt_l / higher_depth)[check_high_mask==1] # val > 1
                        assert x_high[(check_low_mask+check_high_mask)==2].shape[0] == 0
                        x_high[x_high == -float("Inf")], x_low[x_low == float("Inf")] = 0, 0
                        x_high_nan_mask, x_low_nan_mask = x_high.isnan(), x_low.isnan()
                        x_high_NOT_nan_mask, x_low_NOT_nan_mask = x_high_nan_mask*(-1)+1, x_low_nan_mask*(-1)+1
                        x_high, x_low = x_high.nan_to_num(), x_low.nan_to_num()
                        
                        pseudo_prob = torch.zeros_like(depth_prob_l)
                        pseudo_prob = pseudo_prob.scatter(1, higher_depth_idx, x_high)
                        pseudo_prob = pseudo_prob.scatter(1, lower_depth_idx, x_low)
                        pseudo_prob = gt_mask * pseudo_prob

                        mask = (gt_mask) * ((check_low_mask+check_high_mask)==0) * x_high_NOT_nan_mask * x_low_NOT_nan_mask
                        if torch.isnan(pseudo_prob).sum() > 0:
                            print(torch.isnan(pseudo_prob).sum(), torch.isnan(x_high).sum(), torch.isnan(x_low).sum())
                        # P_loss += (torch.nn.functional.binary_cross_entropy(depth_prob_l, pseudo_prob, reduction="none") * mask).sum() / mask.sum()
                        P_loss += torch.sum(((depth_prob_l - pseudo_prob) * mask) ** 2) / mask.sum()

                loss += self.hparams.P_loss_lamb * P_loss / (3*3)
                self.log("train/P_loss", P_loss.item(), prog_bar=True)

            if self.hparams.unimvs_loss and isinstance(batch["depths"], dict):
                unimvs_loss = 0
                dlossw = [2.0, 1.0, 0.5]
                fl_gamas = [0, 1, 2]
                fl_alphas = [0.25, 0.5, 0.75]
                depth_probs = other_kwarg['depth_probs']
                for v in range(nb_views):
                    for l in range(3): # level 0 - 2: ori_resolution - 1/4_resol
                        depth_prob_l = depth_probs[f"level_{l}"].squeeze()[v:v+1, ...].float() # (b, d, h, w)
                        depth_values_l = depth_values[f"level_{l}"][:,v,...].float() # (b, d, h, w)
                        depth_gt_l = batch["depths"][f"level_{l}"][:, v:v+1].squeeze(0)  # (b, h, w)
                        interval_l = other_kwarg['depth_interval_levels'][f"level_{l}"][v] # float
                        gamma = fl_gamas[l]
                        alpha = fl_alphas[l]
                        weight = dlossw[l]
                        depth_mask = depth_gt_l > 0
                        unimvs_loss += unified_focal_loss(depth_prob_l, depth_values_l, interval_l, depth_gt_l, depth_mask, weight, gamma, alpha)

                loss += self.hparams.unimvs_loss_lamb * unimvs_loss
                self.log("train/unimvs_loss", unimvs_loss.item(), prog_bar=True)

            if self.hparams.density_01loss:
                novel_T = renderer_other_output['T']
                novel_alpha = renderer_other_output['alpha']
                density_01loss_val = -torch.mean((novel_T - 0.5) ** 2)
                density_01loss_val += -torch.mean((novel_alpha - 0.5) ** 2)
                
                loss += self.hparams.density_01loss_lamb * density_01loss_val
                self.log("train/density_01loss", density_01loss_val.item(), prog_bar=True)

            if self.hparams.edge_loss:
                # Edge generate
                edges_all = []
                for v in range(nb_views+1):
                    img = (unpre_imgs[0, v].detach().cpu().permute(1,2,0).numpy() * 255).astype(np.uint8)
                    img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).astype(np.uint8)
                    img_blur = cv2.GaussianBlur(img_gray, (3,3), 0) 
                    edges = cv2.Canny(image=img_blur, threshold1=100, threshold2=200)

                    kernel = np.ones((3,3), np.uint8)
                    dilate_iter = 2 if self.hparams.dataset_name == 'dtu' else 1
                    edges = cv2.dilate(edges, kernel, iterations = dilate_iter)
                    edges = torch.from_numpy(edges.astype(float) / 255).long()
                    # edges = (edges-1)*(-1)
                    edges_all.append(edges == 1)

                # add edge mask
                tgt_edge_mask = edges_all[-1][rays_pixs.long()[0],rays_pixs.long()[1]]
                if isinstance(batch["depths"], dict):
                    depth_gt_edge = copy.deepcopy(batch["depths"])
                    for l in range(3):
                        for v in range(nb_views):
                            edges_l = cv2.resize(edges_all[v].float().detach().cpu().numpy(), None, fx=1.0/(2**l), fy=1.0/(2**l), interpolation=cv2.INTER_NEAREST,)
                            edges_l = torch.from_numpy(edges_l).bool()
                            depth_gt_edge[f"level_{l}"][:,v] = batch["depths"][f"level_{l}"][:,v] * edges_l.unsqueeze(0).cuda()
                    
                    rays_gt_depth_edge = rays_gt_depth[tgt_edge_mask]
                else:
                    rays_pixs_edge = rays_pixs[:,tgt_edge_mask]
                
                rays_pixs_edge = rays_pixs[:,tgt_edge_mask]
                
                rays_gt_rgb_edge = rays_gt_rgb[tgt_edge_mask,:]
                rendered_rgb_edge = rendered_rgb[tgt_edge_mask,:]
                rendered_depth_edge = rendered_depth[tgt_edge_mask]

                edge_loss_val = torch.tensor(0.).cuda()
                if isinstance(batch["depths"], dict):
                    edge_loss_val = edge_loss_val + 1 * self.depth_loss(depth_map, depth_gt_edge)
                else:
                    edge_loss_val = edge_loss_val + 0.1 * self_supervision_loss(
                        self.depth_loss,
                        rays_pixs_edge,
                        rendered_depth_edge.detach(),
                        depth_map,
                        rays_gt_rgb_edge,
                        unpre_imgs,
                        rendered_rgb_edge.detach(),
                        batch["intrinsics"],
                        batch["c2ws"],
                        batch["w2cs"],
                    ) #TODO check only compute edge pixels
                
                if isinstance(batch["depths"], dict):
                    edge_loss_val = edge_loss_val + 0.1 * self.depth_loss(rendered_depth_edge, rays_gt_depth_edge)
                
                edge_loss_val = edge_loss_val + img2mse(rendered_rgb_edge, rays_gt_rgb_edge)
                
                self.log("train/edge_loss", edge_loss_val.item(), prog_bar=True)
                
            if self.hparams.novel_depth_loss and not first_time:
                assert for_novel_depth_loss != None
                _2d = for_novel_depth_loss["rays_pixs"].long()
                cas_depth = depth_map["level_0"][0,0,_2d[0],_2d[1]]
                SL1 = torch.nn.SmoothL1Loss(reduction="mean")
                novel_depth_loss_val = 0.1 * SL1(cas_depth, for_novel_depth_loss["rendered_depth"])
                loss += novel_depth_loss_val
                self.log("train/novel_depth_loss", novel_depth_loss_val.item(), prog_bar=False)

            if self.hparams.renderer_geneRGBsigma_loss:
                ### gene_rgb
                gene_rendered_rgb = renderer_other_output['gene_rendered']['gene_rendered_rgb']
                gene_mse_loss = img2mse(gene_rendered_rgb, rays_gt_rgb)

                loss += 0.5*gene_mse_loss
                self.log("train/gene_mse_loss", gene_mse_loss.item(), prog_bar=True)

            ##### casMVS depth start #####
            # Supervising depth maps with either ground truth depth or self-supervision loss
            ## This loss is only used in the generalizable model
            if self.hparams.scene == "None":
                ## if ground truth is available
                if isinstance(batch["depths"], dict) and not self.hparams.wo_depth_gt_supervise:
                    d_loss = self.depth_loss(depth_map, batch["depths"])
                    if d_loss != 0:
                        self.log("train/dlossgt", d_loss.item(), prog_bar=False)
                    
                    if self.hparams.cas_confi:
                        d_w_loss = self.depth_loss_with_confi(depth_map, batch["depths"], confi=other_kwarg['depth_confi'])
                        if d_w_loss != 0:
                            self.log("train/dlossgt_w", d_w_loss.item(), prog_bar=False)
                        d_loss = (d_loss + d_w_loss*100)/2
                        # d_loss = (d_loss + d_w_loss*1000)/2

                        # confi regularization
                        # confi_reg_loss = 0
                        # for l in range(3):
                        #     ones = torch.ones_like(other_kwarg['depth_confi'][f"level_{l}"])
                        #     confi_reg_loss += torch.mean((other_kwarg['depth_confi'][f"level_{l}"] - ones).abs())
                        # print("r",confi_reg_loss)
                        # loss += confi_reg_loss
                    
                    loss = loss + 1 * d_loss
                    
                else:
                    d_loss = self_supervision_loss(
                        self.depth_loss,
                        rays_pixs,
                        rendered_depth.detach(),
                        depth_map,
                        rays_gt_rgb,
                        unpre_imgs,
                        rendered_rgb.detach(),
                        batch["intrinsics"],
                        batch["c2ws"],
                        batch["w2cs"],
                    )
                    loss = loss + 0.1 * d_loss
                    if d_loss != 0:
                        self.log("train/dlosspgt", d_loss.item(), prog_bar=False)
            
            if self.hparams.extra_depth_loss:
                extra_d_loss, L_smooth = 0, 0
                for l in range(3):
                    for v in range(nb_views):
                        imgs_l = F.interpolate(unpre_imgs[:, v], scale_factor=1.0 / (2**l), mode="bilinear", align_corners=True)
                        L_smooth = depth_smoothness(depth_map[f"level_{l}"][:, v:v+1],imgs_l) * 2 ** (1 - l)
                
                loss += self.hparams.L_smooth_lamb * L_smooth
                self.log("train/L_smooth", L_smooth.item(), prog_bar=True)
            
            ##### casMVS depth end #####
            
            if self.hparams.renderer_geneRGBsigma_dist:
                uct_loss, uct_exist = torch.tensor(0.).cuda(), 0
                rendered_rgb_mean = renderer_other_output['rendered_dist_args']['rendered_rgb_mean']
                rendered_rgb_var = renderer_other_output['rendered_dist_args']['rendered_rgb_var']
                density_sum = torch.mean(renderer_other_output['density'])
                inverse_sigmoid_rgb_gt = inverse_sigmoid(rays_gt_rgb)
                for c in range(3): #rgb
                    syn_img_mse = torch.mean((rendered_rgb_mean[:,c]-inverse_sigmoid_rgb_gt[:,c])**2 / (2*rendered_rgb_var[:,c]))
                    syn_var = torch.mean(torch.log(rendered_rgb_var[:,c]) / 2)
                    # uct_loss_c = syn_img_mse + 0.01*syn_var
                    uct_loss_c = 0.00001*syn_img_mse + 0.01*syn_var # modify4
                    if not (torch.isnan(uct_loss_c) or torch.isinf(uct_loss_c)):
                        uct_loss += uct_loss_c
                        uct_exist += 1
                
                if uct_exist != 0:
                    uct_loss /= uct_exist

                loss += self.hparams.uct_lamb*uct_loss
                self.log("train/uct_loss", uct_loss.item(), prog_bar=True)
                if uct_exist != 0:
                    self.log("train/syn_img_mse", syn_img_mse.item(), prog_bar=False)
                    self.log("train/syn_var", syn_var.item(), prog_bar=False)
                # loss += self.hparams.density_sum_lamb*density_sum
                # self.log("train/density_sum", density_sum.item(), prog_bar=True)

            
            mask = rays_gt_depth > 0
            depth_available = mask.sum() > 0

            ## Supervising ray depths
            if depth_available and not self.hparams.wo_depth_gt_supervise:
                ## This loss is only used in the generalizable model
                if self.hparams.scene == "None":
                    loss = loss + 0.1 * self.depth_loss(rendered_depth, rays_gt_depth)

                self.log(
                    f"train/acc_l_{self.eval_metric[0]}mm",
                    acc_threshold(
                        rendered_depth, rays_gt_depth, mask, self.eval_metric[0]
                    ).mean(),
                    prog_bar=False,
                )
                self.log(
                    f"train/acc_l_{self.eval_metric[1]}mm",
                    acc_threshold(
                        rendered_depth, rays_gt_depth, mask, self.eval_metric[1]
                    ).mean(),
                    prog_bar=False,
                )
                self.log(
                    f"train/acc_l_{self.eval_metric[2]}mm",
                    acc_threshold(
                        rendered_depth, rays_gt_depth, mask, self.eval_metric[2]
                    ).mean(),
                    prog_bar=False,
                )

                abs_err = abs_error(rendered_depth, rays_gt_depth, mask).mean()
                self.log("train/abs_err", abs_err, prog_bar=False)

            if self.hparams.patch_nerf_depth_tvLoss:
                assert self.hparams.train_patch == True
                patch_w = int(self.hparams.batch_size ** 0.5)
                L_smooth = depth_smoothness(rendered_depth.reshape(1, 1, patch_w, patch_w),rays_gt_rgb.reshape(1, patch_w, patch_w, 3).permute(0,3,1,2))
                
                loss += self.hparams.patch_nerf_depth_tvLoss_lamb * L_smooth
                self.log("train/nerfDep_smooth", L_smooth.item(), prog_bar=True)

            if self.hparams.vgg_loss:
                assert self.hparams.train_patch == True
                patch_w = int(self.hparams.batch_size ** 0.5)
                perceptual = VGG16_perceptual().cuda()
                syn0, syn1, syn2, syn3 = perceptual(rendered_rgb.reshape(1, patch_w, patch_w, 3).permute(0,3,1,2))
                r0, r1, r2, r3 = perceptual(rays_gt_rgb.reshape(1, patch_w, patch_w, 3).permute(0,3,1,2))
                per_loss = 0
                per_loss += img2mse(syn0,r0)
                per_loss += img2mse(syn1,r1)
                per_loss += img2mse(syn2,r2)
                per_loss += img2mse(syn3,r3)

                if self.hparams.vgg_loss_style:
                    syn0_s, syn1_s, syn2_s, syn3_s = perceptual(renderer_other_output["rendered_style_rgb"].reshape(1, patch_w, patch_w, 3).permute(0,3,1,2))
                    r0_s, r1_s, r2_s, r3_s = perceptual(other_rays_output["rays_pseudo_style_rgb"].reshape(1, patch_w, patch_w, 3).permute(0,3,1,2))
                    per_loss += img2mse(syn0_s,r0_s)
                    per_loss += img2mse(syn1_s,r1_s)
                    per_loss += img2mse(syn2_s,r2_s)
                    per_loss += img2mse(syn3_s,r3_s)

                    per_loss /= 2

                loss += self.hparams.vgg_loss_lamb * per_loss
                self.log("train/vgg_loss", per_loss.item(), prog_bar=True)

            if self.hparams.patch_cs_loss:
                assert self.hparams.train_patch == True
                patch_w = int(self.hparams.batch_size ** 0.5)
                cs_loss_val = 0
                # content
                render_style_patch = mdmm_normalize(renderer_other_output["rendered_style_rgb"].reshape(1, patch_w, patch_w, 3).permute(0,3,1,2))
                gt_patch = mdmm_normalize(rays_gt_rgb.reshape(1, patch_w, patch_w, 3).permute(0,3,1,2))
                patch_render_content = self.MDMM.enc_c.forward(render_style_patch, return_level=True)
                patch_gt_content = self.MDMM.enc_c.forward(gt_patch, return_level=True)
                for l in range(3):
                    cs_loss_val += self.hparams.PCSLoss_C_lamb * img2mse(patch_render_content[f"level_{l}"], patch_gt_content[f"level_{l}"])

                # style
                domain = batch["style_img_label"]
                domain_vec = torch.zeros((1,self.hparams.num_domains)).cuda()
                domain_vec[:,domain] = 1
                patch_render_style = self.MDMM.enc_a.forward(render_style_patch, domain_vec) # mean, var
                # patch_gt_style = self.MDMM.enc_a.forward(other_rays_output["rays_pseudo_style_rgb"].reshape(1, patch_w, patch_w, 3).permute(0,3,1,2), domain_vec) # mean, var
                ## v2
                patch_gt_style = self.MDMM.enc_a.forward(mdmm_style_images, domain_vec) # mean, var
                patch_render_style = torch.cat(patch_render_style, dim=-1)
                patch_gt_style = torch.cat(patch_gt_style, dim=-1)
                cs_loss_val += self.hparams.PCSLoss_S_lamb * img2mse(patch_render_style, patch_gt_style)

                loss += cs_loss_val
                self.log("train/PCS_loss", cs_loss_val.item(), prog_bar=True)

            ## Reconstruction loss
            mse_loss = img2mse(rendered_rgb, rays_gt_rgb)
            # if not (self.hparams.renderer_geneRGBsigma_dist and not first_time):
            loss = loss + mse_loss

            if self.hparams.geonerfMDMM and (not self.hparams.wo_style_loss):
                if not (self.hparams.NVSThenStyle and self.wr_cntr//2 < 10):
                    style_mse_loss = img2mse(renderer_other_output["rendered_style_rgb"], other_rays_output["rays_pseudo_style_rgb"])
                    if self.hparams.adaptive_style_loss:
                        # adaptive_weight = torch.exp(-loss_style_cls)
                        adaptive_weight = torch.exp(-3*loss_style_cls)
                        loss += self.hparams.style_lamb * adaptive_weight * style_mse_loss
                    else:
                        loss += self.hparams.style_lamb * style_mse_loss
                    self.log("train/style_mse_loss", style_mse_loss.item(), prog_bar=True)

            if self.hparams.update_casmvs:
                loss = d_loss if isinstance(batch["depths"], dict) else (0.1*d_loss)
            
            if self.hparams.learn_3dfeat_from_GT:
                # loss = 0
                for l in range(3):
                    feats_s_l = feats_vol[f"level_{l}"]
                    feats_t_l = feats_vol_t[f"level_{l}"].detach()
                    loss += 0.01*img2mse(feats_t_l, feats_s_l)
                
                self.log("train/3dfeat_st", loss, prog_bar=True)

            if self.hparams.edge_loss:
                if edge_loss_val != 0:
                    loss = (1-self.hparams.edge_loss_weight)*loss + self.hparams.edge_loss_weight*edge_loss_val

            return_list = [loss, mse_loss, {"geo_reasoner_output":None}]
            if self.hparams.cycle_loss and first_time:
                return_list[2]["rendered_rgb"] = rendered_rgb
                return_list[2]["rays_pixs"] = rays_pixs
                
                with torch.no_grad():
                    surface_points = rays_os + rendered_depth.unsqueeze(-1) * rays_dir # (bs=1024, 3) world coord.
                    w2c_ref, intrinsics_ref = batch["w2cs"][0, 0], batch["intrinsics"][0, 0] # view 0
                    R = w2c_ref[:3, :3]  # (3, 3)
                    T = w2c_ref[:3, 3:]  # (3, 1)
                    _points_3d = torch.matmul(surface_points, R.t()) + T.reshape(1, 3)
                    _points_2d = _points_3d @ intrinsics_ref.t()
                    _points_2d_v0 = _points_2d[:,:2] / _points_2d[:,2:]
                    _points_2d_v0 = _points_2d_v0.permute(1,0) # (2(x,y), bs)
                    points_2d_v0 = torch.zeros_like(_points_2d_v0)
                    points_2d_v0[0], points_2d_v0[1] = _points_2d_v0[1], _points_2d_v0[0] # (2(y,x), bs)
                    points_2d_v0 = torch.nan_to_num(points_2d_v0, nan=0)
                    points_2d_v0[points_2d_v0 == float("Inf")] = 0
                    if torch.sum(torch.isnan(points_2d_v0)): print("nan!",torch.sum(torch.isnan(points_2d_v0)))
                    if torch.sum(torch.isinf(points_2d_v0)): print("inf!",torch.sum(torch.isinf(points_2d_v0)))
                    points_2d_v0[0], points_2d_v0[1] = torch.clamp(points_2d_v0[0],0,H-1), torch.clamp(points_2d_v0[1],0,W-1)
                    return_list[2]["points_2d_v0"] = points_2d_v0
                    # _img = copy.deepcopy(unpre_imgs[:, 0])
                    # _img[0, :, points_2d_v0[0][0:10].long(), points_2d_v0[1][0:10].long()] = torch.ones_like(rendered_rgb[0:10].permute(1,0))
                    # cv2.imwrite("_v0.png",cv2.cvtColor(_img[0].permute(1,2,0).detach().cpu().numpy()*255, cv2.COLOR_RGB2BGR) )
                    # a

                if self.hparams.novel_depth_loss:
                    return_list[2]["rendered_depth"] = rendered_depth

            if self.hparams.far_view_loss and first_time:
                return_list[2]["geo_reasoner_output"] = geo_reasoner_output


            return return_list

        def train_update_z(batch):
            """Train the renderer's content/style branch by reconstructing input view 0.

            Steps:
              1. Put ``self.MDMM`` in eval mode. Under ``torch.no_grad`` it extracts
                 per-view content features and a style-transferred novel view.
              2. Run the geometry reasoner on the source views.
              3. Sample the common and special features of view 0 at the level-0 ray
                 points of that view.
              4. Append view 0's MDMM style code and its one-hot domain label, and
                 render the image through ``self.renderer`` with
                 ``onlyContentStyle=True``.
              5. Compare the ``'c+s'`` output with the un-preprocessed view 0.

            Args:
                batch: training batch dict. The keys read here are "images",
                    "style_img", "style_img_label", "input_img_label",
                    "affine_mats", "affine_mats_inv", "near_fars",
                    "closest_idxs", "depths_aug", "depths", "depths_h", "c2ws",
                    "w2cs" and "intrinsics".

            Returns:
                ``(loss, mse_loss)`` with ``loss = 10 * mse_loss``.
            """
            assert self.hparams.geonerfMDMM == True and self.hparams.styleTwoBranch == True
            nb_views = self.hparams.nb_views
            H, W = map(int, batch["images"].shape[-2:])

            def _one_hot_domain(label):
                # (1, num_domains) one-hot domain vector on the current CUDA device.
                vec = torch.zeros((1, self.hparams.num_domains)).cuda()
                vec[:, label] = 1
                return vec

            self.MDMM.eval()

            unpre_imgs = self.unpreprocess(batch["images"])
            unpre_style_img = self.unpreprocess(batch["style_img"])
            # MDMM inputs use (x - 0.5) / 0.5 normalisation; `* 0.5 + 0.5` below undoes it.
            mdmm_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            mdmm_images = mdmm_normalize(unpre_imgs)
            mdmm_style_images = mdmm_normalize(unpre_style_img)[0]

            with torch.no_grad():
                per_view_content = [
                    self.MDMM.test_forward_random(mdmm_images[:, v], return_content=True)
                    for v in range(nb_views)
                ]
                content_feats = {
                    f"level_{l}": torch.cat([feat[f"level_{l}"] for feat in per_view_content], dim=0)  # (V, c, h, w)
                    for l in range(3)
                }

                style_domain_vec = _one_hot_domain(batch["style_img_label"])
                style_novel_normalize = self.MDMM.test_forward_transfer(
                    mdmm_images[:, -1], mdmm_style_images, style_domain_vec
                )[0]  # (3, H, W)
                # z_attr is read only after test_forward_transfer; keep this ordering.
                style_feat_domain = {'domain': style_domain_vec[0], 'style_feat': self.MDMM.z_attr[0]}
                style_novel = style_novel_normalize * 0.5 + 0.5  # back to the [0, 1] range, (3, H, W)

            ## Inferring Geometry Reasoner
            _, _, depth_map, depth_values, other_kwarg = self.geo_reasoner(
                imgs=batch["images"][:, :nb_views],
                affine_mats=batch["affine_mats"][:, :nb_views],
                affine_mats_inv=batch["affine_mats_inv"][:, :nb_views],
                near_far=batch["near_fars"][:, :nb_views],
                closest_idxs=batch["closest_idxs"][:, :nb_views],
                gt_depths=batch["depths_aug"][:, :nb_views],
                gt_depths_real=batch["depths"],
                content_feats=content_feats if self.hparams.geonerfMDMM else None,
                content_style_feat={'content': content_feats, 'style': style_feat_domain} if self.hparams.geonerfMDMM and self.hparams.style3Dfeat else None
            )

            ## Ray points covering the whole input view (inputWhole_cs_loss=True)
            rays_pts_ndc_inputView = get_rays_pts(
                H,
                W,
                batch["c2ws"],
                batch["w2cs"],
                batch["intrinsics"],
                batch["near_fars"],
                depth_values,
                self.hparams.nb_coarse,
                self.hparams.nb_fine,
                nb_views=nb_views,
                train=True,
                train_batch_size=self.hparams.batch_size,
                target_img=unpre_imgs[0, -1],
                target_depth=batch["depths_h"][0, -1],
                train_patch=self.hparams.train_patch,
                style_novel=style_novel if self.hparams.geonerfMDMM else None,
                inputWhole_cs_loss=True,
                pred_depth_map=depth_map["level_0"][0, 0],  # depth_map: (1, V, H, W)
            )[2]

            # Style code and domain label of input view 0 (the reconstruction target).
            input_domain_vec = _one_hot_domain(batch["input_img_label"])
            input_style, _ = self.MDMM.enc_a.forward(mdmm_images[:, 0], input_domain_vec)

            ref_img = unpre_imgs[:, 0]
            ref_rays_pts = rays_pts_ndc_inputView["level_0"][:, 0:1, 0]

            # Common features for levels 0-2 first, then special features for levels 0-2.
            # Every level is sampled at the same level-0 ray points of view 0.
            feat_to_cat = []
            for feats in (other_kwarg['common_feat'], other_kwarg['special_feat']):
                for l in range(3):
                    ray_feat_l, _, _ = interpolate_2D(
                        feats[f"level_{l}"][0:1], ref_img, ref_rays_pts
                    )
                    feat_to_cat.append(ray_feat_l.unsqueeze(1))

            n_rays = rays_pts_ndc_inputView["level_0"].shape[0]
            feat_to_cat.append(input_style[0].expand(n_rays, 1, -1))
            feat_to_cat.append(input_domain_vec[0].expand(n_rays, 1, -1))
            interpolated_feats_one = torch.cat(feat_to_cat, dim=-1).unsqueeze(2)  # (N,1,1,c) ; normal interpolated_feats: (N,S,V,c)

            _, style_rgb_ray = self.renderer(None, interpolated_feats_one, None, None, onlyContentStyle=True, return_what_style='both', z=None)

            # common+special reconstruction of view 0
            stylized_input_cs = style_rgb_ray['c+s'].reshape(H, W, 3).permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)

            mse_loss = img2mse(stylized_input_cs, ref_img)
            loss = 10 * mse_loss
            # NOTE(review): the key says "mse_loss", but the value is the 10x-scaled loss.
            # The raw MSE is returned and the caller logs it as "train/img_mse_loss".
            self.log("train/mse_loss", loss.item(), prog_bar=True)

            return loss, mse_loss


        if self.hparams.cycle_loss:
            loss, mse_loss, other_dir = train_(batch, first_time=True)
            new_batch = alter_batch(batch, other=other_dir, cycle=self.hparams.cycle_loss)
            for_novel_depth_loss = {}
            if self.hparams.novel_depth_loss:
                for_novel_depth_loss["rendered_depth"] = other_dir["rendered_depth"]
                for_novel_depth_loss["rays_pixs"] = other_dir["rays_pixs"]
            new_loss, new_mse_loss, _ = train_(new_batch, points_2d=other_dir["points_2d_v0"], for_novel_depth_loss=for_novel_depth_loss)
            loss += self.hparams.cycle_loss_lamb * new_loss

        elif self.hparams.far_view_loss:
            assert isinstance(batch, list)
            loss, mse_loss, other_dir = train_(batch[0], first_time=True)
            new_loss, new_mse_loss, _ = train_(batch[1], geo_reasoner_output=other_dir["geo_reasoner_output"])
            loss += self.hparams.far_view_loss_lamb * new_loss

        else:
            if self.hparams.update_z:
                loss, mse_loss = train_update_z(batch)
            else:
                loss, mse_loss, _ = train_(batch)

        
        with torch.no_grad():
            self.log("train/loss", loss.item(), prog_bar=True)
            psnr = mse2psnr(mse_loss.detach())
            self.log("train/PSNR", psnr.item(), prog_bar=True)
            self.log("train/img_mse_loss", mse_loss.item(), prog_bar=True)
            if self.hparams.cycle_loss or self.hparams.far_view_loss:
                name = "cycle" if self.hparams.cycle_loss else "far"
                self.log(f"train/{name}_loss", new_loss.item(), prog_bar=True)
                psnr = mse2psnr(new_mse_loss.detach())
                self.log(f"train/{name}_PSNR", psnr.item(), prog_bar=False)
                self.log(f"train/{name}_img_mse_loss", new_mse_loss.item(), prog_bar=False)

        # Manual Optimization
        self.manual_backward(loss)

        for i, (n,v) in enumerate(self.geo_reasoner.named_parameters()):
            if v.grad != None and len(v.shape) >= 4 :
                norm = v.grad.detach().data.norm(2)
                # print(f"geo-norm-{i}",norm)
                self.log(f"grad/geo-norm-{i}", norm, prog_bar=False)

        for i, v in enumerate(self.renderer.parameters()):
            if v.grad != None and len(v.shape) == 2 :
                norm = v.grad.detach().data.norm(2)
                # print(f"renderer-norm-{i}",norm)
                self.log(f"grad/renderer-norm-{i}", norm, prog_bar=False)

        opt = self.optimizers()
        sch = self.lr_schedulers()

        # Warming up the learning rate
        if self.trainer.global_step < self.hparams.warmup_steps:
            lr_scale = min(
                1.0, float(self.trainer.global_step + 1) / self.hparams.warmup_steps
            )
            for pg in opt.param_groups:
                pg["lr"] = lr_scale * self.learning_rate

        self.log("train/lr", opt.param_groups[0]["lr"], prog_bar=False)

        opt.step()
        opt.zero_grad()
        sch.step()

        return {"loss": loss}

    @staticmethod
    def _val_view_indices(scene):
        """Return the ``(pair_num, novel_num)`` of the single validation batch to render.

        ``validation_step`` only processes the batch whose running counter
        ``wr_cntr`` satisfies ``wr_cntr // 2 == pair_num`` and
        ``wr_cntr % 2 == novel_num``; all other batches are skipped.

        Args:
            scene: ``hparams.scene``. The literal string ``"None"`` selects the
                generalizable setting; otherwise the last ``/``-separated path
                component is looked up by name.

        Raises:
            ValueError: if no validation view is configured for ``scene``.
        """
        if scene == "None":
            return 5, 1
        view_indices = {
            "Family": (8, 1),
            "Horse": (2, 1),
            "Playground": (9, 1),
            "Train": (11, 0),
        }
        scene_name = scene.split('/')[-1]
        if scene_name not in view_indices:
            raise ValueError(f"No validation view configured for scene '{scene}'.")
        return view_indices[scene_name]

    def _val_content_feats(self, batch, nb_views):
        """Encode the first ``nb_views`` input images with the MDMM content encoder.

        Returns:
            ``(mdmm_images, content_feats)``: the un-preprocessed images
            re-normalised with mean/std 0.5 (the normalisation fed to MDMM here),
            and a dict whose ``f"level_{l}"`` entry (l = 0..2) concatenates the
            per-view content features along dim 0.
        """
        unpre_imgs = self.unpreprocess(batch["images"])
        mdmm_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        mdmm_images = mdmm_normalize(unpre_imgs)

        per_view = [
            self.MDMM.test_forward_random(mdmm_images[:, v], return_content=True)
            for v in range(nb_views)
        ]
        content_feats = {
            f"level_{l}": torch.cat([feat[f"level_{l}"] for feat in per_view], dim=0)  # (V, c, h, w)
            for l in range(3)
        }
        return mdmm_images, content_feats

    def validation_step(self, batch, batch_nb):
        """Render and save the novel view(s) for one selected validation batch.

        Only the batch chosen by ``_val_view_indices`` is processed. Every other
        call prints a "Skip" message, advances ``self.wr_cntr`` and returns
        ``None``.

        For the selected batch:

        * with ``use_midas``, the MiDaS depth of every view (``nb_views`` inputs
          plus the novel view) is visualised and saved under
          ``{logdir}/{dataset_name}/midas/``, and the method returns early;
        * otherwise, unless ``dataset_name == 'timeLapse'``, the geometry
          reasoner runs once and the renderer is run chunk-by-chunk over the
          full H x W image for each style phase in ``input_phi_list``. One PNG
          per phase, plus ``inputs_and_target.png``, is written to
          ``./logs/{dataset_name}/{expname}/``.

        Args:
            batch: validation batch dict (``images``, ``c2ws``, ``w2cs``,
                ``intrinsics``, ``near_fars``, ``depths``, ...).
            batch_nb: batch index; unused, ``self.wr_cntr`` is the counter used.

        Returns:
            The dict produced by ``init_log`` for ``log_keys``. Its values are
            not updated in this method. Returns ``None`` for skipped batches.

        Raises:
            ValueError: for a scene without a configured validation view, or if
                ``tune_result_by_delta_t`` is set without a style phase.
        """
        pair_num, novel_num = self._val_view_indices(self.hparams.scene)

        if not (self.wr_cntr // 2 == pair_num and self.wr_cntr % 2 == novel_num):
            print(f'Skip: pair {self.wr_cntr//2}, novel{self.wr_cntr%2}')
            self.wr_cntr += 1
            return

        ## Keep the geometry reasoner in train mode: per the original note this
        ## makes BatchNorm use the current batch statistics (InstanceNorm-like).
        ## NOTE(review): in train mode BatchNorm also updates its running stats.
        self.geo_reasoner.train()

        log_keys = [
            "val_psnr",
            "val_ssim",
            "val_lpips",
            "val_depth_loss_r",
            "val_abs_err",
            "mask_sum",
        ] + [f"val_acc_{i}mm" for i in self.eval_metric]
        for l in range(3):
            log_keys += [f"val_abs_err_casMVS_level{l}", f"mask_sum_casMVS_level{l}"]
            log_keys += [f"val_acc_{m}mm_casMVS_level{l}" for m in self.eval_metric[:3]]
        log = init_log({}, log_keys)

        H, W = batch["images"].shape[-2:]
        H, W = int(H), int(W)

        nb_views = self.hparams.nb_views

        with torch.no_grad():
            if self.hparams.use_midas:
                import skimage  # only needed for this visualisation path

                _imgs = self.unpreprocess(batch["images"])
                output_imgs = []
                for v in range(nb_views + 1):  # every input view + the novel view
                    raw_img = _imgs[0][v].permute(1, 2, 0).cpu().numpy()
                    raw_img = skimage.img_as_ubyte(raw_img)
                    input_batch = self.midas_transform(raw_img).cuda()
                    prediction = self.midas(input_batch)
                    prediction = torch.nn.functional.interpolate(
                        prediction.unsqueeze(1),
                        size=raw_img.shape[:2],
                        mode="bicubic",
                        align_corners=False,
                    ).squeeze()
                    # Treat the prediction as inverse depth: clamp it from below so
                    # the division stays bounded, invert it, then min-max normalise
                    # to [0, 1] for display.
                    depth_scaling = 40000
                    min_inv_depth = 50
                    print(prediction.cpu().numpy().max())
                    print(prediction.cpu().numpy().min())
                    inv_depth = np.maximum(min_inv_depth, prediction.cpu().numpy())
                    depth = depth_scaling / inv_depth
                    print(depth.max(), depth.min())
                    depth = (depth - depth.min()) / (depth.max() - depth.min())
                    print(depth.max(), depth.min())
                    output_imgs.append(depth)

                depth_img_vis = np.concatenate(output_imgs, axis=1)
                fig = plt.figure()
                ax = plt.axes()
                ax.set_title("input imgs: midas_depth(input view 3 / novel view)")
                plt.axis('off')
                im = ax.imshow(depth_img_vis, cmap='jet')
                cax = fig.add_axes([ax.get_position().x1 + 0.01, ax.get_position().y0, 0.02, ax.get_position().height])
                plt.colorbar(im, cax=cax)
                os.makedirs(f"{self.hparams.logdir}/{self.hparams.dataset_name}/midas/", exist_ok=True)
                plt.savefig(f"{self.hparams.logdir}/{self.hparams.dataset_name}/midas/{self.global_step:08d}_{self.wr_cntr:02d}.png", bbox_inches='tight', dpi=300)
                # Release the figure; without this one figure leaks per call.
                plt.close(fig)
                self.wr_cntr += 1

                return log

            if self.hparams.contentFeature:
                mdmm_images, content_feats = self._val_content_feats(batch, nb_views)

            if self.hparams.geonerfMDMM:
                self.MDMM.eval()
                mdmm_images, content_feats = self._val_content_feats(batch, nb_views)

            if self.hparams.dataset_name != 'timeLapse':
                ## Inferring Geometry Reasoner
                geo_reasoner_output = self.geo_reasoner(
                    imgs=batch["images"][:, :nb_views],
                    affine_mats=batch["affine_mats"][:, :nb_views],
                    affine_mats_inv=batch["affine_mats_inv"][:, :nb_views],
                    near_far=batch["near_fars"][:, :nb_views],
                    closest_idxs=batch["closest_idxs"][:, :nb_views],
                    gt_depths=batch["depths_aug"][:, :nb_views],
                    gt_depths_real=batch["depths"],
                    content_feats=content_feats if self.hparams.geonerfMDMM or self.hparams.contentFeature else None,
                    content_style_feat={'content': content_feats, 'style': None} if self.hparams.geonerfMDMM and self.hparams.style3Dfeat else None,
                )

                feats_vol, feats_fpn, depth_map, depth_values, other_kwarg = geo_reasoner_output

                ## Optionally dump one intermediate volume of the geometry reasoner.
                if self.hparams.save_D or self.hparams.save_O_label or self.hparams.save_depth_prob or self.hparams.save_O:
                    if self.hparams.save_D:
                        keyword = 'disocc_confi'
                    elif self.hparams.save_O_label:
                        keyword = 'O_label'
                    elif self.hparams.save_depth_prob:
                        keyword = 'depth_probs'
                    elif self.hparams.save_O:
                        keyword = 'O'
                    pickle_path = f"3d_feature/{self.hparams.dataset_name}/{self.hparams.expname}/{self.hparams.scene}_{keyword}.pickle"
                    # ``scene`` may contain '/', so create the file's own parent directory.
                    os.makedirs(os.path.dirname(pickle_path), exist_ok=True)
                    with open(pickle_path, 'wb') as f:
                        x = {}
                        for l in range(3):
                            if self.hparams.save_depth_prob:
                                x[f"level_{l}"] = F.softmax(other_kwarg[keyword][f'level_{l}'].detach().cpu(), dim=3)  # (B, V, 1, D, h, w)
                                print(x[f"level_{l}"].max(), x[f"level_{l}"].min())
                            else:
                                x[f"level_{l}"] = other_kwarg[keyword][f'level_{l}'].detach().cpu()
                        pickle.dump(x, f)
                    print("level0:", x["level_0"].shape, "1:", x["level_1"].shape, "2:", x["level_2"].shape)
                    print("input view shape:", batch["images"][:, :nb_views].shape)

                ## Normalizing depth maps to [near, far] -> [0, 1] using the first
                ## and last depth hypotheses of each level.
                depth_map_norm = {}
                for l in range(3):
                    depth_map_norm[f"level_{l}"] = (
                        depth_map[f"level_{l}"] - depth_values[f"level_{l}"][:, :, 0]
                    ) / (
                        depth_values[f"level_{l}"][:, :, -1]
                        - depth_values[f"level_{l}"][:, :, 0]
                    )

                depth_map_norm_gt = {}
                if self.hparams.gene_mask != "None" or self.hparams.occ_mask_useGT:
                    for l in range(3):
                        depth_map_norm_gt[f"level_{l}"] = (
                            batch["depths"][f"level_{l}"][:, 0:nb_views] - depth_values[f"level_{l}"][:, :, 0]
                        ) / (
                            depth_values[f"level_{l}"][:, :, -1]
                            - depth_values[f"level_{l}"][:, :, 0]
                        )

                ## Predicted depth with GT substituted wherever GT depth is positive.
                depth_map_norm_mix = {}
                if self.hparams.occ_mask_useGT:
                    for l in range(3):
                        _depth_gt = batch["depths"][f"level_{l}"][:, 0:nb_views]
                        # Copy so the GT overwrite does not mutate depth_map_norm in place.
                        depth_map_norm_mix[f"level_{l}"] = depth_map_norm[f"level_{l}"].clone()
                        depth_map_norm_mix[f"level_{l}"][_depth_gt > 0] = depth_map_norm_gt[f"level_{l}"][_depth_gt > 0]

                unpre_imgs = self.unpreprocess(batch["images"])

                def _input_domain_vec():
                    # (1, num_domains) indicator with 1 at batch["input_img_label"].
                    vec = torch.zeros((1, self.hparams.num_domains)).cuda()
                    vec[:, batch["input_img_label"]] = 1
                    return vec

                if self.hparams.zInputStyle:  # val time: use input view style
                    z, _ = self.MDMM.enc_a.forward(mdmm_images[:, 0], _input_domain_vec())
                    z = self.renderer.style2remain(z)
                else:
                    dist = tdist.Normal(torch.tensor([0.]), torch.tensor([self.hparams.gauss_var]))
                    z = dist.sample((self.hparams.z_dim,)).reshape(1, self.hparams.z_dim).cuda()

                if self.hparams.z_zInputStyle_Fuse:
                    input_domain_vec = _input_domain_vec()

                    _z, _ = self.MDMM.enc_a.forward(mdmm_images[:, 0], input_domain_vec)
                    t0 = self.renderer.style2phi(_z)

                    z_input, _ = self.MDMM.enc_a.forward(mdmm_images[:, 0], input_domain_vec)
                    z_input = self.renderer.style2remain(z_input)
                    z_input = self.renderer.z2commonSpace(z_input)

                    dist = tdist.Normal(torch.tensor([0.]), torch.tensor([self.hparams.gauss_var]))
                    z_N = dist.sample((self.hparams.z_dim,)).reshape(1, self.hparams.z_dim).cuda()
                    z_N = self.renderer.z2commonSpace(z_N)

                ## Style phases (angles in [0, 2*pi)) to render; [None] = no phase.
                if self.hparams.input_phi_to_test:
                    input_phi_list = [((self.wr_cntr // self.hparams.n_output_views) / self.hparams.n_split) * (2 * math.pi)]
                elif self.hparams.to_calculate_consistency:
                    input_phi_list = [(k / self.hparams.n_split) * (2 * math.pi) for k in range(self.hparams.n_split)]
                elif self.hparams.phi_is_from_input:
                    _z, _ = self.MDMM.enc_a.forward(mdmm_images[:, 0], _input_domain_vec())
                    phi = self.renderer.style2phi(_z)
                    input_phi_list = [phi]
                    print("input t:", phi)
                else:
                    input_phi_list = [None]

                dir_name = f"./logs/{self.hparams.dataset_name}/{self.hparams.expname}"
                os.makedirs(dir_name, exist_ok=True)

                # Which normalised depth maps the renderer sees (loop-invariant).
                if self.hparams.gene_mask != "None":
                    depth_map_norm_input = depth_map_norm_gt
                elif self.hparams.occ_mask_useGT:
                    depth_map_norm_input = depth_map_norm_mix
                else:
                    depth_map_norm_input = depth_map_norm

                # ceil(H * W / chunk) ray chunks cover the full image.
                n_chunks = H * W // self.hparams.chunk + int(H * W % self.hparams.chunk > 0)

                for id_phi, input_phi in enumerate(input_phi_list):
                    print('current Phi: ', input_phi)
                    if self.hparams.eval_from_img:
                        filename = f'{self.hparams.eval_from_img_path}/{self.global_step:08d}_{self.wr_cntr:02d}_mask.png'
                        rendered_rgb = T.ToTensor()(Image.open(filename)).cuda()  # (3,H,W)
                        rendered_depth = torch.zeros((H, W)).cuda()
                        continue

                    if self.hparams.z_zInputStyle_Fuse:
                        if input_phi is not None:
                            # Blend random and input-view style codes by phase distance.
                            delta_t = delta_t_compute(input_phi, t0)
                            z = z_N * delta_t + z_input * (1 - delta_t)
                        else:
                            z = z_input

                    if self.hparams.tune_result_by_delta_t:
                        if input_phi is None:
                            raise ValueError("tune_result_by_delta_t requires a style phase (input_phi).")
                        _z, _ = self.MDMM.enc_a.forward(mdmm_images[:, 0], _input_domain_vec())
                        t0 = self.renderer.style2phi(_z)

                        delta_t = delta_t_compute(input_phi, t0)
                        alpha = delta_t
                        print(t0, input_phi, alpha)

                    rendered_rgb, rendered_depth = [], []
                    output_density, output_T, output_nerf_rgb, output_alpha, output_coord = [], [], [], [], []
                    output_style, output_style_cs = [], []

                    for chunk_idx in range(n_chunks):
                        pts_depth, rays_pts, rays_pts_ndc, rays_dir, _, rays_gt_depth, _, _, _ = get_rays_pts(
                            H,
                            W,
                            batch["c2ws"],
                            batch["w2cs"],
                            batch["intrinsics"],
                            batch["near_fars"],
                            depth_values,
                            self.hparams.nb_coarse,
                            self.hparams.nb_fine,
                            nb_views=nb_views,
                            chunk=self.hparams.chunk,
                            chunk_idx=chunk_idx,
                            gene_mask=self.hparams.gene_mask,
                            target_depth=batch["depths_h"][0, -1],
                            new_c2ws=None,
                        )

                        ## Rendering
                        rend_rgb, rend_depth, renderer_other_output = render_rays(
                            c2ws=batch["c2ws"][0, :nb_views],
                            rays_pts=rays_pts,
                            rays_pts_ndc=rays_pts_ndc,
                            pts_depth=pts_depth,
                            rays_dir=rays_dir,
                            feats_vol=feats_vol,
                            feats_fpn=feats_fpn[:, :nb_views],
                            imgs=unpre_imgs[:, :nb_views],
                            depth_map_norm=depth_map_norm_input,
                            renderer_net=self.renderer,
                            other_kwarg=other_kwarg,
                            other_output=['density', 'D', 'T', 'rgb', 'alpha'],
                            have_dist=self.hparams.renderer_geneRGBsigma_dist,
                            have_gene=self.hparams.renderer_geneRGBsigma_loss,
                            use_att_3d=self.hparams.attention_3d,
                            for_mask=None if self.hparams.gene_mask == "None" else {'pts_d': pts_depth, 'pts_d_gt': rays_gt_depth},
                            two_feats=(self.hparams.check_feat_mode == "2" or self.hparams.separate_occ_feat),
                            content_style_feat={'content': content_feats, 'style': None} if self.hparams.geonerfMDMM else None,
                            input_phi=input_phi,
                            return_what_style='both' if self.hparams.styleTwoBranch else None,
                            input_z=z if self.hparams.styleTwoBranch or self.hparams.add_z else None,
                            alpha=alpha if self.hparams.tune_result_by_delta_t else None,
                        )

                        rendered_rgb.append(rend_rgb)
                        rendered_depth.append(rend_depth)
                        if self.hparams.save_sigma:
                            output_density.append(renderer_other_output['density'])
                        if self.hparams.save_nerf_rgb:
                            output_nerf_rgb.append(renderer_other_output['rgb'])
                        if self.hparams.save_T:
                            output_T.append(renderer_other_output['T'])
                        if self.hparams.save_alpha:
                            output_alpha.append(renderer_other_output['alpha'])
                        if self.hparams.save_coord:
                            output_coord.append(rays_pts_ndc['level_2'][..., 0])
                        if self.hparams.geonerfMDMM:
                            if self.hparams.styleTwoBranch:
                                output_style.append(renderer_other_output['rendered_style_rgb']['c'])
                                output_style_cs.append(renderer_other_output['rendered_style_rgb']['c+s'])
                            else:
                                output_style.append(renderer_other_output['rendered_style_rgb'])

                    rendered_rgb = torch.clamp(
                        torch.cat(rendered_rgb).reshape(H, W, 3).permute(2, 0, 1), 0, 1
                    )
                    rendered_depth = torch.cat(rendered_depth).reshape(H, W)

                    if self.hparams.geonerfMDMM:
                        rendered_style_rgb = torch.clamp(
                            torch.cat(output_style).reshape(H, W, 3).permute(2, 0, 1), 0, 1
                        )
                        if self.hparams.styleTwoBranch:
                            rendered_style_rgb_cs = torch.clamp(
                                torch.cat(output_style_cs).reshape(H, W, 3).permute(2, 0, 1), 0, 1
                            )
                    rendered_rgb, rendered_depth = rendered_rgb.cpu(), rendered_depth.cpu()

                    # Save the most specific rendering that exists for this config;
                    # the content+style image only exists with styleTwoBranch.
                    if self.hparams.geonerfMDMM and self.hparams.styleTwoBranch:
                        save_rgb = rendered_style_rgb_cs
                    elif self.hparams.geonerfMDMM:
                        save_rgb = rendered_style_rgb
                    else:
                        save_rgb = rendered_rgb
                    img_vis = save_rgb.cpu().detach().permute(1, 2, 0).clip(0.0, 1.0).numpy()
                    imageio.imwrite(f"{dir_name}/{id_phi}.png", (img_vis * 255).astype("uint8"))

                # save input views & target view side by side
                img_vis = unpre_imgs.squeeze(0).cpu().detach().permute(2, 0, 3, 1).reshape(H, -1, 3).clip(0.0, 1.0).numpy()
                imageio.imwrite(f"{dir_name}/inputs_and_target.png", (img_vis * 255).astype("uint8"))

            self.wr_cntr += 1

        return log

    def validation_epoch_end(self, outputs):
        """Aggregate per-batch validation logs, log them and write a metrics file.

        Image metrics (PSNR/SSIM/LPIPS) and ``val_depth_loss_r`` are averaged
        over batches. Depth absolute error and accuracies are summed over
        batches and divided by the summed mask counts. Depth metrics are only
        reported when the relevant mask sum is positive, and each CasMVSNet
        level is gated on its own mask sum.

        Metrics are written to
        ``{logdir}/{dataset_name}/{expname}/{txt_name}_metrics.txt``, where
        ``txt_name`` is the first ``/``-separated component of ``expname``.

        Args:
            outputs: list of dicts returned by ``validation_step``. When it is
                empty (e.g. every validation batch was skipped), nothing is
                logged or written.
        """
        if not outputs:
            return

        metrics = self.eval_metric[:3]

        def _stack(key):
            return torch.stack([x[key] for x in outputs])

        mean_psnr = _stack("val_psnr").mean()
        mean_ssim = np.stack([x["val_ssim"] for x in outputs]).mean()
        mean_lpips = np.stack([x["val_lpips"] for x in outputs]).mean()
        mask_sum = _stack("mask_sum").sum()
        mean_d_loss_r = _stack("val_depth_loss_r").mean()
        mean_abs_err = _stack("val_abs_err").sum() / mask_sum
        mean_accs = [_stack(f"val_acc_{m}mm").sum() / mask_sum for m in metrics]

        # Per-CasMVSNet-level depth metrics, each normalised by its own mask sum.
        mask_sums_cas, abs_errs_cas, accs_cas = [], [], []
        for l in range(3):
            mask_sum_l = _stack(f"mask_sum_casMVS_level{l}").sum()
            mask_sums_cas.append(mask_sum_l)
            abs_errs_cas.append(_stack(f"val_abs_err_casMVS_level{l}").sum() / mask_sum_l)
            accs_cas.append(
                [_stack(f"val_acc_{m}mm_casMVS_level{l}").sum() / mask_sum_l for m in metrics]
            )

        self.log("val/PSNR", mean_psnr, prog_bar=False)
        self.log("val/SSIM", mean_ssim, prog_bar=False)
        self.log("val/LPIPS", mean_lpips, prog_bar=False)
        if mask_sum > 0:
            self.log("val/d_loss_r", mean_d_loss_r, prog_bar=False)
            self.log("val/abs_err", mean_abs_err, prog_bar=False)
            for m, acc in zip(metrics, mean_accs):
                self.log(f"val/acc_{m}mm", acc, prog_bar=False)
        for l in range(3):
            if mask_sums_cas[l] > 0:
                self.log(f"val/abs_err_casMVS_level{l}", abs_errs_cas[l], prog_bar=False)

        txt_name = self.hparams.expname.split('/')[0]
        out_dir = f"{self.hparams.logdir}/{self.hparams.dataset_name}/{self.hparams.expname}"
        os.makedirs(out_dir, exist_ok=True)
        with open(f"{out_dir}/{txt_name}_metrics.txt", "w") as metric_file:
            metric_file.write(f"PSNR: {mean_psnr}\n")
            metric_file.write(f"SSIM: {mean_ssim}\n")
            metric_file.write(f"LPIPS: {mean_lpips}\n")
            if mask_sum > 0:
                metric_file.write(f"depth_abs_err: {mean_abs_err}\n")
                for m, acc in zip(metrics, mean_accs):
                    metric_file.write(f"acc_{m}mm: {acc}\n")
            for l in range(3):
                if mask_sums_cas[l] > 0:
                    metric_file.write(f"abs_err_casMVS_level{l}: {abs_errs_cas[l]}\n")
            for l in range(3):
                if mask_sums_cas[l] > 0:
                    for m, acc in zip(metrics, accs_cas[l]):
                        metric_file.write(f"acc_{m}mm_casMVS_level{l}: {acc}\n")



if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    # One model instance is built here and shared by the train and eval branches.
    geonerf = GeoNeRF(args)

    ## Checking the logdir for a checkpoint to continue from. The callback below
    ## names files "ckpt_step-{step:06d}", so (assuming the directory holds only
    ## such files) the lexicographically last one is the latest step.
    ckpt_path = f"{args.logdir}/{args.dataset_name}/{args.expname}/ckpts"
    if os.path.isdir(ckpt_path) and os.listdir(ckpt_path):
        ckpt_file = os.path.join(ckpt_path, sorted(os.listdir(ckpt_path))[-1])
    else:
        ckpt_file = None
    # An explicit --ckpt overrides the auto-detected checkpoint.
    if args.ckpt != "None":
        ckpt_file = args.ckpt
    print("ckpt_file:", ckpt_file)

    ## Setting a callback to automatically save checkpoints (every one is kept)
    checkpoint_callback = ModelCheckpoint(
        ckpt_path,
        filename="ckpt_step-{step:06d}",
        auto_insert_metric_name=False,
        save_top_k=-1,
    )

    ## Setting up a logger
    if args.logger == "wandb":
        logger = WandbLogger(
            name=args.expname,
            project="GeoNeRF",
            save_dir=f"{args.logdir}",
            resume="allow",
            id=args.expname,
        )
    elif args.logger == "tensorboard":
        logger = loggers.TestTubeLogger(
            save_dir=f"{args.logdir}/{args.dataset_name}/{args.expname}",
            name=args.expname + "_logs",
            debug=False,
            create_git_tag=False,
            version=0
        )
    else:
        logger = None

    # Mixed precision for training only; evaluation runs in fp32.
    args.use_amp = not args.eval
    if args.scene == "None":
        if args.finetune:
            val_check_interval = 1000
        elif args.only_llff:
            val_check_interval = 2000
        else:
            val_check_interval = 1.0
            print("epoch:", args.n_epochs)
    else:
        val_check_interval = 1
    trainer = Trainer(
        max_epochs=args.n_epochs,
        callbacks=checkpoint_callback,
        checkpoint_callback=True,
        # Finetuning / partial-update modes load weights manually below instead
        # of resuming the full trainer state.
        resume_from_checkpoint=ckpt_file if not (args.finetune or args.freezeExceptTimephi or args.update_z) else None,
        logger=logger,
        progress_bar_refresh_rate=1,
        gpus=1,
        num_sanity_val_steps=1,
        val_check_interval=val_check_interval,
        check_val_every_n_epoch=1,
        benchmark=True,
        precision=16 if args.use_amp else 32,
        amp_level="O1",
    )

    if not args.eval:  ## Train
        if args.scene != "None":  ## Fine-tune on a single scene
            if ckpt_file is None:
                if args.use_depth:
                    ckpt_file = "pretrained_weights/pretrained_w_depth.ckpt"
                else:
                    ckpt_file = "pretrained_weights/pretrained.ckpt"

                load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner", strict=False)
                load_ckpt(geonerf.renderer, ckpt_file, "renderer")
            else:
                print(f"load {ckpt_file}")

        elif not args.use_depth:  ## Generalizable
            if ckpt_file is None:
                if not args.noLoadCas:
                    ## Loading the pretrained weights from Cascade MVSNet
                    torch.utils.model_zoo.load_url(
                        "https://github.com/kwea123/CasMVSNet_pl/releases/download/1.5/epoch.15.ckpt",
                        model_dir="pretrained_weights",
                    )
                    ckpt_file = "pretrained_weights/epoch.15.ckpt"
                    load_ckpt(geonerf.geo_reasoner, ckpt_file, "model", strict=False, load_costreg=args.load_costreg)
                    print("load casMVS")
            elif args.finetune:
                print(f"load ver0: {ckpt_file}")
                print(f"freeze: {args.freeze}")
                print(f"freeze models except casmvs: {args.update_casmvs}")
                load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner", strict=False, freeze=args.freeze, update_casmvs=args.update_casmvs)
                load_ckpt(geonerf.renderer, ckpt_file, "renderer", strict=False, freeze=args.freeze, update_casmvs=args.update_casmvs)
                if args.learn_3dfeat_from_GT:
                    print(f"learn_3dfeat_from_GT: {args.learn_3dfeat_from_GT}")
                    load_ckpt(geonerf.geo_reasoner_teacher, ckpt_file, "geo_reasoner", strict=False, freeze=True)
            elif args.freezeExceptTimephi:
                print("freeze except timephi")
                print(f"load {ckpt_file}")
                load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner", strict=False, freeze=True)
                load_ckpt(geonerf.renderer, ckpt_file, "renderer", strict=False, freezeExceptTimephi=True)
            elif args.update_z:
                print("update z, freeze everything else")
                print(f"load {ckpt_file}")
                load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner", strict=False, update_z=True)
                load_ckpt(geonerf.renderer, ckpt_file, "renderer", strict=False, update_z=True)
            else:
                print(f"load {ckpt_file}")

        trainer.fit(geonerf)
    else:  ## Eval: reuse the model built above rather than constructing a second one
        if ckpt_file is None:
            if args.use_depth:
                ckpt_file = "pretrained_weights/pretrained_w_depth.ckpt"
            else:
                ckpt_file = "pretrained_weights/pretrained.ckpt"

        print("ckpt_file:", ckpt_file)

        load_ckpt(geonerf.geo_reasoner, ckpt_file, "geo_reasoner")
        load_ckpt(geonerf.renderer, ckpt_file, "renderer", strict=False)

        trainer.validate(geonerf)