import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import (general_conv3d, normalization, prm_generator, prm_fusion,
                    prm_generator_laststage, region_aware_modal_fusion, fusion_postnorm, fusion_prenorm)
from blocks import nchwd2nlc2nchwd, DepthWiseConvBlock, ResBlock, GroupConvBlock, MultiMaskAttentionLayer, MultiMaskCrossBlock
from torch.nn.init import constant_, xavier_uniform_
from mask import mask_gen_fusion

from functools import partial


# Base channel width; Encoder stage widths are basic_dims * (1, 2, 4, 8, 16).
basic_dims = 8
# Hidden width of the FeedForward blocks in the Bottleneck transformer.
mlp_dim = 4096
# Number of attention heads in the Bottleneck transformer.
num_heads = 8
# Number of MaskedTransformer layers in the Bottleneck.
depth = 3
# Number of input modalities (x channels flair, t1ce, t1, t2 in Model.forward).
num_modals = 4

# Expected input volume size. The Encoder downsamples by 16 (four stride-2
# stages) and Model reshapes bottleneck tokens to (H//16, W//16, D//16), so
# inputs must have exactly these spatial dimensions.
H = 128
W = 128
D = 128

class MultiCrossToken(nn.Module):
    """Stack of ``MultiMaskCrossBlock`` layers that iteratively refines ``kernels``.

    Each layer receives ``(kernels, feature_maps, mask)`` and returns updated
    ``(kernels, feature_maps)``; only the final kernels are returned. Every
    layer except the last is built with ``ffn_feature_maps=True``.

    ``H``/``W``/``D`` (image size divided by the strides) and
    ``interpolate_mode`` are stored as attributes but not used by ``forward``.
    """

    def __init__(
            self,
            image_h=H,
            image_w=W,
            image_d=D,
            h_stride=16,
            w_stride=16,
            d_stride=16,
            num_layers=2,
            mlp_ratio=4,
            drop_rate=0.1,
            attn_drop_rate=0.0,
            interpolate_mode='trilinear',
            channel=basic_dims*16):
        super(MultiCrossToken, self).__init__()

        self.channels = channel
        self.H, self.W, self.D = image_h // h_stride, image_w // w_stride, image_d // d_stride
        self.interpolate_mode = interpolate_mode

        blocks = []
        for idx in range(num_layers):
            is_last = idx == num_layers - 1
            blocks.append(
                MultiMaskCrossBlock(
                    feature_channels=channel,
                    num_classes=channel,
                    expand_ratio=mlp_ratio,
                    drop_rate=drop_rate,
                    attn_drop_rate=attn_drop_rate,
                    ffn_feature_maps=not is_last,
                )
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, inputs, kernels, mask):
        feats = inputs
        for block in self.layers:
            kernels, feats = block(kernels, feats, mask)
        return kernels

class Encoder(nn.Module):
    """Five-stage 3D convolutional encoder for a single-channel volume.

    Stage ``k`` (1..5) owns three convolutions ``e{k}_c1``, ``e{k}_c2`` and
    ``e{k}_c3``. Stage 1 maps 1 -> ``basic_dims`` channels; stages 2-5 use a
    stride-2 first convolution and double the width. Each stage outputs
    ``y + c3(c2(y))`` with ``y = c1(input)``; all five stage outputs are returned.
    """

    def __init__(self):
        super(Encoder, self).__init__()

        # Channel widths at stage boundaries: 1 input channel, then 1x..16x basic_dims.
        widths = [1] + [basic_dims * (2 ** k) for k in range(5)]
        for stage in range(1, 6):
            c_in, c_out = widths[stage - 1], widths[stage]
            # Stage 1 keeps resolution; later stages downsample with stride 2.
            down = {} if stage == 1 else {'stride': 2}
            prefix = 'e%d_' % stage
            setattr(self, prefix + 'c1', general_conv3d(c_in, c_out, pad_type='reflect', **down))
            setattr(self, prefix + 'c2', general_conv3d(c_out, c_out, pad_type='reflect'))
            setattr(self, prefix + 'c3', general_conv3d(c_out, c_out, pad_type='reflect'))

    def forward(self, x):
        outputs = []
        feat = x
        for stage in range(1, 6):
            prefix = 'e%d_' % stage
            feat = getattr(self, prefix + 'c1')(feat)
            feat = feat + getattr(self, prefix + 'c3')(getattr(self, prefix + 'c2')(feat))
            outputs.append(feat)
        return tuple(outputs)



class Decoder_sep(nn.Module):
    """Single-modality decoder producing per-voxel class probabilities.

    Four upsampling stages ``d4`` -> ``d1`` each apply a trilinear x2 upsample,
    a conv halving the channels (``*_c1``), concatenation with the encoder
    skip at that scale, a conv (``*_c2``) and a 1x1x1 conv (``*_out``).
    ``seg_layer`` maps ``basic_dims`` channels to ``num_cls`` and a softmax
    over the channel axis is applied.
    """

    def __init__(self, num_cls=4):
        super(Decoder_sep, self).__init__()

        # (stage name, output width multiplier): each stage maps 2*mult -> mult channels.
        for stage, mult in (('d4', 8), ('d3', 4), ('d2', 2), ('d1', 1)):
            wide, narrow = basic_dims * mult * 2, basic_dims * mult
            setattr(self, stage, nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True))
            setattr(self, stage + '_c1', general_conv3d(wide, narrow, pad_type='reflect'))
            setattr(self, stage + '_c2', general_conv3d(wide, narrow, pad_type='reflect'))
            setattr(self, stage + '_out', general_conv3d(narrow, narrow, k_size=1, padding=0, pad_type='reflect'))

        self.seg_layer = nn.Conv3d(in_channels=basic_dims, out_channels=num_cls, kernel_size=1, stride=1, padding=0, bias=True)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x1, x2, x3, x4, x5):
        out = x5
        for stage, skip in (('d4', x4), ('d3', x3), ('d2', x2), ('d1', x1)):
            upsampled = getattr(self, stage)(out)
            merged = torch.cat((getattr(self, stage + '_c1')(upsampled), skip), dim=1)
            out = getattr(self, stage + '_out')(getattr(self, stage + '_c2')(merged))
        return self.softmax(self.seg_layer(out))


class Decoder_fusion(nn.Module):
    """Decoder that upsamples the fused bottleneck and merges multi-modal skips.

    The bottleneck ``fusion`` volume is refined by ``CT5`` (a MultiCrossToken
    stack fed with the level-5 modality features ``dx5``), concatenated with
    ``fusion`` and reduced to ``basic_dims*16`` channels. Four upsampling
    stages follow; at each one the skip features are fused by an ``RFM*``
    module (which also receives ``mask``) and concatenated with the upsampled
    decoder path. Each stage also emits an auxiliary prediction via
    ``prm_fusion*``.

    NOTE(review): ``CT4``, ``RFM5`` and ``seg_d4`` are constructed but never
    used in ``forward`` (they still appear in the state_dict), and
    ``self.softmax`` is assigned twice with an identical module.
    """
    def __init__(self, num_cls=4):
        super(Decoder_fusion, self).__init__()

        # Level 5: merges cat(CT5 output, fusion) (basic_dims*32 channels) back to basic_dims*16.
        self.d5_c2 = general_conv3d(basic_dims*32, basic_dims*16, pad_type='reflect')
        self.d5_out = general_conv3d(basic_dims*16, basic_dims*16, k_size=1, padding=0, pad_type='reflect')

        self.CT5 = MultiCrossToken(h_stride=16, w_stride=16, d_stride=16, channel=basic_dims*16)
        # NOTE(review): CT4 is not used in forward (its call is commented out there).
        self.CT4 = MultiCrossToken(h_stride=8, w_stride=8, d_stride=8, channel=basic_dims*8)

        # Levels 4..1: *_c1 halves channels after upsampling, *_c2 merges cat(RFM output, decoder path),
        # *_out is a 1x1x1 conv.
        self.d4_c1 = general_conv3d(basic_dims*16, basic_dims*8, pad_type='reflect')
        self.d4_c2 = general_conv3d(basic_dims*16, basic_dims*8, pad_type='reflect')
        self.d4_out = general_conv3d(basic_dims*8, basic_dims*8, k_size=1, padding=0, pad_type='reflect')

        self.d3_c1 = general_conv3d(basic_dims*8, basic_dims*4, pad_type='reflect')
        self.d3_c2 = general_conv3d(basic_dims*8, basic_dims*4, pad_type='reflect')
        self.d3_out = general_conv3d(basic_dims*4, basic_dims*4, k_size=1, padding=0, pad_type='reflect')

        self.d2_c1 = general_conv3d(basic_dims*4, basic_dims*2, pad_type='reflect')
        self.d2_c2 = general_conv3d(basic_dims*4, basic_dims*2, pad_type='reflect')
        self.d2_out = general_conv3d(basic_dims*2, basic_dims*2, k_size=1, padding=0, pad_type='reflect')

        self.d1_c1 = general_conv3d(basic_dims*2, basic_dims, pad_type='reflect')
        self.d1_c2 = general_conv3d(basic_dims*2, basic_dims, pad_type='reflect')
        self.d1_out = general_conv3d(basic_dims, basic_dims, k_size=1, padding=0, pad_type='reflect')

        self.seg_layer = nn.Conv3d(in_channels=basic_dims, out_channels=num_cls, kernel_size=1, stride=1, padding=0, bias=True)
        self.softmax = nn.Softmax(dim=1)

        # Upsamplers used to bring the auxiliary predictions to level-1 resolution.
        self.up2 = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        self.up4 = nn.Upsample(scale_factor=4, mode='trilinear', align_corners=True)
        self.up8 = nn.Upsample(scale_factor=8, mode='trilinear', align_corners=True)
        self.up16 = nn.Upsample(scale_factor=16, mode='trilinear', align_corners=True)

        # NOTE(review): RFM5 is not used in forward.
        self.RFM5 = fusion_prenorm(in_channel=basic_dims*16, num_cls=num_cls)
        self.RFM4 = fusion_postnorm(in_channel=basic_dims*8, num_cls=num_cls)
        self.RFM3 = fusion_postnorm(in_channel=basic_dims*4, num_cls=num_cls)
        self.RFM2 = fusion_postnorm(in_channel=basic_dims*2, num_cls=num_cls)
        self.RFM1 = fusion_postnorm(in_channel=basic_dims*1, num_cls=num_cls)

        # Auxiliary prediction heads, one per decoder level.
        # self.prm_fusion5 = prm_fusion(in_channel=basic_dims*16, basic_dim=basic_dims, num_cls=num_cls)
        self.prm_fusion5 = prm_fusion(in_channel=basic_dims*16, basic_dim=basic_dims, num_cls=num_cls)
        self.prm_fusion4 = prm_fusion(in_channel=basic_dims*8, basic_dim=basic_dims, num_cls=num_cls)
        self.prm_fusion3 = prm_fusion(in_channel=basic_dims*4, basic_dim=basic_dims, num_cls=num_cls)
        self.prm_fusion2 = prm_fusion(in_channel=basic_dims*2, basic_dim=basic_dims, num_cls=num_cls)
        self.prm_fusion1 = prm_fusion(in_channel=basic_dims*1, basic_dim=basic_dims, num_cls=num_cls)

        # NOTE(review): seg_d4 is unused; the softmax below replaces the identical module assigned above.
        self.seg_d4 = nn.Conv3d(in_channels=basic_dims*16, out_channels=num_cls, kernel_size=1, stride=1, padding=0, bias=True)
        self.softmax = nn.Softmax(dim=1)



    def forward(self, dx1, dx2, dx3, dx4, dx5, fusion, mask):
        """Decode the fused bottleneck into class probabilities.

        Args:
            dx1..dx4: skip features for levels 1-4, passed to RFM1..RFM4 with
                ``mask`` (Model passes per-modality stacks, presumably
                (B, 4, C, h, w, d) - confirm against the RFM modules).
            dx5: level-5 modality features, passed to CT5 as its ``inputs``.
            fusion: fused bottleneck volume with basic_dims*16 channels.
            mask: modality-presence mask forwarded to CT5 and the RFM modules.

        Returns:
            ``(pred, prm_preds)`` where ``pred`` is the channel softmax of
            ``seg_layer`` and ``prm_preds`` holds the five auxiliary predictions
            (levels 1..5) upsampled by 1, 2, 4, 8 and 16 respectively.
        """


        # Level 5: refine the fused bottleneck with the modality features, then upsample.
        prm_pred5 = self.prm_fusion5(fusion)
        de_x5 = self.CT5(dx5, fusion, mask)
        de_x5 = torch.cat((de_x5, fusion), dim=1)
        # de_x5 = self.CT5(dx5, fusion, mask)
        # de_x5 = torch.cat((de_x5, test_fused), dim=1)
        de_x5 = self.d5_out(self.d5_c2(de_x5))
        de_x5 = self.d4_c1(self.up2(de_x5))

        # Levels 4..1: fuse the skip features with RFM, concat with the decoder path, upsample.
        prm_pred4 = self.prm_fusion4(de_x5)
        # de_x4 = self.CT4(dx4, de_x5, mask)
        de_x4 = self.RFM4(dx4, mask)
        de_x4 = torch.cat((de_x4, de_x5), dim=1)
        de_x4 = self.d4_out(self.d4_c2(de_x4))
        de_x4 = self.d3_c1(self.up2(de_x4))

        prm_pred3 = self.prm_fusion3(de_x4)
        de_x3 = self.RFM3(dx3, mask)
        de_x3 = torch.cat((de_x3, de_x4), dim=1)
        de_x3 = self.d3_out(self.d3_c2(de_x3))
        de_x3 = self.d2_c1(self.up2(de_x3))

        prm_pred2 = self.prm_fusion2(de_x3)
        de_x2 = self.RFM2(dx2, mask)
        de_x2 = torch.cat((de_x2, de_x3), dim=1)
        de_x2 = self.d2_out(self.d2_c2(de_x2))
        de_x2 = self.d1_c1(self.up2(de_x2))

        prm_pred1 = self.prm_fusion1(de_x2)
        de_x1 = self.RFM1(dx1, mask)
        de_x1 = torch.cat((de_x1, de_x2), dim=1)
        de_x1 = self.d1_out(self.d1_c2(de_x1))

        logits = self.seg_layer(de_x1)
        pred = self.softmax(logits)


    

        # Auxiliary predictions are upsampled to the level-1 resolution.
        return pred, (prm_pred1, self.up2(prm_pred2), self.up4(prm_pred3), self.up8(prm_pred4), self.up16(prm_pred5))
        # return pred, (prm_pred1, self.up2(prm_pred2), self.up4(prm_pred3), self.up8(prm_pred4), self.up16(pred4))



class Residual(nn.Module):
    """Add an identity skip connection around ``fn``: returns ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x


class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input before calling ``fn``."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normed = self.norm(x)
        return self.fn(normed)


class PreNormDrop(nn.Module):
    """LayerNorm the input, apply ``fn``, then dropout the result."""

    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(self.norm(x))
        return self.dropout(transformed)


class GELU(nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu`` with its default settings."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        activated = F.gelu(x)
        return activated


class FeedForward(nn.Module):
    """Two-layer MLP: Linear(dim->hidden) -> GELU -> Dropout -> Linear(hidden->dim) -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        expand = [nn.Linear(dim, hidden_dim), GELU(), nn.Dropout(p=dropout_rate)]
        project = [nn.Linear(hidden_dim, dim), nn.Dropout(p=dropout_rate)]
        self.net = nn.Sequential(*(expand + project))

    def forward(self, x):
        out = self.net(x)
        return out


class MaskedResidual(nn.Module):
    """Skip connection around a masked sub-layer.

    ``fn(x, mask)`` must return ``(y, attn)``; this returns ``(y + x, attn)``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, mask):
        branch, attn = self.fn(x, mask)
        return branch + x, attn


class MaskedPreNormDrop(nn.Module):
    """LayerNorm, then masked sub-layer ``fn(x, mask) -> (y, attn)``, then dropout on ``y``.

    Returns ``(dropout(y), attn)``; ``attn`` is passed through untouched.
    """

    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x, mask):
        out, attn = self.fn(self.norm(x), mask)
        return self.dropout(out), attn


class MaskedAttention(nn.Module):
    """Multi-head self-attention whose scores are masked by ``mask_gen_fusion``.

    Args:
        dim: token embedding size; must be divisible by ``heads`` (see reshape).
        heads: number of attention heads.
        qkv_bias: add a bias to the fused q/k/v projection.
        qk_scale: score scale; defaults to ``head_dim ** -0.5`` (a falsy value
            also falls back to the default).
        dropout_rate: dropout on the attention weights and on the output projection.
        num_class: number of modality token groups; the sequence length is split
            into ``num_class + 1`` equal groups when building the mask (in
            Bottleneck: four modalities plus one fusion group).
    """

    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0, num_class=4
    ):
        super().__init__()
        self.num_heads = heads
        head_dim = dim // heads
        self.scale = qk_scale or head_dim ** -0.5
        self.num_class = num_class

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        """Attend over ``x`` of shape (B, N, C) under the modality mask.

        Returns:
            ``(out, attn)``: the projected output (B, N, C) and the post-softmax,
            post-dropout attention weights (B, heads, N, N).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, C // heads)
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Move the mask to wherever the scores live instead of hard-coding CUDA,
        # so the layer also runs on CPU and on non-default GPUs.
        self_mask = mask_gen_fusion(
            B, self.num_heads, N // (self.num_class + 1), self.num_class, mask
        ).to(attn.device, non_blocking=True)
        # NOTE(review): a query row whose mask is entirely zero becomes all -inf and
        # softmax yields NaN; confirm mask_gen_fusion always keeps at least one key.
        attn = attn.masked_fill(self_mask == 0, float("-inf"))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn


class MaskedTransformer(nn.Module):
    """Stack of ``depth`` pre-norm transformer layers with masked self-attention.

    Each layer is ``x = x + Drop(Attn(LN(x), mask))`` followed by
    ``x = x + FFN(LN(x))``. ``forward`` returns the final tokens and a list of
    the detached attention maps, one per layer. ``n_levels`` and ``n_points``
    are accepted but unused.
    """

    def __init__(self, embedding_dim, depth, heads, mlp_dim, dropout_rate=0.1, n_levels=1, n_points=4):
        super(MaskedTransformer, self).__init__()
        attn_blocks, ffn_blocks = [], []
        # Build attention and FFN per layer in alternation, so parameter
        # initialisation consumes the RNG in the same order as before.
        for _ in range(depth):
            attention = MaskedAttention(embedding_dim, heads=heads, dropout_rate=dropout_rate)
            attn_blocks.append(MaskedResidual(MaskedPreNormDrop(embedding_dim, dropout_rate, attention)))
            ffn = FeedForward(embedding_dim, mlp_dim, dropout_rate)
            ffn_blocks.append(Residual(PreNorm(embedding_dim, ffn)))
        self.depth = depth
        self.cross_attention_list = nn.ModuleList(attn_blocks)
        self.cross_ffn_list = nn.ModuleList(ffn_blocks)

    def forward(self, x, mask):
        attn_maps = []
        for layer in range(self.depth):
            attn_block = self.cross_attention_list[layer]
            ffn_block = self.cross_ffn_list[layer]
            x, attn = attn_block(x, mask)
            attn_maps.append(attn.detach())
            x = ffn_block(x)
        return x, attn_maps


class Bottleneck(nn.Module):
    """Joint masked transformer over four modality token sets plus a fusion set.

    Each modality volume (B, C, h, w, d) is flattened to (B, h*w*d, C); the four
    sequences and ``fusion`` are concatenated along the token axis, processed
    by ``trans_bottle`` with ``mask``, and split back into five equal chunks.
    ``pos`` is accepted but not used.
    """

    def __init__(self):
        super(Bottleneck, self).__init__()

        self.trans_bottle = MaskedTransformer(embedding_dim=basic_dims*16, depth=depth, heads=num_heads, mlp_dim=mlp_dim)
        self.num_cls = num_modals

    def forward(self, x, mask, fusion, pos):
        flair, t1ce, t1, t2 = x
        tokens = [vol.flatten(2).transpose(1, 2).contiguous() for vol in (flair, t1ce, t1, t2)]
        tokens.append(fusion)
        joint, attn = self.trans_bottle(torch.cat(tokens, dim=1), mask)
        flair_out, t1ce_out, t1_out, t2_out, fusion_out = torch.chunk(joint, self.num_cls + 1, dim=1)
        return flair_out, t1ce_out, t1_out, t2_out, fusion_out, attn
    



class ITHPv2_Spatial_Robust(nn.Module):
    """Two-stage variational bottleneck over four modality feature maps.

    Stage 1 encodes modalities 1 and 2 into a Gaussian latent ``z1`` and decodes
    it to reconstruct modalities 3 and 4. Stage 2 encodes ``z1`` together with
    modalities 3 and 4 into ``z2`` and decodes it to reconstruct all four
    modalities. Reconstruction terms only count modalities whose mask entry is
    nonzero.

    Channel sizes are hard-coded for 128-channel modality maps (as Model passes):
    256 = two modalities, 384 = ``z1`` + two modalities, 512 = four modalities,
    so ``bottleneck_dims[0]`` has to be 128.

    Args:
        ITHP_args: dict with keys ``modal_dims``, ``inter_dim``,
            ``bottleneck_dims`` (two ints), ``drop_prob`` and optional ``alpha``
            (stage-1 reconstruction weight, default 0.7) and ``beta`` (stage-2
            reconstruction weight, default 0.3).
    """

    def __init__(self, ITHP_args):
        super(ITHPv2_Spatial_Robust, self).__init__()
        self.spatial_shape = (8, 8, 8)
        self.modal_dims = ITHP_args['modal_dims']  # [32, 32, 32, 32]
        self.inter_dim = ITHP_args['inter_dim']
        self.bottleneck_dims = ITHP_args['bottleneck_dims']  # [128, 256]
        self.drop_prob = ITHP_args['drop_prob']  # stored only; not used in this class

        # Stage 1: [modal1, modal2] (256 ch) -> (mu1, logvar1).
        self.stage1_encoder = nn.Sequential(
            nn.Conv3d(256, self.inter_dim, kernel_size=1),
            nn.GELU(),
            nn.Conv3d(self.inter_dim, self.bottleneck_dims[0]*2, kernel_size=1)
        )

        # z1 -> reconstruction of [modal3, modal4] (256 ch).
        self.stage1_decoder = nn.Sequential(
            nn.Conv3d(self.bottleneck_dims[0], self.inter_dim, kernel_size=1),
            nn.GELU(),
            nn.Conv3d(self.inter_dim, 256, kernel_size=1)
        )

        # Stage 2: [z1, modal3, modal4] (384 ch) -> (mu2, logvar2).
        self.stage2_encoder = nn.Sequential(
            nn.Conv3d(384, self.inter_dim, kernel_size=1),
            nn.GELU(),
            nn.Conv3d(self.inter_dim, self.bottleneck_dims[1]*2, kernel_size=1)
        )

        # z2 -> reconstruction of all four modalities (512 ch).
        self.stage2_decoder = nn.Sequential(
            nn.Conv3d(self.bottleneck_dims[1], self.inter_dim, kernel_size=1),
            nn.GELU(),
            nn.Conv3d(self.inter_dim, 512, kernel_size=1)
        )

        self.missing_embedding = nn.Embedding(2, 32)

        self.alpha = ITHP_args.get('alpha', 0.7)
        self.beta = ITHP_args.get('beta', 0.3)
        self.criterion = nn.MSELoss()  # not used by forward / masked_loss

    def kl_divergence(self, mu, logvar):
        """KL(N(mu, exp(logvar)) || N(0, 1)), averaged over every element."""
        return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, 1)`` (also in eval mode)."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def process_missing(self, x, mask):
        """Replace samples with ``mask == 0`` by a feature derived from ``missing_embedding``.

        NOTE(review): a fresh, randomly initialised ``nn.Linear(32, C)`` is built on
        every call, so its weights are never trained and the output differs
        between calls. The method is not called by ``forward``; register the
        projection in ``__init__`` before relying on it.
        """
        B, C = x.shape[0], x.shape[1]

        base_feat = self.missing_embedding(
            torch.zeros(B, device=x.device, dtype=torch.long)  # [B,32]
        )
        missing_feat = nn.Linear(32, C, device=x.device)(base_feat)  # [B,C]
        missing_feat = missing_feat.view(B, C, 1, 1, 1).expand_as(x)  # [B,C,D,H,W]

        return torch.where((mask == 0).view(B,1,1,1,1), missing_feat, x)

    def forward(self, modalities, missing_mask):
        """Run both stages and compute the bottleneck loss.

        Args:
            modalities: sequence of four feature maps (B, 128, d, h, w); their
                order defines the chain (1, 2 -> 3, 4 -> all).
            missing_mask: (B, 4) mask aligned with ``modalities``; nonzero/True
                marks a present modality.

        Returns:
            ``(z2, total_loss, (kl_loss1, recon_loss1, kl_loss2, recon_loss2))``.
        """
        modal1, modal2, modal3, modal4 = modalities

        # Stage 1: modalities 1 + 2 -> z1.
        stage1_params = self.stage1_encoder(torch.cat([modal1, modal2], dim=1))
        mu1, logvar1 = torch.chunk(stage1_params, 2, dim=1)
        # Weight the KL term by the fraction of missing entries among modalities 1-2.
        kl_weight1 = 1.0 - missing_mask[:, :2].float().mean()
        z1 = self.reparameterize(mu1, logvar1)

        # Reconstruct modalities 3 + 4 from z1.
        recon_34 = self.stage1_decoder(z1) / 255.0
        target_34 = torch.cat([modal3, modal4], dim=1).float() / 255.0
        # One mask column per modality, in the same order as target_34's channel
        # blocks (modal3 first, then modal4); masked_loss splits the channels
        # into one equal group per column.
        recon_loss1 = self.masked_loss(recon_34, target_34, missing_mask[:, 2:])
        kl_loss1 = self.kl_divergence(mu1, logvar1) * kl_weight1

        # Stage 2: z1 + modalities 3 + 4 -> z2.
        stage2_params = self.stage2_encoder(torch.cat([z1, modal3, modal4], dim=1))
        mu2, logvar2 = torch.chunk(stage2_params, 2, dim=1)
        kl_weight2 = missing_mask.float().mean(dim=1)  # per-sample fraction present
        z2 = self.reparameterize(mu2, logvar2)

        # Reconstruct all four modalities from z2.
        recon_all = self.stage2_decoder(z2) / 255.0
        target_all = torch.cat(modalities, dim=1).float() / 255.0  # scaled by 1/255

        recon_loss2 = self.masked_loss(recon_all, target_all, missing_mask)
        kl_loss2 = (self.kl_divergence(mu2, logvar2) * kl_weight2).mean()

        total_loss = (kl_loss1 + self.alpha * recon_loss1) + \
                     (kl_loss2 + self.beta * recon_loss2)

        return z2, total_loss, (kl_loss1, recon_loss1, kl_loss2, recon_loss2)

    def masked_loss(self, pred, target, mask):
        """Squared error restricted to the channel groups of present modalities.

        The C channels of ``pred``/``target`` (B, C, ...) are split into
        ``mask.shape[1]`` equal contiguous groups; group ``i`` of sample ``b``
        counts when ``mask[b, i]`` is nonzero.

        Returns:
            Sum of the kept squared errors divided by the number of kept
            (sample, channel) pairs.

        Raises:
            ValueError: if C is not divisible by the number of mask columns.
        """
        B, C = pred.shape[0], pred.shape[1]
        groups = mask.shape[1]
        if C % groups != 0:
            raise ValueError(
                f"channel count {C} is not divisible by the {groups} mask columns")

        expanded_mask = mask.to(device=pred.device, dtype=pred.dtype)
        expanded_mask = expanded_mask.repeat_interleave(C // groups, dim=1).view(B, C, 1, 1, 1)

        # NOTE(review): the denominator counts (sample, channel) pairs, not voxels,
        # so the loss scales with the spatial size; confirm this is intended.
        loss = (pred - target).pow(2) * expanded_mask
        return loss.sum() / (expanded_mask.sum() + 1e-6)


    

class MaskModal(nn.Module):
    """Zero out absent modalities and fold the modality axis into channels.

    ``x`` is (B, K, C, h, w, z) and ``mask`` indexes its first two axes; entries
    selected by ``mask`` are copied and all others become zero. The result is
    reshaped to (B, K*C, h, w, z).
    """

    def __init__(self):
        super(MaskModal, self).__init__()

    def forward(self, x, mask):
        b, k, c, h, w, z = x.size()
        out = x.new_zeros(x.shape)
        out[mask] = x[mask]
        return out.view(b, -1, h, w, z)


class Model(nn.Module):
    """Multi-modal 3D segmentation network with per-modality encoders.

    Four Encoders process flair, t1ce, t1 and t2 (x channels 0..3). Absent
    modalities are zeroed by ``MaskModal``. The level-5 features go through an
    ITHP variational bottleneck (with a per-sample, present-first random
    modality order) and a masked transformer Bottleneck. ``Decoder_fusion``
    then produces the fused segmentation.
    """

    def __init__(self, num_cls=4):
        super(Model, self).__init__()
        self.flair_encoder = Encoder()
        self.t1ce_encoder = Encoder()
        self.t1_encoder = Encoder()
        self.t2_encoder = Encoder()
        self.Bottleneck = Bottleneck()
        self.decoder_fusion = Decoder_fusion(num_cls=num_cls)
        self.decoder_sep = Decoder_sep(num_cls=num_cls)

        # Zero-initialised parameter sized for 5 token groups; Bottleneck.forward receives it but ignores it.
        self.pos = nn.Parameter(torch.zeros(1, ((H//16) * (W//16) * (D//16))*5, basic_dims*16))
        # NOTE(review): not used by forward.
        self.fusion = nn.Parameter(nn.init.normal_(torch.zeros(1, ((H//16) * (W//16) * (D//16)), basic_dims*16), mean=0.0, std=1.0))

        # Custom flag (separate from nn.Module.training): when True, forward also
        # runs the per-modality decoders and returns auxiliary outputs.
        self.is_training = False

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                torch.nn.init.kaiming_normal_(m.weight) #

        # NOTE(review): ithp and up_fusion are created after the Kaiming loop above,
        # so their Conv3d layers keep PyTorch's default initialisation.
        ithp_config = {
            'modal_dims': 128*4,
            'inter_dim': 256,
            'bottleneck_dims': [128, 256],
            'drop_prob': 0.1,
            'alpha': 0.7,
            'beta': 0.3
        }

        self.ithp = ITHPv2_Spatial_Robust(ithp_config)
        # Projects the ITHP latent z2 (256 ch) to the bottleneck width (128 ch).
        self.up_fusion = nn.Conv3d(256, 128, kernel_size=1, padding=0)

        self.masker = MaskModal()

    @staticmethod
    def _tokens_to_volume(tokens, batch_size):
        """Reshape (B, N, C) bottleneck tokens to a (B, C, H//16, W//16, D//16) volume."""
        return tokens.view(batch_size, (H//16), (W//16), (D//16), basic_dims*16).permute(0, 4, 1, 2, 3).contiguous()

    @staticmethod
    def _shuffle_present_first(stacked, mask):
        """Randomly permute the modality axis per sample, present modalities first.

        Args:
            stacked: (B, K, ...) modality stack whose axis 1 is aligned with ``mask``.
            mask: (B, K) presence mask (entries equal to 1/True are present).

        Returns:
            ``(stacked_ordered, mask_ordered)`` permuted with the same per-sample order.
        """
        batch_size = stacked.size(0)
        device = stacked.device

        orders = []
        for i in range(batch_size):
            present_mask = (mask[i] == 1)
            present_indices = torch.nonzero(present_mask, as_tuple=True)[0]
            absent_indices = torch.nonzero(~present_mask, as_tuple=True)[0]
            present_shuffled = present_indices[torch.randperm(len(present_indices), device=device)]
            absent_shuffled = absent_indices[torch.randperm(len(absent_indices), device=device)]
            orders.append(torch.cat([present_shuffled, absent_shuffled]))
        new_order = torch.stack(orders, dim=0)

        mask_ordered = torch.gather(mask, dim=1, index=new_order)
        index = new_order.view(batch_size, stacked.size(1), *([1] * (stacked.dim() - 2))).expand_as(stacked)
        stacked_ordered = torch.gather(stacked, dim=1, index=index)
        return stacked_ordered, mask_ordered

    def forward(self, x, mask):
        """Segment a multi-modal volume.

        Args:
            x: input with modalities on dim 1 in the order flair, t1ce, t1, t2
                and spatial size (H, W, D).
            mask: modality-presence mask with columns in the same order; used to
                index (B, 4, ...) stacks, so presumably a (B, 4) bool tensor.

        Returns:
            ``fuse_pred`` when ``self.is_training`` is False. Otherwise
            ``(fuse_pred, (flair_pred, t1ce_pred, t1_pred, t2_pred), prm_preds,
            ithp_loss, modality_volumes, modality_volumes + (fusion_tokens,))``.
        """
        B = x.size(0)

        flair_x1, flair_x2, flair_x3, flair_x4, flair_x5 = self.flair_encoder(x[:, 0:1, :, :, :])
        t1ce_x1, t1ce_x2, t1ce_x3, t1ce_x4, t1ce_x5 = self.t1ce_encoder(x[:, 1:2, :, :, :])
        t1_x1, t1_x2, t1_x3, t1_x4, t1_x5 = self.t1_encoder(x[:, 2:3, :, :, :])
        t2_x1, t2_x2, t2_x3, t2_x4, t2_x5 = self.t2_encoder(x[:, 3:4, :, :, :])

        # Zero the features of absent modalities; each result is (B, 4*C, ...).
        x1 = self.masker(torch.stack((flair_x1, t1ce_x1, t1_x1, t2_x1), dim=1), mask) #Bx4xCxHWZ
        x2 = self.masker(torch.stack((flair_x2, t1ce_x2, t1_x2, t2_x2), dim=1), mask)
        x3 = self.masker(torch.stack((flair_x3, t1ce_x3, t1_x3, t2_x3), dim=1), mask)
        x4 = self.masker(torch.stack((flair_x4, t1ce_x4, t1_x4, t2_x4), dim=1), mask)
        mask_x5 = self.masker(torch.stack((flair_x5, t1ce_x5, t1_x5, t2_x5), dim=1), mask)
        flair_x5, t1ce_x5, t1_x5, t2_x5 = torch.chunk(mask_x5, num_modals, dim=1)
        flair_x4, t1ce_x4, t1_x4, t2_x4 = torch.chunk(x4, num_modals, dim=1)
        flair_x3, t1ce_x3, t1_x3, t2_x3 = torch.chunk(x3, num_modals, dim=1)
        flair_x2, t1ce_x2, t1_x2, t2_x2 = torch.chunk(x2, num_modals, dim=1)
        flair_x1, t1ce_x1, t1_x1, t2_x1 = torch.chunk(x1, num_modals, dim=1)

        # Keep the canonical (flair, t1ce, t1, t2) order so that position i always
        # matches mask[:, i], both for the shuffle below and for the Bottleneck's
        # attention mask.
        x_bottle = (flair_x5, t1ce_x5, t1_x5, t2_x5)

        ########################## random shuffle Robust ############################
        modalities_ordered, mask_ordered = self._shuffle_present_first(torch.stack(x_bottle, dim=1), mask)
        ordered_bottle = [modalities_ordered[:, i] for i in range(num_modals)]

        fused_feature, total_loss, _ = self.ithp(ordered_bottle, mask_ordered)

        # Fused bottleneck as a volume (B, 128, h, w, d) and as tokens (B, h*w*d, 128).
        fusion_vol = self.up_fusion(fused_feature)
        test_fused = fusion_vol.permute(0, 2, 3, 4, 1).view(B, (H//16)*(W//16)*(D//16), basic_dims*16).contiguous()

        flair_trans, t1ce_trans, t1_trans, t2_trans, _fusion_trans, attn = self.Bottleneck(x_bottle, mask, test_fused, self.pos)

        flair_tra = self._tokens_to_volume(flair_trans, B)
        t1ce_tra = self._tokens_to_volume(t1ce_trans, B)
        t1_tra = self._tokens_to_volume(t1_trans, B)
        t2_tra = self._tokens_to_volume(t2_trans, B)

        de_x4 = torch.stack((flair_x4, t1ce_x4, t1_x4, t2_x4), dim=1)
        de_x3 = torch.stack((flair_x3, t1ce_x3, t1_x3, t2_x3), dim=1)
        de_x2 = torch.stack((flair_x2, t1ce_x2, t1_x2, t2_x2), dim=1)
        de_x1 = torch.stack((flair_x1, t1ce_x1, t1_x1, t2_x1), dim=1)

        # Each modality's transformer output plus the fused volume (batch size taken from
        # the data rather than hard-coded, and each modality paired with its own features).
        fused_flair = fusion_vol + flair_tra
        fused_t1ce = fusion_vol + t1ce_tra
        fused_t1 = fusion_vol + t1_tra
        fused_t2 = fusion_vol + t2_tra
        de_x5 = (fused_flair, fused_t1ce, fused_t1, fused_t2)

        fuse_pred, prm_preds = self.decoder_fusion(de_x1, de_x2, de_x3, de_x4, de_x5, fusion_vol, mask)

        if self.is_training:
            flair_pred = self.decoder_sep(flair_x1, flair_x2, flair_x3, flair_x4, flair_x5)
            t1ce_pred = self.decoder_sep(t1ce_x1, t1ce_x2, t1ce_x3, t1ce_x4, t1ce_x5)
            t1_pred = self.decoder_sep(t1_x1, t1_x2, t1_x3, t1_x4, t1_x5)
            t2_pred = self.decoder_sep(t2_x1, t2_x2, t2_x3, t2_x4, t2_x5)

            return fuse_pred, (flair_pred, t1ce_pred, t1_pred, t2_pred), prm_preds, total_loss, (flair_tra, t1ce_tra, t1_tra, t2_tra), (flair_tra, t1ce_tra, t1_tra, t2_tra, test_fused)

        return fuse_pred



