# GeoNeRF is a generalizable NeRF model that renders novel views
# without requiring per-scene optimization. This software is the 
# implementation of the paper "GeoNeRF: Generalizing NeRF with 
# Geometry Priors" by Mohammad Mahdi Johari, Yann Lepoittevin,
# and Francois Fleuret.

# Copyright (c) 2022 ams International AG

# This file is part of GeoNeRF.
# GeoNeRF is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3 as
# published by the Free Software Foundation.

# GeoNeRF is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with GeoNeRF. If not, see <http://www.gnu.org/licenses/>.

# This file incorporates work covered by the following copyright and  
# permission notice:

    # Copyright 2020 Google LLC
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     https://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

import math

def weights_init(m):
    """Initialise an ``nn.Linear`` module in place; any other module is left alone.

    Meant to be used via ``module.apply(weights_init)``. The weight matrix is
    drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) where ``fan_in`` is
    ``weight.size(1)``. The bias, if present, ends up equal to the constant
    ``+1/sqrt(fan_in)``.
    """
    if not isinstance(m, nn.Linear):
        return
    bound = 1.0 / math.sqrt(m.weight.size(1))
    m.weight.data.uniform_(-bound, bound)
    if m.bias is None:
        return
    # NOTE(review): ``uniform_(bound, bound)`` is a degenerate range, i.e. a
    # constant fill with +bound (it still draws from the RNG, so it is kept as
    # uniform_ to leave the random stream unchanged). This may be intentional
    # (a positive bias ahead of a ReLU); use ``(-bound, bound)`` if a symmetric
    # range was meant.
    m.bias.data.uniform_(bound, bound)

def masked_softmax(x, mask, **kwargs):
    """Softmax of ``x`` that assigns zero probability where ``mask == 0``.

    Masked entries are set to ``-inf`` before calling ``torch.softmax``; any
    keyword arguments (typically ``dim``) are passed straight through.
    ``mask`` must be broadcastable to ``x``. If every entry along the softmax
    dimension is masked, the output there is NaN.
    """
    neg_inf = -float("inf")
    logits = x.masked_fill(mask == 0, neg_inf)
    return torch.softmax(logits, **kwargs)

def weighted_mean(values, weights, dim, keepdim=True):
    """Weighted average of ``values`` along ``dim``.

    Wherever the weights along ``dim`` sum to zero they are all replaced by 1,
    so that slice falls back to a plain (unweighted) mean instead of 0/0.
    ``weights`` must broadcast against ``values``.
    """
    zero_total = weights.sum(dim=dim, keepdim=True) == 0
    safe_weights = weights.masked_fill(zero_total, 1)
    numerator = torch.sum(safe_weights * values, dim=dim, keepdim=keepdim)
    denominator = torch.sum(safe_weights, dim=dim, keepdim=keepdim)
    return numerator / denominator

def weighted_var_and_ang(values, weights, dim, keepdim=True):
    """Weighted (biased) variance and mean of ``values`` along ``dim``.

    Returns ``(variance, average)`` -- the same order as ``torch.var_mean`` --
    so callers can swap one for the other. (``ang`` in the name appears to be
    a typo for ``avg``; the name is kept for compatibility.)

    The mean is always computed with the reduced dimension kept, so that
    ``values - average`` broadcasts slice-by-slice even when the caller asks
    for ``keepdim=False``; the reduced dimension is squeezed at the end in that
    case.
    """
    average = weighted_mean(values, weights, dim, keepdim=True)
    variance = weighted_mean((values - average) ** 2, weights, dim, keepdim=True)
    if not keepdim:
        average = average.squeeze(dim)
        variance = variance.squeeze(dim)
    return variance, average

## Auto-encoder network
class ConvAutoEncoder(nn.Module):
    """1D U-Net-style auto-encoder over a ``(batch, num_ch, S)`` sequence.

    Encoder: three Conv1d -> LayerNorm -> ELU -> MaxPool1d(2) stages, taking
    the length from ``S`` down to ``S // 8``. Decoder: three ConvTranspose1d
    (stride 2) -> LayerNorm -> ELU stages bringing it back to ``S``. The
    outputs of the 2nd and 1st encoder stages are concatenated (skip
    connections) before the 2nd and 3rd decoder stages, and the raw input is
    concatenated before the output conv. Output shape equals input shape.

    Args:
        num_ch: number of input (and output) channels.
        S: sequence length. Must be a positive multiple of 8: the LayerNorm
            sizes and skip concatenations assume three exact halvings.

    Raises:
        ValueError: if ``S`` is not a positive multiple of 8.
    """

    def __init__(self, num_ch, S):
        super(ConvAutoEncoder, self).__init__()
        # Any other S passes __init__ but would crash in forward() with a
        # LayerNorm/concat shape mismatch; fail early with a clear message.
        if S <= 0 or S % 8 != 0:
            raise ValueError(
                f"ConvAutoEncoder needs S to be a positive multiple of 8, got {S}"
            )

        # Encoder (attribute names and Sequential layout are part of the
        # state_dict, so they must stay stable for checkpoint loading).
        self.conv1 = self._encoder_stage(num_ch, num_ch * 2, S)
        self.conv2 = self._encoder_stage(num_ch * 2, num_ch * 4, S // 2)
        self.conv3 = self._encoder_stage(num_ch * 4, num_ch * 4, S // 4)

        # Decoder: input channels of t_conv2/t_conv3 include the skip tensors.
        self.t_conv1 = self._decoder_stage(num_ch * 4, num_ch * 4, S // 4)
        self.t_conv2 = self._decoder_stage(num_ch * 8, num_ch * 2, S // 2)
        self.t_conv3 = self._decoder_stage(num_ch * 4, num_ch, S)

        # Output: fuses decoder output (num_ch) with the raw input (num_ch).
        self.conv_out = nn.Sequential(
            nn.Conv1d(num_ch * 2, num_ch, 3, stride=1, padding=1),
            nn.LayerNorm(S, elementwise_affine=False),
            nn.ELU(alpha=1.0, inplace=True),
        )

    @staticmethod
    def _encoder_stage(in_ch, out_ch, length):
        """Length-preserving Conv1d -> LayerNorm(length) -> ELU -> MaxPool1d(2)."""
        return nn.Sequential(
            nn.Conv1d(in_ch, out_ch, 3, stride=1, padding=1),
            nn.LayerNorm(length, elementwise_affine=False),
            nn.ELU(alpha=1.0, inplace=True),
            nn.MaxPool1d(2),
        )

    @staticmethod
    def _decoder_stage(in_ch, out_ch, length):
        """Length-doubling ConvTranspose1d -> LayerNorm(length) -> ELU."""
        return nn.Sequential(
            nn.ConvTranspose1d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.LayerNorm(length, elementwise_affine=False),
            nn.ELU(alpha=1.0, inplace=True),
        )

    def forward(self, x):
        """Encode/decode ``x`` of shape ``(batch, num_ch, S)``; returns the same shape."""
        identity = x
        skip1 = self.conv1(x)
        skip2 = self.conv2(skip1)
        x = self.conv3(skip2)

        x = self.t_conv1(x)
        x = self.t_conv2(torch.cat([x, skip2], dim=1))
        x = self.t_conv3(torch.cat([x, skip1], dim=1))

        return self.conv_out(torch.cat([x, identity], dim=1))


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / temperature) @ v``.

    ``attn_dropout`` is accepted for signature compatibility only; no dropout
    is applied.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, q, k, v, mask=None):
        """Attend ``q`` over ``k``/``v``.

        ``k`` is transposed on axes 2/3, so inputs are expected to be 4-D
        (in this file: ``(batch, heads, length, depth)``). Where ``mask == 0``
        the key is excluded via ``masked_softmax``.

        Returns:
            ``(output, attn)``: attended values and the attention weights.
        """
        scores = (q / self.temperature) @ k.transpose(2, 3)
        if mask is None:
            weights = F.softmax(scores, dim=-1)
        else:
            weights = masked_softmax(scores, mask, dim=-1)
        return weights @ v, weights


class PositionwiseFeedForward(nn.Module):
    """Position-wise two-layer MLP with a residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + w_2(relu(w_1(x))))``. ``dropout`` is accepted for
    signature compatibility only and is not used.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # Linear layers act on the last axis, i.e. independently per position.
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        out = self.w_2(hidden)
        out += x  # residual; in-place on the fresh w_2 output, x is untouched
        return self.layer_norm(out)


class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and post-LayerNorm.

    Queries/keys/values are projected (without bias) into ``n_head`` heads of
    size ``d_k``/``d_k``/``d_v``, attended per head, concatenated, projected
    back to ``d_model`` and combined as ``LayerNorm(q + fc(heads))``.
    ``dropout`` is accepted for signature compatibility only and is not used.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k**0.5)

        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q, k, v: tensors whose projections are viewed as
                ``(batch, length, n_head, depth)``, i.e. ``(batch, length, d_model)``.
            mask: optional mask; axes 1 and 2 are swapped and a singleton head
                axis is inserted before it reaches the attention. Callers in
                this file pass ``(batch, len_k, 1)`` -- TODO confirm for others.

        Returns:
            ``(output, attn)``: ``(batch, len_q, d_model)`` output and the
            per-head attention weights.
        """
        batch, len_q = q.size(0), q.size(1)
        len_k, len_v = k.size(1), v.size(1)
        residual = q

        def to_heads(t, proj, length, depth):
            # (b, l, n*d) -> (b, l, n, d) -> (b, n, l, d)
            return proj(t).view(batch, length, self.n_head, depth).transpose(1, 2)

        q_heads = to_heads(q, self.w_qs, len_q, self.d_k)
        k_heads = to_heads(k, self.w_ks, len_k, self.d_k)
        v_heads = to_heads(v, self.w_vs, len_v, self.d_v)

        if mask is not None:
            # Broadcast the same mask over every head.
            mask = mask.transpose(1, 2).unsqueeze(1)

        attended, attn = self.attention(q_heads, k_heads, v_heads, mask=mask)

        # (b, n, lq, dv) -> (b, lq, n, dv) -> (b, lq, n*dv): concatenate heads.
        merged = attended.transpose(1, 2).contiguous().view(batch, len_q, -1)
        out = self.fc(merged)
        out += residual

        return self.layer_norm(out), attn


class EncoderLayer(nn.Module):
    """Transformer encoder layer: multi-head attention then a position-wise FFN.

    ``dropout`` is forwarded to both sub-layers, which accept but ignore it.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None, query_input=None):
        """Apply attention (self-attention, or cross-attention when
        ``query_input`` is given) followed by the feed-forward block.

        Args:
            enc_input: keys and values; also the queries if ``query_input`` is None.
            slf_attn_mask: optional mask forwarded to ``MultiHeadAttention``.
            query_input: optional separate query tensor.

        Returns:
            ``(enc_output, enc_slf_attn)``: layer output and attention weights.
        """
        # Identity check: ``tensor == None`` goes through Tensor.__eq__ instead
        # of testing for the sentinel.
        queries = enc_input if query_input is None else query_input
        enc_output, enc_slf_attn = self.slf_attn(
            queries, enc_input, enc_input, mask=slf_attn_mask
        )
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


class Renderer(nn.Module):
    """Predicts per-sample RGB and density from multi-view features.

    For each sample point it builds one token per source view plus one
    view-independent token (from the mean/variance of the 2D features) and
    runs masked self-attention over them. Then:
      * density: the view-independent token is processed along the ray by a
        1D ``ConvAutoEncoder`` and decoded by an MLP with a ReLU output;
      * colour: per-view tokens concatenated with ``viewdirs`` give blending
        weights (masked softmax over views) applied to the source colours.

    Args:
        nb_samples_per_ray: samples per ray ``S``; passed to ``ConvAutoEncoder``.
        weightedMeanVar: if True, the mean/variance of the 2D features are
            weighted by the first disocclusion-confidence channel.
        gene_mask: ``"None"`` for normal rendering; ``"interval"`` or
            ``"one_pt"`` makes ``forward`` return mask tensors instead.
        check_feat_mode: if not ``"None"``, only the last 8 of the 24 3D
            feature channels are used.
    """
    def __init__(self, nb_samples_per_ray, weightedMeanVar=False, gene_mask="None", check_feat_mode="None"):
        super(Renderer, self).__init__()

        self.gene_mask = gene_mask
        self.check_feat_mode = check_feat_mode
        self.weightedMeanVar = weightedMeanVar

        self.dim = 32
        # Token input = 3D features + 1 visibility channel + 8 2D-feature channels.
        # check_feat_mode keeps only the last 8 3D channels (see forward).
        if check_feat_mode != "None":
            self.attn_token_gen = nn.Linear(8 + 1 + 8, self.dim)
        else:
            self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        # 16 = variance (8) and mean (8) of the 2D features, concatenated.
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Setting mask of var_mean always enabled
        # NOTE(review): `.cuda()` makes construction require a GPU, and this is a
        # plain tensor attribute (not a registered buffer), so `module.to(...)`
        # does not move it.
        self.var_mean_mask = torch.tensor([1]).cuda()
        self.var_mean_mask.requires_grad = False

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        # Per-view token (dim) + 9 viewdirs channels.
        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs, feat, occ_masks, viewdirs_novel, for_mask=None, input_phi=None, z=None, alpha=None):
        """Render colour and density for a batch of ray samples.

        Args:
            viewdirs: view-direction encodings; flattened to ``(N*S, ...)``
                and concatenated to the per-view tokens (9 channels expected
                by ``rgb_fc1``).
            feat: per-sample, per-view features (``N, S, V`` leading axes) with
                channels ``[0:24]`` 3D features, ``[24:32]`` 2D features,
                ``[32:35]`` colours, ``[35:36]`` visibility, ``[36:39]``
                disocclusion confidence (read only if ``weightedMeanVar``).
            occ_masks: per-view occlusion masks, flattened to ``(N*S, ...)``
                and multiplied into the view channels of the visibility mask.
            viewdirs_novel: not used here.
            for_mask: dict with ``'pts_d'`` / ``'pts_d_gt'``; read only when
                ``gene_mask == "interval"``.
            input_phi, z, alpha: not used here.

        Returns:
            ``(N, S, 4)`` tensor: ``[r, g, b, sigma]`` normally, or
            ``[need_gene_mask x3, ray_pts_cnt]`` when ``gene_mask != "None"``.
        """
        ## Viewing samples regardless of batch or ray
        # feat: (batch_size, sample points along a ray, nb_view, dim)
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24] # feature 3D (level 0, 1, 2)
        if self.check_feat_mode != "None":
            v_feat = v_feat[...,-8:]
        s_feat = feat[..., 24 : 24 + 8] # feature 2D (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        if self.weightedMeanVar:
            disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3] # disocc_confi (level 0, 1, 2)

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

        ## Mean and variance of 2D features provide view-independent tokens
        # Both branches return a (variance, mean) pair reduced over views (dim=1).
        if self.weightedMeanVar:
            var_mean = weighted_var_and_ang(s_feat, disocc_confi[...,0:1], dim=1, keepdim=True)
        else:
            var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        # The view-independent token is appended last along the view axis.
        tokens = torch.cat([tokens, var_mean], dim=1)

        # Mask-generation mode: return mask tensors instead of rendering.
        if self.gene_mask != "None":
            # NOTE(review): a value other than "interval"/"one_pt" leaves
            # need_gene_mask/ray_pts_cnt unassigned -> UnboundLocalError below.
            if self.gene_mask == "interval":
                # pts_d_gt is unsqueezed and repeated S times, so it presumably
                # holds one GT depth per ray -- TODO confirm.
                pts_d = for_mask['pts_d'].reshape(-1,1,1)
                pts_d_gt = for_mask['pts_d_gt'].unsqueeze(-1).repeat(1,S).reshape(-1,1,1)
                # 1 where all vis_mask entries, or all occ_masks entries, are 0 ...
                need_gene_mask = torch.zeros_like(vis_mask[:,0:1,:])
                need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
                need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdims=True) == 0, 1) #(bs*n_sample, 1, 1)
                # ... except samples farther than 0.2 from a non-zero GT depth.
                need_gene_mask[(pts_d-pts_d_gt > 0.2)*(pts_d_gt != 0)] = 0
                need_gene_mask[(pts_d-pts_d_gt < -0.2)*(pts_d_gt != 0)] = 0
                need_gene_mask = need_gene_mask.squeeze(-1).reshape(N,S,1).repeat(1,1,3)

                # 1 for samples within 0.2 of a non-zero GT depth, else 0.
                ray_pts_cnt = torch.zeros_like(pts_d_gt) # to cnt set to 1; else 0
                ray_pts_cnt[(pts_d-pts_d_gt <= 0.2)*(pts_d-pts_d_gt >= 0)*(pts_d_gt != 0)] = 1
                ray_pts_cnt[(pts_d-pts_d_gt >= -0.2)*(pts_d-pts_d_gt < 0)*(pts_d_gt != 0)] = 1

            elif self.gene_mask == "one_pt":
                need_gene_mask = torch.zeros_like(vis_mask[:,0:1,:])
                need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
                need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdims=True) == 0, 1)
                need_gene_mask = need_gene_mask.squeeze(-1).reshape(N,S,1).repeat(1,1,3)

                # Every sample is counted in this mode.
                ray_pts_cnt = torch.ones_like(need_gene_mask)[...,0] # to cnt set to 1; else 0

            # (N, S, 4): need_gene_mask on 3 channels followed by ray_pts_cnt.
            outputs = torch.cat([need_gene_mask, ray_pts_cnt.reshape(N,S,1)], -1)

            return outputs



        ## Adding a new channel to mask for var_mean
        vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )
        ## If a point is not visible by any source view, force its masks to enabled
        # (assuming 0/1 visibility, sum == 1 means only the always-on var_mean channel is set)
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

        ## Taking occ_masks into account, but remembering if there were any visibility before that
        mask_cloned = vis_mask.clone()
        vis_mask[:, :-1] *= occ_masks
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
        masks = vis_mask * mask_cloned

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma with an Auto-Encoder and MLP
        # The last (view-independent) token is laid out as (N, dim, S) so the
        # auto-encoder convolves along the samples of each ray.
        sigma_tokens = tokens[:, -1:]
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens)
        sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

        sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

        ## Concatenating positional encodings and predicting RGB weights
        rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = self.rgb_fc3(rgb_tokens)
        # Blending weights over source views; masked views get zero weight.
        rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)

        rgb = (colors * rgb_w).sum(1)

        # black = torch.zeros_like(rgb)
        # white = torch.ones_like(rgb)
        # black_gene = black * need_gene_mask
        # _white = white * (need_gene_mask-1)*(-1)
        # rgb = black_gene + _white

        ## save_masks
        # sigma = torch.sum(masks[:,0:-1],dim=1)
        ##

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)

        return outputs



class Renderer_v1(nn.Module):
    """Renderer variant with a separate per-view density branch.

    Colour is predicted as in ``Renderer`` (per-view tokens + var/mean token,
    masked self-attention, masked softmax blending of source colours).
    Density uses its own per-view tokens and attention stack
    (``density_attn_layers``). The per-view outputs are merged with softmax
    weights computed from the disocclusion confidence (``weight_sigma``), then
    passed through ``ConvAutoEncoder`` along the ray and an MLP.
    """
    def __init__(self, nb_samples_per_ray):
        '''modify geometry part'''
        super(Renderer_v1, self).__init__()

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        # Density tokens take the same input as attn_token_gen (3D + vis + 2D).
        self.density_attn_token_gen = nn.Linear(8*3 + 1 + 8, self.dim)
        # density_dim = 8*3+1
        # density_d_k = density_dim // n_head
        # density_d_v = density_dim // n_head
        self.density_attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Setting mask of var_mean always enabled
        # NOTE(review): requires CUDA at construction; plain attribute, so
        # `module.to(...)` does not move it.
        self.var_mean_mask = torch.tensor([1]).cuda()
        self.var_mean_mask.requires_grad = False

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)
        # Maps the 3 disocclusion-confidence channels to one per-view logit.
        self.weight_sigma = nn.Linear(3, 1)

        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs, feat, occ_masks):
        """Render colour and density for a batch of ray samples.

        Args:
            viewdirs: view-direction encodings, flattened to ``(N*S, ...)``.
            feat: features with leading axes ``(N, S, V)`` and exactly 39
                channels: ``[0:24]`` 3D, ``[24:32]`` 2D, ``[32:35]`` colours,
                ``[35:36]`` visibility, ``[36:39]`` disocclusion confidence.
            occ_masks: per-view occlusion masks, flattened to ``(N*S, ...)``.

        Returns:
            ``(N, S, 4)`` tensor ``[r, g, b, sigma]``.
        """
        ## Viewing samples regardless of batch or ray
        # feat: (batch_size, sample points along a ray, nb_view, dim=8*level(3)+8+3+1+3=39)
        assert feat.shape[-1] == 39
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24] # feature 3D (level 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8] # feature 2D (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3] # disocc_confi (level 0, 1, 2)

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

        ## Mean and variance of 2D features provide view-independent tokens
        var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        tokens = torch.cat([tokens, var_mean], dim=1) # (-1, nb_view+1, dim)

        # for sigma
        sigma_tokens = F.elu(
            self.density_attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )

        # Density mask: per-view only (no var_mean channel appended).
        sigma_vis_mask = vis_mask.clone()

        ## Adding a new channel to mask for var_mean
        vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )

        def vis_mask_func(vis_mask, is_sigma=False):
            # Builds the attention mask; is_sigma skips occ_masks and guarantees
            # at least one enabled view per point.
            ## If a point is not visible by any source view, force its masks to enabled
            # NOTE(review): `== 1` assumes the always-on var_mean channel is present.
            # For the sigma mask (no such channel) it instead enables all views
            # when exactly one view is visible -- confirm this is intended.
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

            ## Taking occ_masks into account, but remembering if there were any visibility before that
            mask_cloned = vis_mask.clone()
            if not is_sigma:
                vis_mask[:, :-1] *= occ_masks
            # else:
            #     vis_mask *= occ_masks
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
            masks = vis_mask * mask_cloned

            if is_sigma:
                masks = masks.masked_fill(masks.sum(dim=1, keepdims=True) == 0, 1) # may occur all mask values = 0 for a point x

            return masks

        masks = vis_mask_func(vis_mask)
        sigma_masks = vis_mask_func(sigma_vis_mask, True)

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma with a MHA and MLP
        for layer in self.density_attn_layers:
            sigma_tokens, _ = layer(sigma_tokens, sigma_masks)

        # assert sigma_tokens.shape == [N * S, V, self.dim]
        # sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
        # sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        # sigma_per_views = torch.relu(self.sigma_fc3(sigma_tokens))
        # weight_per_views = torch.softmax(self.weight_sigma(disocc_confi), dim=1)
        # sigma = torch.sum(weight_per_views * sigma_per_views, dim=1)

        # + autoencoder
        # Softmax over views of a confidence-derived logit (not masked by sigma_masks);
        # the weighted sum collapses the view axis to one token per sample.
        weight_per_views = torch.softmax(self.weight_sigma(disocc_confi), dim=1)
        sigma_tokens = torch.sum(weight_per_views * sigma_tokens, dim=1)
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens).transpose(1, 2).reshape(N * S, self.dim)
        sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        sigma = torch.relu(self.sigma_fc3(sigma_tokens))

        ## Concatenating positional encodings and predicting RGB weights
        rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = self.rgb_fc3(rgb_tokens)
        rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)

        rgb = (colors * rgb_w).sum(1)

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)

        return outputs



class Renderer_v1_mlp(nn.Module):
    """Renderer variant whose density branch is a per-view MLP instead of attention.

    Colour is predicted as in ``Renderer``. For density, the 3 disocclusion
    confidence channels go through a stack of Linear layers. Each layer's
    output is multiplied by a per-view "bias" projected from the 3D/vis/2D
    features, and the per-view densities are averaged over views.
    """
    def __init__(self, nb_samples_per_ray):
        '''modify geometry part (use MLP instead of MHA)'''
        super(Renderer_v1_mlp, self).__init__()

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Setting mask of var_mean always enabled
        # NOTE(review): requires CUDA at construction; plain attribute, so
        # `module.to(...)` does not move it.
        self.var_mean_mask = torch.tensor([1]).cuda()
        self.var_mean_mask.requires_grad = False

        ## For aggregating data along ray samples
        # self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        # Multiplicative modulation from (3D + vis + 2D) features, one per view.
        self.sigma_bias = nn.Linear(8*3 + 1 + 8, self.dim)
        # First layer consumes the 3 disocclusion-confidence channels; 7 layers total.
        self.sigma_layers = nn.ModuleList([nn.Linear(3, self.dim)] + [nn.Linear(self.dim, self.dim) for i in range(6)])
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs, feat, occ_masks):
        """Render colour and density for a batch of ray samples.

        Args:
            viewdirs: view-direction encodings, flattened to ``(N*S, ...)``.
            feat: features with leading axes ``(N, S, V)`` and exactly 39
                channels: ``[0:24]`` 3D, ``[24:32]`` 2D, ``[32:35]`` colours,
                ``[35:36]`` visibility, ``[36:39]`` disocclusion confidence.
            occ_masks: per-view occlusion masks, flattened to ``(N*S, ...)``.

        Returns:
            ``(N, S, 4)`` tensor ``[r, g, b, sigma]``.
        """
        ## Viewing samples regardless of batch or ray
        # feat: (batch_size, sample points along a ray, nb_view, dim=8*level(3)+8+3+1+3=39)
        assert feat.shape[-1] == 39
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24] # feature 3D (level 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8] # feature 2D (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3] # disocc_confi (level 0, 1, 2)

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

        ## Mean and variance of 2D features provide view-independent tokens
        var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        tokens = torch.cat([tokens, var_mean], dim=1) # (-1, nb_view+1, dim)

        # for sigma
        # sigma_tokens = F.elu(
        #     self.density_attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        # )

        sigma_vis_mask = vis_mask.clone()

        ## Adding a new channel to mask for var_mean
        # Kept under a separate name so `vis_mask` stays the per-view mask used below.
        _vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )

        def vis_mask_func(vis_mask, is_sigma=False):
            ## If a point is not visible by any source view, force its masks to enabled
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

            ## Taking occ_masks into account, but remembering if there were any visibility before that
            mask_cloned = vis_mask.clone()
            if not is_sigma:
                vis_mask[:, :-1] *= occ_masks
            # else:
            #     vis_mask *= occ_masks
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
            masks = vis_mask * mask_cloned

            if is_sigma:
                masks = masks.masked_fill(masks.sum(dim=1, keepdims=True) == 0, 1) # may occur all mask values = 0 for a point x

            return masks

        masks = vis_mask_func(_vis_mask)
        # NOTE(review): sigma_masks is computed but never used below (the
        # per-view sigma mean is unmasked).
        sigma_masks = vis_mask_func(sigma_vis_mask, True)

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma with weighted MLP
        h = disocc_confi
        bias = self.sigma_bias(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        for sigma_layer in self.sigma_layers:
            h = sigma_layer(h) * bias
            h = F.elu(h)
        sigma_tokens = F.elu(self.sigma_fc2(h))
        sigma_per_views = torch.relu(self.sigma_fc3(sigma_tokens))
        # Plain mean over the view axis.
        sigma = torch.mean(sigma_per_views, dim=1)

        ## Concatenating positional encodings and predicting RGB weights
        rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = self.rgb_fc3(rgb_tokens)
        rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)

        rgb = (colors * rgb_w).sum(1)

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)

        return outputs



class Renderer_v2(nn.Module):
    """Renderer variant with a modified colour (appearance) branch.

    Density is predicted as in ``Renderer``: attention, then the var/mean
    token through ``ConvAutoEncoder`` and an MLP. Colour is regressed per view
    by an MLP over ``[viewdirs, disocc_confi, 3D features]``. Each layer is
    multiplicatively modulated by a projection of ``[texture features,
    colours, visibility]``, and the per-view sigmoid RGB is averaged over
    views.
    """
    def __init__(self, nb_samples_per_ray):
        '''modify appearance part'''
        super(Renderer_v2, self).__init__()

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Setting mask of var_mean always enabled
        # NOTE(review): requires CUDA at construction; plain attribute, so
        # `module.to(...)` does not move it.
        self.var_mean_mask = torch.tensor([1]).cuda()
        self.var_mean_mask.requires_grad = False

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        # Modulation input: texture features (8) + colours (3) + visibility (1).
        self.rgb_bias = nn.Linear(8+3+1, self.dim)
        # First layer input: viewdirs (9) + disocc_confi (3) + 3D features (24).
        self.rgb_layers = nn.ModuleList([nn.Linear(9+3+24, self.dim)] + [nn.Linear(self.dim, self.dim) for i in range(6)])

        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 3)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs, feat, occ_masks):
        """Render colour and density for a batch of ray samples.

        Args:
            viewdirs: view-direction encodings, flattened to ``(N*S, ...)``.
            feat: features with leading axes ``(N, S, V)`` and exactly 47
                channels: ``[0:24]`` 3D, ``[24:32]`` 2D, ``[32:35]`` colours,
                ``[35:36]`` visibility, ``[36:39]`` disocclusion confidence,
                ``[39:47]`` texture features.
            occ_masks: per-view occlusion masks, flattened to ``(N*S, ...)``.

        Returns:
            ``(N, S, 4)`` tensor ``[r, g, b, sigma]``.
        """
        ## Viewing samples regardless of batch or ray
        # feat: (batch_size, sample points along a ray, nb_view, dim)
        assert feat.shape[-1] == (39+8)
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24] # feature 3D (level 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8] # feature 2D (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3] # disocc_confi (level 0, 1, 2)
        tex_feat = feat[..., 24 + 8 + 3 + 1 + 3 : 24 + 8 + 3 + 1 + 3 + 8]

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

        ## Mean and variance of 2D features provide view-independent tokens
        var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        tokens = torch.cat([tokens, var_mean], dim=1)

        ## Adding a new channel to mask for var_mean
        # Kept under a separate name so `vis_mask` stays the per-view mask used below.
        _vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )

        def vis_mask_func(vis_mask, is_sigma=False):
            # NOTE(review): is_sigma is not used in this variant.
            ## If a point is not visible by any source view, force its masks to enabled
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

            ## Taking occ_masks into account, but remembering if there were any visibility before that
            mask_cloned = vis_mask.clone()
            vis_mask[:, :-1] *= occ_masks
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
            masks = vis_mask * mask_cloned

            return masks

        masks = vis_mask_func(_vis_mask)

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma with an Auto-Encoder and MLP
        sigma_tokens = tokens[:, -1:]
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens)
        sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

        sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

        ## Concatenating positional encodings and predicting RGB weights
        # Per-view RGB regression; the attention tokens are not used for colour here.
        h = torch.cat([viewdirs, disocc_confi, v_feat], dim=-1)
        bias = self.rgb_bias(torch.cat([tex_feat, colors, vis_mask], dim=-1))
        for rgb_layer in self.rgb_layers:
            h = rgb_layer(h) * bias
            h = F.elu(h)
        rgb_tokens = F.elu(self.rgb_fc2(h))
        rgb_per_views = torch.sigmoid(self.rgb_fc3(rgb_tokens))
        # Plain mean over views (masks are not applied here).
        rgb = torch.mean(rgb_per_views, dim=1)

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)

        return outputs



class Renderer_density_v1(nn.Module):
    """Renderer variant whose density comes from confidence-weighted view features.

    Colour follows the attention path: one token per source view plus one
    view-independent mean/variance token go through masked self-attention, and
    the per-view tokens predict blending weights for the source colours.
    Density does not use the attention output. The per-view 3D/2D features are
    averaged over views with ``softmax(disocc_confi)`` weights. The result is
    turned into one token per sample, processed along each ray by
    ``ConvAutoEncoder`` and decoded by a small MLP.

    Args:
        nb_samples_per_ray: number of samples per ray, forwarded to
            ``ConvAutoEncoder``.
    """

    def __init__(self, nb_samples_per_ray):
        super().__init__()

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)
        self.sigma_token_gen = nn.Linear(24 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for _ in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Mask of the var_mean token is always enabled.
        # Registered as a non-persistent buffer so it follows the module through
        # .to()/.cuda(). A hard-coded .cuda() pins it to whichever GPU is current
        # at construction time and fails on CPU-only hosts. Being non-persistent,
        # it stays out of the state_dict, so existing checkpoints still load.
        self.register_buffer("var_mean_mask", torch.tensor([1]), persistent=False)

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        # Per-view token (dim) + 9 view-direction channels -> one blending logit.
        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs, feat, occ_masks):
        """Predict colour and density for every ray sample.

        Args:
            viewdirs: per-view direction features. They are flattened to
                (N*S, V, 9) and concatenated to the per-view attention tokens.
            feat: (N, S, V, 39) per-view features laid out as
                [v_feat(24) | s_feat(8) | colors(3) | vis_mask(1) |
                 disocc_confi(3)].
            occ_masks: per-view occlusion masks. They are flattened to
                (N*S, V, ...) and multiplied into the visibility mask.

        Returns:
            (N, S, 4) tensor of [r, g, b, sigma].
        """
        assert feat.shape[-1] == 39, (
            f"expected 39 feature channels, got {feat.shape[-1]}"
        )
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24]  # 3D features, levels 0/1/2 with 8 channels each
        s_feat = feat[..., 24 : 24 + 8]  # 2D features (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3]  # levels 0/1/2

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

        ## Mean and variance of 2D features provide view-independent tokens
        var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        tokens = torch.cat([tokens, var_mean], dim=1)  # (N*S, V+1, dim)

        ## Density token from confidence-weighted per-view features.
        # Softmax over views. Confidence column l weights level l of v_feat
        # (channels 8l:8l+8); level-0 confidence also weights s_feat.
        view_w = F.softmax(disocc_confi, dim=1)  # (N*S, V, 3)
        level_w = view_w.repeat_interleave(8, dim=-1)  # (N*S, V, 24)
        fused_v_feat = (level_w * v_feat).sum(dim=1)  # (N*S, 24)
        fused_s_feat = (view_w[..., 0:1] * s_feat).sum(dim=1)  # (N*S, 8)
        sigma_tokens = F.elu(
            self.sigma_token_gen(torch.cat([fused_v_feat, fused_s_feat], dim=-1))
        ).unsqueeze(1)  # (N*S, 1, dim)

        ## Adding a new channel to mask for var_mean
        vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )
        ## If a point is not visible by any source view (only the always-on
        ## var_mean channel is set), force all of its masks to enabled
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

        ## Taking occ_masks into account, but remembering if there were any visibility before that
        mask_cloned = vis_mask.clone()
        vis_mask[:, :-1] *= occ_masks
        # If occlusion removed every view, fall back to enabling all of them.
        # Multiplying by the pre-occlusion copy keeps invisible views disabled.
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
        masks = vis_mask * mask_cloned

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma: Auto-Encoder along the S samples of each ray, then MLP
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens)
        sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

        sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

        ## Concatenating view directions and predicting RGB blending weights
        rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = self.rgb_fc3(rgb_tokens)
        rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)

        rgb = (colors * rgb_w).sum(1)

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)

        return outputs



class Renderer_RGBresidual(nn.Module):
    """Renderer that mixes blended source colours with a directly regressed colour.

    Density follows the attention path. Per-view tokens plus a mean/variance
    token go through masked self-attention; the mean/variance token is
    processed along each ray by ``ConvAutoEncoder`` and decoded by an MLP.
    Colour combines two predictions:
    * attention-weighted blending of the source colours;
    * an MLP on all views' colours, texture features and the novel-view
      direction features.
    The mix weight is the mean ``disocc_confi`` value of each sample.

    Args:
        nb_samples_per_ray: number of samples per ray, forwarded to
            ``ConvAutoEncoder``.
    """

    def __init__(self, nb_samples_per_ray):
        super().__init__()

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for _ in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Mask of the var_mean token is always enabled.
        # Registered as a non-persistent buffer so it follows the module through
        # .to()/.cuda(). A hard-coded .cuda() pins it to whichever GPU is current
        # at construction time and fails on CPU-only hosts. Being non-persistent,
        # it stays out of the state_dict, so existing checkpoints still load.
        self.register_buffer("var_mean_mask", torch.tensor([1]), persistent=False)

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        # Per-view token (dim) + 9 view-direction channels -> one blending logit.
        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        # Direct colour regressor. Its input width is hard-coded for 3 views:
        # colours (3*3) + texture features (3*24) + novel-view directions (27).
        # TODO: derive from nb_view instead of hard-coding 3 views.
        self.rgb_nerf_fc1 = nn.Linear(3 * 3 + 3 * 24 + 9 * 3, self.dim * 2)
        self.rgb_nerf_fc2 = nn.Linear(self.dim * 2, self.dim)
        self.rgb_nerf_fc3 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_nerf_fc4 = nn.Linear(self.dim // 2, 3)
        # Unused in forward. Kept so the state_dict layout (and therefore
        # existing checkpoints) stays unchanged.
        self.rgb_final = nn.Linear(3, 3)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs_diff, feat, occ_masks, viewdirs_novel):
        """Predict colour and density for every ray sample.

        Args:
            viewdirs_diff: per-view direction features. They are flattened to
                (N*S, V, 9) and concatenated to the per-view attention tokens.
            feat: (N, S, V, C) per-view features laid out as
                [v_feat(24) | s_feat(8) | colors(3) | vis_mask(1) |
                 disocc_confi(3) | tex_feat(24)].
            occ_masks: per-view occlusion masks. They are flattened to
                (N*S, V, ...) and multiplied into the visibility mask.
            viewdirs_novel: novel-view direction features, flattened to
                (N*S, -1). Their width must make the ``rgb_nerf_fc1`` input add
                up (27 channels when V == 3).

        Returns:
            (N, S, 4) tensor of [r, g, b, sigma].
        """
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24]  # 3D features (levels 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8]  # 2D features (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3]  # levels 0/1/2
        tex_feat = feat[..., 24 + 8 + 3 + 1 + 3 : 24 + 8 + 3 + 1 + 3 + 8 * 3]  # levels 0/1/2

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs_diff = viewdirs_diff.view(-1, *viewdirs_diff.shape[2:])
        viewdirs_novel = viewdirs_novel.view(-1, viewdirs_novel.shape[-1])

        ## Mean and variance of 2D features provide view-independent tokens
        var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        tokens = torch.cat([tokens, var_mean], dim=1)  # (N*S, V+1, dim)

        ## Adding a new channel to mask for var_mean
        vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )
        ## If a point is not visible by any source view (only the always-on
        ## var_mean channel is set), force all of its masks to enabled
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

        ## Taking occ_masks into account, but remembering if there were any visibility before that
        mask_cloned = vis_mask.clone()
        vis_mask[:, :-1] *= occ_masks
        # If occlusion removed every view, fall back to enabling all of them.
        # Multiplying by the pre-occlusion copy keeps invisible views disabled.
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
        masks = vis_mask * mask_cloned

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma: Auto-Encoder along the S samples of each ray, then MLP
        sigma_tokens = tokens[:, -1:]
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens)
        sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

        sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

        ## Direct colour regression from all views' colours and texture features
        tex_feat = tex_feat.reshape(-1, V * 24)
        rgb_nerf_tokens = torch.cat(
            [colors.reshape(-1, V * 3), tex_feat, viewdirs_novel], dim=-1
        )
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc1(rgb_nerf_tokens))
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc2(rgb_nerf_tokens))
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc3(rgb_nerf_tokens))
        rgb_nerf = torch.sigmoid(self.rgb_nerf_fc4(rgb_nerf_tokens))

        ## Concatenating view directions and predicting RGB blending weights
        rgb_tokens = torch.cat([tokens[:, :-1], viewdirs_diff], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = masked_softmax(self.rgb_fc3(rgb_tokens), masks[:, :-1], dim=1)

        ## Mix blended and regressed colour by the mean disocclusion confidence
        D_mean = disocc_confi.reshape(disocc_confi.shape[0], -1).mean(dim=-1, keepdim=True)
        rgb = D_mean * ((colors * rgb_w).sum(1)) + (1 - D_mean) * rgb_nerf

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)

        return outputs



class Renderer_geneRGBsigma(nn.Module):
    """Renderer that switches to "generated" colour/density for unseen samples.

    Samples seen by at least one non-occluded source view use the attention
    path. Density comes from the mean/variance token processed by
    ``ConvAutoEncoder`` plus an MLP; colour is attention-weighted blending of
    the source colours. Samples that no source view sees, or that every view
    has occluded, instead use two regressors:
    * ``sigma_guess_*`` on the processed token plus ``global_geoFeat``;
    * ``rgb_nerf_*`` on all views' colours, texture and feat_SA features,
      optionally with 3D-attention features.

    Args:
        nb_samples_per_ray: number of samples per ray, forwarded to
            ``ConvAutoEncoder``.
        use_attention_3d: add ``FeatureSelfAttention_3d`` features to the
            colour regressor. Then ``forward`` needs ``for_att_3d``.
        output_gene: also return the un-masked generated colour/density.
    """

    def __init__(self, nb_samples_per_ray, use_attention_3d=False, output_gene=False):
        super().__init__()

        self.use_attention_3d = use_attention_3d
        self.output_gene = output_gene
        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        if use_attention_3d:
            self.feat_3d_attention_layers = FeatureSelfAttention_3d()

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for _ in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Mask of the var_mean token is always enabled.
        # Registered as a non-persistent buffer so it follows the module through
        # .to()/.cuda(). A hard-coded .cuda() pins it to whichever GPU is current
        # at construction time and fails on CPU-only hosts. Being non-persistent,
        # it stays out of the state_dict, so existing checkpoints still load.
        self.register_buffer("var_mean_mask", torch.tensor([1]), persistent=False)

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        # Layer widths below assume exactly 3 source views; forward reshapes
        # per-view features with the runtime V, so V must match.
        nb_view = 3
        self.sigma_guess_fc1 = nn.Linear(self.dim + nb_view * 32, self.dim * 2)
        self.sigma_guess_fc2 = nn.Linear(self.dim * 2, self.dim // 2)
        self.sigma_guess_fc3 = nn.Linear(self.dim // 2, 1)

        # Per-view token (dim) + 9 view-direction channels -> one blending logit.
        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        # Per view: colour(3) + tex_feat(24) + feat_SA(16) + 9 channels of
        # novel-view direction features (+16 3D-attention features).
        rgb_nerf_input_dim = 3+24+16+9 if not self.use_attention_3d else 3+24+16+9+16
        self.rgb_nerf_fc1 = nn.Linear(nb_view * rgb_nerf_input_dim, self.dim * 2)
        self.rgb_nerf_fc2 = nn.Linear(self.dim * 2, self.dim)
        self.rgb_nerf_fc3 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_nerf_fc4 = nn.Linear(self.dim // 2, 3)
        # Unused in forward. Kept so the state_dict layout (and therefore
        # existing checkpoints) stays unchanged.
        self.rgb_final = nn.Linear(3, 3)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs_diff, feat, occ_masks, viewdirs_novel, for_att_3d=None):
        """Predict colour and density for every ray sample.

        Args:
            viewdirs_diff: per-view direction features. They are flattened to
                (N*S, V, 9) and concatenated to the per-view attention tokens.
            feat: (N, S, V, C) per-view features laid out as
                [v_feat(24) | s_feat(8) | colors(3) | vis_mask(1) |
                 tex_feat(24) | feat_SA(16) | global_geoFeat(32)].
            occ_masks: per-view occlusion masks. They are flattened to
                (N*S, V, ...) and multiplied into the visibility mask.
            viewdirs_novel: novel-view direction features, flattened to
                (N*S, -1). Their width must make the ``rgb_nerf_fc1`` input add
                up (27 channels when V == 3).
            for_att_3d: dict with keys ``'tex_feat_2'`` and
                ``'rays_pts_ndc_2'``. Required when ``use_attention_3d``.

        Returns:
            (N, S, 4) tensor of [r, g, b, sigma]. When ``output_gene`` is set,
            a tuple of that tensor and an (N, S, 4) tensor holding the
            un-masked generated colour and density.

        Raises:
            ValueError: ``use_attention_3d`` is set but ``for_att_3d`` is None.
        """
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24]  # 3D features (levels 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8]  # 2D features (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        tex_feat = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 8*3]  # levels 0/1/2
        feat_SA = feat[..., 24 + 8 + 3 + 1 + 8*3 : 24 + 8 + 3 + 1 + 8*3 + 16]
        global_geoFeat = feat[..., 24 + 8 + 3 + 1 + 8*3 + 16 : 24 + 8 + 3 + 1 + 8*3 + 16 + 32]

        if self.use_attention_3d:
            if for_att_3d is None:
                raise ValueError("for_att_3d is required when use_attention_3d=True")
            tex_feat_att_3d = self.feat_3d_attention_layers(
                all_feat=for_att_3d['tex_feat_2'], pts_ndc=for_att_3d['rays_pts_ndc_2']
            )
            tex_feat_att_3d = tex_feat_att_3d.reshape(N * S, V, -1)

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs_diff = viewdirs_diff.view(-1, *viewdirs_diff.shape[2:])
        viewdirs_novel = viewdirs_novel.view(-1, viewdirs_novel.shape[-1])

        ## Mean and variance of 2D features provide view-independent tokens
        var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        tokens = torch.cat([tokens, var_mean], dim=1)  # (N*S, V+1, dim)

        ## need_gene_mask is 1 where no source view sees the sample or every
        ## view has it occluded; those samples use the generated predictions.
        need_gene_mask = torch.zeros_like(vis_mask[:, 0:1, :])
        need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
        need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdims=True) == 0, 1)
        need_gene_mask = need_gene_mask.squeeze(-1)  # (N*S, 1)
        keep_mask = 1 - need_gene_mask

        ## Adding a new channel to mask for var_mean
        vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )
        ## If a point is not visible by any source view (only the always-on
        ## var_mean channel is set), force all of its masks to enabled
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

        ## Taking occ_masks into account, but remembering if there were any visibility before that
        mask_cloned = vis_mask.clone()
        vis_mask[:, :-1] *= occ_masks
        # If occlusion removed every view, fall back to enabling all of them.
        # Multiplying by the pre-occlusion copy keeps invisible views disabled.
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
        masks = vis_mask * mask_cloned

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma: Auto-Encoder along the S samples of each ray, then MLP
        sigma_tokens = tokens[:, -1:]
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens)
        _sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

        sigma_tokens = F.elu(self.sigma_fc1(_sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        _sigma_known = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

        ## Generated density from the processed token + global geometry features
        global_geoFeat = global_geoFeat.reshape(-1, 1, V * 32)
        sigma_guess_tokens = torch.cat([_sigma_tokens, global_geoFeat], dim=-1)
        sigma_guess_tokens = F.elu(self.sigma_guess_fc1(sigma_guess_tokens))
        sigma_guess_tokens = F.elu(self.sigma_guess_fc2(sigma_guess_tokens))
        _sigma_guess = torch.relu(self.sigma_guess_fc3(sigma_guess_tokens[:, 0]))

        sigma = _sigma_known * keep_mask + _sigma_guess * need_gene_mask

        ## Generated colour from all views' colours and texture features
        tex_feat = tex_feat.reshape(-1, V * 24)
        feat_SA = feat_SA.reshape(-1, V * 16)
        rgb_nerf_tokens = torch.cat(
            [colors.reshape(-1, V * 3), tex_feat, feat_SA, viewdirs_novel], dim=-1
        )
        if self.use_attention_3d:
            rgb_nerf_tokens = torch.cat(
                [rgb_nerf_tokens, tex_feat_att_3d.reshape(-1, V * 16)], dim=-1
            )
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc1(rgb_nerf_tokens))
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc2(rgb_nerf_tokens))
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc3(rgb_nerf_tokens))
        _rgb_nerf = torch.sigmoid(self.rgb_nerf_fc4(rgb_nerf_tokens))

        ## Concatenating view directions and predicting RGB blending weights
        rgb_tokens = torch.cat([tokens[:, :-1], viewdirs_diff], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = masked_softmax(self.rgb_fc3(rgb_tokens), masks[:, :-1], dim=1)
        _rgb_ori = (colors * rgb_w).sum(1)

        rgb = _rgb_ori * keep_mask + _rgb_nerf * need_gene_mask

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)

        if self.output_gene:
            outputs_gene = torch.cat([_rgb_nerf, _sigma_guess], -1)
            outputs_gene = outputs_gene.reshape(N, S, -1)
            return outputs, outputs_gene
        return outputs

from inplace_abn import InPlaceABN
class FeatureSelfAttention_3d(nn.Module):
    """Attend from each ray sample to a reduced 3D feature volume, per view.

    Each view's 8-channel feature volume is reduced by two stride-4 Conv3d
    layers to 32 channels. For every sample, the reduced feature interpolated
    at the sample plus its 3 NDC coordinates forms a query. That query
    attends over all voxels of the same view's reduced volume, and the
    attended 32-d feature is projected to 16 channels.

    Args:
        norm_act: unused. Kept for signature compatibility; the
            ConvBnReLU3D-based convolution that took it is disabled.
    """

    def __init__(self, norm_act=InPlaceABN):
        super().__init__()

        # Two plain stride-4 convolutions (no normalisation or activation
        # in between).
        self.conv1 = nn.Sequential(
            nn.Conv3d(8, 16, kernel_size=7, stride=4, padding=3),
            nn.Conv3d(16, 32, kernel_size=7, stride=4, padding=3),
        )
        # Query = interpolated 32-d feature + 3 NDC coordinates.
        self.query_fc = nn.Linear(32 + 3, 32, bias=False)

        # NOTE(review): not used by forward (its call was disabled). It has no
        # parameters; kept only so the attribute stays available.
        self.downsample = nn.Upsample(scale_factor=0.75, mode='trilinear')

        dim = 32
        d_inner = dim
        n_head = 4
        d_k = dim // n_head
        d_v = dim // n_head
        self.att = EncoderLayer(dim, d_inner, n_head, d_k, d_v)

        # Reduce the attended feature from 32 to 16 channels.
        self.smooth = nn.Linear(32, 16)

    def forward(self, all_feat, pts_ndc):
        """Compute 3D-attention features for every sample and view.

        Args:
            all_feat: (B, V, 8, D, H, W) per-view feature volumes.
            pts_ndc: sample NDC coordinates, indexed as ``pts_ndc[:, :, v]``.
                Presumably (n_rays, n_samples, V, 3); TODO confirm.

        Returns:
            (n_rays, n_samples, V, 16) tensor of attended features.
        """
        # Local import, as in the original, to avoid importing utils at module load.
        from utils.utils import interpolate_3D

        B, V, C, D, h, w = all_feat.shape
        # (B*V, 32, D', H', W'); each spatial dim is reduced by the two stride-4 convs.
        volumes = self.conv1(all_feat.reshape(-1, C, D, h, w))

        att_feat_3d = []
        for v in range(V):
            # NOTE(review): volumes[v] indexes the flattened B*V axis and the
            # query is reshaped to batch 1, so this assumes B == 1.
            sampled = interpolate_3D(volumes[v:v + 1], pts_ndc[:, :, v])  # (bs, nb_samples, c)
            bs, nb_samples, feat_c = sampled.shape
            keys = volumes[v].permute(1, 2, 3, 0).reshape(B, -1, feat_c)  # (1, D'*H'*W', c)
            query = torch.cat((sampled, pts_ndc[:, :, v]), dim=-1).reshape(1, -1, feat_c + 3)
            query = self.query_fc(query)
            attended = self.att(keys, query_input=query)[0].reshape(bs, nb_samples, feat_c)
            att_feat_3d.append(self.smooth(attended).unsqueeze(0))

        # (V, bs, nb_samples, 16) -> (bs, nb_samples, V, 16)
        return torch.cat(att_feat_3d, dim=0).permute(1, 2, 0, 3)

class Renderer_geneRGBsigma_dist(nn.Module):
    """Generative renderer that also predicts a per-channel colour variance.

    Like ``Renderer_geneRGBsigma``: samples seen by a non-occluded source view
    use attention-based density and colour blending, and the rest use the
    generated ``sigma_guess``/``rgb_nerf`` predictions. Additionally, a variance
    head predicts one variance per RGB channel from the first hidden layer of
    the density-guess MLP plus the per-view direction features. That variance
    is returned alongside the generated colour mean.

    Args:
        nb_samples_per_ray: number of samples per ray, forwarded to
            ``ConvAutoEncoder``.
        min_var: floor added to the softplus variance.
        use_attention_3d: add ``FeatureSelfAttention_3d`` features to the
            colour regressor. Then ``forward`` needs ``for_att_3d``.
        sample: replace the generated colour by a sample drawn around it.
        modify: experiment switch. With 4 (and not ``test``), the generated
            colour is left un-sigmoided. With 3 or 4 (and not ``test``), the
            final colour is the blended source colour only.
        test: disables the ``modify`` training-time behaviours above.
    """

    def __init__(self, nb_samples_per_ray, min_var=0.01, use_attention_3d=False, sample=False, modify=-1, test=False):
        super().__init__()

        self.test = test
        self.modify = modify
        self.use_attention_3d = use_attention_3d
        self.sample = sample
        self.min_var = min_var
        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        if use_attention_3d:
            self.feat_3d_attention_layers = FeatureSelfAttention_3d()

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for _ in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Mask of the var_mean token is always enabled.
        # Registered as a non-persistent buffer so it follows the module through
        # .to()/.cuda(). A hard-coded .cuda() pins it to whichever GPU is current
        # at construction time and fails on CPU-only hosts. Being non-persistent,
        # it stays out of the state_dict, so existing checkpoints still load.
        self.register_buffer("var_mean_mask", torch.tensor([1]), persistent=False)

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        # Layer widths below assume exactly 3 source views; forward reshapes
        # per-view features with the runtime V, so V must match.
        nb_view = 3
        self.sigma_guess_fc1 = nn.Linear(self.dim + nb_view * 32, self.dim * 2)
        self.sigma_guess_fc2 = nn.Linear(self.dim * 2, self.dim // 2)
        self.sigma_guess_fc3 = nn.Linear(self.dim // 2, 1)

        # Variance head: shared first hidden layer of the sigma-guess MLP
        # (dim*2) + per-view direction features (9 per view).
        self.var_fc2 = nn.Linear(self.dim * 2 + nb_view * 9, self.dim // 2)
        self.var_fc3 = nn.Linear(self.dim // 2, 3)  # r,g,b has its own variance
        self.softplus = nn.Softplus()

        # Per-view token (dim) + 9 view-direction channels -> one blending logit.
        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        # Per view: colour(3) + tex_feat(24) + feat_SA(16) + 9 channels of
        # novel-view direction features (+16 3D-attention features).
        rgb_nerf_input_dim = 3+24+16+9 if not self.use_attention_3d else 3+24+16+9+16
        self.rgb_nerf_fc1 = nn.Linear(nb_view * rgb_nerf_input_dim, self.dim * 2)
        self.rgb_nerf_fc2 = nn.Linear(self.dim * 2, self.dim)
        self.rgb_nerf_fc3 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_nerf_fc4 = nn.Linear(self.dim // 2, 3)
        # Unused in forward. Kept so the state_dict layout (and therefore
        # existing checkpoints) stays unchanged.
        self.rgb_final = nn.Linear(3, 3)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs_diff, feat, occ_masks, viewdirs_novel, for_att_3d=None):
        """Predict colour, density and the generated-colour distribution.

        Args:
            viewdirs_diff: per-view direction features. They are flattened to
                (N*S, V, 9); they feed the colour-blending MLP and the
                variance head.
            feat: (N, S, V, C) per-view features laid out as
                [v_feat(24) | s_feat(8) | colors(3) | vis_mask(1) |
                 tex_feat(24) | feat_SA(16) | global_geoFeat(32)].
            occ_masks: per-view occlusion masks. They are flattened to
                (N*S, V, ...) and multiplied into the visibility mask.
            viewdirs_novel: novel-view direction features, flattened to
                (N*S, -1). Their width must make the ``rgb_nerf_fc1`` input add
                up (27 channels when V == 3).
            for_att_3d: dict with keys ``'tex_feat_2'`` and
                ``'rays_pts_ndc_2'``. Required when ``use_attention_3d``.

        Returns:
            A tuple of:
            * an (N, S, 4) tensor of [r, g, b, sigma];
            * ``{'dist_var': (N, S, 3), 'dist_mean': (N, S, 3)}``, where
              ``dist_mean`` is the generated colour. It is raw (un-sigmoided)
              when ``modify == 4`` and not ``test``.

        Raises:
            ValueError: ``use_attention_3d`` is set but ``for_att_3d`` is None.
        """
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])
        v_feat = feat[..., :24]  # 3D features (levels 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8]  # 2D features (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        tex_feat = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 8*3]  # levels 0/1/2
        feat_SA = feat[..., 24 + 8 + 3 + 1 + 8*3 : 24 + 8 + 3 + 1 + 8*3 + 16]
        global_geoFeat = feat[..., 24 + 8 + 3 + 1 + 8*3 + 16 : 24 + 8 + 3 + 1 + 8*3 + 16 + 32]

        if self.use_attention_3d:
            if for_att_3d is None:
                raise ValueError("for_att_3d is required when use_attention_3d=True")
            tex_feat_att_3d = self.feat_3d_attention_layers(
                all_feat=for_att_3d['tex_feat_2'], pts_ndc=for_att_3d['rays_pts_ndc_2']
            )
            tex_feat_att_3d = tex_feat_att_3d.reshape(N * S, V, -1)

        occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
        viewdirs_diff = viewdirs_diff.view(-1, *viewdirs_diff.shape[2:])
        viewdirs_novel = viewdirs_novel.view(-1, viewdirs_novel.shape[-1])

        ## Mean and variance of 2D features provide view-independent tokens
        var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = torch.cat(var_mean, dim=-1)
        var_mean = F.elu(self.var_mean_fc1(var_mean))
        var_mean = F.elu(self.var_mean_fc2(var_mean))

        ## Converting the input features to tokens (view-dependent) before self-attention
        tokens = F.elu(
            self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
        )
        tokens = torch.cat([tokens, var_mean], dim=1)  # (N*S, V+1, dim)

        ## need_gene_mask is 1 where no source view sees the sample or every
        ## view has it occluded; those samples use the generated predictions.
        need_gene_mask = torch.zeros_like(vis_mask[:, 0:1, :])
        need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
        need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdims=True) == 0, 1)
        need_gene_mask = need_gene_mask.squeeze(-1)  # (N*S, 1)
        keep_mask = 1 - need_gene_mask

        ## Adding a new channel to mask for var_mean
        vis_mask = torch.cat(
            [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
        )
        ## If a point is not visible by any source view (only the always-on
        ## var_mean channel is set), force all of its masks to enabled
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

        ## Taking occ_masks into account, but remembering if there were any visibility before that
        mask_cloned = vis_mask.clone()
        vis_mask[:, :-1] *= occ_masks
        # If occlusion removed every view, fall back to enabling all of them.
        # Multiplying by the pre-occlusion copy keeps invisible views disabled.
        vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
        masks = vis_mask * mask_cloned

        ## Performing self-attention
        for layer in self.attn_layers:
            tokens, _ = layer(tokens, masks)

        ## Predicting sigma: Auto-Encoder along the S samples of each ray, then MLP
        sigma_tokens = tokens[:, -1:]
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens)
        _sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

        sigma_tokens = F.elu(self.sigma_fc1(_sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        _sigma_known = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

        ## Shared first layer for the density guess and the variance head
        global_geoFeat = global_geoFeat.reshape(-1, 1, V * 32)
        sigma_guess_tokens = torch.cat([_sigma_tokens, global_geoFeat], dim=-1)
        shared_hidden = self.sigma_guess_fc1(sigma_guess_tokens)

        # Generated density
        sigma_guess_tokens = F.elu(shared_hidden)
        sigma_guess_tokens = F.elu(self.sigma_guess_fc2(sigma_guess_tokens))
        _sigma_guess = torch.relu(self.sigma_guess_fc3(sigma_guess_tokens[:, 0]))

        # Variance, depending on geometry and per-view direction features
        var_tokens = F.elu(shared_hidden)
        var_tokens = torch.cat([var_tokens, viewdirs_diff.reshape(-1, 1, V * 9)], dim=-1)
        var_tokens = F.elu(self.var_fc2(var_tokens))
        dist_var = self.min_var + self.softplus(self.var_fc3(var_tokens[:, 0]))

        sigma = _sigma_known * keep_mask + _sigma_guess * need_gene_mask

        ## Generated colour from all views' colours and texture features
        tex_feat = tex_feat.reshape(-1, V * 24)
        feat_SA = feat_SA.reshape(-1, V * 16)
        rgb_nerf_tokens = torch.cat(
            [colors.reshape(-1, V * 3), tex_feat, feat_SA, viewdirs_novel], dim=-1
        )
        if self.use_attention_3d:
            rgb_nerf_tokens = torch.cat(
                [rgb_nerf_tokens, tex_feat_att_3d.reshape(-1, V * 16)], dim=-1
            )
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc1(rgb_nerf_tokens))
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc2(rgb_nerf_tokens))
        rgb_nerf_tokens = F.elu(self.rgb_nerf_fc3(rgb_nerf_tokens))
        rgb_nerf_raw = self.rgb_nerf_fc4(rgb_nerf_tokens)
        if self.modify == 4 and not self.test:
            # modify == 4 during training keeps the raw (pre-sigmoid) output.
            _rgb_nerf = rgb_nerf_raw
        else:
            _rgb_nerf = torch.sigmoid(rgb_nerf_raw)

        if self.sample:
            # NOTE(review): the sigmoid below only fits the raw-output case
            # (modify == 4 and not test). In every other configuration
            # _rgb_nerf is already sigmoid-ed, so this applies a second sigmoid
            # and compresses colours into roughly [0.5, 0.73]. Confirm the
            # intended mapping; an earlier version clipped to [0, 1] instead.
            _rgb_nerf = torch.sigmoid(
                torch.normal(mean=_rgb_nerf, std=(dist_var / 100) ** 0.5)
            )

        ## Concatenating view directions and predicting RGB blending weights
        rgb_tokens = torch.cat([tokens[:, :-1], viewdirs_diff], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = masked_softmax(self.rgb_fc3(rgb_tokens), masks[:, :-1], dim=1)
        _rgb_ori = (colors * rgb_w).sum(1)

        if self.modify in [3, 4] and not self.test:
            rgb = _rgb_ori
        else:
            rgb = _rgb_ori * keep_mask + _rgb_nerf * need_gene_mask

        outputs = torch.cat([rgb, sigma], -1)
        outputs = outputs.reshape(N, S, -1)
        dist_mean = _rgb_nerf.reshape(N, S, -1)
        dist_var = dist_var.reshape(N, S, -1)

        return outputs, {'dist_var': dist_var, 'dist_mean': dist_mean}


class Renderer_mvslike(nn.Module):
    """MVS-style renderer predicting per-sample colour and density.

    Per-view features of every ray sample are aggregated over the input views
    with softmax weights driven by the per-view, per-level disocclusion
    confidence ``disocc_confi``. Density comes from a D-layer MLP whose hidden
    activations are multiplied by level-dependent biases. Colour comes either
    from a view-dependent MLP or, with ``weighted_rgb``, from
    attention-predicted blending weights over the input-view colours.

    Args:
        nb_samples_per_ray, gene_mask, output_ch: unused; kept for interface
            compatibility.
        D: number of density MLP layers (``level_choose`` in ``forward`` is
            written for D=8).
        W: hidden width.
        input_ch: input width of the density MLP (it receives one 8-channel
            geometry level).
        input_ch_feat: channels of one geometry feature level.
        tex_2Dto3D: predict 2D texture blending weights with self-attention.
        use_tex_bias: additionally modulate the colour MLP with texture biases.
        weighted_rgb: blend input colours with attention weights instead of
            regressing colour with the view MLP.
    """

    def __init__(self, nb_samples_per_ray, gene_mask="None",
                 D=8, W=128, input_ch=8, output_ch=4, input_ch_feat=8, tex_2Dto3D=False, use_tex_bias=False, weighted_rgb=False):
        super(Renderer_mvslike, self).__init__()

        self.D = D # depth of linear layers
        self.W = W # hidden layer dims
        self.input_ch = input_ch # input mlp channel dims.
        # texture feature (3D tex 11 + 2D feat 8 + rgb 3) or (2D feat 8 + rgb 3)
        input_ch_views = 11+8+3 if use_tex_bias else 8+3
        self.in_ch, self.in_ch_views, self.in_ch_feat = input_ch, input_ch_views, input_ch_feat
        self.tex_2Dto3D = tex_2Dto3D
        self.use_tex_bias = use_tex_bias
        self.weighted_rgb = weighted_rgb

        if tex_2Dto3D:
            tex_w_dim = 8+9 # geometry feature + r(dir_diff)
            self.dim = 16
            d_inner = self.dim
            n_head = 4
            d_k = self.dim // n_head
            d_v = self.dim // n_head
            num_layers = 2

            self.tex_attn_token_gen = nn.Linear(tex_w_dim, self.dim)
            self.tex_weight_attn_layers = nn.ModuleList([
                    EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                    for i in range(num_layers)
                ]
            )
            self.tex_weight_fc1 = nn.Linear(self.dim, self.dim//2)
            self.tex_weight_fc2 = nn.Linear(self.dim//2, self.dim//4)
            self.tex_weight_fc3 = nn.Linear(self.dim//4, 1)

            self.tex_weight_final = nn.Linear(3, 1) # 3 levels weight to final weight

        # (D, cos) -> texture weight. Used by the texture-bias branch and by the
        # default 2D texture branch (no tex_2Dto3D, no weighted_rgb); previously it
        # was only created for use_tex_bias, so the default branch crashed.
        if self.use_tex_bias or not (tex_2Dto3D or weighted_rgb):
            self.tex_weight_linear = nn.Linear(2, 1)

        self.pts_linears = nn.ModuleList(
            [nn.Linear(self.in_ch, W, bias=True)] + [nn.Linear(W, W, bias=True) for i in range(D-1)])
        for l in range(3):
            pts_bias_l = nn.Linear(input_ch_feat, W)
            setattr(self, f"pts_bias_{l}", pts_bias_l)
            if use_tex_bias:
                tex_pts_bias_l = nn.Linear(input_ch_views, W)
                setattr(self, f"tex_pts_bias_{l}", tex_pts_bias_l)

        if self.weighted_rgb:
            self.dim = 32
            # geometry feature (3 levels, each already converted to a covering
            # level) + D (level 2) + 2D feature + 2D visibility mask
            self.rgb_weight_attn_token_gen = nn.Linear(24+1+8+1, self.dim)
            d_inner = self.dim
            n_head = 4
            d_k = self.dim // n_head
            d_v = self.dim // n_head
            num_layers = 4
            self.rgb_weight_attn_layers = nn.ModuleList(
                [
                    EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                    for i in range(num_layers)
                ]
            )
            self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
            self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
            self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        else:
            if self.use_tex_bias:
                self.views_linears = nn.ModuleList([nn.Linear(W+input_ch_views, W), nn.Linear(W, W), nn.Linear(W, W//2)])
            else:
                self.views_linears = nn.ModuleList([nn.Linear(W+input_ch_views, W//2)])

            self.feature_linear = nn.Linear(W, W)

        self.alpha_linear = nn.Linear(W, 1)
        self.rgb_linear = nn.Linear(W//2, 3)

        ## Initialization
        self.pts_linears.apply(weights_init)
        if not self.weighted_rgb:
            self.views_linears.apply(weights_init)
            self.feature_linear.apply(weights_init)
        self.alpha_linear.apply(weights_init)
        self.rgb_linear.apply(weights_init)

    @staticmethod
    def _actual_level_index(invis_pts_range):
        """Return, per sample and requested level, the level actually used.

        A sample that is out of range (for all input views) at a finer level
        falls back to coarser levels: level 0 is shifted by one for each of the
        levels 0 and 1 it is out of range at, level 1 by one if it is out of
        range at level 1, and level 2 is used as is.

        Args:
            invis_pts_range: (N*S, 1, 3) tensor of 0/1 flags, 1 = out of the 3D
                range of every input view at that level.

        Returns:
            (N*S, 3) tensor with the same dtype as ``invis_pts_range`` holding
            the actual level index in {0, 1, 2}.
        """
        invis = invis_pts_range[:, 0, :]  # (N*S, 3), values are exactly 0 or 1
        what_level = torch.arange(3, dtype=invis.dtype, device=invis.device).expand_as(invis).clone()
        what_level[:, 1] += invis[:, 1]
        what_level[:, 0] += invis[:, 0] + invis[:, 1]
        return what_level

    @staticmethod
    def _select_level_bias(bias_vals, what_level, level):
        """Pick, per sample, the bias of the level actually used for ``level``.

        Args:
            bias_vals: sequence of 3 tensors (N*S, W), the bias computed from
                each feature level.
            what_level: (N*S, 3) output of ``_actual_level_index``.
            level: requested level index.

        Returns:
            (N*S, W) tensor; rows whose actual level matches no entry stay zero.
        """
        actual = what_level[:, level:level + 1]  # (N*S, 1), broadcast over W
        out = torch.zeros_like(bias_vals[0])
        for l, val in enumerate(bias_vals):
            out = torch.where(actual == l, val, out)
        return out

    def forward(self, _viewdirs, feat, occ_masks, viewdirs_novel, for_mask=None):
        """Render (rgb, sigma) for every ray sample at the three feature levels.

        Args:
            _viewdirs: per-view direction features, either a tensor shaped
                (N, S, V, C) or a dict with "PE" and "cos" tensors of that form.
                Per the original comments "cos" is the cosine difference between
                input and novel directions (C=1) and "PE" adds a positional
                encoding (C=9) -- TODO confirm against the caller.
            feat: (N, S, V, C) per-view features whose first 75 channels are
                [3D geometry 3x8 | 2D feature 8 | rgb 3 | 2D visibility 1 |
                disocc_confi 3 | texture 3x11 | mask3d 3].
            occ_masks, viewdirs_novel, for_mask: unused.

        Returns:
            outputs: dict "level_0".."level_2" -> (N, S, 4) tensor (rgb, sigma);
                "level_l" is zeroed where the sample is invisible at any level
                <= l.
            B: dict "level_0".."level_2" -> (N, S, 1) mean over views of the
                range-masked disocclusion confidence.
        """
        ## Viewing samples regardless of batch or ray
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])  # (N*S, V, C)
        v_feat = feat[..., :24] # feature 3D (level 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8] # feature 2D (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach() # 2d mask: 1=visible for an input view in 2D space
        disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3] # disocc_confi (level 0, 1, 2)
        texFeat = feat[..., 24 + 8 + 3 + 1 + 3: 24 + 8 + 3 + 1 + 3 + 11*3] # texture feature (level 0, 1, 2)
        mask3d = feat[..., 24 + 8 + 3 + 1 + 3 + 11*3: 24 + 8 + 3 + 1 + 3 + 11*3 + 1*3] # mask3d (level 0, 1, 2): 1=visible for an input view in 3D space

        if isinstance(_viewdirs, dict):
            viewdirs_PE = _viewdirs["PE"].view(-1, *_viewdirs["PE"].shape[2:])
            viewdirs_cos = _viewdirs["cos"].view(-1, *_viewdirs["cos"].shape[2:])
        else:
            # A single tensor serves every branch below.
            viewdirs_PE = viewdirs_cos = _viewdirs.view(-1, *_viewdirs.shape[2:])

        ### aggregate feature
        invis_pts_range = torch.zeros_like(disocc_confi[:,0:1,:]) # (N*S, 1, 3)
        # 1 = out of 3D space range for ALL input views at that level
        invis_pts_range = invis_pts_range.masked_fill(mask3d.sum(dim=1, keepdims=True) == 0, 1)
        # (N*S, V, level=3): additionally 1 where D == 0 for ALL input views
        invis_pts = invis_pts_range.masked_fill(disocc_confi.sum(dim=1, keepdims=True) == 0, 1).repeat(1,V,1).bool()
        # Zero D of views whose 3D range does not contain the point (those values
        # come from border interpolation). 1 = visible.
        D_mask3d = disocc_confi.masked_fill(mask3d == 0, 0)
        # Enable every view of points that no view covers, so masked_softmax never
        # receives a fully masked set of views.
        D_mask3d_convert = D_mask3d.masked_fill(D_mask3d.sum(dim=1, keepdims=True) == 0, 1)

        ## geometry feature
        geo_weight = masked_softmax(disocc_confi, D_mask3d_convert, dim=1) # (N*S, V, level=3)
        geo_feat_3d = {
            f"level_{l}": torch.sum(geo_weight[..., l:l+1] * v_feat[..., l*8:(l+1)*8], dim=1) # (N*S, 8)
            for l in range(3)
        }

        ## texture feature
        if self.tex_2Dto3D:
            tex_weight_output = {}
            for l in range(3):
                tex_weight_input_l = torch.cat((v_feat[..., l*8:(l+1)*8], viewdirs_PE), dim=-1) # (N*S, V, 8+9)
                token = F.elu(self.tex_attn_token_gen(tex_weight_input_l)) # (N*S, V, 16)
                for layer in self.tex_weight_attn_layers:
                    token, _ = layer(token, D_mask3d_convert[..., l:l+1])
                token = F.elu(self.tex_weight_fc1(token))
                token = F.elu(self.tex_weight_fc2(token))
                tex_weight_output[f"level_{l}"] = self.tex_weight_fc3(token)

            # Views out of range at a finer level use the weight of a coarser one.
            convert_tex_weight = {}
            for l in range(3):
                convert_tex_weight[f"level_{l}"] = tex_weight_output[f"level_{l}"].clone()
                if l == 2:
                    continue
                mask3d_l = (~mask3d.bool())[..., l].unsqueeze(-1) # 1=invisible (out of range)
                convert_tex_weight[f"level_{l}"][mask3d_l==1] = tex_weight_output[f"level_{l+1}"][mask3d_l==1].clone()
                if l == 0:
                    mask3d_l = (~mask3d.bool())[..., l+1].unsqueeze(-1)
                    convert_tex_weight[f"level_{l}"][mask3d_l==1] = tex_weight_output[f"level_{l+2}"][mask3d_l==1].clone()

            # 3 levels weight to final weight
            convert_tex_weight_l_all = torch.cat((convert_tex_weight["level_0"], convert_tex_weight["level_1"], convert_tex_weight["level_2"]), dim=-1) # (N*S, V, level=3)
            _tex_w = self.tex_weight_final(convert_tex_weight_l_all) # (N*S, V, 1)
            tex_w = masked_softmax(_tex_w, D_mask3d_convert[..., 2:], dim=1)

            # weight * 2d tex feature/img
            tex_feat_2d = torch.sum(tex_w*s_feat, dim=1) # (N*S, 8)
            color_2d = torch.sum(tex_w*colors, dim=1) # (N*S, 3)
        elif not self.weighted_rgb:
            # 2D texture feature weighted from (D at level 2, view cosine)
            # force-enable points invisible in 2D for all inputs (masked_softmax needs one valid view)
            vis_mask_convert_invis_into_1 = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)

            D_l2_mask2d = D_mask3d[..., 2:]
            _tex_weight_2d = self.tex_weight_linear(torch.cat((D_l2_mask2d, viewdirs_cos), dim=-1))
            tex_weight_2d = masked_softmax(_tex_weight_2d, vis_mask_convert_invis_into_1, dim=1) # (N*S, V, 1)

            tex_feat_2d = torch.sum(tex_weight_2d*s_feat, dim=1) # (N*S, 8)
            color_2d = torch.sum(tex_weight_2d*colors, dim=1) # (N*S, 3)

        if self.use_tex_bias:
            # 3d texture feature
            _tex_weight = torch.zeros_like(disocc_confi)
            for l in range(3):
                _tex_weight[..., l:l+1] = self.tex_weight_linear(torch.cat((D_mask3d[..., l:l+1], viewdirs_cos), dim=-1))
            tex_weight = masked_softmax(_tex_weight, D_mask3d_convert, dim=1) # (N*S, V, level=3)

            tex_feat_3d = {
                f"level_{l}": torch.sum(tex_weight[..., l:l+1]*texFeat[..., l*11:(l+1)*11], dim=1) # (N*S, 11)
                for l in range(3)
            }

        ## predict sigma and rgb by MLP
        # Level actually used for each requested level (out-of-range samples fall
        # back to coarser levels); shared by the geometry and texture biases.
        what_level = self._actual_level_index(invis_pts_range) # (N*S, level=3)

        pts_bias_val = [getattr(self, f"pts_bias_{l}")(geo_feat_3d[f"level_{l}"]) for l in range(3)] # each (N*S, W)

        # Requested level per density layer; layers listed below reuse the
        # previous layer's bias.
        level_choose = [2, 2, 1, 1, 1, 0, 0, 0]
        bias = []
        for i in range(len(self.pts_linears)):
            if i in [1, 3, 4, 6, 7]:
                bias.append(bias[-1])
                continue
            bias.append(self._select_level_bias(pts_bias_val, what_level, level_choose[i]))

        if self.use_tex_bias:
            tex_pts_bias_val = [
                getattr(self, f"tex_pts_bias_{l}")(torch.cat((tex_feat_3d[f"level_{l}"], tex_feat_2d, color_2d), dim=-1))
                for l in range(3)
            ] # each (N*S, W)
            level_choose = [1, 0]
            tex_bias = [
                self._select_level_bias(tex_pts_bias_val, what_level, level_choose[i])
                for i in range(len(self.views_linears)-1)
            ]

        # density MLP, each layer modulated by its bias
        h = geo_feat_3d["level_2"].clone()
        for layer, layer_bias in zip(self.pts_linears, bias):
            h = F.relu(layer(h) * layer_bias)

        sigma = torch.relu(self.alpha_linear(h))

        if self.weighted_rgb:
            # Views out of range at a finer level use the v_feat of a coarser one.
            convert_v_feat = {}
            for l in range(3):
                convert_v_feat[f"level_{l}"] = v_feat[..., l*8:(l+1)*8].clone()
                if l == 2:
                    continue
                mask3d_l = (~mask3d.bool())[..., l].unsqueeze(-1).repeat(1,1,8) # mask3d: 1=visible
                convert_v_feat[f"level_{l}"][mask3d_l==1] = v_feat[..., (l+1)*8:(l+2)*8][mask3d_l==1].clone()
                if l == 0:
                    mask3d_l = (~mask3d.bool())[..., l+1].unsqueeze(-1).repeat(1,1,8)
                    convert_v_feat[f"level_{l}"][mask3d_l==1] = v_feat[..., (l+2)*8:][mask3d_l==1].clone()

            # 3 levels feat to final feat
            convert_v_feat_l_all = torch.cat((convert_v_feat["level_0"], convert_v_feat["level_1"], convert_v_feat["level_2"]), dim=-1) # (N*S, V, 8*level=24)

            rgb_w_tokens = F.elu(self.rgb_weight_attn_token_gen(torch.cat([convert_v_feat_l_all, disocc_confi[..., 2:], s_feat, vis_mask], dim=-1))) # (N*S, V, dim)

            ## Performing self-attention
            for layer in self.rgb_weight_attn_layers:
                rgb_w_tokens, _ = layer(rgb_w_tokens, D_mask3d_convert[..., 2:])

            ## Concatenating positional encodings and predicting RGB weights
            rgb_tokens = torch.cat([rgb_w_tokens, viewdirs_PE], dim=-1)
            rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
            rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
            rgb_w = self.rgb_fc3(rgb_tokens)
            rgb_w = masked_softmax(rgb_w, D_mask3d_convert[..., 2:], dim=1)

            rgb = (colors * rgb_w).sum(1)

        else:
            feature = self.feature_linear(h)
            if self.use_tex_bias:
                h = torch.cat([feature, tex_feat_3d["level_2"], tex_feat_2d, color_2d], -1)
                last = len(self.views_linears) - 1
                for i, layer in enumerate(self.views_linears):
                    # every layer except the last is modulated by a texture bias
                    h = layer(h) * tex_bias[i] if i != last else layer(h)
                    h = F.relu(h)
            else:
                h = torch.cat([feature, tex_feat_2d, color_2d], -1)
                for layer in self.views_linears:
                    h = F.relu(layer(h))

            rgb = torch.sigmoid(self.rgb_linear(h))

        ## visibility B: all 3 levels are kept (its purpose is to update D)
        B_all = torch.mean(D_mask3d, dim=1) # (N*S, level=3)

        ## force invisible points (out of range & D=0) to zero density and rgb;
        # masking accumulates, so level l is also masked by every finer level.
        outputs, B = {}, {}
        for l in range(3):
            sigma = sigma.masked_fill(invis_pts[:, 0, l:l+1] == 1, 0)
            rgb = rgb.masked_fill(invis_pts[:, 0, l:l+1].repeat(1,3) == 1, 0)

            outputs[f"level_{l}"] = torch.cat([rgb, sigma], -1).reshape(N, S, -1)
            B[f"level_{l}"] = B_all[..., l].reshape(N, S, -1)

        return outputs, B



class Renderer_mvslike_v2(nn.Module):
    """Attention-based MVS-style renderer (variant 2).

    For each of the 3 feature levels, per-view tokens plus one global
    (variance, aggregated geometry) token go through shared self-attention
    layers. Tokens of views out of range at a finer level are replaced by those
    of a coarser level, and the 3 levels are fused. The global token yields
    sigma (after a ray-wise auto-encoder); the view tokens yield colour
    blending weights.

    Args:
        nb_samples_per_ray: number of samples per ray, passed to the
            ray-wise ``ConvAutoEncoder``.
        gene_mask, D, W, input_ch, output_ch, input_ch_feat, tex_2Dto3D,
        use_tex_bias, weighted_rgb: stored for interface compatibility with
            ``Renderer_mvslike``; not used by ``forward``.
    """

    def __init__(self, nb_samples_per_ray, gene_mask="None",
                 D=8, W=128, input_ch=8, output_ch=4, input_ch_feat=8, tex_2Dto3D=False, use_tex_bias=False, weighted_rgb=False):
        super(Renderer_mvslike_v2, self).__init__()

        self.D = D # depth of linear layers
        self.W = W # hidden layer dims
        self.input_ch = input_ch # input mlp channel dims.
        input_ch_views = 11+8+3 if use_tex_bias else 8+3 # texture feature
        self.in_ch, self.in_ch_views, self.in_ch_feat = input_ch, input_ch_views, input_ch_feat
        self.tex_2Dto3D = tex_2Dto3D
        self.use_tex_bias = use_tex_bias
        self.weighted_rgb = weighted_rgb

        ## Self-Attention Settings
        self.dim = 16
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        ## Processing the variance of 2D features together with the aggregated geometry feature
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Mask of the var_mean token, always enabled. A non-persistent buffer
        # follows the module across devices (.to/.cuda) instead of being pinned to
        # CUDA at construction, and keeps the state_dict keys unchanged.
        self.register_buffer("var_mean_mask", torch.tensor([1]), persistent=False)

        self.attn_token_gen = nn.Linear(8 + 1 + 8, self.dim)

        ## 3 levels feature aggregate
        self.levels_agg_fc1 = nn.Linear(self.dim*3, self.dim)
        self.levels_agg_fc2 = nn.Linear(self.dim, self.dim)

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, _viewdirs, feat, occ_masks, viewdirs_novel, for_mask=None):
        """Render (rgb, sigma) for every ray sample.

        Args:
            _viewdirs: per-view direction features, a tensor shaped
                (N, S, V, 9) or a dict whose "PE" entry has that form (its
                last dim must match ``rgb_fc1``'s extra 9 inputs).
            feat: (N, S, V, C) per-view features whose first 75 channels are
                [3D geometry 3x8 | 2D feature 8 | rgb 3 | 2D visibility 1 |
                disocc_confi 3 | texture 3x11 | mask3d 3].
            occ_masks, viewdirs_novel, for_mask: unused.

        Returns:
            outputs: dict "level_0".."level_2" -> (N, S, 4) tensor (rgb, sigma);
                all three are computed from level 2 and zeroed where the sample
                is invisible at level 2.
            B: dict "level_0".."level_2" -> (N, S, 1) mean over the input views
                of the range-masked disocclusion confidence at that level.
        """
        ## Viewing samples regardless of batch or ray
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])  # (N*S, V, C)
        v_feat = feat[..., :24] # feature 3D (level 0, 1, 2)
        s_feat = feat[..., 24 : 24 + 8] # feature 2D (level 0)
        colors = feat[..., 24 + 8 : 24 + 8 + 3]
        vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach() # 2d mask: 1=visible for an input view in 2D space
        disocc_confi = feat[..., 24 + 8 + 3 + 1 : 24 + 8 + 3 + 1 + 3] # disocc_confi (level 0, 1, 2)
        # channels 39:72 hold the texture feature, unused by this renderer
        mask3d = feat[..., 24 + 8 + 3 + 1 + 3 + 11*3: 24 + 8 + 3 + 1 + 3 + 11*3 + 1*3] # mask3d (level 0, 1, 2): 1=visible for an input view in 3D space

        if isinstance(_viewdirs, dict):
            # Only the positionally encoded directions are consumed (rgb_fc1 takes
            # dim + 9 inputs); previously `viewdirs` stayed unbound on this path.
            viewdirs = _viewdirs["PE"].view(-1, *_viewdirs["PE"].shape[2:])
        else:
            viewdirs = _viewdirs.view(-1, *_viewdirs.shape[2:])

        ### aggregate feature
        invis_pts_range = torch.zeros_like(disocc_confi[:,0:1,:]) # (N*S, 1, 3)
        # 1 = out of 3D space range for ALL input views at that level
        invis_pts_range = invis_pts_range.masked_fill(mask3d.sum(dim=1, keepdims=True) == 0, 1)
        # (N*S, V, level=3): additionally 1 where D == 0 for ALL input views
        invis_pts = invis_pts_range.masked_fill(disocc_confi.sum(dim=1, keepdims=True) == 0, 1).repeat(1,V,1).bool()
        D_mask3d = disocc_confi.masked_fill(mask3d == 0, 0) # (N*S, V, level=3)
        # force-enable points no view covers, so masked_softmax never sees a fully masked set of views
        D_mask3d_convert = D_mask3d.masked_fill(D_mask3d.sum(dim=1, keepdims=True) == 0, 1)

        ## geometry feature
        geo_weight = masked_softmax(disocc_confi, D_mask3d_convert, dim=1) # (N*S, V, level=3)

        geo_feat_3d = {
            f"level_{l}": torch.sum(geo_weight[..., l:l+1]*v_feat[..., l*8:(l+1)*8], dim=1, keepdim=True) # (N*S, 1, 8)
            for l in range(3)
        }

        # global token: variance of the 2D features across views + aggregated geometry
        var, _ = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
        var_mean = {}
        for l in range(3):
            var_mean_l = torch.cat((var, geo_feat_3d[f"level_{l}"]), dim=-1)
            var_mean_l = F.elu(self.var_mean_fc1(var_mean_l))
            var_mean[f"level_{l}"] = F.elu(self.var_mean_fc2(var_mean_l))

        tokens = {}
        for l in range(3):
            tokens_l = F.elu(self.attn_token_gen(torch.cat([v_feat[..., l*8:(l+1)*8], vis_mask, s_feat], dim=-1))) # (N*S, V, 16)
            tokens[f"level_{l}"] = torch.cat([tokens_l, var_mean[f"level_{l}"]], dim=1) # (N*S, V+1, 16)

        # attention mask: view masks followed by the always-enabled var_mean token
        attn_mask = torch.cat((D_mask3d_convert, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, 3)), dim=1) # (N*S, V+1, 3)

        ## Performing self-attention
        for l in range(3):
            for layer in self.attn_layers:
                tokens[f"level_{l}"], _ = layer(tokens[f"level_{l}"], attn_mask[..., l:l+1])

        ## tokens out of range at a finer level are replaced by coarser ones
        convert_tokens = {}
        for l in range(3):
            convert_tokens[f"level_{l}"] = tokens[f"level_{l}"].clone()
            if l == 2:
                continue
            # per-view out-of-range flags (mask3d: 1=visible) plus the global
            # flag for the var_mean token (invis_pts_range: 1=invisible for all views)
            mask3d_l = (~mask3d.bool())[..., l].unsqueeze(-1).repeat(1,1,self.dim)
            invis_pts_range_l = invis_pts_range[..., l].unsqueeze(-1).repeat(1,1,self.dim)
            _invis_pts = torch.cat((mask3d_l, invis_pts_range_l), dim=1) # 1=invisible
            convert_tokens[f"level_{l}"][_invis_pts==1] = tokens[f"level_{l+1}"][_invis_pts==1].clone()
            if l == 0:
                mask3d_l = (~mask3d.bool())[..., l+1].unsqueeze(-1).repeat(1,1,self.dim)
                invis_pts_range_l = invis_pts_range[..., l+1].unsqueeze(-1).repeat(1,1,self.dim)
                _invis_pts = torch.cat((mask3d_l, invis_pts_range_l), dim=1) # 1=invisible
                convert_tokens[f"level_{l}"][_invis_pts==1] = tokens[f"level_{l+2}"][_invis_pts==1].clone()

        ## 3 levels feature aggregate
        final_tokens = torch.cat([convert_tokens[f"level_{l}"] for l in range(3)], dim=-1)
        final_tokens = F.elu(self.levels_agg_fc1(final_tokens))
        final_tokens = F.elu(self.levels_agg_fc2(final_tokens))

        ## Predicting sigma with an Auto-Encoder and MLP
        sigma_tokens = final_tokens[:, -1:]
        sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
        sigma_tokens = self.auto_enc(sigma_tokens)
        sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

        sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
        sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
        sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

        ## Concatenating positional encodings and predicting RGB weights
        rgb_tokens = torch.cat([final_tokens[:, :-1], viewdirs], dim=-1)
        rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
        rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
        rgb_w = self.rgb_fc3(rgb_tokens) # (N*S, V, 1)
        # Mask with level 2 only: a 3-level mask would broadcast rgb_w to 3
        # columns and weight each colour channel with a different level's mask.
        rgb_w = masked_softmax(rgb_w, D_mask3d_convert[..., 2:], dim=1)

        rgb = (colors * rgb_w).sum(1)

        ## visibility B: all 3 levels are kept (its purpose is to update D). Averaged
        # over the real input views only, not the forced-on or var_mean entries.
        B_all = torch.mean(D_mask3d, dim=1) # (N*S, level=3)

        ## force invisible points (out of range & D=0) to zero density and rgb;
        # every output level uses level 2.
        invis_l2 = invis_pts[:, 0, 2:] # (N*S, 1)
        sigma = sigma.masked_fill(invis_l2, 0)
        rgb = rgb.masked_fill(invis_l2.repeat(1,3), 0)

        outputs, B = {}, {}
        for l in range(3):
            outputs[f"level_{l}"] = torch.cat([rgb, sigma], -1).reshape(N, S, -1)
            B[f"level_{l}"] = B_all[..., l].reshape(N, S, -1)

        return outputs, B

class style2phi_network(nn.Module):
    """Map an 8-dimensional style code to a scalar in [0, 2*pi].

    Four ReLU-activated linear layers (8 -> 4 -> 4 -> 4 -> 4) are followed by a
    linear head and a sigmoid whose output is rescaled to [0, 2*pi].
    """

    def __init__(self):
        super().__init__()

        hidden = [nn.Linear(8, 4)]
        hidden.extend(nn.Linear(4, 4) for _ in range(3))
        self.linears = nn.ModuleList(hidden)
        self.final_linear = nn.Linear(4, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h = x
        for layer in self.linears:
            h = self.relu(layer(h))
        unit = self.sigmoid(self.final_linear(h))  # in [0, 1]
        return unit * (2 * math.pi)

class style2remain_network(nn.Module):
    """Map an 8-dimensional style code to a 7-dimensional vector.

    Four linear layers (8 -> 8 -> 7 -> 7 -> 7), each followed by LeakyReLU.
    """

    def __init__(self):
        super().__init__()

        layer_sizes = [(8, 8), (8, 7), (7, 7), (7, 7)]
        self.linears = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in layer_sizes])
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        out = x
        for layer in self.linears:
            out = layer(out)
            out = self.lrelu(out)
        return out

class style2weather_network(nn.Module):
    """Map an 8-dimensional style code to an 8-dimensional vector.

    Four 8 -> 8 linear layers, each followed by ReLU.
    """

    def __init__(self):
        super().__init__()

        self.linears = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
        self.relu = nn.ReLU()

    def forward(self, x):
        out = x
        for layer in self.linears:
            out = layer(out)
            out = self.relu(out)
        return out

class phi2code_network(nn.Module):
    """Map an ``input_dim``-dimensional input to an 8-dimensional code.

    Four linear layers (input_dim -> 4 -> 4 -> 4 -> 8); LeakyReLU follows every
    layer except the last, which is followed by the ``final_actv`` activation.

    Args:
        final_actv: activation after the last layer, 'lrelu' (default) or
            'sigmoid'.
        input_dim: number of input features.

    Raises:
        ValueError: if ``final_actv`` is not a supported name (previously the
            module was built without a final activation and ``forward`` failed
            with AttributeError).
    """

    _FINAL_ACTIVATIONS = {"sigmoid": nn.Sigmoid, "lrelu": nn.LeakyReLU}

    def __init__(self, final_actv='lrelu', input_dim=1):
        super(phi2code_network, self).__init__()
        if final_actv not in self._FINAL_ACTIVATIONS:
            raise ValueError(
                f"final_actv must be one of {sorted(self._FINAL_ACTIVATIONS)}, got {final_actv!r}"
            )
        self.final_actv = final_actv
        self.linears = nn.ModuleList(
                        [nn.Linear(input_dim, 4)] + [nn.Linear(4, 4) for i in range(2)] + [nn.Linear(4, 8)]
                    )
        self.lrelu = nn.LeakyReLU()
        self.finalActv = self._FINAL_ACTIVATIONS[final_actv]()

    def forward(self, x):
        *hidden, last = self.linears
        for fc in hidden:
            x = self.lrelu(fc(x))
        return self.finalActv(last(x))

class RendererStyle(nn.Module):
    """GeoNeRF renderer extended with a content/style decoding branch.

    Besides GeoNeRF's attention-based renderer, which blends source-view
    colours and predicts a density per sample (``rgb_sigma``), this module
    decodes a stylised colour (``style_rgb``) from three levels of content
    features (64/128/256 channels) modulated by an 8-d style code. The style
    decoder is selected by the constructor flags, checked in this order:
    ``style3Dfeat``, ``styleMLP_cls``, ``styleLast``, otherwise the default
    modulated decoder (optionally driven by a phase angle when ``timephi``).

    Args:
        nb_samples_per_ray: number of samples per ray, passed to ``ConvAutoEncoder``.
        gene_mask: ``"None"`` for normal rendering; ``"interval"`` or ``"one_pt"``
            make ``forward`` return a generation mask instead.
        n_domain: width of the domain vector appended to the style input.
        catDomain: append the domain vector to the style input at every
            decoder level (otherwise only at the coarsest level).
        catStyle: concatenate the style code before each upsampling layer.
        styleMLP_cls: keep one set of style/content layers per domain and pick
            it from the domain vector.
        styleLast: decode content first and apply the style code only at the end.
        adainNormalize: normalise decoder activations over the samples of each
            ray before style modulation (default decoder only).
        style3Dfeat: decode from additional per-view 3D style features.
        timephi: modulate the default decoder with codes of cos/sin of a phase
            angle predicted from the style code (or given as ``input_phi``).
        weatherEmbedding: with ``timephi``, fuse a learned embedding of the
            domain label into the cos/sin codes.
        weatherEncode: with ``timephi``, fuse an MLP encoding of the style code
            into the cos/sin codes.
        weatherEncodeCls: with ``weatherEncode``, also compute a 4-output
            prediction from that encoding (stored in ``self.weather_pred``).
        add_z: with ``timephi``, mix a latent ``z`` into the cos/sin codes.
        zInputStyle: build ``style2remain`` and set ``self.z_dim = 7``.
        delta_t, delta_t_1x1: build ``self.delta_t_estimator`` (1x1 variant when
            ``delta_t_1x1``); it is not used by ``forward``.

    NOTE(review): ``self.z_dim`` is only defined when ``zInputStyle`` is set, yet
    the default decoder reads it for ``weatherEmbedding``/``weatherEncode``/``add_z``.
    """

    def __init__(self, nb_samples_per_ray, gene_mask="None", n_domain=5, catDomain=False, catStyle=False, styleMLP_cls=False, styleLast=False, adainNormalize=False, style3Dfeat=False, timephi=False, weatherEmbedding=False, weatherEncode=False, weatherEncodeCls=False, add_z=False,
                 zInputStyle=False, delta_t=False, delta_t_1x1=False):
        super(RendererStyle, self).__init__()

        self.gene_mask = gene_mask
        self.catDomain = catDomain
        self.catStyle = catStyle
        self.styleMLP_cls = styleMLP_cls
        self.styleLast = styleLast
        self.adainNormalize = adainNormalize
        self.style3Dfeat = style3Dfeat
        self.timephi = timephi
        self.weatherEmbedding = weatherEmbedding
        self.weatherEncode = weatherEncode
        self.weatherEncodeCls = weatherEncodeCls
        self.add_z = add_z

        self.zInputStyle = zInputStyle
        self.delta_t = delta_t
        self.delta_t_1x1 = delta_t_1x1

        # Sub-modules are created in the original order. Changing it would
        # change the random initialisation and the parameter order, which
        # optimizer checkpoints rely on.

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Mask entry of the var_mean token, always enabled.
        ## A non-persistent buffer follows the module across .to()/.cuda()
        ## instead of forcing CUDA at construction time. It is not saved in
        ## the state_dict, so existing checkpoints load unchanged.
        self.register_buffer("var_mean_mask", torch.tensor([1]), persistent=False)

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        if self.timephi:
            self.style2phi = style2phi_network()
            self.cosphi_gamma = phi2code_network()
            self.sinphi_beta = phi2code_network()
            if self.weatherEmbedding:
                self.weather_embed = nn.Embedding(4, 8)
                self.cosphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
                self.sinphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
            if self.weatherEncode:
                self.weather_encode_fcs = style2weather_network()
                self.cosphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
                self.sinphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
            if self.weatherEncodeCls:
                # The classifier consumes the output of weather_encode_fcs.
                assert self.weatherEncode, "weatherEncodeCls requires weatherEncode"
                self.weatherEncoding_classifier = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

        if self.zInputStyle:
            self.style2remain = style2remain_network()
            self.z_dim = 7

        if self.add_z:
            z_cos_in = 8 + self.z_dim
            self.z_cos_mapping = nn.Sequential(
                nn.Linear(z_cos_in, z_cos_in), nn.ReLU(), nn.Linear(z_cos_in, z_cos_in), nn.ReLU()
            )
            self.z_sin_mapping = nn.Sequential(
                nn.Linear(z_cos_in, z_cos_in), nn.ReLU(), nn.Linear(z_cos_in, z_cos_in), nn.ReLU()
            )

        if self.style3Dfeat:
            in_ch = [8, 16, 32]
            self.in_ch = in_ch
            self.hidden_layer_1 = nn.Linear(in_ch[2], in_ch[1])
            self.hidden_layer_0 = nn.Linear(in_ch[1], in_ch[0])

            self.style3D_var_1 = nn.Linear(in_ch[1], in_ch[1])
            self.style3D_mean_1 = nn.Linear(in_ch[1], in_ch[1])
            self.style3D_var_0 = nn.Linear(in_ch[0], in_ch[0])
            self.style3D_mean_0 = nn.Linear(in_ch[0], in_ch[0])

            self.style_rgb_fc1 = nn.Linear(in_ch[0], in_ch[0] // 2)
            self.style_rgb_fc2 = nn.Linear(in_ch[0] // 2, 3)

        elif self.styleMLP_cls:
            in_ch = [64, 128, 256]
            self.in_ch = in_ch

            # Per-level lists of per-domain layers; converted to ModuleLists below.
            for name in ("style_var_fc", "style_mean_fc", "content_r", "content_b"):
                for l in range(3):
                    setattr(self, f"{name}_{l}", [])
            self.generate_hidden_layers = []
            self.generate_layers = []

            for d in range(n_domain):
                for l in range(3):
                    if self.catDomain or l == 2:
                        style_in = 8 + n_domain
                    else:
                        style_in = 8
                    getattr(self, f"style_var_fc_{l}").append(nn.Linear(style_in, in_ch[l]))
                    getattr(self, f"style_mean_fc_{l}").append(nn.Linear(style_in, in_ch[l]))
                    getattr(self, f"content_r_{l}").append(nn.Linear(in_ch[l], in_ch[l]))
                    getattr(self, f"content_b_{l}").append(nn.Linear(in_ch[l], in_ch[l]))

                self.generate_hidden_layers.append(nn.ModuleList(
                    [nn.Linear(in_ch[i], in_ch[i]) for i in reversed(range(3))]
                ))
                if self.catStyle:
                    # NOTE(review): rebuilt for every domain and only the last copy
                    # is kept (forward shares it across domains). It is kept as-is
                    # so initialisation and parameter layout stay unchanged.
                    self.generate_layers_0 = nn.ModuleList(
                        [nn.Linear(in_ch[i + 1] + 8, in_ch[i + 1]) for i in reversed(range(2))]
                    )
                self.generate_layers.append(nn.ModuleList(
                    [nn.Linear(in_ch[i + 1], in_ch[i]) for i in reversed(range(2))]
                ))

            for name in ("style_var_fc", "style_mean_fc", "content_r", "content_b"):
                for l in range(3):
                    setattr(self, f"{name}_{l}", nn.ModuleList(getattr(self, f"{name}_{l}")))
            self.generate_hidden_layers = nn.ModuleList(self.generate_hidden_layers)
            self.generate_layers = nn.ModuleList(self.generate_layers)
            self.style_rgb_fc1 = nn.Linear(in_ch[0], in_ch[0] // 2)
            self.style_rgb_fc2 = nn.Linear(in_ch[0] // 2, 3)

        elif self.styleLast:
            in_ch = [64, 128, 256]
            self.in_ch = in_ch
            for l in range(2):
                content_r_l = nn.Linear(in_ch[l], in_ch[l])
                content_b_l = nn.Linear(in_ch[l], in_ch[l])
                setattr(self, f"content_r_{l}", content_r_l)
                setattr(self, f"content_b_{l}", content_b_l)

            self.generate_layers = nn.ModuleList(
                [nn.Linear(in_ch[i + 1], in_ch[i]) for i in reversed(range(2))]
            )

            self.content2style = nn.Linear(in_ch[0], 8)
            self.style_r = nn.Linear(8, 8)
            self.style_b = nn.Linear(8, 8)

            self.style_rgb_fc1 = nn.Linear(8, 3)

        else:
            in_ch = [64, 128, 256]
            self.in_ch = in_ch
            for l in range(3):
                if self.catDomain or (l == 2 and not self.timephi):
                    style_in = 8 + n_domain
                else:
                    style_in = 8
                    if self.weatherEmbedding or self.weatherEncode or self.add_z:
                        style_in += self.z_dim
                style_var_fc_l = nn.Linear(style_in, in_ch[l])
                style_mean_fc_l = nn.Linear(style_in, in_ch[l])
                content_r_l = nn.Linear(in_ch[l], in_ch[l])
                content_b_l = nn.Linear(in_ch[l], in_ch[l])
                setattr(self, f"style_var_fc_{l}", style_var_fc_l)
                setattr(self, f"style_mean_fc_{l}", style_mean_fc_l)
                setattr(self, f"content_r_{l}", content_r_l)
                setattr(self, f"content_b_{l}", content_b_l)

            self.generate_hidden_layers = nn.ModuleList(
                [nn.Linear(in_ch[i], in_ch[i]) for i in reversed(range(3))]
            )
            if self.catStyle:
                self.generate_layers_0 = nn.ModuleList(
                    [nn.Linear(in_ch[i + 1] + 8, in_ch[i + 1]) for i in reversed(range(2))]
                )
            self.generate_layers = nn.ModuleList(
                [nn.Linear(in_ch[i + 1], in_ch[i]) for i in reversed(range(2))]
            )
            self.style_rgb_fc1 = nn.Linear(in_ch[0], in_ch[0] // 2)
            self.style_rgb_fc2 = nn.Linear(in_ch[0] // 2, 3)

        if self.delta_t:
            if self.delta_t_1x1:
                self.delta_t_estimator = Delta_estimator_1x1()
            else:
                self.delta_t_estimator = Delta_estimator()

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs, feat, occ_masks, viewdirs_novel, for_mask=None, onlyContentStyle=False, input_phi=None, z=None, alpha=None):
        """Render colour/density and the stylised colour for ray samples.

        Args:
            viewdirs: per-view direction features. They are flattened to
                ``(N*S, V, ...)`` and concatenated to the attention tokens, so
                their last dimension must be 9 (``rgb_fc1`` takes ``dim + 9``).
            feat: ``(N, S, V, C)`` per-sample, per-view features (N rays,
                S samples per ray, V source views). The channel layout without
                ``onlyContentStyle`` is:
                - 24 3D-volume features
                - 8 2D features
                - 3 colours
                - 1 visibility mask
                - 64+128+256 content features
                - 8 style channels
                - 5 domain channels
                - (for ``style3Dfeat``) 3D style features
                With ``onlyContentStyle`` only the content/style/domain part
                is present.
            occ_masks: occlusion masks. They are flattened to ``(N*S, ...)`` and
                multiplied into the per-view visibility mask.
            viewdirs_novel: unused.
            for_mask: dict with ``'pts_d'`` and ``'pts_d_gt'``. It is read
                only when ``gene_mask == "interval"``.
            onlyContentStyle: skip the geometry renderer and decode the style
                colour from the first view's content features. This is
                supported by the default decoder only.
            input_phi: optional phase angle used instead of ``style2phi``
                (``timephi`` only).
            z: latent code; required when ``add_z``.
            alpha: unused.

        Returns:
            If ``gene_mask`` is not ``"None"`` and ``onlyContentStyle`` is False,
            returns an ``(N, S, 4)`` tensor: a 3-channel generation mask plus a
            per-sample count flag.

            Otherwise returns ``(rgb_sigma, style_rgb)``:
            - ``rgb_sigma`` is ``(N, S, 4)``, or ``None`` with ``onlyContentStyle``.
            - ``style_rgb`` is ``(N, S, 3)``.

        Raises:
            NotImplementedError: ``onlyContentStyle`` with a non-default decoder.
            ValueError: unknown ``gene_mask`` mode.
        """
        if onlyContentStyle and (self.style3Dfeat or self.styleMLP_cls or self.styleLast):
            # These decoders aggregate with the attention weights ``rgb_w``
            # (style3Dfeat also needs the per-view 3D style features), which
            # only exist when the geometry renderer runs.
            raise NotImplementedError(
                "onlyContentStyle=True is only supported by the default style decoder"
            )

        ## Viewing samples regardless of batch or ray
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])  # (N*S, V, C)

        n_content = 64 + 128 + 256  # content feature levels 0, 1, 2
        if not onlyContentStyle:
            v_feat = feat[..., :24]  # 3D features (levels 0, 1, 2)
            s_feat = feat[..., 24:24 + 8]  # 2D features (level 0)
            colors = feat[..., 24 + 8:24 + 8 + 3]
            vis_mask = feat[..., 24 + 8 + 3:24 + 8 + 3 + 1].detach()
            content_start = 24 + 8 + 3 + 1
        else:
            content_start = 0
        style_start = content_start + n_content
        domain_start = style_start + 8
        contents = feat[..., content_start:style_start]
        style = feat[..., style_start:domain_start][:, 0, :]  # same for all views
        domain_vec = feat[..., domain_start:domain_start + 5][:, 0, :]  # TODO: domain length hard-coded to 5

        ## Split the content features into consecutive levels of widths self.in_ch.
        contents_level = {}
        offset = 0
        for l, width in enumerate(self.in_ch):
            contents_level[f"level_{l}"] = contents[..., offset:offset + width]
            offset += width

        if not onlyContentStyle:
            occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
            viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

            if self.gene_mask != "None":
                # Mask generation does not depend on the attention pass below.
                return self._gene_mask_outputs(vis_mask, occ_masks, for_mask, N, S)

            if self.style3Dfeat:
                style3D_feat = feat[..., domain_start + 5:]  # levels 0, 1, 2
                c0, c1 = self.in_ch[0], self.in_ch[1]
                style3D_level = {
                    "level_0": style3D_feat[..., :c0],
                    "level_1": style3D_feat[..., c0:c0 + c1],
                    "level_2": style3D_feat[..., c0 + c1:],
                }

            ## Mean and variance of 2D features provide view-independent tokens
            var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
            var_mean = torch.cat(var_mean, dim=-1)
            var_mean = F.elu(self.var_mean_fc1(var_mean))
            var_mean = F.elu(self.var_mean_fc2(var_mean))

            ## Converting the input features to tokens (view-dependent) before self-attention
            tokens = F.elu(
                self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
            )
            tokens = torch.cat([tokens, var_mean], dim=1)

            ## Adding a new channel to mask for var_mean
            vis_mask = torch.cat(
                [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
            )
            ## If a point is not visible by any source view, force its masks to enabled
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdim=True) == 1, 1)

            ## Taking occ_masks into account, but remembering if there were any visibility before that
            mask_cloned = vis_mask.clone()
            vis_mask[:, :-1] *= occ_masks
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdim=True) == 1, 1)
            masks = vis_mask * mask_cloned

            ## Performing self-attention
            for layer in self.attn_layers:
                tokens, _ = layer(tokens, masks)

            ## Predicting sigma with an Auto-Encoder and MLP
            sigma_tokens = tokens[:, -1:]
            sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
            sigma_tokens = self.auto_enc(sigma_tokens)
            sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

            sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
            sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
            sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

            ## Concatenating positional encodings and predicting RGB weights
            rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
            rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
            rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
            rgb_w = self.rgb_fc3(rgb_tokens)
            rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)

            rgb = (colors * rgb_w).sum(1)

            rgb_sigma = torch.cat([rgb, sigma], -1).reshape(N, S, -1)

        ## Content + style decoding
        if self.style3Dfeat:
            x = (style3D_level["level_2"] * rgb_w).sum(1)
            for l in reversed(range(2)):
                hidden_layer_l = getattr(self, f"hidden_layer_{l}")
                style3D_var_l = getattr(self, f"style3D_var_{l}")
                style3D_mean_l = getattr(self, f"style3D_mean_{l}")
                style_input = (style3D_level[f"level_{l}"] * rgb_w).sum(1)

                x = hidden_layer_l(x)
                x = torch.sin(x * style3D_var_l(style_input) + style3D_mean_l(style_input))

            x = F.relu(self.style_rgb_fc1(x))
            style_rgb = torch.sigmoid(self.style_rgb_fc2(x)).reshape(N, S, -1)

        elif self.styleMLP_cls:
            # The domain of the first sample selects the per-domain layers for the whole batch.
            domain_label = domain_vec.nonzero()[0, 1]
            x = (contents_level["level_2"] * rgb_w).sum(1)
            for idx, l in enumerate(reversed(range(3))):
                style_var_fc_l = getattr(self, f"style_var_fc_{l}")
                style_mean_fc_l = getattr(self, f"style_mean_fc_{l}")

                if self.catDomain or idx == 0:
                    style_input = torch.cat((style, domain_vec), dim=-1)
                else:
                    style_input = style
                x = x * style_var_fc_l[domain_label](style_input) + style_mean_fc_l[domain_label](style_input)
                x = self.generate_hidden_layers[domain_label][idx](x)

                content_r_l = getattr(self, f"content_r_{l}")
                content_b_l = getattr(self, f"content_b_{l}")

                contents_l = (contents_level[f"level_{l}"] * rgb_w).sum(1)
                x = torch.sin(x * content_r_l[domain_label](contents_l) + content_b_l[domain_label](contents_l))
                if idx != 2:
                    if self.catStyle:
                        x = torch.cat((x, style), dim=-1)
                        x = torch.relu(self.generate_layers_0[idx](x))
                    x = self.generate_layers[domain_label][idx](x)

            x = F.relu(self.style_rgb_fc1(x))
            style_rgb = torch.sigmoid(self.style_rgb_fc2(x)).reshape(N, S, -1)

        elif self.styleLast:
            x = (contents_level["level_2"] * rgb_w).sum(1)
            for idx, l in enumerate(reversed(range(2))):
                x = self.generate_layers[idx](x)

                content_r_l = getattr(self, f"content_r_{l}")
                content_b_l = getattr(self, f"content_b_{l}")

                contents_l = (contents_level[f"level_{l}"] * rgb_w).sum(1)
                x = torch.sin(x * content_r_l(contents_l) + content_b_l(contents_l))

            x = F.relu(self.content2style(x))
            x = torch.sin(x * self.style_r(style) + self.style_b(style))

            style_rgb = torch.sigmoid(self.style_rgb_fc1(x)).reshape(N, S, -1)

        else:
            if self.timephi:
                cosphi_code, sinphi_code = self._phase_codes(style, domain_vec, input_phi, z)

            x = (contents_level["level_2"] * rgb_w).sum(1) if not onlyContentStyle else contents_level["level_2"][:, 0]
            for idx, l in enumerate(reversed(range(3))):
                style_var_fc_l = getattr(self, f"style_var_fc_{l}")
                style_mean_fc_l = getattr(self, f"style_mean_fc_{l}")

                # Normalise x over the samples of each ray.
                # NOTE(review): there is no epsilon, so an activation that is
                # constant along a ray divides by zero.
                if self.adainNormalize:
                    x_var, x_mean = torch.var_mean(x.reshape(N, S, -1), dim=1, unbiased=False, keepdim=True)
                    x = (x.reshape(N, S, -1) - x_mean) / x_var ** 0.5
                    x = x.reshape(-1, x.shape[-1])

                if self.timephi:
                    x = x * style_var_fc_l(cosphi_code) + style_mean_fc_l(sinphi_code)
                else:
                    if self.catDomain or idx == 0:
                        style_input = torch.cat((style, domain_vec), dim=-1)
                    else:
                        style_input = style
                    x = x * style_var_fc_l(style_input) + style_mean_fc_l(style_input)
                x = self.generate_hidden_layers[idx](x)

                content_r_l = getattr(self, f"content_r_{l}")
                content_b_l = getattr(self, f"content_b_{l}")

                contents_l = (contents_level[f"level_{l}"] * rgb_w).sum(1) if not onlyContentStyle else contents_level[f"level_{l}"][:, 0]
                x = torch.sin(x * content_r_l(contents_l) + content_b_l(contents_l))
                if idx != 2:
                    if self.catStyle:
                        x = torch.cat((x, style), dim=-1)
                        x = torch.relu(self.generate_layers_0[idx](x))
                    x = self.generate_layers[idx](x)

            x = F.relu(self.style_rgb_fc1(x))
            style_rgb = torch.sigmoid(self.style_rgb_fc2(x)).reshape(N, S, -1)

        if onlyContentStyle:
            return None, style_rgb
        return rgb_sigma, style_rgb

    def _gene_mask_outputs(self, vis_mask, occ_masks, for_mask, N, S):
        """Build the ``(N, S, 4)`` output of the ``gene_mask`` modes.

        Channels 0-2 are 1 for samples with no visible or no unoccluded source
        view. Channel 3 flags which samples are counted. In ``"interval"``
        mode both are restricted to samples within 0.2 of a non-zero
        ground-truth depth. In ``"one_pt"`` mode every sample is counted.
        """
        if self.gene_mask not in ("interval", "one_pt"):
            raise ValueError(f"Unknown gene_mask mode: {self.gene_mask!r}")

        need_gene_mask = torch.zeros_like(vis_mask[:, 0:1, :])
        need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdim=True) == 0, 1)
        need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdim=True) == 0, 1)  # (N*S, 1, 1)

        ray_pts_cnt = None
        if self.gene_mask == "interval":
            pts_d = for_mask['pts_d'].reshape(-1, 1, 1)
            pts_d_gt = for_mask['pts_d_gt'].unsqueeze(-1).repeat(1, S).reshape(-1, 1, 1)
            diff = pts_d - pts_d_gt
            has_gt = pts_d_gt != 0
            need_gene_mask[(diff > 0.2) * has_gt] = 0
            need_gene_mask[(diff < -0.2) * has_gt] = 0

            ray_pts_cnt = torch.zeros_like(pts_d_gt)  # 1 for counted samples, else 0
            ray_pts_cnt[(diff <= 0.2) * (diff >= 0) * has_gt] = 1
            ray_pts_cnt[(diff >= -0.2) * (diff < 0) * has_gt] = 1

        need_gene_mask = need_gene_mask.squeeze(-1).reshape(N, S, 1).repeat(1, 1, 3)
        if ray_pts_cnt is None:  # "one_pt": every sample is counted
            ray_pts_cnt = torch.ones_like(need_gene_mask[..., 0])

        return torch.cat([need_gene_mask, ray_pts_cnt.reshape(N, S, 1)], -1)

    def _phase_codes(self, style, domain_vec, input_phi, z):
        """Return the (cos, sin) modulation codes used by the ``timephi`` decoder.

        Side effects:
            - Stores ``phi[0, 0]`` in ``self.t`` when the phase is predicted
              (i.e. ``input_phi`` is None).
            - Stores the classifier output in ``self.weather_pred`` when
              ``weatherEncodeCls`` is set.
        """
        if input_phi is None:
            phi = self.style2phi(style)
            self.t = phi[0, 0]
        else:
            phi = input_phi * torch.ones_like(style[:, 0:1])
        cosphi, sinphi = torch.cos(phi), torch.sin(phi)
        cosphi_code, sinphi_code = self.cosphi_gamma(cosphi), self.sinphi_beta(sinphi)

        if self.add_z:
            assert z is not None, "add_z=True requires a latent code z"
            z = z.repeat(cosphi_code.shape[0], 1)
            cosphi_code = self.z_cos_mapping(torch.cat((z, cosphi_code), dim=-1))
            sinphi_code = self.z_sin_mapping(torch.cat((z, sinphi_code), dim=-1))

        if self.weatherEmbedding:
            # Index of the active domain minus one. If the first sample is in
            # domain 0, a single random label in [0, 4) is used for all samples.
            domain_label = domain_vec.nonzero()[:, 1] - 1
            if domain_label[0].item() == -1:
                random_label = torch.randint(low=0, high=4, size=(1,)).to(domain_label.device)
                domain_label = torch.ones_like(domain_label) * random_label

            weather_embed_ = self.weather_embed(domain_label).squeeze()
            cosphi_code = self.cosphiWeather_fcs(torch.cat((cosphi_code, weather_embed_), dim=-1))
            sinphi_code = self.sinphiWeather_fcs(torch.cat((sinphi_code, weather_embed_), dim=-1))
        if self.weatherEncode:
            weather_encoding = self.weather_encode_fcs(style)
            cosphi_code = self.cosphiWeather_fcs(torch.cat((cosphi_code, weather_encoding), dim=-1))
            sinphi_code = self.sinphiWeather_fcs(torch.cat((sinphi_code, weather_encoding), dim=-1))
            if self.weatherEncodeCls:
                self.weather_pred = self.weatherEncoding_classifier(weather_encoding)

        return cosphi_code, sinphi_code


class RendererStyle_one(nn.Module):
    """Ray renderer with a geometry branch and a single style-modulated colour MLP.

    Geometry branch (skipped when ``forward(onlyContentStyle=True)``):
    per-view tokens built from 3D/2D features and visibility go through 4
    self-attention layers. The extra var/mean token is aggregated along each
    ray by ``ConvAutoEncoder`` to predict sigma. The per-view tokens (plus
    ``viewdirs``) predict softmax blending weights ``rgb_w`` over the
    source-view colours.

    Style branch: three levels with widths ``in_ch = [8, 16, 32]`` are
    processed from 32 down to 8. At each level ``x`` is scaled and shifted by
    linear maps of the style input, then passed through
    ``sin(x * r(content) + b(content))``. The style input is
    ``style``/``domain_vec``, or with ``timephi`` the cos/sin codes of a phase
    ``phi``. Three linear layers then decode an RGB colour.

    Args:
        nb_samples_per_ray: samples per ray, passed to ``ConvAutoEncoder``.
        gene_mask: "None" for normal rendering. "interval" or "one_pt" makes
            ``forward`` return a generation mask instead of colours.
        n_domain: accepted but unused.
        catDomain: feed ``cat(style, domain_vec)`` at every level instead of
            only at the first one.
        catStyle: re-inject ``style`` before each level-to-level projection.
        styleMLP_cls, styleLast: stored but not read by this class.
        adainNormalize: normalise ``x`` along the ray before each modulation.
        style3Dfeat: not supported by this class (``forward`` raises).
        timephi: modulate with codes of a phase ``phi`` instead of the raw style.
        weatherEmbedding, weatherEncode, weatherEncodeCls: mix a weather
            embedding/encoding into the phi codes (timephi only).
        add_z: concatenate an external code ``z`` to the phi codes (timephi only).
        zInputStyle: build ``style2remain`` and force ``z_dim = 7``.
        delta_t, delta_t_1x1: build a ``Delta_estimator`` or
            ``Delta_estimator_1x1`` (not used by ``forward``).
        z_dim: width of the code appended to the phi codes when ``add_z`` or a
            weather option is enabled. Overridden to 7 by ``zInputStyle``.
    """

    def __init__(self, nb_samples_per_ray, gene_mask="None", n_domain=5, catDomain=False, catStyle=False, styleMLP_cls=False, styleLast=False, adainNormalize=False, style3Dfeat=False, timephi=False, weatherEmbedding=False, weatherEncode=False, weatherEncodeCls=False, add_z=False,
                 zInputStyle=False, delta_t=False, delta_t_1x1=False, z_dim=8):
        super(RendererStyle_one, self).__init__()

        self.gene_mask = gene_mask
        self.catDomain = catDomain
        self.catStyle = catStyle
        self.styleMLP_cls = styleMLP_cls
        self.styleLast = styleLast
        self.adainNormalize = adainNormalize
        self.style3Dfeat = style3Dfeat
        self.timephi = timephi
        self.weatherEmbedding = weatherEmbedding
        self.weatherEncode = weatherEncode
        self.weatherEncodeCls = weatherEncodeCls
        self.add_z = add_z

        self.zInputStyle = zInputStyle
        self.delta_t = delta_t
        self.delta_t_1x1 = delta_t_1x1
        # Width of the extra code appended to the phi codes. It used to be set
        # only under zInputStyle, so add_z / weather options without
        # zInputStyle crashed with AttributeError further down.
        self.z_dim = z_dim

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [EncoderLayer(self.dim, d_inner, n_head, d_k, d_v) for _ in range(num_layers)]
        )

        ## Processing the mean and variance of input features (8 var + 8 mean)
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Mask entry of the var_mean token (always enabled). A non-persistent
        ## buffer follows the module across devices, unlike the old hard-coded
        ## .cuda(), and stays out of state_dict, so existing checkpoints still load.
        self.register_buffer("var_mean_mask", torch.tensor([1]), persistent=False)

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        # +9: width of the per-view viewdirs concatenated to each token
        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        if self.timephi:
            self.style2phi = style2phi_network()
            self.cosphi_gamma = phi2code_network()
            self.sinphi_beta = phi2code_network()
            if self.weatherEmbedding:
                self.weather_embed = nn.Embedding(4, 8)
                self.cosphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
                self.sinphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
            if self.weatherEncode:
                self.weather_encode_fcs = style2weather_network()
                self.cosphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
                self.sinphiWeather_fcs = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
            if self.weatherEncodeCls:
                assert self.weatherEncode == True
                self.weatherEncoding_classifier = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))

        if self.zInputStyle:
            self.style2remain = style2remain_network()
            self.z_dim = 7

        if self.add_z:
            z_cos_in = 8 + self.z_dim
            self.z_cos_mapping = nn.Sequential(
                nn.Linear(z_cos_in, z_cos_in), nn.ReLU(), nn.Linear(z_cos_in, z_cos_in), nn.ReLU()
            )
            self.z_sin_mapping = nn.Sequential(
                nn.Linear(z_cos_in, z_cos_in), nn.ReLU(), nn.Linear(z_cos_in, z_cos_in), nn.ReLU()
            )

        in_ch = [8, 16, 32]
        self.in_ch = in_ch
        for l in range(3):
            style_in = 8
            if self.weatherEmbedding or self.weatherEncode or self.add_z:
                style_in += self.z_dim
            style_var_fc_l = nn.Linear(style_in, in_ch[l])
            style_mean_fc_l = nn.Linear(style_in, in_ch[l])
            content_r_l = nn.Linear(in_ch[l], in_ch[l])
            content_b_l = nn.Linear(in_ch[l], in_ch[l])
            setattr(self, f"style_var_fc_{l}", style_var_fc_l)
            setattr(self, f"style_mean_fc_{l}", style_mean_fc_l)
            setattr(self, f"content_r_{l}", content_r_l)
            setattr(self, f"content_b_{l}", content_b_l)

        # Indexed by loop position idx (levels 2, 1, 0 in that order)
        self.generate_hidden_layers = nn.ModuleList(
            [nn.Linear(in_ch[i], in_ch[i]) for i in reversed(range(3))]
        )
        if self.catStyle:
            self.generate_layers_0 = nn.ModuleList(
                [nn.Linear(in_ch[i + 1] + 8, in_ch[i + 1]) for i in reversed(range(2))]
            )
        self.generate_layers = nn.ModuleList(
            [nn.Linear(in_ch[i + 1], in_ch[i]) for i in reversed(range(2))]
        )
        self.style_rgb_fc1 = nn.Linear(in_ch[0], in_ch[0] // 2)
        self.style_rgb_fc2 = nn.Linear(in_ch[0] // 2, in_ch[0] // 2)
        self.style_rgb_fc3 = nn.Linear(in_ch[0] // 2, 3)

        if self.delta_t:
            if self.delta_t_1x1:
                self.delta_t_estimator = Delta_estimator_1x1()
            else:
                self.delta_t_estimator = Delta_estimator()

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def _generation_mask(self, vis_mask, occ_masks, for_mask, N, S):
        """Return the (N, S, 4) tensor ``[need_gene_mask x3, ray_pts_cnt]``.

        ``need_gene_mask`` is 1 for samples whose ``vis_mask`` entries are all
        zero, or whose ``occ_masks`` entries are all zero.
        In "interval" mode, samples whose depth is more than 0.2 away from a
        non-zero GT depth are reset to 0. ``ray_pts_cnt`` marks samples within
        0.2 of a non-zero GT depth. In "one_pt" mode ``ray_pts_cnt`` is all ones.

        Raises:
            ValueError: if ``self.gene_mask`` is not "interval" or "one_pt".
        """
        if self.gene_mask not in ("interval", "one_pt"):
            raise ValueError(f"Unsupported gene_mask: {self.gene_mask!r}")

        need_gene_mask = torch.zeros_like(vis_mask[:, 0:1, :])
        need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
        need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdims=True) == 0, 1)  # (N*S, 1, 1)

        if self.gene_mask == "interval":
            pts_d = for_mask['pts_d'].reshape(-1, 1, 1)
            # One GT depth per ray, broadcast to all S samples of that ray.
            pts_d_gt = for_mask['pts_d_gt'].unsqueeze(-1).repeat(1, S).reshape(-1, 1, 1)
            diff = pts_d - pts_d_gt
            has_gt = pts_d_gt != 0
            need_gene_mask[(diff > 0.2) * has_gt] = 0
            need_gene_mask[(diff < -0.2) * has_gt] = 0
            need_gene_mask = need_gene_mask.squeeze(-1).reshape(N, S, 1).repeat(1, 1, 3)

            ray_pts_cnt = torch.zeros_like(pts_d_gt)  # 1 = counted, 0 = ignored
            ray_pts_cnt[(diff <= 0.2) * (diff >= 0) * has_gt] = 1
            ray_pts_cnt[(diff >= -0.2) * (diff < 0) * has_gt] = 1
        else:  # "one_pt"
            need_gene_mask = need_gene_mask.squeeze(-1).reshape(N, S, 1).repeat(1, 1, 3)
            ray_pts_cnt = torch.ones_like(need_gene_mask)[..., 0]

        return torch.cat([need_gene_mask, ray_pts_cnt.reshape(N, S, 1)], -1)

    def forward(self, viewdirs, feat, occ_masks, viewdirs_novel, for_mask=None, onlyContentStyle=False, input_phi=None, z=None, alpha=None):
        """Predict density/colour and a stylised colour for ray samples.

        Args:
            viewdirs: per-view direction features (last dim 9), viewed as
                (N*S, V, 9). Unused when ``onlyContentStyle``.
            feat: (N, S, V, C) features. Full channel layout:
                24 3D feats | 8 2D feats | 3 colours | 1 visibility |
                8+16+32 content levels | 8 style | 5 domain one-hot.
                With ``onlyContentStyle`` only the trailing
                contents | style | domain part is expected.
            occ_masks: per-view occlusion masks, viewed as (N*S, V, ...).
            viewdirs_novel, alpha: unused.
            for_mask: dict with 'pts_d' and 'pts_d_gt'. Only read when
                ``gene_mask == "interval"``.
            onlyContentStyle: skip the geometry branch and take contents from view 0.
            input_phi: in timephi mode, overrides the phase predicted from ``style``.
            z: extra code, required when ``add_z`` is set.

        Returns:
            ``(rgb_sigma, style_rgb)``, where ``rgb_sigma`` is (N, S, 4)
            (``None`` when ``onlyContentStyle``) and ``style_rgb`` is (N, S, 3).
            When ``gene_mask != "None"`` and ``onlyContentStyle`` is False, a
            single (N, S, 4) generation-mask tensor is returned instead.

        Raises:
            ValueError: unknown ``gene_mask``, or ``z`` missing while ``add_z`` is set.
            NotImplementedError: ``style3Dfeat`` is set.
        """
        ## Viewing samples regardless of batch or ray
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])  # (N*S, V, C)
        if not onlyContentStyle:
            v_feat = feat[..., :24]  # 3D features (levels 0, 1, 2)
            s_feat = feat[..., 24 : 24 + 8]  # 2D features (level 0)
            colors = feat[..., 24 + 8 : 24 + 8 + 3]
            vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
            content_start = 24 + 8 + 3 + 1
        else:
            content_start = 0
        style_start = content_start + sum(self.in_ch)  # contents span 8 + 16 + 32 channels
        contents = feat[..., content_start:style_start]
        # Style and domain vectors are the same for all views, so view 0 is taken.
        style = feat[..., style_start : style_start + 8][:, 0, :]
        domain_vec = feat[..., style_start + 8 : style_start + 8 + 5][:, 0, :]  # TODO: domain length hard-coded to 5

        ## Split contents into per-level chunks of width in_ch[l]
        bounds = [0]
        for width in self.in_ch:
            bounds.append(bounds[-1] + width)
        contents_level = {f"level_{l}": contents[..., bounds[l]:bounds[l + 1]] for l in range(3)}

        if not onlyContentStyle:
            occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
            viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

            if self.gene_mask != "None":
                # Mask-generation mode: nothing is rendered.
                return self._generation_mask(vis_mask, occ_masks, for_mask, N, S)

            ## Mean and variance of 2D features provide view-independent tokens
            var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
            var_mean = torch.cat(var_mean, dim=-1)
            var_mean = F.elu(self.var_mean_fc1(var_mean))
            var_mean = F.elu(self.var_mean_fc2(var_mean))

            ## Converting the input features to tokens (view-dependent) before self-attention
            tokens = F.elu(self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1)))
            tokens = torch.cat([tokens, var_mean], dim=1)

            ## Adding a new channel to mask for var_mean (moved to vis_mask's device/dtype)
            var_mean_mask = self.var_mean_mask.to(vis_mask).view(1, 1, 1).expand(N * S, -1, -1)
            vis_mask = torch.cat([vis_mask, var_mean_mask], dim=1)
            ## If a point is not visible by any source view, force its masks to enabled
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

            ## Taking occ_masks into account, but remembering if there were any visibility before that
            mask_cloned = vis_mask.clone()
            vis_mask[:, :-1] *= occ_masks
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
            masks = vis_mask * mask_cloned

            ## Performing self-attention
            for layer in self.attn_layers:
                tokens, _ = layer(tokens, masks)

            ## Predicting sigma with an Auto-Encoder and MLP
            sigma_tokens = tokens[:, -1:]
            sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
            sigma_tokens = self.auto_enc(sigma_tokens)
            sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

            sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
            sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
            sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

            ## Concatenating positional encodings and predicting RGB weights
            rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
            rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
            rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
            rgb_w = self.rgb_fc3(rgb_tokens)
            rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)

            rgb = (colors * rgb_w).sum(1)

            rgb_sigma = torch.cat([rgb, sigma], -1)
            rgb_sigma = rgb_sigma.reshape(N, S, -1)

        if self.style3Dfeat:
            # This path needs per-level 3D style features (style3D_level) and
            # layers (hidden_layer_*, style3D_var_*/style3D_mean_*) that this
            # class never builds. It previously died with a NameError.
            raise NotImplementedError("style3Dfeat is not supported by RendererStyle_one")

        def level_content(l):
            """Per-sample content of level l: view 0, or blended with rgb_w."""
            level_feat = contents_level[f"level_{l}"]
            if onlyContentStyle:
                return level_feat[:, 0]
            return (level_feat * rgb_w).sum(1)

        ## Phase codes used as the style input in timephi mode
        if self.timephi:
            if input_phi is None:
                phi = self.style2phi(style)
                # Side effect: the first predicted phase is kept on the module
                # (presumably read by the caller for logging - TODO confirm).
                self.t = phi[0, 0]
            else:
                phi = input_phi * torch.ones_like(style[:, 0:1])
            cosphi_code = self.cosphi_gamma(torch.cos(phi))
            sinphi_code = self.sinphi_beta(torch.sin(phi))
            if self.add_z:
                if z is None:
                    raise ValueError("z must be provided when add_z is enabled")
                z = z.repeat(cosphi_code.shape[0], 1)
                cosphi_code = self.z_cos_mapping(torch.cat((z, cosphi_code), dim=-1))
                sinphi_code = self.z_sin_mapping(torch.cat((z, sinphi_code), dim=-1))

            if self.weatherEmbedding:
                # Domain index 0 maps to label -1; such batches get a random
                # weather label in [0, 4) shared by every sample.
                domain_label = domain_vec.nonzero()[:, 1] - 1
                if domain_label[0].item() == -1:
                    random_label = torch.randint(low=0, high=4, size=(1,)).to(domain_label.device)
                    domain_label = torch.ones_like(domain_label) * random_label

                weather_embed_ = self.weather_embed(domain_label).squeeze()
                cosphi_code = self.cosphiWeather_fcs(torch.cat((cosphi_code, weather_embed_), dim=-1))
                sinphi_code = self.sinphiWeather_fcs(torch.cat((sinphi_code, weather_embed_), dim=-1))
            if self.weatherEncode:
                weather_encoding = self.weather_encode_fcs(style)
                cosphi_code = self.cosphiWeather_fcs(torch.cat((cosphi_code, weather_encoding), dim=-1))
                sinphi_code = self.sinphiWeather_fcs(torch.cat((sinphi_code, weather_encoding), dim=-1))
                if self.weatherEncodeCls:
                    self.weather_pred = self.weatherEncoding_classifier(weather_encoding)

        ## Style MLP: levels 2 -> 1 -> 0 (widths 32 -> 16 -> 8)
        x = level_content(2)
        for idx, l in enumerate(reversed(range(3))):
            style_var_fc_l = getattr(self, f"style_var_fc_{l}")
            style_mean_fc_l = getattr(self, f"style_mean_fc_{l}")

            # normalize x along ray; the clamp only matters for (near-)zero
            # variance, e.g. S == 1, which used to produce NaN/inf
            if self.adainNormalize:
                x_var, x_mean = torch.var_mean(x.reshape(N, S, -1), dim=1, unbiased=False, keepdim=True)
                x = (x.reshape(N, S, -1) - x_mean) / x_var.clamp_min(1e-8) ** 0.5
                x = x.reshape(-1, x.shape[-1])

            if self.timephi:
                x = x * style_var_fc_l(cosphi_code) + style_mean_fc_l(sinphi_code)
            else:
                # NOTE(review): style_var_fc_* expect style_in (8, or 8 + z_dim)
                # inputs, while cat(style, domain_vec) is 8 + 5 wide. This
                # non-timephi path looks inconsistent with __init__ - confirm.
                if self.catDomain or idx == 0:
                    style_input = torch.cat((style, domain_vec), dim=-1)
                else:
                    style_input = style
                x = x * style_var_fc_l(style_input) + style_mean_fc_l(style_input)
            x = self.generate_hidden_layers[idx](x)

            content_r_l = getattr(self, f"content_r_{l}")
            content_b_l = getattr(self, f"content_b_{l}")
            contents_l = level_content(l)
            x = torch.sin(x * content_r_l(contents_l) + content_b_l(contents_l))
            if idx != 2:
                if self.catStyle:
                    x = torch.relu(self.generate_layers_0[idx](torch.cat((x, style), dim=-1)))
                x = self.generate_layers[idx](x)

        x = F.relu(self.style_rgb_fc1(x))
        x = F.relu(self.style_rgb_fc2(x))
        style_rgb = torch.sigmoid(self.style_rgb_fc3(x)).reshape(N, S, -1)

        if onlyContentStyle:
            return None, style_rgb
        return rgb_sigma, style_rgb

def fourier_feature(code, z=None, gaussian=False, L=8, which=None):
    """Encode ``code`` with sinusoids, concatenated along the last dimension.

    Args:
        code: tensor to encode.
        z: frequency tensor, required when ``gaussian`` is True.
        gaussian: if True, return ``[cos(2*pi*z*code), sin(2*pi*z*code)]``.
        L: number of frequency bands for the non-gaussian modes.
        which: 'cos' -> ``cos(code / 2**l)`` for l in [0, L);
            'sin' -> ``sin(code / 2**l)`` for l in [0, L);
            'all' -> interleaved ``cos(code * 2**l), sin(code * 2**l)``.
            Note that 'cos'/'sin' use decreasing frequencies while 'all'
            uses increasing ones.

    Raises:
        ValueError: ``gaussian`` without ``z``, or an unsupported ``which``.
    """
    if gaussian:
        # explicit check: `assert` is stripped under -O and `z != None` is
        # ambiguous for array-like z
        if z is None:
            raise ValueError("fourier_feature: z is required when gaussian=True")
        angle = 2 * math.pi * z * code
        return torch.cat((torch.cos(angle), torch.sin(angle)), dim=-1)

    if which == 'cos':
        e = [torch.cos(code / (2 ** l)) for l in range(L)]
    elif which == 'sin':
        e = [torch.sin(code / (2 ** l)) for l in range(L)]
    elif which == 'all':
        e = []
        for l in range(L):
            e.append(torch.cos(code * (2 ** l)))
            e.append(torch.sin(code * (2 ** l)))
    else:
        # previously fell through to an opaque torch.cat([]) error
        raise ValueError(f"fourier_feature: unsupported which={which!r}; expected 'cos', 'sin' or 'all'")

    return torch.cat(e, dim=-1)


class Conv2dBlock(nn.Module):
    """Explicit padding -> Conv2d -> optional normalization -> optional activation.

    Args:
        input_dim, output_dim: input / output channel counts.
        kernel_size, stride: forwarded to ``nn.Conv2d``, which gets no padding
            of its own.
        padding: amount of explicit padding applied before the convolution.
        norm: 'batch', 'instance', 'layer', 'adain', 'spectral' (spectrally
            normalised conv, no separate norm layer) or 'none'.
        activation: 'relu', 'lrelu', 'prelu', 'selu', 'tanh' or 'none'.
        pad_type: 'reflect', 'replicate' or 'zero'.

    Raises:
        ValueError: for an unsupported ``pad_type``, ``norm`` or ``activation``.
            This was ``assert 0`` before, which is stripped under ``python -O``
            and left the attribute undefined.
    """

    def __init__(self, input_dim ,output_dim, kernel_size, stride,
                 padding=0, norm='none', activation='relu', pad_type='zero'):
        super(Conv2dBlock, self).__init__()
        self.use_bias = True
        # initialize padding
        if pad_type == 'reflect':
            self.pad = nn.ReflectionPad2d(padding)
        elif pad_type == 'replicate':
            self.pad = nn.ReplicationPad2d(padding)
        elif pad_type == 'zero':
            self.pad = nn.ZeroPad2d(padding)
        else:
            raise ValueError("Unsupported padding type: {}".format(pad_type))

        # initialize normalization
        norm_dim = output_dim
        if norm == 'batch':
            self.norm = nn.BatchNorm2d(norm_dim)
        elif norm == 'instance':
            self.norm = nn.InstanceNorm2d(norm_dim)
        elif norm == 'layer':
            self.norm = LayerNorm(norm_dim)
        elif norm == 'adain':
            self.norm = AdaptiveInstanceNorm2d(norm_dim)
        elif norm == 'none' or norm == 'spectral':
            self.norm = None
        else:
            raise ValueError("Unsupported normalization: {}".format(norm))

        # initialize activation
        if activation == 'relu':
            self.activation = nn.ReLU(inplace=True)
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == 'prelu':
            self.activation = nn.PReLU()
        elif activation == 'selu':
            self.activation = nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'none':
            self.activation = None
        else:
            raise ValueError("Unsupported activation: {}".format(activation))

        # initialize convolution
        if norm == 'spectral':
            self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)

    def forward(self, x):
        """Apply pad, conv, then the optional norm and activation."""
        x = self.conv(self.pad(x))
        # explicit None checks instead of relying on nn.Module truthiness
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

class Delta_estimator(nn.Module):
    """Conv encoder that regresses a 2-vector in [-2, 2] per input image.

    The stack is a 7x7 ``Conv2dBlock`` at full resolution, then two stride-2
    4x4 blocks that double the channel count, then ``n_downsample - 2`` more
    stride-2 blocks at constant width. It ends with global average pooling,
    a 1x1 conv to 2 channels, and ``2 * tanh``.
    """

    def __init__(self, n_downsample=4, input_dim=3*2, dim=64, norm='instance', activ='lrelu', pad_type='reflect'):
        super(Delta_estimator, self).__init__()
        block_kwargs = dict(norm=norm, activation=activ, pad_type=pad_type)
        layers = [Conv2dBlock(input_dim, dim, 7, 1, 3, **block_kwargs)]
        # stride-2 blocks that double the width
        for _ in range(2):
            layers.append(Conv2dBlock(dim, dim * 2, 4, 2, 1, **block_kwargs))
            dim = dim * 2
        # further stride-2 blocks at constant width
        layers.extend(Conv2dBlock(dim, dim, 4, 2, 1, **block_kwargs) for _ in range(n_downsample - 2))
        layers.append(nn.AdaptiveAvgPool2d(1))  # global average pooling
        layers.append(nn.Conv2d(dim, 2, 1, 1, 0))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (B, 2, 1, 1) tensor with values in [-2, 2]."""
        return 2 * torch.tanh(self.model(x))

class Delta_estimator_1x1(nn.Module):
    """1x1-conv variant of ``Delta_estimator``: a 2-vector in [-2, 2] per image.

    The stack is a 1x1 conv to ``dim`` channels, two 1x1 convs that each drop
    2 channels, then ``n_downsample - 2`` constant-width 1x1 convs (each
    followed by LeakyReLU(0.2)). It ends with a 1x1 conv to 2 channels,
    global average pooling and ``2 * tanh``. ``norm``, ``activ`` and
    ``pad_type`` are accepted for signature parity only and are not used.
    """

    def __init__(self, n_downsample=4, input_dim=3*2, dim=6, norm='instance', activ='lrelu', pad_type='reflect'):
        super(Delta_estimator_1x1, self).__init__()

        def conv_lrelu(c_in, c_out):
            return [nn.Conv2d(c_in, c_out, 1, 1, 0), nn.LeakyReLU(0.2, inplace=True)]

        layers = conv_lrelu(input_dim, dim)
        for _ in range(2):
            layers += conv_lrelu(dim, dim - 2)
            dim -= 2
        for _ in range(n_downsample - 2):
            layers += conv_lrelu(dim, dim)
        layers += [nn.Conv2d(dim, 2, 1, 1, 0), nn.AdaptiveAvgPool2d(1)]  # project, then global average pooling
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return a (B, 2, 1, 1) tensor with values in [-2, 2]."""
        return 2 * torch.tanh(self.model(x))

class mapping_net(nn.Module):
    """Small 7 -> 7 MLP: three LeakyReLU-activated linear layers plus a linear output."""

    def __init__(self):
        super(mapping_net, self).__init__()
        self.linears = nn.ModuleList(nn.Linear(7, 7) for _ in range(3))
        self.final_linear = nn.Linear(7, 7)
        self.lrelu = nn.LeakyReLU()

    def forward(self, x):
        """Map a (..., 7) tensor to a (..., 7) tensor."""
        h = x
        for linear in self.linears:
            h = self.lrelu(linear(h))
        return self.final_linear(h)

class rgb2phi_network(nn.Module):
    """Map a 3-channel input (presumably an RGB colour, per the name) to a 2-vector in [-1, 1].

    The network is four LeakyReLU-activated linear layers (3 -> 2, then
    2 -> 2 three times), followed by a 2 -> 2 linear layer squashed into
    [-1, 1] with ``2 * sigmoid - 1``.
    """

    def __init__(self):
        super(rgb2phi_network, self).__init__()
        layers = [nn.Linear(3, 2)]
        layers.extend(nn.Linear(2, 2) for _ in range(3))
        self.linears = nn.ModuleList(layers)
        self.final_linear = nn.Linear(2, 2)
        self.lrelu = nn.LeakyReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a (..., 3) tensor to a (..., 2) tensor with values in [-1, 1]."""
        h = x
        for layer in self.linears:
            h = self.lrelu(layer(h))
        logits = self.final_linear(h)
        return 2 * self.sigmoid(logits) - 1  # rescale (0, 1) to (-1, 1)

class RendererStyle2branch(nn.Module):
    def __init__(self, nb_samples_per_ray, gene_mask="None", n_domain=5, style3Dfeat=False, timephi=False, weatherEmbedding=False, STB_ver='v0', wo_z=False, phi_final_actv='lrelu', use_fourier_feature=False, branch2_noPhi=False, z_dim=8, fourier_phi=False, phiNoCosSin=False, 
                 zInputStyle=False, update_z=False, delta_t=False, delta_t_1x1=False, z_zInputStyle_Fuse=False, rgb2t=False):
        super(RendererStyle2branch, self).__init__()

        self.gene_mask = gene_mask
        self.style3Dfeat = style3Dfeat
        self.timephi = timephi
        self.weatherEmbedding = weatherEmbedding
        self.STB_ver = STB_ver
        self.wo_z = wo_z
        self.use_fourier_feature = use_fourier_feature
        self.branch2_noPhi = branch2_noPhi
        self.z_dim = z_dim
        self.fourier_phi = fourier_phi
        self.phiNoCosSin = phiNoCosSin
        self.zInputStyle = zInputStyle
        self.update_z = update_z
        self.delta_t = delta_t
        self.delta_t_1x1 = delta_t_1x1
        self.z_zInputStyle_Fuse = z_zInputStyle_Fuse
        self.rgb2t = rgb2t

        self.dim = 32
        self.attn_token_gen = nn.Linear(24 + 1 + 8, self.dim)

        ## Self-Attention Settings
        d_inner = self.dim
        n_head = 4
        d_k = self.dim // n_head
        d_v = self.dim // n_head
        num_layers = 4
        self.attn_layers = nn.ModuleList(
            [
                EncoderLayer(self.dim, d_inner, n_head, d_k, d_v)
                for i in range(num_layers)
            ]
        )

        ## Processing the mean and variance of input features
        self.var_mean_fc1 = nn.Linear(16, self.dim)
        self.var_mean_fc2 = nn.Linear(self.dim, self.dim)

        ## Setting mask of var_mean always enabled
        self.var_mean_mask = torch.tensor([1]).cuda()
        self.var_mean_mask.requires_grad = False

        ## For aggregating data along ray samples
        self.auto_enc = ConvAutoEncoder(self.dim, nb_samples_per_ray)

        self.sigma_fc1 = nn.Linear(self.dim, self.dim)
        self.sigma_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.sigma_fc3 = nn.Linear(self.dim // 2, 1)

        self.rgb_fc1 = nn.Linear(self.dim + 9, self.dim)
        self.rgb_fc2 = nn.Linear(self.dim, self.dim // 2)
        self.rgb_fc3 = nn.Linear(self.dim // 2, 1)

        # styleMLP
        if self.update_z:
            self.z = nn.Parameter(torch.randn((1,self.z_dim)))
        if self.timephi:
            self.style2phi = style2phi_network()
            # if self.fourier_phi:
            #     self.cosphi_gamma = nn.Sequential(nn.Linear(8,8),nn.LeakyReLU())
            #     self.sinphi_beta = nn.Sequential(nn.Linear(8,8),nn.LeakyReLU())
            # else:
            phi_dim = 3 if self.fourier_phi else 1
            self.cosphi_gamma = phi2code_network(final_actv=phi_final_actv,input_dim=phi_dim)
            self.sinphi_beta = phi2code_network(final_actv=phi_final_actv,input_dim=phi_dim)
            if self.weatherEmbedding:
                self.weather_embed = nn.Embedding(5,8)
                self.cosphiWeather_fcs = nn.Sequential(nn.Linear(16,16),nn.ReLU())
                self.sinphiWeather_fcs = nn.Sequential(nn.Linear(16,16),nn.ReLU())
                self.cosphiWeather_special_fcs = nn.Sequential(nn.Linear(24,24),nn.ReLU())
                self.sinphiWeather_special_fcs = nn.Sequential(nn.Linear(24,24),nn.ReLU())

        if self.style3Dfeat:
            # in_ch = [32, 64, 128]
            in_ch = [8, 16, 32]
            self.in_ch = in_ch
            self.hidden_layer_1 = nn.Linear(in_ch[2], in_ch[1])
            self.hidden_layer_0 = nn.Linear(in_ch[1], in_ch[0])

            self.style3D_var_1 = nn.Linear(in_ch[1], in_ch[1])
            self.style3D_mean_1 = nn.Linear(in_ch[1], in_ch[1])
            self.style3D_var_0 = nn.Linear(in_ch[0], in_ch[0])
            self.style3D_mean_0 = nn.Linear(in_ch[0], in_ch[0])

            self.style_rgb_fc1 = nn.Linear(in_ch[0], in_ch[0]//2)
            self.style_rgb_fc2= nn.Linear(in_ch[0]//2, 3)
        else:
            ## common style MLP
            in_ch = [8, 16, 32]
            self.in_ch = in_ch
            for l in range(3):
                style_in = 8
                if self.weatherEmbedding:
                    style_in += 8
                style_var_fc_l = nn.Linear(style_in, in_ch[l])
                style_mean_fc_l = nn.Linear(style_in, in_ch[l])
                content_r_l = nn.Linear(in_ch[l], in_ch[l])
                content_b_l = nn.Linear(in_ch[l], in_ch[l])
                setattr(self, f"style_var_fc_{l}", style_var_fc_l)
                setattr(self, f"style_mean_fc_{l}", style_mean_fc_l)
                setattr(self, f"content_r_{l}", content_r_l)
                setattr(self, f"content_b_{l}", content_b_l)
            
            self.generate_hidden_layers = nn.ModuleList(
                [ nn.Linear(in_ch[i], in_ch[i]) for i in reversed(range(3)) ]
            )
            self.generate_layers = nn.ModuleList(
                [ nn.Linear(in_ch[i+1], in_ch[i]) for i in reversed(range(2)) ]
            )

            ## special style MLP
            # if self.STB_ver == 'v1' or self.STB_ver == 'v2':
            self.specialFeat2weight = nn.Sequential(
                nn.Linear(8+16+32, 32), nn.ReLU(), nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8), nn.Sigmoid()
            )

                    
            if self.zInputStyle:
                self.style2remain = style2remain_network()
                self.z_dim = 7
            
            if self.z_zInputStyle_Fuse:
                self.z2commonSpace = mapping_net()

            in_ch = [8, 16, 32]
            self.in_ch = in_ch
            if self.wo_z or self.branch2_noPhi: z_cos_in = 8
            else: z_cos_in = 8 + self.z_dim
            if not use_fourier_feature:
                self.z_cos_mapping = nn.Sequential(
                    nn.Linear(z_cos_in, z_cos_in), nn.ReLU(), nn.Linear(z_cos_in, z_cos_in), nn.ReLU()
                )
                self.z_sin_mapping = nn.Sequential(
                    nn.Linear(z_cos_in, z_cos_in), nn.ReLU(), nn.Linear(z_cos_in, z_cos_in), nn.ReLU()
                )
            for l in range(3):
                style_in = 8
                if not self.wo_z and not self.branch2_noPhi:
                    style_in += self.z_dim
                if self.weatherEmbedding:
                    style_in += 8
                style_var_fc_l = nn.Linear(style_in, in_ch[l])
                style_mean_fc_l = nn.Linear(style_in, in_ch[l])
                content_r_l = nn.Linear(in_ch[l], in_ch[l])
                content_b_l = nn.Linear(in_ch[l], in_ch[l])
                setattr(self, f"styleSpecial_var_fc_{l}", style_var_fc_l)
                setattr(self, f"styleSpecial_mean_fc_{l}", style_mean_fc_l)
                setattr(self, f"contentSpecial_r_{l}", content_r_l)
                setattr(self, f"contentSpecial_b_{l}", content_b_l)
            
            self.generate_hidden_layers_special = nn.ModuleList(
                [ nn.Linear(in_ch[i], in_ch[i]) for i in reversed(range(3)) ]
            )
            self.generate_layers_special = nn.ModuleList(
                [ nn.Linear(in_ch[i+1], in_ch[i]) for i in reversed(range(2)) ]
            )

            ## shared RGB decoder
            self.style_rgb_fc1 = nn.Linear(in_ch[0], in_ch[0]//2)
            self.style_rgb_fc2 = nn.Linear(in_ch[0]//2, in_ch[0]//2)
            self.style_rgb_fc3 = nn.Linear(in_ch[0]//2, 3)

        if self.delta_t:
            if self.delta_t_1x1:
                self.delta_t_estimator = Delta_estimator_1x1()
            else:
                self.delta_t_estimator = Delta_estimator()

        if self.rgb2t:
            self.rgb2phi_net = rgb2phi_network()

        ## Initialization
        self.sigma_fc3.apply(weights_init)

    def forward(self, viewdirs, feat, occ_masks, viewdirs_novel, for_mask=None, onlyContentStyle=False, input_phi=None, return_what_style="common", z=None, alpha=None):
        """Predict per-sample colour/density and a style-modulated colour.

        Args:
            viewdirs: Per-view direction encodings; reshaped to ``(N*S, V, ...)``
                and concatenated to the view tokens before RGB-weight prediction.
            feat: Per-sample, per-view features of shape ``(N, S, V, C)``
                (rays, samples per ray, source views, channels). When
                ``onlyContentStyle`` is False the channels are
                ``[3D feat (24) | 2D feat (8) | colour (3) | visibility (1) |
                common content (56) | special content (56) | style (8) |
                domain (5) | 3D style features (rest, if style3Dfeat)]``;
                when it is True the geometry prefix (first 36 channels) is absent.
            occ_masks: Per-view occlusion masks, multiplied into the visibility mask.
            viewdirs_novel: Not used by this method.
            for_mask: Dict with ``'pts_d'`` and ``'pts_d_gt'``; only read when
                ``self.gene_mask == "interval"``.
            onlyContentStyle: Skip the geometry/attention stage and use the first
                view's content features instead of an RGB-weighted blend.
            input_phi: Fixed phi value; when None, phi is predicted from the style
                code with ``self.style2phi`` (timephi mode only).
            return_what_style: ``"common"``, ``"common+spec"`` or ``"both"``.
            z: Latent code for the special branch; required when neither
                ``self.wo_z`` nor ``self.update_z`` is set.
            alpha: Optional weight blending the style colour with the
                geometry-branch colour (``"common+spec"`` / ``"both"`` only).

        Returns:
            ``(rgb_sigma, style_rgb)``. ``rgb_sigma`` is rgb and sigma concatenated
            on the last dim with shape ``(N, S, -1)``, or ``None`` when
            ``onlyContentStyle``. ``style_rgb`` is reshaped to ``(N, S, -1)``; for
            ``"both"`` it is a dict with keys ``'c'`` (common) and ``'c+s'``
            (common+special). When ``self.gene_mask != "None"`` a single
            ``(N, S, 4)`` tensor (generation mask + point count) is returned instead.

        Raises:
            ValueError: unknown ``self.gene_mask`` / ``return_what_style``, or
                ``z`` missing when it is required.
        """
        ## Viewing samples regardless of batch or ray
        # feat: (batch_size, sample points along a ray, nb_view, dim)
        N, S, V = feat.shape[:3]
        feat = feat.view(-1, *feat.shape[2:])

        ## Channel offsets of the content/style/domain part of `feat`; it follows
        ## the 36 geometry channels unless `onlyContentStyle` is set.
        content_dim = 8 + 16 + 32  # one content branch (level 0=8, 1=16, 2=32)
        style_dim, domain_dim = 8, 5  # TODO: domain length is hard-coded to 5
        base = 0 if onlyContentStyle else 24 + 8 + 3 + 1
        common_end = base + content_dim
        special_end = common_end + content_dim
        style_end = special_end + style_dim
        domain_end = style_end + domain_dim

        if not onlyContentStyle:
            v_feat = feat[..., :24]  # feature 3D (level 0, 1, 2)
            s_feat = feat[..., 24 : 24 + 8]  # feature 2D (level 0)
            colors = feat[..., 24 + 8 : 24 + 8 + 3]
            vis_mask = feat[..., 24 + 8 + 3 : 24 + 8 + 3 + 1].detach()
        common_contents = feat[..., base:common_end]
        special_contents = feat[..., common_end:special_end]
        style = feat[..., special_end:style_end][:, 0, :]  # same for all views
        domain_vec = feat[..., style_end:domain_end][:, 0, :]  # same for all views

        ## Split each content branch into its per-level channel ranges
        common_contents_level = {}
        special_contents_level = {}
        for l in range(3):
            begin = sum(self.in_ch[:l])
            end = begin + self.in_ch[l]
            common_contents_level[f"level_{l}"] = common_contents[..., begin:end]
            special_contents_level[f"level_{l}"] = special_contents[..., begin:end]

        if not onlyContentStyle:
            occ_masks = occ_masks.view(-1, *occ_masks.shape[2:])
            viewdirs = viewdirs.view(-1, *viewdirs.shape[2:])

            if self.style3Dfeat:
                # The 3D style features start right after the domain vector. The
                # old single-content offset (64+128+256) no longer matches the
                # common/special layout sliced above.
                style3D_feat = feat[..., domain_end:]
                style3D_level = {}
                style3D_level['level_0'] = style3D_feat[..., :self.in_ch[0]]
                style3D_level['level_1'] = style3D_feat[..., self.in_ch[0]:self.in_ch[0] + self.in_ch[1]]
                style3D_level['level_2'] = style3D_feat[..., self.in_ch[0] + self.in_ch[1]:]

            ## Mean and variance of 2D features provide view-independent tokens
            var_mean = torch.var_mean(s_feat, dim=1, unbiased=False, keepdim=True)
            var_mean = torch.cat(var_mean, dim=-1)
            var_mean = F.elu(self.var_mean_fc1(var_mean))
            var_mean = F.elu(self.var_mean_fc2(var_mean))

            ## Converting the input features to tokens (view-dependent) before self-attention
            tokens = F.elu(
                self.attn_token_gen(torch.cat([v_feat, vis_mask, s_feat], dim=-1))
            )
            tokens = torch.cat([tokens, var_mean], dim=1)

            if self.gene_mask != "None":
                if self.gene_mask == "interval":
                    pts_d = for_mask['pts_d'].reshape(-1,1,1)
                    pts_d_gt = for_mask['pts_d_gt'].unsqueeze(-1).repeat(1,S).reshape(-1,1,1)
                    need_gene_mask = torch.zeros_like(vis_mask[:,0:1,:])
                    need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
                    need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdims=True) == 0, 1) #(bs*n_sample, 1, 1)
                    need_gene_mask[(pts_d-pts_d_gt > 0.2)*(pts_d_gt != 0)] = 0
                    need_gene_mask[(pts_d-pts_d_gt < -0.2)*(pts_d_gt != 0)] = 0
                    need_gene_mask = need_gene_mask.squeeze(-1).reshape(N,S,1).repeat(1,1,3)

                    ray_pts_cnt = torch.zeros_like(pts_d_gt) # to cnt set to 1; else 0
                    ray_pts_cnt[(pts_d-pts_d_gt <= 0.2)*(pts_d-pts_d_gt >= 0)*(pts_d_gt != 0)] = 1
                    ray_pts_cnt[(pts_d-pts_d_gt >= -0.2)*(pts_d-pts_d_gt < 0)*(pts_d_gt != 0)] = 1

                elif self.gene_mask == "one_pt":
                    need_gene_mask = torch.zeros_like(vis_mask[:,0:1,:])
                    need_gene_mask = need_gene_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 0, 1)
                    need_gene_mask = need_gene_mask.masked_fill(occ_masks.sum(dim=1, keepdims=True) == 0, 1)
                    need_gene_mask = need_gene_mask.squeeze(-1).reshape(N,S,1).repeat(1,1,3)

                    ray_pts_cnt = torch.ones_like(need_gene_mask)[...,0] # to cnt set to 1; else 0

                else:
                    raise ValueError(f"Unknown gene_mask mode: {self.gene_mask!r}")

                outputs = torch.cat([need_gene_mask, ray_pts_cnt.reshape(N,S,1)], -1)

                return outputs

            ## Adding a new channel to mask for var_mean
            vis_mask = torch.cat(
                [vis_mask, self.var_mean_mask.view(1, 1, 1).expand(N * S, -1, -1)], dim=1
            )
            ## If a point is not visible by any source view, force its masks to enabled
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)

            ## Taking occ_masks into account, but remembering if there were any visibility before that
            mask_cloned = vis_mask.clone()
            vis_mask[:, :-1] *= occ_masks
            vis_mask = vis_mask.masked_fill(vis_mask.sum(dim=1, keepdims=True) == 1, 1)
            masks = vis_mask * mask_cloned

            ## Performing self-attention
            for layer in self.attn_layers:
                tokens, _ = layer(tokens, masks)

            ## Predicting sigma with an Auto-Encoder and MLP
            sigma_tokens = tokens[:, -1:]
            sigma_tokens = sigma_tokens.view(N, S, self.dim).transpose(1, 2)
            sigma_tokens = self.auto_enc(sigma_tokens)
            sigma_tokens = sigma_tokens.transpose(1, 2).reshape(N * S, 1, self.dim)

            sigma_tokens = F.elu(self.sigma_fc1(sigma_tokens))
            sigma_tokens = F.elu(self.sigma_fc2(sigma_tokens))
            sigma = torch.relu(self.sigma_fc3(sigma_tokens[:, 0]))

            ## Concatenating positional encodings and predicting RGB weights
            rgb_tokens = torch.cat([tokens[:, :-1], viewdirs], dim=-1)
            rgb_tokens = F.elu(self.rgb_fc1(rgb_tokens))
            rgb_tokens = F.elu(self.rgb_fc2(rgb_tokens))
            rgb_w = self.rgb_fc3(rgb_tokens)
            rgb_w = masked_softmax(rgb_w, masks[:, :-1], dim=1)

            rgb = (colors * rgb_w).sum(1)

            rgb_sigma = torch.cat([rgb, sigma], -1)
            rgb_sigma = rgb_sigma.reshape(N, S, -1)

        def _pool_views(t):
            # Blend per-view features with the RGB blending weights. Without the
            # geometry stage there are no weights, so take the first view.
            if onlyContentStyle:
                return t[:, 0]
            return (t * rgb_w).sum(1)

        def _decode_rgb(x):
            # Shared MLP head: modulated feature -> sigmoid colour, (N, S, -1).
            x = F.relu(self.style_rgb_fc1(x))
            x = F.relu(self.style_rgb_fc2(x))
            return torch.sigmoid(self.style_rgb_fc3(x)).reshape(N, S, -1)

        ## content + style
        ## adain
        if self.style3Dfeat:
            x = (style3D_level["level_2"] * rgb_w).sum(1)
            for l in reversed(range(2)):
                hidden_layer_l = getattr(self, f"hidden_layer_{l}")
                style3D_var_l = getattr(self, f"style3D_var_{l}")
                style3D_mean_l = getattr(self, f"style3D_mean_{l}")
                style_input = (style3D_level[f"level_{l}"] * rgb_w).sum(1)

                x = hidden_layer_l(x)
                x = torch.sin(x * style3D_var_l(style_input) + style3D_mean_l(style_input))

            x = F.relu(self.style_rgb_fc1(x))
            style_rgb = torch.sigmoid(self.style_rgb_fc2(x))
            style_rgb = style_rgb.reshape(N, S, -1)

        else:
            if self.timephi:
                if input_phi is None:
                    phi = self.style2phi(style)
                    self.t = phi[0,0]
                else:
                    phi = input_phi*torch.ones_like(style[:,0:1])
                if self.phiNoCosSin:
                    cosphi, sinphi = phi, phi
                    # NOTE(review): this path previously left _cosphi_code /
                    # _sinphi_code unset (NameError further down). Route it
                    # through the same gamma/beta layers as the cos/sin path,
                    # since phi has the same shape as cos(phi). Confirm against
                    # the __init__ configuration.
                    _cosphi_code, _sinphi_code = self.cosphi_gamma(cosphi), self.sinphi_beta(sinphi)
                elif self.fourier_phi:
                    cosphi, sinphi = fourier_feature(phi, L=4, which='all'), fourier_feature(phi, L=4, which='all')
                    _cosphi_code, _sinphi_code = cosphi, sinphi
                else:
                    cosphi, sinphi = torch.cos(phi), torch.sin(phi)
                    _cosphi_code, _sinphi_code = self.cosphi_gamma(cosphi), self.sinphi_beta(sinphi)
                if self.weatherEmbedding:
                    domain_label = domain_vec.nonzero()[:,1]
                    weather_embed_ = self.weather_embed(domain_label).squeeze()
                    cosphi_code = self.cosphiWeather_fcs(torch.cat((_cosphi_code,weather_embed_),dim=-1))
                    sinphi_code = self.sinphiWeather_fcs(torch.cat((_sinphi_code,weather_embed_),dim=-1))
                else:
                    cosphi_code, sinphi_code = _cosphi_code, _sinphi_code

            ## common feature
            x_common = _pool_views(common_contents_level["level_2"])
            for idx, l in enumerate(reversed(range(3))):
                style_var_fc_l = getattr(self, f"style_var_fc_{l}")
                style_mean_fc_l = getattr(self, f"style_mean_fc_{l}")

                if self.timephi:
                    x_common = x_common * style_var_fc_l(cosphi_code) + style_mean_fc_l(sinphi_code)
                else:
                    x_common = x_common * style_var_fc_l(style) + style_mean_fc_l(style)
                x_common = self.generate_hidden_layers[idx](x_common)

                content_r_l = getattr(self, f"content_r_{l}")
                content_b_l = getattr(self, f"content_b_{l}")

                contents_l = _pool_views(common_contents_level[f"level_{l}"])
                x_common = torch.sin(x_common * content_r_l(contents_l) + content_b_l(contents_l))
                if idx != 2:
                    x_common = self.generate_layers[idx](x_common)

            if return_what_style in ['common+spec','both']:
                if not self.wo_z:
                    if self.update_z:
                        z = self.z.repeat(cosphi_code.shape[0],1)
                    else:
                        if z is None:
                            raise ValueError("`z` must be provided when neither wo_z nor update_z is set")
                        z = z.repeat(cosphi_code.shape[0],1)

                    if self.use_fourier_feature:
                        z_cos_code = fourier_feature(_cosphi_code,z,gaussian=True)
                        z_sin_code = fourier_feature(_sinphi_code,z,gaussian=True)
                    elif self.branch2_noPhi:
                        z_cos_code = self.z_cos_mapping(z)
                        z_sin_code = self.z_sin_mapping(z)
                    else:
                        z_cos_code = self.z_cos_mapping(torch.cat((z,_cosphi_code),dim=-1))
                        z_sin_code = self.z_sin_mapping(torch.cat((z,_sinphi_code),dim=-1))

                    if self.weatherEmbedding:
                        z_cos_code = self.cosphiWeather_special_fcs(torch.cat((z_cos_code,weather_embed_),dim=-1))
                        z_sin_code = self.sinphiWeather_special_fcs(torch.cat((z_sin_code,weather_embed_),dim=-1))
                else:
                    z_cos_code, z_sin_code = cosphi_code, sinphi_code
                    z_cos_code = self.z_cos_mapping(z_cos_code)
                    z_sin_code = self.z_sin_mapping(z_sin_code)

                ## special feature
                special_feats_list = []
                if self.STB_ver == 'v2':
                    x_special = _pool_views(common_contents_level["level_2"])
                else:
                    x_special = _pool_views(special_contents_level["level_2"])
                for idx, l in enumerate(reversed(range(3))):
                    style_var_fc_l = getattr(self, f"styleSpecial_var_fc_{l}")
                    style_mean_fc_l = getattr(self, f"styleSpecial_mean_fc_{l}")

                    if self.timephi:
                        x_special = x_special * style_var_fc_l(z_cos_code) + style_mean_fc_l(z_sin_code)
                    else:
                        x_special = x_special * style_var_fc_l(style) + style_mean_fc_l(style)
                    x_special = self.generate_hidden_layers[idx](x_special)

                    content_r_l = getattr(self, f"contentSpecial_r_{l}")
                    content_b_l = getattr(self, f"contentSpecial_b_{l}")

                    if self.STB_ver == 'v2':
                        # v2 modulates with common content but still records the
                        # special content for the combine-weight predictor.
                        contents_l = _pool_views(common_contents_level[f"level_{l}"])
                        special_feats_list.append(_pool_views(special_contents_level[f"level_{l}"]))
                    else:
                        contents_l = _pool_views(special_contents_level[f"level_{l}"])
                        special_feats_list.append(contents_l)
                    x_special = torch.sin(x_special * content_r_l(contents_l) + content_b_l(contents_l))
                    if idx != 2:
                        x_special = self.generate_layers[idx](x_special)

                ## combine common and special features
                if self.STB_ver in ['v1','v2']:
                    combine_w = self.specialFeat2weight(torch.cat(special_feats_list,dim=-1))
                    x = (1-combine_w)*x_common + combine_w*x_special
                else:
                    x = x_common + x_special

            if return_what_style == 'common':
                style_rgb = _decode_rgb(x_common)
            elif return_what_style == 'common+spec':
                style_rgb = _decode_rgb(x)
            elif return_what_style == 'both':
                style_rgb = {'c': _decode_rgb(x_common), 'c+s': _decode_rgb(x)}
            else:
                raise ValueError(f"Unknown return_what_style: {return_what_style!r}")

        if alpha is not None:
            if return_what_style == 'common+spec':
                style_rgb = alpha * style_rgb + (1-alpha) * rgb_sigma[...,:-1]
            elif return_what_style == 'both':
                style_rgb['c+s'] = alpha * style_rgb['c+s'] + (1-alpha) * rgb_sigma[...,:-1]

        if onlyContentStyle:
            return None, style_rgb
        else:
            return rgb_sigma, style_rgb
