# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # Copyright (c) 2020 AI葵

    # This file is part of CasMVSNet_pl.
    # CasMVSNet_pl is free software: you can redistribute it and/or modify
    # it under the terms of the GNU General Public License version 3 as
    # published by the Free Software Foundation.

    # CasMVSNet_pl is distributed in the hope that it will be useful,
    # but WITHOUT ANY WARRANTY; without even the implied warranty of
    # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    # GNU General Public License for more details.

    # You should have received a copy of the GNU General Public License
    # along with CasMVSNet_pl. If not, see <http://www.gnu.org/licenses/>.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from utils.utils import homo_warp
from inplace_abn import InPlaceABN

from utils.rendering import sigma2weights
import torchvision.transforms as T

def get_depth_values(current_depth, n_depths, depth_interval):
    """Return ``n_depths`` evenly spaced depth hypotheses around ``current_depth``.

    The first hypothesis lies half a span (``n_depths / 2 * depth_interval``)
    below ``current_depth`` and is clamped to at least 1e-7 so it stays
    positive; subsequent hypotheses advance by ``depth_interval``.

    :param current_depth: depth estimate, e.g. (b, 1, h, w) upsampled from the
        previous cascade level
    :param n_depths: number of hypotheses to generate
    :param depth_interval: spacing between consecutive hypotheses
    :return: hypotheses broadcast to (b, n_depths, h, w)
    """
    half_span = n_depths / 2 * depth_interval
    lowest_depth = torch.clamp_min(current_depth - half_span, 1e-7)
    # (n_depths,) -> (1, n_depths, 1, 1) so it broadcasts over batch and pixels
    steps = torch.arange(
        0, n_depths, device=current_depth.device, dtype=current_depth.dtype
    )
    steps = steps.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    return lowest_depth + depth_interval * steps

def unity_regression(prob_volume, depth_values, interval):
    """Winner-take-all depth refined by the peak probability.

    For every pixel the hypothesis with the highest probability is selected
    and shifted by ``(1 - peak_probability) * interval``.

    :param interval: hypothesis spacing (scalar, or a tensor broadcastable to
        (b, 1, h, w))
    :param prob_volume: (b, d, h, w)
    :param depth_values: (b, d, h, w)
    :return: (b, h, w)
    """
    peak_prob, peak_idx = prob_volume.max(dim=1, keepdim=True)
    peak_depth = depth_values.gather(1, peak_idx)
    refined = peak_depth + (1 - peak_prob) * interval
    return refined.squeeze(1)

class ConvBnReLU(nn.Module):
    """Bias-free ``nn.Conv2d`` followed by the ``norm_act`` layer.

    No separate ReLU module is applied; whatever activation exists comes from
    ``norm_act`` (``InPlaceABN`` by default). Submodule names ``conv`` and
    ``bn`` are part of the state_dict layout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        pad=1,
        norm_act=InPlaceABN,
    ):
        super().__init__()
        # The normalization layer supplies the affine shift, so the conv has no bias.
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=pad, bias=False,
        )
        self.bn = norm_act(out_channels)

    def forward(self, x):
        y = self.conv(x)
        return self.bn(y)


class ConvBnReLU3D(nn.Module):
    """Bias-free ``nn.Conv3d`` followed by the ``norm_act`` layer.

    Despite the name, no separate ReLU module is applied; any activation is
    whatever ``norm_act`` (``InPlaceABN`` by default) performs. Submodule
    names ``conv`` and ``bn`` are part of the state_dict layout.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        pad=1,
        norm_act=InPlaceABN,
    ):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=pad, bias=False,
        )
        self.bn = norm_act(out_channels)

    def forward(self, x):
        y = self.conv(x)
        return self.bn(y)


class FeatureNet(nn.Module):
    """Three-scale FPN-style feature extractor for RGB images.

    Returns features at full, 1/2 and 1/4 resolution with 8, 16 and 32
    channels, keyed ``"level_0"``, ``"level_1"`` and ``"level_2"``.
    Submodules are registered in a fixed order; keep it unchanged so
    checkpoints and parameter ordering stay compatible.
    """

    def __init__(self, norm_act=InPlaceABN):
        super().__init__()

        def down_stage(cin, cout):
            # A stride-2 5x5 conv followed by two 3x3 convs at the new scale.
            return nn.Sequential(
                ConvBnReLU(cin, cout, 5, 2, 2, norm_act=norm_act),
                ConvBnReLU(cout, cout, 3, 1, 1, norm_act=norm_act),
                ConvBnReLU(cout, cout, 3, 1, 1, norm_act=norm_act),
            )

        self.conv0 = nn.Sequential(
            ConvBnReLU(3, 8, 3, 1, 1, norm_act=norm_act),
            ConvBnReLU(8, 8, 3, 1, 1, norm_act=norm_act),
        )
        self.conv1 = down_stage(8, 16)
        self.conv2 = down_stage(16, 32)

        # 1x1 lateral projections into a shared 32-channel space.
        self.toplayer = nn.Conv2d(32, 32, 1)
        self.lat1 = nn.Conv2d(16, 32, 1)
        self.lat0 = nn.Conv2d(8, 32, 1)

        # 3x3 smoothing convs that also shrink the channels of the finer FPN outputs.
        self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
        self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        """Bilinearly upsample ``x`` by 2 and add ``y``.

        ``y`` must have the upsampled spatial size, so the input H and W are
        expected to be divisible by 4.
        """
        upsampled = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return upsampled + y

    def forward(self, x, dummy=None):
        """Compute the feature pyramid.

        :param x: (B, 3, H, W) images
        :param dummy: not used by this method
        :return: dict with "level_0" (B, 8, H, W), "level_1" (B, 16, H/2, W/2)
            and "level_2" (B, 32, H/4, W/4)
        """
        c0 = self.conv0(x)
        c1 = self.conv1(c0)
        c2 = self.conv2(c1)

        top = self.toplayer(c2)
        mid = self._upsample_add(top, self.lat1(c1))
        fine = self._upsample_add(mid, self.lat0(c0))

        mid = self.smooth1(mid)
        fine = self.smooth0(fine)

        return {"level_0": fine, "level_1": mid, "level_2": top}


class CostRegNet(nn.Module):
    """3D U-Net that regularizes a cost volume.

    The last two input dimensions are zero-padded (bottom/right) so they are
    multiples of 8, which the three stride-2 stages require. The volume then
    goes through an encoder/decoder with additive skip connections and two
    heads: ``br2`` yields an 8-channel feature volume and ``br1`` -> ``prob``
    yields a 1-channel volume. Both outputs are cropped back to the
    unpadded size.
    """

    def __init__(self, in_channels, norm_act=InPlaceABN):
        super().__init__()

        def up_block(cin, cout):
            # Stride-2 transposed conv that exactly doubles each spatial dim,
            # followed by the normalization layer.
            return nn.Sequential(
                nn.ConvTranspose3d(
                    cin, cout, 3, padding=1, output_padding=1, stride=2, bias=False
                ),
                norm_act(cout),
            )

        # Registration order is kept identical so state_dict keys, parameter
        # order and initialization order are unchanged.
        # Encoder
        self.conv0 = ConvBnReLU3D(in_channels, 8, norm_act=norm_act)
        self.conv1 = ConvBnReLU3D(8, 16, stride=2, norm_act=norm_act)
        self.conv2 = ConvBnReLU3D(16, 16, norm_act=norm_act)
        self.conv3 = ConvBnReLU3D(16, 32, stride=2, norm_act=norm_act)
        self.conv4 = ConvBnReLU3D(32, 32, norm_act=norm_act)
        self.conv5 = ConvBnReLU3D(32, 64, stride=2, norm_act=norm_act)
        self.conv6 = ConvBnReLU3D(64, 64, norm_act=norm_act)

        # Decoder
        self.conv7 = up_block(64, 32)
        self.conv9 = up_block(32, 16)
        self.conv11 = up_block(16, 8)

        # Output heads
        self.br1 = ConvBnReLU3D(8, 8, norm_act=norm_act)
        self.br2 = ConvBnReLU3D(8, 8, norm_act=norm_act)
        self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1)

    def forward(self, x):
        """Regularize ``x`` and return ``(feature_volume, prob_volume)``.

        :param x: (B, C_in, D, H, W) cost volume
        :return: (B, 8, D, H, W) features and (B, 1, D, H, W) probability logits
        """
        height, width = x.shape[-2], x.shape[-1]
        padded = height % 8 != 0 or width % 8 != 0
        if padded:
            # Pad to the next multiple of 8 strictly above the size, i.e.
            # 8 * (n // 8 + 1) - n == 8 - n % 8. A dimension that is already a
            # multiple of 8 still receives 8 extra rows/cols in this branch.
            x = F.pad(
                x, (0, 8 - width % 8, 0, 8 - height % 8), mode="constant", value=0
            )

        skip0 = self.conv0(x)
        skip1 = self.conv2(self.conv1(skip0))
        skip2 = self.conv4(self.conv3(skip1))

        x = self.conv6(self.conv5(skip2))
        # Release each skip tensor as soon as it has been consumed.
        x = skip2 + self.conv7(x)
        del skip2
        x = skip1 + self.conv9(x)
        del skip1
        x = skip0 + self.conv11(x)
        del skip0

        head = self.br1(x)
        # br2 always runs with autograd enabled, even inside torch.no_grad().
        with torch.enable_grad():
            features = self.br2(x)
        logits = self.prob(head)

        if padded:
            features = features[..., :height, :width]
            logits = logits[..., :height, :width]

        return features, logits


class FeatureSelfAttention(nn.Module):
    """Global context for per-view 8-channel features via one self-attention layer.

    Strided convolutions downsample the input roughly 32x. The result is
    halved again when the grid has more than 200 positions, flattened into
    tokens, passed through ``EncoderLayer`` and projected to 16 channels
    with a 1x1 conv.
    """

    def __init__(self, norm_act=InPlaceABN):
        super().__init__()

        self.conv1 = nn.Sequential(
            ConvBnReLU(8, 16, 7, 4, 3, norm_act=norm_act),   # ~ /4
            ConvBnReLU(16, 32, 7, 4, 3, norm_act=norm_act),  # ~ /4
            ConvBnReLU(32, 64, 5, 2, 2, norm_act=norm_act),  # ~ /2
        )

        from model.self_attn_renderer import EncoderLayer

        embed_dim = 64
        n_heads = 4
        per_head = embed_dim // n_heads
        # Positional args as in the original: (dim, d_inner, n_head, d_k, d_v).
        self.att = EncoderLayer(embed_dim, embed_dim, n_heads, per_head, per_head)

        # 1x1 projection from 64 to 16 channels.
        self.smooth = nn.Conv2d(64, 16, 1, padding=0)

    def _upsample_add(self, x, y):
        """Bilinearly upsample ``x`` by 2 and add ``y`` (not used by forward)."""
        upsampled = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return upsampled + y

    def forward(self, x, dummy=None):
        """Apply token self-attention to downsampled features.

        :param x: (B*V, 8, H, W) features
        :param dummy: not used by this method
        :return: (B*V, 16, h, w) with h, w roughly H/32, W/32 (or half that
            when the token grid exceeded 200 positions)
        """
        feat = self.conv1(x)  # (B*V, 64, ~H/32, ~W/32)
        n_img, channels, fh, fw = feat.shape
        if fh * fw > 200:
            # Bound the number of attention tokens.
            feat = F.interpolate(feat, scale_factor=0.5, mode="bilinear", align_corners=True)
            fh, fw = feat.shape[-2], feat.shape[-1]

        tokens = feat.permute(0, 2, 3, 1).reshape(n_img, -1, channels)
        attended = self.att(tokens)[0]
        feat = attended.reshape(n_img, fh, fw, channels).permute(0, 3, 1, 2)

        return self.smooth(feat)

class contentFeatureNet(nn.Module):
    """Reduce three levels of content features to 8/16/32 channels.

    Each level passes through its own ``ConvBnReLU`` (64->8, 128->16,
    256->32). With ``pyramid=True`` the reduced maps are additionally fused
    top-down FPN-style, as in ``FeatureNet``.
    """

    def __init__(self, norm_act=InPlaceABN, kernel=1, pad=0, pyramid=False):
        super().__init__()

        self.pyramid = pyramid

        # Registered as conv_0, conv_1, conv_2 (names are state_dict keys).
        for level, (cin, cout) in enumerate(zip([64, 128, 256], [8, 16, 32])):
            setattr(
                self,
                f"conv_{level}",
                ConvBnReLU(cin, cout, kernel, 1, pad, norm_act=norm_act),
            )

        if pyramid:
            self.toplayer = nn.Conv2d(32, 32, 1)
            self.lat1 = nn.Conv2d(16, 32, 1)
            self.lat0 = nn.Conv2d(8, 32, 1)

            # 3x3 smoothing convs that shrink the channels of the FPN outputs.
            self.smooth1 = nn.Conv2d(32, 16, 3, padding=1)
            self.smooth0 = nn.Conv2d(32, 8, 3, padding=1)

    def _upsample_add(self, x, y):
        """Bilinearly upsample ``x`` by 2 and add ``y``."""
        upsampled = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return upsampled + y

    def forward(self, content):
        """Reduce (and optionally fuse) the ``level_0..level_2`` entries of ``content``.

        :param content: dict with "level_0", "level_1", "level_2" feature maps
            of 64, 128 and 256 channels
        :return: dict with the same keys holding 8-, 16- and 32-channel maps
        """
        reduced = [
            getattr(self, f"conv_{level}")(content[f"level_{level}"])
            for level in range(3)
        ]

        if not self.pyramid:
            return {f"level_{level}": feat for level, feat in enumerate(reduced)}

        c0, c1, c2 = reduced
        top = self.toplayer(c2)
        mid = self._upsample_add(top, self.lat1(c1))
        fine = self._upsample_add(mid, self.lat0(c0))

        mid = self.smooth1(mid)
        fine = self.smooth0(fine)

        return {"level_0": fine, "level_1": mid, "level_2": top}

class SPAdaIN(nn.Module):
    """Adaptive instance normalization conditioned on a feature vector.

    ``x`` is normalized with ``norm(planes)``. It is then scaled and shifted
    per channel by values predicted from ``addition`` with two linear layers
    (``conv_weight`` and ``conv_bias``; the names are kept for checkpoint
    compatibility).
    """

    def __init__(self, norm, input_nc, planes):
        super(SPAdaIN, self).__init__()
        self.conv_weight = nn.Linear(input_nc, planes)
        self.conv_bias = nn.Linear(input_nc, planes)
        self.norm = norm(planes)

    def forward(self, x, addition):
        """Normalize ``x`` and modulate it with ``addition``.

        :param x: (N, planes, D, H, W) volume
        :param addition: conditioning vector(s) whose last dim is ``input_nc``,
            e.g. (input_nc,), (1, input_nc) or (N, input_nc) for per-sample
            modulation
        :return: tensor of the same shape as ``x``
        """
        x = self.norm(x)
        planes = self.conv_weight.out_features
        # Keep the leading (batch) dimension instead of flattening everything
        # into the channel axis: a single vector still yields
        # (1, planes, 1, 1, 1), while a batch of N vectors now yields
        # (N, planes, 1, 1, 1) and modulates each sample with its own vector.
        weight = self.conv_weight(addition).reshape(-1, planes, 1, 1, 1)
        bias = self.conv_bias(addition).reshape(-1, planes, 1, 1, 1)
        return weight * x + bias

class SPAdaINResBlock(nn.Module):
    """Residual block built from ``SPAdaIN -> ReLU -> Conv3d`` units.

    The main path applies the unit twice and the shortcut applies it once.
    Every ``SPAdaIN`` is conditioned on the same ``addition``, and the
    channel count ``planes`` is preserved.
    """

    def __init__(self, input_nc, planes, norm=nn.InstanceNorm3d, conv_kernel_size=1, padding=0):
        super().__init__()

        def make_conv():
            return nn.Conv3d(planes, planes, kernel_size=conv_kernel_size, stride=1, padding=padding)

        def make_modulation():
            return SPAdaIN(norm=norm, input_nc=input_nc, planes=planes)

        # Registration order is fixed (state_dict / parameter order / init order).
        self.spadain1 = make_modulation()
        self.relu = nn.ReLU()
        self.conv1 = make_conv()
        self.spadain2 = make_modulation()
        self.conv2 = make_conv()
        self.spadain_res = make_modulation()
        self.conv_res = make_conv()

    def forward(self, x, addition):
        main = self.conv1(self.relu(self.spadain1(x, addition)))
        main = self.conv2(self.relu(self.spadain2(main, addition)))

        shortcut = self.conv_res(self.relu(self.spadain_res(x, addition)))

        return main + shortcut

class styleAdain(nn.Module):
    """Conv3d stack with two SPAdaIN residual blocks conditioned on a style code.

    content (input_dim ch) -> 1x1x1 conv (input_dim // 2 ch) -> SPAdaIN block
    -> 1x1x1 conv -> SPAdaIN block -> 3x3x3 conv (output_dim ch).
    The style tensor's last dimension must be 8 (``input_nc=8``).
    ``norm_act`` is not used by this class.
    """

    def __init__(self, input_dim, output_dim, norm_act=InPlaceABN):
        super().__init__()
        hidden = input_dim // 2

        self.conv1 = nn.Conv3d(input_dim, hidden, 1, stride=1, padding=0)
        self.conv2 = nn.Conv3d(hidden, hidden, 1, stride=1, padding=0)
        self.conv3 = nn.Conv3d(hidden, output_dim, 3, stride=1, padding=1)

        self.spadain_block1 = SPAdaINResBlock(input_nc=8, planes=hidden)
        self.spadain_block2 = SPAdaINResBlock(input_nc=8, planes=hidden)

    def forward(self, content, style):
        hidden = self.spadain_block1(self.conv1(content), style)
        hidden = self.spadain_block2(self.conv2(hidden), style)
        return self.conv3(hidden)

class CasMVSNet(nn.Module):
    def __init__(self, num_groups=8, norm_act=InPlaceABN, levels=3, use_depth=False, use_disocclude=False, geoFeatComplete="None", texFeatComplete=False, texFeat=False, nb_views=3, pDensity=False, upperbound=False, upperbound_noise=False, upperbound_gauss=False, D01=False, D_gauss=False, O_label=False, P_constraint=False, unimvs=False, use_featSA=False, use_global_geoFeat=False, check_feat_mode="None", separate_occ_feat=False, is_teacher=False,
                 cas_confi=False, texFeat_woUnet=False, save_var_confi=-1,
                 geonerfMDMM=False, style3Dfeat=False, styleTwoBranch=False, unparallExtract=False, contentPyramid=False, contentFeature=False):
        """Build the cascade MVS network and its optional, flag-controlled sub-modules.

        Every argument is stored on ``self``. Only the modules created here
        are described below; how each flag is consumed at run time is decided
        in the other methods.

        NOTE: sub-modules are registered in a fixed order. Reordering or
        renaming them changes ``state_dict`` keys and ``parameters()`` order,
        which breaks saved checkpoints and optimizer states.

        Args:
            num_groups: number of groups ``G`` of the group-wise correlation
                cost volume; the input channel count of most ``CostRegNet``s.
            norm_act: normalization layer class passed to most sub-modules.
                ``FeatureNet`` and ``FeatureSelfAttention`` are built with
                their own defaults instead.
            levels: number of cascade levels. The per-level lists
                (``n_depths``, ``interval_ratios``, kernel lists, style dims)
                have 3 entries, so values above 3 are not supported.
            use_depth: the coarsest level's ``CostRegNet`` takes ``G + 1``
                input channels (spikes concatenated in ``build_cost_volumes``).
            use_disocclude: adds ``disocc_conv_{l}`` per level.
            geoFeatComplete: ``'v1'`` adds depth-axis-only convs
                ``geo_feat_conv_{l}``; ``'v2'`` adds 16->8 1x1x1 convs.
            texFeat: adds a ``tex_unet_{l}`` ``CostRegNet`` per level with
                ``G + (nb_views - 1) * 3`` input channels.
            texFeatComplete: adds 1x1x1 ``tex_feat_conv_{l}`` per level.
            nb_views: only used to size the ``tex_unet_{l}`` inputs.
            O_label: adds 8->4->1 ``O_label_conv_{l}`` per level.
            use_featSA: adds ``feat_self_attention``.
            use_global_geoFeat: adds ``geo_global_conv``.
            check_feat_mode: ``"1"`` adds ``cost_reg_2_B``; ``"2"`` adds
                ``cost_reg_{l}_B`` for every level.
            separate_occ_feat: also adds ``cost_reg_{l}_B`` for every level.
            geonerfMDMM, contentFeature: either one adds ``content_feat_fc``.
            styleTwoBranch: additionally adds ``content_special_feat_fc``
                (3x3 convs when ``unparallExtract``). Requires ``geonerfMDMM``.
            style3Dfeat: adds ``style_adain_{l}`` per level. Requires
                ``geonerfMDMM``.
            contentPyramid: enables FPN fusion inside the content nets.
            pDensity, upperbound, upperbound_noise, upperbound_gauss, D01,
            D_gauss, P_constraint, unimvs, is_teacher, cas_confi,
            texFeat_woUnet, save_var_confi, unparallExtract: only stored here.
        """
        super(CasMVSNet, self).__init__()
        self.levels = levels  # 3 depth levels
        self.n_depths = [8, 32, 48]  # hypotheses per level, indexed by level (level 2 is the coarsest)
        self.interval_ratios = [1, 2, 4]  # per-level multiplier of depth_interval
        self.use_depth = use_depth
        self.use_disocclude = use_disocclude
        self.upperbound = upperbound
        self.upperbound_noise = upperbound_noise
        self.upperbound_gauss = upperbound_gauss
        self.D_gauss = D_gauss
        self.D01 = D01
        self.geoFeatComplete = geoFeatComplete
        self.texFeatComplete = texFeatComplete
        self.texFeat = texFeat
        self.texFeat_woUnet = texFeat_woUnet
        self.pDensity = pDensity
        self.O_label = O_label
        self.P_constraint = P_constraint
        self.unimvs = unimvs
        self.use_featSA = use_featSA
        self.use_global_geoFeat = use_global_geoFeat
        self.check_feat_mode = check_feat_mode
        self.separate_occ_feat = separate_occ_feat
        self.is_teacher = is_teacher
        self.cas_confi = cas_confi
        self.save_var_confi = save_var_confi
        self.geonerfMDMM = geonerfMDMM
        self.style3Dfeat = style3Dfeat
        self.styleTwoBranch = styleTwoBranch
        # The style branches are only supported together with geonerfMDMM.
        # NOTE(review): `assert` is stripped under `python -O`.
        if self.style3Dfeat or self.styleTwoBranch: assert self.geonerfMDMM == True
        self.unparallExtract = unparallExtract
        self.contentPyramid = contentPyramid
        self.contentFeature = contentFeature

        self.G = num_groups  # number of groups in groupwise correlation
        # NOTE(review): FeatureNet is built with its default norm layer, not `norm_act`.
        self.feature = FeatureNet()
        if use_featSA:
            self.feat_self_attention = FeatureSelfAttention()
        if use_global_geoFeat:
            self.geo_global_conv = nn.Sequential(
                ConvBnReLU(8, 16, 7, 4, 3, norm_act=norm_act), # //4
                ConvBnReLU(16, 32, 5, 2, 2, norm_act=norm_act), # //2
            )

        # One cost regularizer per level. With use_depth, the coarsest level
        # receives one extra channel (the depth spikes).
        for l in range(self.levels):
            if l == self.levels - 1 and self.use_depth:
                cost_reg_l = CostRegNet(self.G + 1, norm_act)
            else:
                cost_reg_l = CostRegNet(self.G, norm_act)

            setattr(self, f"cost_reg_{l}", cost_reg_l)

        # Mode "1": a second level-2 regularizer whose output is used as the
        # level-2 feature volume in create_neural_volume.
        if self.check_feat_mode == "1":
            self.cost_reg_2_B = CostRegNet(self.G, norm_act)

        # Mode "2" / separate_occ_feat: a second regularizer per level that
        # produces the "B" feature set.
        if self.check_feat_mode == "2" or self.separate_occ_feat:
            for l in range(self.levels):
                cost_reg_l_B = CostRegNet(self.G, norm_act)
                setattr(self, f"cost_reg_{l}_B", cost_reg_l_B)

        if self.use_disocclude:
            for l in range(self.levels):
                # disocc_conv_l = nn.Conv3d(1, 1, 3, stride=1, padding=1)
                disocc_conv_l =  ConvBnReLU3D(1, 1, norm_act=norm_act)
                setattr(self, f"disocc_conv_{l}", disocc_conv_l)

        if self.geoFeatComplete == 'v1':
            # Kernels of shape (k, 1, 1): they mix only along the first volume
            # axis (presumably depth for a (B, C, D, h, w) volume), with a
            # wider kernel at coarser levels.
            k = [3, 5, 7]
            p = [1, 2, 3]
            for l in range(self.levels):
                # geo_feat_conv_l = nn.Sequential(
                #                     # nn.Conv3d(8, 8, 3, stride=1, padding=1),
                #                     # nn.Conv3d(8, 8, 3, stride=1, padding=1),
                #                     ConvBnReLU3D(8, 8, norm_act=norm_act),
                #                     ConvBnReLU3D(8, 8, norm_act=norm_act)
                # )
                geo_feat_conv_l = nn.Sequential(
                                    ConvBnReLU3D(8, 8, kernel_size=(k[l],1,1), pad=(p[l],0,0), norm_act=norm_act),
                                    ConvBnReLU3D(8, 8, kernel_size=(k[l],1,1), pad=(p[l],0,0), norm_act=norm_act)
                )
                setattr(self, f"geo_feat_conv_{l}", geo_feat_conv_l)
        elif self.geoFeatComplete == 'v2':
            # Pointwise (1x1x1) convs mapping 16 input channels down to 8.
            for l in range(self.levels):
                geo_feat_conv_l = nn.Sequential(
                                    # nn.Conv3d(8, 8, 3, stride=1, padding=1),
                                    # nn.Conv3d(8, 8, 3, stride=1, padding=1),
                                    ConvBnReLU3D(16, 8, kernel_size=1, pad=0, norm_act=norm_act),
                                    ConvBnReLU3D(8, 8, kernel_size=1, pad=0, norm_act=norm_act)
                )
                setattr(self, f"geo_feat_conv_{l}", geo_feat_conv_l)

        if self.texFeat:
            # NOTE(review): create_neural_volume feeds G + len(idx) * 3 channels
            # into tex_unet_{l}; this matches only if len(idx) == nb_views - 1 — confirm.
            for l in range(self.levels):
                tex_unet_l = CostRegNet(self.G + (nb_views-1)*3, norm_act)
                setattr(self, f"tex_unet_{l}", tex_unet_l)

        if self.texFeatComplete:
            #TODO: kernel not (k,1,1)
            for l in range(self.levels):
                tex_feat_conv_l = nn.Sequential(
                                        # ConvBnReLU3D(8, 8, norm_act=norm_act),
                                        # ConvBnReLU3D(8, 8, norm_act=norm_act)
                                        ConvBnReLU3D(8, 8, kernel_size=1, pad=0, norm_act=norm_act),
                                        ConvBnReLU3D(8, 8, kernel_size=1, pad=0, norm_act=norm_act)
                )
                setattr(self, f"tex_feat_conv_{l}", tex_feat_conv_l)

        if self.O_label:
            # Pointwise convs reducing 8 channels to a single channel per level.
            for l in range(self.levels):
                O_label_conv_l = nn.Sequential(
                                    # nn.Conv3d(8, 8, 3, stride=1, padding=1),
                                    # nn.Conv3d(8, 8, 3, stride=1, padding=1),
                                    ConvBnReLU3D(8, 4, kernel_size=1, pad=0, norm_act=norm_act),
                                    ConvBnReLU3D(4, 1, kernel_size=1, pad=0, norm_act=norm_act)
                )
                setattr(self, f"O_label_conv_{l}", O_label_conv_l)

        # if self.cas_confi:
        #     self.var2confi = nn.Sequential( 
        #                         nn.Conv2d(1,1,1,stride=1),
        #     )

        # Content-feature reducers (64/128/256 -> 8/16/32 channels per level).
        if self.geonerfMDMM or self.contentFeature:
            self.content_feat_fc = contentFeatureNet(pyramid=self.contentPyramid)
            if self.styleTwoBranch:
                if self.unparallExtract:
                    self.content_special_feat_fc = contentFeatureNet(kernel=3,pad=1,pyramid=self.contentPyramid)
                else:
                    self.content_special_feat_fc = contentFeatureNet(pyramid=self.contentPyramid)


        if self.style3Dfeat:
            # input_dims = [64+8, 128+8, 256+8] # content_feat concat cost_reg_feat
            input_dims = [8+8, 16+8, 32+8] # content_feat concat cost_reg_feat
            output_dims = [8, 16, 32]
            for l in range(self.levels):
                style_adain_l = styleAdain(input_dim=input_dims[l], output_dim=output_dims[l])
                setattr(self, f"style_adain_{l}", style_adain_l)

    def build_cost_volumes(self, feats, affine_mats, affine_mats_inv, depth_values, idx, spikes):
        """Build the group-wise correlation cost volume of reference view ``idx[0]``.

        Every source view ``idx[1:]`` is warped to the reference view at each
        depth hypothesis with ``homo_warp``. The warped source features are
        summed, multiplied group-wise with the reference features, averaged
        over the channels of each group, and divided by the number of source
        views actually used.

        :param feats: (B, V, C, h, w) per-view features; C must be divisible by ``self.G``
        :param affine_mats: per-view matrices indexed as ``[:, view]``; only the
            first 3 rows of each ``src @ ref_inv`` product are used
            (assumed (B, V, 4, 4) — TODO confirm)
        :param affine_mats_inv: presumably the inverses of ``affine_mats``, same layout
        :param depth_values: (B, D, h, w) depth hypotheses
        :param idx: view indices; ``idx[0]`` is the reference, ``idx[1:]`` the sources
        :param spikes: ``None`` or a tensor concatenated to the volume along dim 1
        :return: (B, G, D, h, w) cost volume, plus the ``spikes`` channels if given
        """
        B, C = feats.shape[0], feats.shape[2]
        D = depth_values.shape[1]
        # Normalize by the number of source views that are actually summed
        # below. Dividing by V - 1 is only correct when idx covers every view.
        # Clamping to 1 makes an idx without sources yield zeros, not NaN.
        n_src = max(len(idx) - 1, 1)

        ref_feats, src_feats = feats[:, idx[0]], feats[:, idx[1:]]
        src_feats = src_feats.permute(1, 0, 2, 3, 4)  # (n_src, B, C, h, w)

        affine_mats_inv = affine_mats_inv[:, idx[0]]
        affine_mats = affine_mats[:, idx[1:]]
        affine_mats = affine_mats.permute(1, 0, 2, 3)  # (n_src, B, ., .)

        ref_volume = ref_feats.unsqueeze(2).repeat(1, 1, D, 1, 1)  # (B, C, D, h, w)
        ref_volume = ref_volume.view(B, self.G, C // self.G, *ref_volume.shape[-3:])

        volume_sum = 0
        for i in range(len(idx) - 1):
            # A_src @ A_ref^-1, presumably mapping reference-view coordinates
            # into source view i. The product is computed in float64 and cast
            # back to float32.
            proj_mat = (affine_mats[i].double() @ affine_mats_inv.double()).float()[
                :, :3
            ]
            warped_volume, _ = homo_warp(src_feats[i], proj_mat, depth_values)

            warped_volume = warped_volume.view_as(ref_volume)
            volume_sum = volume_sum + warped_volume  # (B, G, C//G, D, h, w)

        volume = (volume_sum * ref_volume).mean(dim=2) / n_src

        if spikes is None:
            return volume
        return torch.cat([volume, spikes], dim=1)

    def create_neural_volume(
        self,
        feats,
        affine_mats,
        affine_mats_inv,
        idx,
        init_depth_min,
        depth_interval,
        gt_depths,
        imgs=None,
        gt_depths_real=None,
        view_id=None,
        content_style_feat=None,
    ):
        if feats["level_0"].shape[-1] >= 800:
            hres_input = True
        else:
            hres_input = False

        B, V = affine_mats.shape[:2]

        if self.check_feat_mode == '2' or self.separate_occ_feat:
            v_feat = {'A':{},'B':{}}
        else:
            v_feat = {}
        v_label_feat = {}
        depth_maps = {}
        depth_values = {}
        depth_probs = {}
        disocclude_confi = {}
        pDensity_output = {}
        weights_mvs_output = {}
        other_vol_output = {}
        depth_interval_levels = {}
        texFeat_outputs = {}
        depth_confi = {}
        style3D_outputs = {}
        for l in reversed(range(self.levels)):  # (2, 1, 0)
            feats_l = feats[f"level_{l}"]  # (B*V, C, h, w)
            feats_l = feats_l.view(B, V, *feats_l.shape[1:])  # (B, V, C, h, w)
            h, w = feats_l.shape[-2:]
            depth_interval_l = depth_interval * self.interval_ratios[l]
            D = self.n_depths[l]
            if l == self.levels - 1:  # coarsest level
                depth_values_l = init_depth_min + depth_interval_l * torch.arange(
                    0, D, device=feats_l.device, dtype=feats_l.dtype
                )  # (D)
                depth_values_l = depth_values_l[None, :, None, None].expand(
                    -1, -1, h, w
                )

                if self.use_depth:
                    gt_mask = gt_depths > 0
                    sp_idx_float = (
                        gt_mask * (gt_depths - init_depth_min) / (depth_interval_l)
                    )[:, :, None]
                    spikes = (
                        torch.arange(D).view(1, 1, -1, 1, 1).cuda()
                        == sp_idx_float.floor().long()
                    ) * (1 - sp_idx_float.frac())
                    spikes = spikes + (
                        torch.arange(D).view(1, 1, -1, 1, 1).cuda()
                        == sp_idx_float.ceil().long()
                    ) * (sp_idx_float.frac())
                    spikes = (spikes * gt_mask[:, :, None]).float()
            else:
                depth_lm1 = depth_l.detach()  # the depth of previous level
                depth_lm1 = F.interpolate(
                    depth_lm1, scale_factor=2, mode="bilinear", align_corners=True
                )  # (B, 1, h, w)
                depth_values_l = get_depth_values(depth_lm1, D, depth_interval_l)

            affine_mats_l = affine_mats[..., l]
            affine_mats_inv_l = affine_mats_inv[..., l]

            if l == self.levels - 1 and self.use_depth:
                spikes_ = spikes
            else:
                spikes_ = None

            if hres_input:
                v_feat_l = checkpoint(
                    self.build_cost_volumes,
                    feats_l,
                    affine_mats_l,
                    affine_mats_inv_l,
                    depth_values_l,
                    idx,
                    spikes_,
                    preserve_rng_state=False,
                )
            else:
                v_feat_l = self.build_cost_volumes(
                    feats_l,
                    affine_mats_l,
                    affine_mats_inv_l,
                    depth_values_l,
                    idx,
                    spikes_,
                )

            if self.texFeat:
                h, w = v_feat_l.shape[-2:]
                imgs_feat = torch.empty((1, len(idx)*3, D, h, w), device=v_feat_l.device, dtype=torch.float)
                T_resize = T.Resize((h,w))
                imgs_resize = torch.empty((B,V,3,h,w), device=v_feat_l.device, dtype=torch.float)
                for v in range(V):
                    imgs_resize[0,v] = T_resize(imgs[0,v])
                imgs_resize = imgs_resize.view(B, V, -1, h, w).permute(1, 0, 2, 3, 4)
                imgs_feat[:,:3,...] = imgs_resize[idx[0]].unsqueeze(2).expand(-1, -1, D, -1, -1)
                affine_mats_inv_l_ref = affine_mats_inv_l[:, idx[0]]
                #TODO: should be range(1,len(idx))
                for i in range(len(idx) - 1):
                    proj_mat = (affine_mats_l[:, idx[i]].double() @ affine_mats_inv_l_ref.double()).float()[
                        :, :3
                    ]
                    imgs_feat[:, (i+1)*3:(i+2)*3,...], _ = homo_warp(imgs_resize[idx[i]], proj_mat, depth_values_l)
                
                texFeat_input = torch.cat([v_feat_l, imgs_feat], dim=1) # (B, c, D, h, w)
                tex_unet_l = getattr(self, f"tex_unet_{l}")
                texFeat_output_l, _ = tex_unet_l(texFeat_input)
            
            if self.texFeat_woUnet:
                h, w = v_feat_l.shape[-2:]
                imgs_feat = torch.empty((1, 3, D, h, w), device=v_feat_l.device, dtype=torch.float)
                T_resize = T.Resize((h,w))
                imgs_resize = torch.empty((B,V,3,h,w), device=v_feat_l.device, dtype=torch.float)
                for v in range(V):
                    imgs_resize[0,v] = T_resize(imgs[0,v])
                imgs_resize = imgs_resize.view(B, V, -1, h, w).permute(1, 0, 2, 3, 4)
                imgs_feat[:,:3,...] = imgs_resize[idx[0]].unsqueeze(2).expand(-1, -1, D, -1, -1)
                
                texFeat_output_l = torch.cat([v_feat_l, imgs_feat], dim=1) # (B, c, D, h, w)

            cost_reg_l = getattr(self, f"cost_reg_{l}")
            if self.check_feat_mode == "2"  or self.separate_occ_feat: # feat_mode=2: actually only use level 2 also
                v_feat_l_A, depth_prob = cost_reg_l(v_feat_l)  # depth_prob: (B, 1, D, h, w)
                cost_reg_l_B = getattr(self, f"cost_reg_{l}_B")
                v_feat_l_B, _ = cost_reg_l_B(v_feat_l)  # depth_prob: (B, 1, D, h, w)
            elif self.check_feat_mode == "1" and l == 2:
                _, depth_prob = cost_reg_l(v_feat_l)  # depth_prob: (B, 1, D, h, w)
                v_feat_l, _ = self.cost_reg_2_B(v_feat_l)  # depth_prob: (B, 1, D, h, w)
            else:
                v_feat_l, depth_prob = cost_reg_l(v_feat_l)  # depth_prob: (B, 1, D, h, w)

            if self.style3Dfeat:
                assert content_style_feat != None
                # concat content_feat
                # content_feat_l = content_style_feat['content'][f'level_{l}'].unsqueeze(0) # (1, V, c, h, w)
                content_feat_l = feats[f'level_{l}'].unsqueeze(0) # (1, V, c, h, w)
                content_dim = content_feat_l.shape[2]
                h, w = v_feat_l.shape[-2:]
                concat_len = 1 #len(idx)
                contents_feat_stack = torch.empty((1, concat_len*content_dim, D, h, w), device=v_feat_l.device, dtype=torch.float)
                content_feat_l = content_feat_l.view(B, V, -1, h, w).permute(1, 0, 2, 3, 4)
                contents_feat_stack[:,:content_dim,...] = content_feat_l[idx[0]].unsqueeze(2).expand(-1, -1, D, -1, -1)
                affine_mats_inv_l_ref = affine_mats_inv_l[:, idx[0]]
                for i in range(1,concat_len):
                    proj_mat = (affine_mats_l[:, idx[i]].double() @ affine_mats_inv_l_ref.double()).float()[
                        :, :3
                    ]
                    contents_feat_stack[:, (i)*content_dim:(i+1)*content_dim,...], _ = homo_warp(content_feat_l[idx[i]], proj_mat, depth_values_l)
                geo_content_input = torch.cat([v_feat_l, contents_feat_stack], dim=1) # (B, c, D, h, w)
                style_adain_l = getattr(self, f"style_adain_{l}")
                style_output_l = style_adain_l(geo_content_input, content_style_feat['style']['style_feat'])

            if self.pDensity:
                _alpha_mvs, _depth_cdf_l, _weights_mvs = sigma2weights(torch.relu(depth_prob).permute(0,1,3,4,2), return_all=True)
                alpha_mvs, depth_cdf_l, weights_mvs = _alpha_mvs.permute(0,1,4,2,3), _depth_cdf_l.permute(0,1,4,2,3), _weights_mvs.permute(0,1,4,2,3)
                if self.use_disocclude:
                    disocclude_confi[f"level_{l}"] = depth_cdf_l

                pDensity_output[f"level_{l}"] = torch.relu(depth_prob)
                weights_mvs_output[f"level_{l}"] = weights_mvs
                depth_l = (weights_mvs * depth_values_l[:, None]).sum(
                    dim=2
                ) # (B, 1, h, w)

            elif self.unimvs:
                depth_l = unity_regression(torch.sigmoid(depth_prob).squeeze(0), depth_values_l, depth_interval_l).unsqueeze(1) # (B, 1, h, w)

            else:
                depth_l = (F.softmax(depth_prob, dim=2) * depth_values_l[:, None]).sum(
                    dim=2
                ) # (B, 1, h, w)
            
            if self.cas_confi:
                _depth_prob = F.softmax(depth_prob, dim=2)
                E_X = (_depth_prob * depth_values_l[:, None]).sum(dim=2) # (B, 1, h, w)
                E_X2 = (_depth_prob * (depth_values_l[:, None]**2)).sum(dim=2) # (B, 1, h, w)
                var_l = E_X2 + E_X**2
                # depth_confi_l = self.var2confi(var_l)
                depth_confi_l = -var_l
                confi_shape = depth_confi_l.shape
                tmp_depth_confi_l = depth_confi_l.reshape(-1)
                tmp_depth_confi_l = F.softmax(tmp_depth_confi_l)
                depth_confi_l = tmp_depth_confi_l.reshape(confi_shape)

                if self.save_var_confi != -1:
                    if l == self.save_var_confi: 
                        save_var = var_l
                    if l == 0:
                        depth_l = save_var

            if self.use_disocclude and not self.pDensity:
                if self.D01:
                    repeat_depth_l = depth_l.unsqueeze(0).repeat(1, 1, self.n_depths[l], 1, 1)
                    d_diff = depth_values_l - repeat_depth_l
                    depth_cdf_l = (d_diff <= 0).float()

                elif self.D_gauss:
                    mean_depth_values_l = (torch.max(depth_values_l, dim=1, keepdim=True)[0] + torch.min(depth_values_l, dim=1, keepdim=True)[0]) / 2
                    reverse_depth_l = (depth_l - mean_depth_values_l.squeeze(0))*(-1) + mean_depth_values_l.squeeze(0)
                    std_map = torch.zeros_like(depth_l) + 0.15
                    if torch.sum(torch.isnan(reverse_depth_l)) > 0:
                        print("reverse_depth_l",torch.sum(torch.isnan(reverse_depth_l)))
                        print("depth_l",torch.sum(torch.isnan(depth_l)))
                        print("mean_depth_val",torch.sum(torch.isnan(mean_depth_values_l)))
                        print("depth_prob",torch.sum(torch.isnan(depth_prob)))
                        print("depth_prob 0",torch.sum(torch.sum(depth_prob, dim=2)==0))
                        print("depth_prob softmax",torch.sum(torch.isnan(F.softmax(depth_prob, dim=2))))
                    
                    gauss = torch.distributions.normal.Normal(reverse_depth_l, std_map)
                    
                    reverse_depth_values_l = (depth_values_l - mean_depth_values_l)*(-1) + mean_depth_values_l
                    depth_cdf_l = gauss.cdf(reverse_depth_values_l).unsqueeze(0)

                else:
                    depth_prob_softmax = F.softmax(depth_prob, dim=2)
                    depth_cdf = torch.zeros_like(depth_prob)
                    for d in range(depth_prob.shape[2]):
                        depth_cdf[:,:,d,:,:] = torch.sum(depth_prob_softmax[:,:,d:,:,:],dim=2)
                    # disocc_conv_l = getattr(self, f"disocc_conv_{l}")
                    # depth_cdf = disocc_conv_l(depth_cdf)
                    # depth_cdf_l = nn.Sigmoid()(depth_cdf)
                    depth_cdf_l = depth_cdf

                if self.upperbound and gt_depths_real != {}:
                    if self.upperbound_noise:
                        _gt_mask = gt_depths_real[f"level_{l}"] > 0
                        # gt_depths_cur = gt_depths_real[f"level_{l}"] + _gt_mask*(torch.rand_like(gt_depths_real[f"level_{l}"])-0.5)
                        gt_depths_cur = gt_depths_real[f"level_{l}"] + _gt_mask*(torch.rand_like(gt_depths_real[f"level_{l}"])*2-1)
                    else:
                        # gt_depths_cur = gt_depths_real[f"level_{l}"]
                        ## test not use GT, use predicted depth ##
                        if self.is_teacher:
                            gt_depths_cur = gt_depths_real[f"level_{l}"]
                        else:
                            gt_depths_cur = gt_depths_real[f"level_{l}"]
                            # if gt_depths_real != {}:
                            #     gt_depths_cur = torch.zeros_like(depth_l)
                            #     gt_depths_cur[gt_depths_real[f"level_{l}"]>0] = depth_l[gt_depths_real[f"level_{l}"]>0]
                            # else:
                            #     gt_depths_cur = depth_l
                        ## test not use GT, use predicted depth ##

                        ## to test edge:predict depth/gt depth begin ##
                        # if l == 0:
                        # import cv2
                        # import numpy as np
                        # img = (imgs[0, view_id].detach().cpu().permute(1,2,0).numpy() * 255).astype(np.uint8)
                        # img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).astype(np.uint8)
                        # img_blur = cv2.GaussianBlur(img_gray, (3,3), 0) 
                        # edges = cv2.Canny(image=img_blur, threshold1=100, threshold2=200)

                        # kernel = np.ones((3,3), np.uint8)
                        # dilate_iter = 2
                        # edges = cv2.dilate(edges, kernel, iterations = dilate_iter)
                        # edges = torch.from_numpy(edges.astype(float) / 255).float().cuda()
                        # edges = cv2.resize(edges.float().detach().cpu().numpy(), None, fx=1.0/(2**l), fy=1.0/(2**l), interpolation=cv2.INTER_NEAREST,)
                        # edges = torch.from_numpy(edges).cuda()

                        # gt_depths_cur = depth_l*edges + gt_depths_real[f"level_{l}"]*(edges-1)*(-1)
                        # gt_depths_cur = depth_l*(edges-1)*(-1) + gt_depths_real[f"level_{l}"]*edges
                        # gt_depths_cur = gt_depths_real[f"level_{l}"]*(edges-1)*(-1)
                        # gt_depths_cur = gt_depths_real[f"level_{l}"]*edges
                        # _gt_mask = (gt_depths_real[f"level_{l}"] > 0).float()
                        # gt_depths_cur = depth_l*_gt_mask
                        # gt_depths_cur = gt_depths_real[f"level_{l}"]*_gt_mask + (-1)*(_gt_mask-1)*depth_l
                        ## to test edge:predict depth/gt depth end ##

                    repeat_gt_d = gt_depths_cur.repeat(1, self.n_depths[l], 1, 1)
                    d_diff = (depth_values_l - repeat_gt_d)
                    no_gt_mask, gt_mask = (repeat_gt_d == 0), (repeat_gt_d != 0)
                    
                    if self.P_constraint:
                        depth_gt_l = gt_depths_cur
                        lower_depth_mask = (d_diff <= 0)
                        lower_depth_mask_2 = torch.zeros_like(d_diff)
                        lower_depth_mask_2[d_diff > 0] = -1e10
                        lower_depth_idx = torch.argmax(d_diff * lower_depth_mask + lower_depth_mask_2, dim=1, keepdim=True)
                        lower_tmp = (d_diff * lower_depth_mask + lower_depth_mask_2).gather(1, lower_depth_idx)
                        lower_depth = depth_values_l.gather(1, lower_depth_idx)

                        higher_depth_mask = (d_diff >= 0)
                        higher_depth_mask_2 = torch.zeros_like(d_diff)
                        higher_depth_mask_2[d_diff < 0] = 1e10
                        higher_depth_idx = torch.argmin(d_diff * higher_depth_mask + higher_depth_mask_2, dim=1, keepdim=True)
                        higher_tmp = (d_diff * higher_depth_mask + higher_depth_mask_2).gather(1, higher_depth_idx)
                        higher_depth = depth_values_l.gather(1, higher_depth_idx)

                        check_low_mask = (lower_tmp == -1e10)
                        check_high_mask = (higher_tmp == 1e10)
                        x_low, x_high = torch.zeros_like(depth_gt_l), torch.zeros_like(depth_gt_l)
                        # d1 gt d2
                        x_high[(check_low_mask+check_high_mask)==0] = ((depth_gt_l-lower_depth) / (higher_depth-lower_depth))[(check_low_mask+check_high_mask)==0]
                        x_low[(check_low_mask+check_high_mask)==0] = (1 - ((depth_gt_l-lower_depth) / (higher_depth-lower_depth)))[(check_low_mask+check_high_mask)==0]
                        
                        x_high[x_high == -float("Inf")], x_low[x_low == float("Inf")] = 0, 0
                        x_high_nan_mask, x_low_nan_mask = x_high.isnan(), x_low.isnan()
                        x_high_NOT_nan_mask, x_low_NOT_nan_mask = x_high_nan_mask*(-1)+1, x_low_nan_mask*(-1)+1
                        x_high, x_low = x_high.nan_to_num(), x_low.nan_to_num()
                        
                        # assert torch.sum(x[x == -float("Inf")]) == torch.sum(x[gt_depths_real[f"level_{l}"] == 0])
                        depth_cdf_l_no_gt = no_gt_mask * depth_cdf_l
                        depth_cdf_l_gt = gt_mask * (d_diff.unsqueeze(0) <= 0)
                        depth_cdf_l = depth_cdf_l_no_gt + depth_cdf_l_gt

                        higher_depth_cdf = depth_cdf_l.float().clone()
                        higher_depth_cdf = higher_depth_cdf.squeeze(0).scatter(1, higher_depth_idx, x_high).unsqueeze(0)
                        use_mask = (gt_mask) * ((check_low_mask+check_high_mask)==0) * x_high_NOT_nan_mask * x_low_NOT_nan_mask
                        not_use_mask = (-1) * use_mask + 1
                        depth_cdf_l_no_gt = not_use_mask * depth_cdf_l
                        depth_cdf_l_gt = use_mask * higher_depth_cdf
                        depth_cdf_l = depth_cdf_l_no_gt + depth_cdf_l_gt
                    
                    elif self.upperbound_gauss:
                        depth_cdf_l_no_gt = no_gt_mask * depth_cdf_l
                        
                        depth_gt_l = gt_depths_cur
                        mean_depth_values_l = (torch.max(depth_values_l, dim=1, keepdim=True)[0] + torch.min(depth_values_l, dim=1, keepdim=True)[0]) / 2
                        reverse_depth_gt_l = (depth_gt_l - mean_depth_values_l.squeeze(0))*(-1) + mean_depth_values_l.squeeze(0)
                        std_map = torch.zeros_like(depth_gt_l) + 0.15
                        gauss = torch.distributions.normal.Normal(reverse_depth_gt_l, std_map)
                        
                        reverse_depth_values_l = (depth_values_l - mean_depth_values_l)*(-1) + mean_depth_values_l
                        depth_cdf_l_gauss = gauss.cdf(reverse_depth_values_l)
                        
                        depth_cdf_l_gt = gt_mask * depth_cdf_l_gauss
                        depth_cdf_l = depth_cdf_l_no_gt + depth_cdf_l_gt # (1, 1, D, h, w)
                    
                    else:
                        depth_cdf_l_no_gt = no_gt_mask * depth_cdf_l
                        depth_cdf_l_gt = gt_mask * (d_diff.unsqueeze(0) <= 0)
                        depth_cdf_l = depth_cdf_l_no_gt + depth_cdf_l_gt
                        
                    
            if self.geoFeatComplete == 'v1':
                OD = v_feat_l*depth_cdf_l
                geo_feat_conv_l = getattr(self, f"geo_feat_conv_{l}")
                O_tilde = geo_feat_conv_l(OD)
                v_feat_l = OD + O_tilde*(1-depth_cdf_l)

            elif self.geoFeatComplete == 'v2':
                OD = v_feat_l*depth_cdf_l
                O_surface = torch.sum(F.softmax(depth_prob, dim=2) * v_feat_l, dim=2, keepdim=True).repeat(1,1,v_feat_l.shape[2],1,1)
                geo_feat_conv_l = getattr(self, f"geo_feat_conv_{l}")
                O_tilde = geo_feat_conv_l(torch.cat((v_feat_l, O_surface), dim=1))
                v_feat_l = OD + O_tilde*(1-depth_cdf_l)
            
            if self.texFeatComplete:
                OA = texFeat_output_l*depth_cdf_l
                tex_feat_conv_l = getattr(self, f"tex_feat_conv_{l}")
                O_tilde = tex_feat_conv_l(OA)
                texFeat_output_l = OA + O_tilde*(1-depth_cdf_l)
                
            if self.O_label:
                O_label_conv_l = getattr(self, f"O_label_conv_{l}")
                v_label_feat_l = torch.sigmoid(O_label_conv_l(v_feat_l))
                v_label_feat[f"level_{l}"] = v_label_feat_l

            if self.check_feat_mode == '2' or self.separate_occ_feat:
                v_feat["A"][f"level_{l}"] = v_feat_l_A
                v_feat["B"][f"level_{l}"] = v_feat_l_B
            else:
                v_feat[f"level_{l}"] = v_feat_l
            depth_maps[f"level_{l}"] = depth_l
            depth_values[f"level_{l}"] = depth_values_l
            depth_probs[f"level_{l}"] = depth_prob
            depth_interval_levels[f"level_{l}"] = depth_interval_l
            if self.use_disocclude:
                disocclude_confi[f"level_{l}"] = depth_cdf_l
            if self.texFeat or self.texFeat_woUnet:
                texFeat_outputs[f"level_{l}"] = texFeat_output_l
            if self.cas_confi:
                depth_confi[f"level_{l}"] = depth_confi_l
            if self.style3Dfeat:
                style3D_outputs[f"level_{l}"] = style_output_l

        other_vol_output['depth_interval_levels'] = depth_interval_levels
        other_vol_output['depth_probs'] = depth_probs
        if self.use_disocclude:
            other_vol_output['disocclude_confi'] = disocclude_confi
        if self.texFeat or self.texFeat_woUnet:
            other_vol_output['texFeat'] = texFeat_outputs
        if self.pDensity:
            other_vol_output['pDensity'] = pDensity_output
            other_vol_output['weights_mvs'] = weights_mvs_output
        if self.O_label:
            other_vol_output['O_label'] = v_label_feat
        if self.cas_confi:
            other_vol_output['depth_confi'] = depth_confi
        if self.style3Dfeat:
            other_vol_output['style3D'] = style3D_outputs

        return v_feat, depth_maps, depth_values, other_vol_output

    def forward(
        self, imgs, affine_mats, affine_mats_inv, near_far, closest_idxs, gt_depths=None, gt_depths_real=None,
        content_feats=None, content_style_feat=None,
    ):
        """Build a neural (cost) volume for every view and gather the results per level.

        Features come either from ``self.content_feat_fc(content_feats)`` (when
        ``geonerfMDMM`` or ``contentFeature`` is enabled) or from the feature
        pyramid ``self.feature``. Then ``self.create_neural_volume`` is called once
        per view, using that view's neighbour indices ``closest_idxs[0, i]`` and
        depth range ``near_far[0, i]``. Per-view outputs are collected for
        ``level_0``..``level_2``.

        Args:
            imgs: images of shape (B, V, 3, H, W).
            affine_mats, affine_mats_inv: forwarded unchanged to
                ``create_neural_volume``.
            near_far: per-view near/far bounds. Only batch element 0 is read.
            closest_idxs: per-view neighbour indices. Only batch element 0 is read.
            gt_depths: optional tensor, sliced per view along dim 1. When it is
                ``None``, ``None`` is forwarded for every view.
            gt_depths_real: optional dict ``{"level_l": tensor}``, sliced per view
                along dim 1. Any non-dict value results in an empty dict being
                forwarded.
            content_feats: features consumed by ``content_feat_fc`` (and by
                ``content_special_feat_fc`` when ``styleTwoBranch``) on the
                content-feature path.
            content_style_feat: forwarded unchanged to ``create_neural_volume``.

        Returns:
            ``(feats_vol, feats_fpn, depth_map, depth_values, other_output)``:
            feats_vol: per-level volumes stacked over views at dim 1. Nested
                under "A"/"B" when ``check_feat_mode == '2'`` or
                ``separate_occ_feat``.
            feats_fpn: level-0 features reshaped to (B, V, ...).
            depth_map: per-level depth maps concatenated over views at dim 1.
            depth_values: per-level depth hypotheses stacked over views at dim 1.
            other_output: dict of optional, flag-dependent extras, plus
                "depth_interval_levels", "depth_probs" and "feats".
        """
        B, V, _, H, W = imgs.shape
        # Some configurations keep two feature volumes ("A" and "B") per level.
        split_feats = self.check_feat_mode == '2' or self.separate_occ_feat
        use_texFeat = self.texFeat or self.texFeat_woUnet

        if self.geonerfMDMM or self.contentFeature:
            feats = self.content_feat_fc(content_feats)
            if self.styleTwoBranch:
                special_feats = self.content_special_feat_fc(content_feats)
        else:
            ## Feature Pyramid
            feats = self.feature(
                imgs.reshape(B * V, 3, H, W)
            )  # (B*V, 8, H, W), (B*V, 16, H//2, W//2), (B*V, 32, H//4, W//4)
        feats_fpn = feats["level_0"].reshape(B, V, *feats["level_0"].shape[1:])

        if self.use_featSA:
            feat_SA = self.feat_self_attention(feats["level_0"]).unsqueeze(0)

        def _per_level():
            """Return a fresh ``{"level_l": []}`` accumulator for the 3 levels."""
            return {f"level_{l}": [] for l in range(3)}

        if split_feats:
            feats_vol = {"A": _per_level(), "B": _per_level()}
        else:
            feats_vol = _per_level()
        depth_map = _per_level()
        depth_values = _per_level()
        disocc_confi = _per_level()
        pDensity_output = _per_level()
        weights_mvs_output = _per_level()
        O_label_val_output = _per_level()
        depth_probs_output = _per_level()
        depth_interval_levels = _per_level()
        texFeat_outputs = _per_level()
        depth_confi_outputs = _per_level()
        style3D_outputs = _per_level()

        ## Create cost volumes for each view
        # NOTE: only batch element 0 of near_far / closest_idxs is read, so this
        # effectively assumes B == 1.
        for i in range(V):
            # Keep the neighbour indices on the images' device. A bare `.cuda()`
            # targets the current default GPU, which fails on CPU and on other GPUs.
            permuted_idx = torch.as_tensor(closest_idxs[0, i], device=imgs.device)

            init_depth_min = near_far[0, i, 0]
            depth_interval = (
                (near_far[0, i, 1] - near_far[0, i, 0])
                / self.n_depths[-1]
                / self.interval_ratios[-1]
            )

            _gt_depths_real = {}
            if isinstance(gt_depths_real, dict):
                for l in range(3):
                    _gt_depths_real[f"level_{l}"] = gt_depths_real[f"level_{l}"][:, i:i+1]

            # gt_depths defaults to None, so only slice it when it was given.
            view_gt_depths = None if gt_depths is None else gt_depths[:, i : i + 1]

            v_feat, d_map, d_values, other_vol_output = self.create_neural_volume(
                feats,
                affine_mats,
                affine_mats_inv,
                idx=permuted_idx,
                init_depth_min=init_depth_min,
                depth_interval=depth_interval,
                gt_depths=view_gt_depths,
                gt_depths_real=_gt_depths_real,
                imgs=imgs,
                view_id=i,
                content_style_feat=content_style_feat,
            )
            depth_interval_ls_out = other_vol_output['depth_interval_levels']
            depth_probs_out = other_vol_output['depth_probs']
            if self.use_disocclude:
                diso_confi = other_vol_output['disocclude_confi']
            if use_texFeat:
                _texFeat = other_vol_output['texFeat']
            if self.pDensity:
                pDensity_out = other_vol_output['pDensity']
                weights_mvs_out = other_vol_output['weights_mvs']
            if self.O_label:
                O_label_val_out = other_vol_output['O_label']
            if self.cas_confi:
                depth_confi_out = other_vol_output['depth_confi']
            if self.style3Dfeat:
                style3D_out = other_vol_output['style3D']

            for l in range(3):
                key = f"level_{l}"
                if split_feats:
                    feats_vol['A'][key].append(v_feat['A'][key])
                    feats_vol['B'][key].append(v_feat['B'][key])
                else:
                    feats_vol[key].append(v_feat[key])
                depth_map[key].append(d_map[key])
                depth_values[key].append(d_values[key])
                depth_interval_levels[key].append(depth_interval_ls_out[key])
                depth_probs_output[key].append(depth_probs_out[key])
                if self.use_disocclude:
                    disocc_confi[key].append(diso_confi[key])
                if self.pDensity:
                    pDensity_output[key].append(pDensity_out[key])
                    weights_mvs_output[key].append(weights_mvs_out[key])
                if self.O_label:
                    O_label_val_output[key].append(O_label_val_out[key])
                if use_texFeat:
                    texFeat_outputs[key].append(_texFeat[key])
                if self.cas_confi:
                    depth_confi_outputs[key].append(depth_confi_out[key])
                if self.style3Dfeat:
                    style3D_outputs[key].append(style3D_out[key])

        # Merge the per-view lists into tensors with the view axis at dim 1.
        for l in range(3):
            key = f"level_{l}"
            if split_feats:
                feats_vol['A'][key] = torch.stack(feats_vol['A'][key], dim=1)
                feats_vol['B'][key] = torch.stack(feats_vol['B'][key], dim=1)
            else:
                feats_vol[key] = torch.stack(feats_vol[key], dim=1)
            # Depth maps and confidences already carry a singleton dim 1, so they are concatenated.
            depth_map[key] = torch.cat(depth_map[key], dim=1)
            depth_values[key] = torch.stack(depth_values[key], dim=1)
            depth_probs_output[key] = torch.stack(depth_probs_output[key], dim=1)
            if self.use_disocclude:
                disocc_confi[key] = torch.stack(disocc_confi[key], dim=1)
            if self.pDensity:
                pDensity_output[key] = torch.stack(pDensity_output[key], dim=1)
                weights_mvs_output[key] = torch.stack(weights_mvs_output[key], dim=1)
            if self.O_label:
                O_label_val_output[key] = torch.stack(O_label_val_output[key], dim=1)
            if use_texFeat:
                texFeat_outputs[key] = torch.stack(texFeat_outputs[key], dim=1)
            if self.cas_confi:
                depth_confi_outputs[key] = torch.cat(depth_confi_outputs[key], dim=1)
            if self.style3Dfeat:
                style3D_outputs[key] = torch.stack(style3D_outputs[key], dim=1)

        other_output = {}
        other_output['depth_interval_levels'] = depth_interval_levels
        other_output['depth_probs'] = depth_probs_output
        if self.use_disocclude:
            other_output['disocc_confi'] = disocc_confi
        if use_texFeat:
            other_output['texFeat'] = texFeat_outputs
        if self.pDensity:
            other_output['pDensity'] = pDensity_output
            other_output['weights_mvs'] = weights_mvs_output
        if self.O_label:
            other_output['O_label'] = O_label_val_output
        if self.use_featSA:
            other_output['feat_SA'] = feat_SA
        if self.use_global_geoFeat:
            # NOTE(review): assumes the non-split feats_vol layout ("level_2" at the top level).
            b, v, c, d, h, w = feats_vol["level_2"].shape
            _feat = feats_vol["level_2"].permute(0, 1, 3, 2, 4, 5).reshape(-1, c, h, w)
            _global_geoFeat = self.geo_global_conv(_feat)
            _, c_, h_, w_ = _global_geoFeat.shape
            global_geoFeat = _global_geoFeat.reshape(b, v, d, c_, h_, w_).permute(0, 1, 3, 2, 4, 5)
            other_output['global_geoFeat'] = global_geoFeat
        if self.cas_confi:
            other_output['depth_confi'] = depth_confi_outputs
        if self.style3Dfeat:
            other_output['style3D'] = style3D_outputs
        if self.styleTwoBranch:
            # NOTE(review): special_feats only exists on the content-feature path
            # (geonerfMDMM / contentFeature). Otherwise this raises NameError; confirm
            # that combination is never configured.
            if self.unparallExtract:
                other_output['common_feat'] = special_feats
                other_output['special_feat'] = feats
            else:
                other_output['common_feat'] = feats
                other_output['special_feat'] = special_feats

        other_output['feats'] = feats

        return feats_vol, feats_fpn, depth_map, depth_values, other_output


    def output_novel_disocc_func(
        self, imgs, affine_mats, affine_mats_inv, near_far, closest_idxs, gt_depths=None, gt_depths_real=None
    ):
        """Build the neural volume for the target (novel) view only.

        The target view is the last one along dim 1 of ``imgs``. Features for all
        views are extracted with ``self.feature``. A single call to
        ``self.create_neural_volume`` then builds the volume for the target view,
        using its neighbour indices ``closest_idxs[0, V-1]`` and depth range
        ``near_far[0, V-1]``.

        Args:
            imgs: images of shape (B, V, 3, H, W), where V = input views + 1 (novel).
            affine_mats, affine_mats_inv: forwarded unchanged to
                ``create_neural_volume``.
            near_far: per-view near/far bounds. Only batch element 0 is read.
            closest_idxs: per-view neighbour indices. Only batch element 0 is read.
            gt_depths: optional tensor, sliced at the target view along dim 1.
                ``None`` is forwarded when it is absent.
            gt_depths_real: optional dict ``{"level_l": tensor}``, sliced at the
                target view along dim 1. Any non-dict value results in an empty dict.

        Returns:
            ``(disocc_confi, depth_values, feats_vol)``: dicts keyed
            ``"level_0".."level_2"`` whose tensors carry a singleton view
            dimension at dim 1. When ``self.use_disocclude`` is off, every
            ``disocc_confi`` entry is an empty list.
        """
        B, V, _, H, W = imgs.shape  ### V = input views + 1(novel)
        tgt = V - 1  # index of the target (novel) view

        ## Feature Pyramid
        feats = self.feature(
            imgs.reshape(B * V, 3, H, W)
        )  # (B*V, 8, H, W), (B*V, 16, H//2, W//2), (B*V, 32, H//4, W//4)

        # Keep the neighbour indices on the images' device instead of the
        # current default CUDA device.
        permuted_idx = torch.as_tensor(closest_idxs[0, tgt], device=imgs.device)

        init_depth_min = near_far[0, tgt, 0]
        depth_interval = (
            (near_far[0, tgt, 1] - near_far[0, tgt, 0])
            / self.n_depths[-1]
            / self.interval_ratios[-1]
        )

        _gt_depths_real = {}
        if isinstance(gt_depths_real, dict):
            for l in range(3):
                _gt_depths_real[f"level_{l}"] = gt_depths_real[f"level_{l}"][:, tgt:tgt+1]

        # gt_depths defaults to None, so only slice it when it was given.
        view_gt_depths = None if gt_depths is None else gt_depths[:, tgt : tgt + 1]

        v_feat, _, d_values, other_vol_output = self.create_neural_volume(
            feats,
            affine_mats,
            affine_mats_inv,
            idx=permuted_idx,
            init_depth_min=init_depth_min,
            depth_interval=depth_interval,
            gt_depths=view_gt_depths,
            gt_depths_real=_gt_depths_real,
            imgs=imgs
        )

        levels = [f"level_{l}" for l in range(3)]
        # A single view is produced, so insert its view axis at dim 1
        # (same values as stacking a one-element list along dim 1).
        feats_vol = {key: v_feat[key].unsqueeze(1) for key in levels}
        depth_values = {key: d_values[key].unsqueeze(1) for key in levels}
        if self.use_disocclude:
            diso_confi = other_vol_output['disocclude_confi']
            disocc_confi = {key: diso_confi[key].unsqueeze(1) for key in levels}
        else:
            disocc_confi = {key: [] for key in levels}

        return disocc_confi, depth_values, feats_vol
