from typing import Sequence, Type, Tuple, Union, List, Optional, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops.layers.torch import Rearrange
from einops import rearrange
from collections import deque
import numpy as np
import math


# Number of spatial dimensions of the volumes handled in this file (H, W, D).
# Selects Conv%dd / BatchNorm%dd / InstanceNorm%dd via getattr and sizes the
# attention windows (patch_size ** ndims tokens per window).
ndims = 3 # H,W,D
# Reduced-precision dtype intended for attention; only referenced by
# commented-out code in Attention.forward, so currently unused.
att_dtype = torch.float16

# these functions are adopted from timm.
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
    """Per-sample stochastic depth (adopted from timm).

    During training, every sample along the first axis of ``x`` is zeroed with
    probability ``drop_prob``; survivors are optionally rescaled by
    ``1 / (1 - drop_prob)`` so the expected value is unchanged. Outside of
    training, or when ``drop_prob`` is 0, ``x`` is returned untouched (the very
    same object).

    Args:
        x: tensor whose first dimension indexes samples; any rank is accepted.
        drop_prob: probability of dropping a sample's whole residual path.
        training: apply the random mask only when True.
        scale_by_keep: rescale surviving samples by the keep probability.

    Returns:
        ``x`` itself, or ``x`` multiplied by a broadcastable per-sample mask.
    """
    if not training or drop_prob == 0.:
        return x

    survival = 1 - drop_prob
    # One mask entry per sample, broadcast over every remaining axis.
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask


class timm_DropPath(nn.Module):
    """``nn.Module`` wrapper around :func:`drop_path` (per-sample stochastic depth).

    The random mask is only applied while the module is in training mode.

    Args:
        drop_prob: probability of dropping a sample's residual path.
        scale_by_keep: rescale kept samples by ``1 / (1 - drop_prob)``.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        rounded = round(self.drop_prob, 3)
        return f'drop_prob={rounded:0.3f}'


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor

def timm_trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place with a truncated normal distribution (timm variant).

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`; the method is most accurate when
    :math:`a \leq \text{mean} \leq b`.

    NOTE: as in timm, ``a`` and ``b`` bound the *final* values (after mean/std
    are applied), so they must be chosen relative to ``mean`` and ``std`` rather
    than in units of standard deviations.

    The fill runs under ``torch.no_grad()`` so it is safe to call on
    parameters that require gradients.

    Args:
        tensor: an n-dimensional ``torch.Tensor`` to overwrite.
        mean: mean of the normal distribution.
        std: standard deviation of the normal distribution.
        a: minimum cutoff value.
        b: maximum cutoff value.

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> timm_trunc_normal_(w, std=.02)
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled

def count_parameters(model):
    """Return the number of scalar values in ``model``'s trainable parameters.

    Parameters with ``requires_grad=False`` (frozen) are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

from functools import reduce
def prod_func(Vec):
    """Return the product of all items in ``Vec`` (like ``math.prod``).

    An empty sequence yields 1, the multiplicative identity; without the
    initial value ``reduce`` would raise TypeError on empty input.
    """
    return reduce(lambda x, y: x * y, Vec, 1)


class MLP(nn.Module):
    """Feed-forward block for transformer layers.

    The variant is chosen by ``MLP_type`` (case-insensitive):

    * ``"basic"``: Linear(in->hid) -> act -> Dropout -> Linear(hid->out) -> Dropout.
      Acts on the last axis only, so any leading shape is accepted.
    * ``"scmlp"``: 1x1x1 Conv3d(in->out) + BatchNorm3d + act on the
      channels-first view, then two Linear(out->out) + act + Dropout stages on
      the channels-last tensor.
    * ``"conv"``: 1x1x1 Conv3d(in->hid) + BN + act + Dropout, depthwise
      ``kernel_size`` Conv3d + BN + act, 1x1x1 Conv3d(hid->out) + BN + Dropout.

    The convolutional variants expect a channels-last volume of shape
    (B, H, W, D, C_in) and return (B, H, W, D, C_out).

    Args:
        in_feats: input channel count.
        MLP_type: one of "basic", "scmlp", "conv".
        hid_feats: hidden width (defaults to ``in_feats``).
        out_feats: output width (defaults to ``in_feats``).
        kernel_size: depthwise kernel size for the "conv" variant.
        act_name: activation name passed to ``get_activation``.
        drop: dropout probability.
        bias: bias flag for the convolutions (Linear layers keep their default).

    Raises:
        ValueError: if ``MLP_type`` is not recognised.
    """
    def __init__(self,
                in_feats,
                MLP_type="basic", # scmlp conv basic
                hid_feats=None,
                out_feats=None,
                kernel_size=3,
                act_name="GELU",
                drop=0.,
                bias=False,
            )->None:
        super(MLP, self).__init__()

        out_feats = out_feats or in_feats
        hid_feats = hid_feats or in_feats
        kind = MLP_type.lower()

        # The conv variants permute to channels-first for Conv3d/BatchNorm3d and
        # back to channels-last; all patterns are 5-D to match ndims == 3 volumes.
        if kind == "scmlp":
            # 1x1x1 conv (channel mixing) + BN + act, then a two-layer MLP.
            self.net = nn.Sequential(*[
                Rearrange('B h w d c -> B c h w d'),
                nn.Conv3d(in_feats, out_feats, kernel_size=1, bias=bias),
                nn.BatchNorm3d(out_feats),
                get_activation(act_name),
                Rearrange('B c h w d -> B h w d c'),
                nn.Linear(out_feats, out_feats),
                get_activation(act_name),
                nn.Dropout(drop),
                nn.Linear(out_feats, out_feats),
                get_activation(act_name),
                nn.Dropout(drop),
            ])

        elif kind == "conv":
            # Convolutional MLP (RVT, CVPR 2022): expand -> depthwise -> project.
            self.net = nn.Sequential(*[
                Rearrange('B h w d c -> B c h w d'),
                nn.Conv3d(in_feats, hid_feats, kernel_size=1, bias=bias),
                nn.BatchNorm3d(hid_feats),
                get_activation(act_name),
                nn.Dropout(drop),
                nn.Conv3d(hid_feats, hid_feats, kernel_size=kernel_size,
                    padding=int(kernel_size//2), groups=hid_feats, bias=bias),
                nn.BatchNorm3d(hid_feats),
                get_activation(act_name),
                nn.Conv3d(hid_feats, out_feats, kernel_size=1, bias=bias),
                nn.BatchNorm3d(out_feats),
                nn.Dropout(drop),
                Rearrange('B c h w d -> B h w d c'),
            ])

        elif kind == "basic":
            self.net = nn.Sequential(*[
                    nn.Linear(in_feats, hid_feats),
                    get_activation(act_name),
                    nn.Dropout(drop),
                    nn.Linear(hid_feats, out_feats),
                    nn.Dropout(drop)
            ])

        else:
            # Fail at construction instead of with AttributeError in forward.
            raise ValueError(
                f"Unknown MLP_type {MLP_type!r}; expected 'basic', 'scmlp' or 'conv'"
            )

    def forward(self, x)-> torch.Tensor:
        x = self.net(x)
        return x



class Attention(nn.Module):
    """Multi-head attention over flattened windows of ``N`` tokens.

    ``attention_type="local"``: queries, keys and values are all projected from
    ``x`` by one ``Linear(dim, 3 * dim)`` (window self-attention); ``q_ms`` is
    ignored.

    ``attention_type="global"``: only keys and values are projected from ``x``
    (``Linear(dim, 2 * dim)``); the queries are the raw (unprojected) window
    tokens in ``q_ms``. ``q_ms`` may hold a different number of windows than
    ``x``; its windows are tiled cyclically so that window ``j`` of ``x`` uses
    query window ``j % B``.

    No relative position bias is applied (none is registered on the module).

    Args:
        dim: channel count ``C`` of the tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        patch_size: window edge length (an int is expanded to ``ndims`` entries)
            or a per-axis sequence; stored as ``self.patch_size``.
        attention_type: ``"local"`` or ``"global"``.
        qkv_bias: add a bias to the q/k/v projection.
        qk_scale: attention logit scale; falls back to ``head_dim ** -0.5``.
        attn_drop: dropout on the attention probabilities.
        proj_drop: dropout after the output projection.

    Raises:
        ValueError: if ``attention_type`` is not ``"local"`` or ``"global"``.
    """
    def __init__(self,
                dim,
                num_heads,
                patch_size,
                attention_type = "local",
                qkv_bias=True,
                qk_scale=None,
                attn_drop=0.,
                proj_drop=0.,
            )->None:

        super().__init__()

        if isinstance(patch_size, int):
            patch_size = [patch_size]*ndims
        self.patch_size = patch_size

        self.num_heads = num_heads
        assert dim%num_heads==0, "`dim` must be divisible by `num_heads`"
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        if attention_type not in ("local", "global"):
            raise ValueError(
                f"attention_type must be 'local' or 'global', got {attention_type!r}"
            )
        self.attention_type = attention_type

        # local: project q, k and v from x; global: only k and v (q comes from q_ms).
        n_projections = 3 if attention_type == "local" else 2
        self.qkv = nn.Linear(dim, dim * n_projections, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, q_ms):
        """Attend within each window.

        Args:
            x: window tokens of shape (B_, N, C).
            q_ms: ignored for local attention. For global attention, query
                windows of shape (B, N, C) laid out like ``x`` (token-major,
                channels last).

        Returns:
            (out, attn): ``out`` has shape (B_, N, C); ``attn`` holds the
            attention probabilities, shape (B_, num_heads, N, N).
        """
        B_, N, C = x.size()
        head_dim = C // self.num_heads

        if self.attention_type == "local":
            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv[0], kv[1]

            B = q_ms.size(0)
            # Split channels into heads exactly like k/v: (B, N, C) -> (B, nH, N, hd).
            # Reshaping (B, N, C) straight to (B, nH, N, hd) would mix tokens
            # across heads whenever num_heads > 1.
            q_windows = q_ms.reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
            # Tile the B query windows cyclically over the B_ key/value windows
            # (window j uses query window j % B) and match k's dtype.
            reps = -(-B_ // B)  # ceil(B_ / B)
            q = q_windows.repeat(reps, 1, 1, 1)[:B_].to(k.dtype)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = nn.functional.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn


def get_patches(x, patch_size):
    """Split a channels-last volume into non-overlapping cubic windows.

    If a spatial size is not a multiple of ``patch_size``, ``x`` is first
    resized with ``downsampler_fn`` (trilinear) to the largest multiple of
    ``patch_size`` not exceeding it.

    Args:
        x: tensor of shape (B, H, W, D, C).
        patch_size: window edge length.

    Returns:
        (windows, H, W, D): ``windows`` has shape
        (B * (H//p) * (W//p) * (D//p), p, p, p, C), ordered by batch, then
        H-window, W-window, D-window. ``H, W, D`` are the spatial sizes
        actually windowed (after any resize).

    Raises:
        ValueError: if any spatial size is smaller than ``patch_size``.
    """
    B, H, W, D, C = x.size()
    if min(H, W, D) < patch_size:
        raise ValueError(
            f"get_patches: spatial size {(H, W, D)} is smaller than patch_size={patch_size}"
        )

    # Largest multiple of patch_size that fits in each axis (exact integer math).
    fitted = [H - H % patch_size, W - W % patch_size, D - D % patch_size]
    if fitted != [H, W, D]:
        # Resize channels-first, then return to channels-last.
        x = downsampler_fn(x.permute(0, 4, 1, 2, 3), fitted).permute(0, 2, 3, 4, 1)
        B, H, W, D, C = x.size()

    # Splitting each axis into (num_windows, patch_size) is always a valid view,
    # even for non-contiguous inputs.
    x = x.view(B, H // patch_size, patch_size,
                W // patch_size, patch_size,
                D // patch_size, patch_size,
                C
    )
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, patch_size, patch_size, patch_size, C)
    return windows, H, W, D


def get_image(windows, patch_size, Hatt, Watt, Datt, H, W, D):
    """Reassemble windows produced by ``get_patches`` into a channels-last volume.

    Args:
        windows: window tensor whose leading axis enumerates windows in
            ``get_patches`` order (batch, H-window, W-window, D-window); the
            remaining axes may be (p, p, p, C) or flattened (p**3, C).
        patch_size: window edge length p.
        Hatt, Watt, Datt: spatial size that was windowed.
        H, W, D: target spatial size; when it exceeds the windowed size, the
            volume is resized back with ``downsampler_fn``.

    Returns:
        Tensor of shape (B, H, W, D, C).
    """
    windows_per_item = Hatt * Watt * Datt / patch_size / patch_size / patch_size
    batch = int(windows.size()[0] / windows_per_item)

    grid = windows.view(
        batch,
        Hatt // patch_size, Watt // patch_size, Datt // patch_size,
        patch_size, patch_size, patch_size,
        -1,
    )
    # (B, nh, nw, nd, p, p, p, C) -> (B, nh, p, nw, p, nd, p, C) -> (B, Hatt, Watt, Datt, C)
    volume = grid.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(batch, Hatt, Watt, Datt, -1)

    needs_resize = (H - Hatt) + (W - Watt) + (D - Datt) > 0
    if needs_resize:
        volume = downsampler_fn(volume.permute(0, 4, 1, 2, 3), [H, W, D]).permute(0, 2, 3, 4, 1)
    return volume



class ViTBlock(nn.Module):
    """Pre-norm transformer block with window attention on channels-last volumes.

    forward computes
        x = x + drop_path(gamma1 * WindowAttention(norm1(x)))
        x = x + drop_path(gamma2 * MLP(norm2(x)))
    where WindowAttention splits the volume into cubic windows of edge
    ``patch_size`` (get_patches), runs ``Attention`` per window and stitches
    the result back to the input size (get_image).

    ``layer_scale`` (int or float) enables learnable per-channel multipliers
    gamma1/gamma2 initialised to that value; any other value (including None)
    uses constant 1.0 multipliers.
    """
    def __init__(self,
                embedd_dim,
                input_dims,
                num_heads,
                MLP_type,
                patch_size,
                mlp_ratio,
                qkv_bias,
                qk_scale,
                drop,
                attn_drop,
                drop_path,
                act_layer,
                attention_type,
                norm_layer,
                layer_scale,
        )->None:
        super().__init__()
        self.patch_size = patch_size
        #self.new_dims = [patch_size* (d//patch_size) for d in input_dims]
        #self.num_windows = prod_func([d//patch_size for d in self.new_dims])
        # Windows per volume at the nominal input size; not used by forward.
        self.num_windows = prod_func([d//patch_size for d in input_dims])

        self.norm1 = norm_layer(embedd_dim)
        # NOTE(review): spatialConv is never called (its use in forward is
        # commented out). nn.Conv2d cannot consume the 5-D (b, c, h, w, d)
        # tensor produced by the Rearrange below, so enabling it as-is would
        # fail; Conv3d was probably intended, but switching changes the weight
        # shape stored in existing checkpoints.
        self.spatialConv = nn.Sequential(*[
                Rearrange("b h w d c -> b c h w d"),
                nn.Conv2d(embedd_dim, embedd_dim, groups=embedd_dim, kernel_size=3, padding=1,
                                 bias=False),
                get_norm('instance', num_features=embedd_dim),
                get_activation(act_layer),
                Rearrange("b c h w d -> b h w d c"),
        ])

        self.attn = Attention(embedd_dim,
                              attention_type=attention_type,
                              num_heads=num_heads,
                              patch_size=patch_size,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              attn_drop=attn_drop,
                              proj_drop=drop,
        )

        # Stochastic depth only when a positive rate is given.
        self.drop_path = timm_DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(embedd_dim)
        self.mlp = MLP(in_feats=embedd_dim, hid_feats=int(embedd_dim * mlp_ratio),
                        act_name=act_layer, drop=drop,
                        MLP_type=MLP_type
        )

        # Layer scale: exact type check, so bools and non-builtin numbers are ignored.
        self.layer_scale = False
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(embedd_dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(embedd_dim), requires_grad=True)
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0

    def forward(self, x, q_ms):
        """Apply the block.

        Args:
            x: channels-last volume of shape (B, H, W, D, C).
            q_ms: query windows forwarded unchanged to ``self.attn``; ignored
                by local attention. For global attention it is presumably
                (num_query_windows, patch_size ** ndims, C) — confirm with the
                caller.

        Returns:
            Tensor with the same shape as ``x``.
        """
        #x = downsampler_fn(x.permute(0, 4, 1, 2, 3), self.new_dims).permute(0, 2, 3, 4, 1)
        B, H, W, D, C = x.size()
        shortcut = x

        #x =  self.spatialConv(x)

        x = self.norm1(x)
        # Windows may come from a resized copy when H/W/D are not multiples of patch_size.
        x_windows, Hatt, Watt, Datt = get_patches(x, self.patch_size)
        x_windows = x_windows.view(-1, self.patch_size ** ndims, C)

        attn_windows, _ = self.attn(x_windows, q_ms)
        # Stitch windows back and resize to the original (H, W, D) if needed.
        x = get_image(attn_windows, self.patch_size, Hatt, Watt, Datt, H, W, D)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x

class PatchEmbed(nn.Module):
    """Project a channels-first volume to ``out_chans`` features.

    A single ``ndims``-dimensional convolution (Conv3d for ndims == 3) followed
    by dropout. With the defaults (kernel 3, stride 1, padding 1) the spatial
    size is preserved.

    Args:
        in_chans: input channels.
        out_chans: output (embedding) channels.
        drop_rate: dropout probability applied after the convolution.
        kernel_size, stride, padding, dilation, groups, bias: forwarded to the
            convolution.
    """
    def __init__(self, in_chans=3, out_chans=32,
                 drop_rate= 0,
                 kernel_size=3,
                 stride=1, padding=1,
                 dilation=1, groups=1, bias=False,
        )->None:
        super().__init__()

        conv_cls = getattr(nn, f"Conv{ndims}d")
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.proj = conv_cls(in_channels=in_chans, out_channels=out_chans, **conv_kwargs)
        self.drop = nn.Dropout(p=drop_rate)

    def forward(self, x):
        embedded = self.proj(x)
        return self.drop(embedded)

    

class ViTLayer(nn.Module):
    """A stack of ``depth`` ViTBlocks applied to a channels-first volume.

    All blocks share ``dim``, ``num_heads``, ``patch_size`` and
    ``attention_type``; ``drop_path`` may be a per-block list or a scalar.
    ``dim_out`` and ``norm_type`` are accepted for interface compatibility but
    are not used.
    """
    def __init__(self,
                attention_type,
                dim,
                dim_out,
                depth,
                input_dims,
                num_heads,
                patch_size,
                MLP_type,
                mlp_ratio,
                qkv_bias,
                qk_scale,
                drop,
                attn_drop,
                drop_path,
                norm_layer,
                norm_type,
                layer_scale,
                act_layer
        )->None:
        super().__init__()
        self.patch_size = patch_size
        self.input_dims = input_dims
        self.blocks = nn.ModuleList(
                            [ViTBlock(embedd_dim=dim,
                                    num_heads=num_heads,
                                    MLP_type=MLP_type,
                                    patch_size=patch_size,
                                    mlp_ratio=mlp_ratio,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    attention_type=attention_type,
                                    drop=drop,
                                    attn_drop=attn_drop,
                                    drop_path=drop_path[k] if isinstance(drop_path, list) else drop_path,
                                    act_layer=act_layer,
                                    norm_layer=norm_layer,
                                    layer_scale=layer_scale,
                                    input_dims=input_dims)
                            for k in range(depth)]
                            )

    def forward(self, inp, q_ms, CONCAT_ok:bool):
        """Run all blocks and merge the result with the input.

        Args:
            inp: channels-first volume (b, c, h, w, d) providing keys/values.
            q_ms: optional channels-first volume whose windows are the queries
                for global attention; None for local attention.
            CONCAT_ok: if True, concatenate ``inp`` and the block output;
                otherwise add them (residual).

        Returns:
            Channels-first tensor: ``inp + out``, or ``cat((inp, out), dim=-1)``.
        """
        x = rearrange(inp, 'b c h w d-> b h w d c')

        q_ms_patches = None
        if q_ms is not None:
            q_ms = rearrange(q_ms, 'b c h w d-> b h w d c')
            # q_ms is identical for every block, so window it once instead of
            # once per block. Blocks keep the channel count, so x's C is final.
            q_ms_patches, _, _, _ = get_patches(q_ms, self.patch_size)
            q_ms_patches = q_ms_patches.view(-1, self.patch_size ** ndims, x.size()[-1])

        for blk in self.blocks: # apply depth
            x = blk(x, q_ms_patches)

        x = rearrange(x, 'b h w d c-> b c h w d')

        if CONCAT_ok:
            # NOTE(review): dim=-1 is the depth axis of the (b, c, h, w, d)
            # tensor; channel concatenation would be dim=1 — confirm intent.
            x = torch.cat((inp, x), dim=-1)
        else:
            x = inp + x
        return x



class ViT(nn.Module):
    """Pyramid of ViTLayers over a list of feature maps.

    For each level ``i`` the input ``KQs[i]`` is embedded to ``hid_dim``
    channels by its own ``PatchEmbed``. Level 0 uses local (window
    self-)attention on ``K = embed(KQs[0])``; every later level uses global
    attention with keys/values from that same ``K`` and queries from
    ``embed(KQs[i])``.

    ``PYR_SCALES`` and ``norm_type`` are not used by this class beyond being
    passed through (``norm_type``) to ViTLayer. ``img_size[i]`` is the nominal
    spatial size of level ``i``. ``drop_path_rate`` is spread linearly over
    all blocks of all levels.
    """
    def __init__(self,
                PYR_SCALES=None,
                feats_num=None,
                hid_dim=None,
                depths=None,
                patch_size=None,
                mlp_ratio=None,
                num_heads=None,
                MLP_type=None,
                norm_type=None,
                act_layer=None,
                drop_path_rate:float=0.2,
                qkv_bias:bool=True,
                qk_scale:Optional[float]=None,
                drop_rate:float=0.,
                attn_drop_rate:float=0.,
                norm_layer=nn.LayerNorm,
                layer_scale=None,
                img_size=None):
        super().__init__()

        num_levels = len(feats_num)
        # A scalar patch_size is shared by all levels.
        patch_size = patch_size if isinstance(patch_size, list) else [patch_size for _ in range(num_levels)]
        # Unused below; also means img_size=None fails here.
        hwd = img_size[-1]

        self.patch_embed = nn.ModuleList(
                            [PatchEmbed(
                                    in_chans=feats_num[i],
                                        out_chans=hid_dim,
                                        drop_rate=drop_rate
                                    )
                                    for i in range(num_levels)]
        )
        # Stochastic-depth rate per block, increasing linearly over all levels' blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.levels = nn.ModuleList()


        for i in range(num_levels):
            #print(f'i:{i},\n  feats_num[i]:{feats_num[i]}\n'+
            #      f'depths[i]:{depths[i]},\n num_heads[i]:{num_heads[i]},\n' +
            #      f'dpr[i]: {dpr[sum(depths[:i]):sum(depths[:i + 1])]}' +
            #      f'patch_size[i]:{patch_size[i]},\n img_size[i]:{img_size[i]},\n')
            level = ViTLayer(dim=hid_dim,
                            dim_out=hid_dim,
                            depth=depths[i],
                            num_heads=num_heads[i],
                            patch_size=patch_size[i],
                            MLP_type=MLP_type,
                            attention_type =  "local" if i == 0 else "global",
                            drop_path=dpr[sum(depths[:i]):sum(depths[:i+1])],
                            input_dims=img_size[i],
                            mlp_ratio=mlp_ratio,
                            qkv_bias=qkv_bias,
                            qk_scale=qk_scale,
                            drop=drop_rate,
                            attn_drop=attn_drop_rate,
                            norm_layer=norm_layer,
                            layer_scale=layer_scale,
                            norm_type=norm_type,
                            act_layer=act_layer
            )
            self.levels.append(level)
        # Custom init below is currently not applied.
        #self.apply(self._init_weights)


    def _init_weights(self, m):
        """Truncated-normal (std .02) Linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            timm_trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # NOTE(review): no parameter whose name contains 'rpb' is registered by
        # the modules visible in this file, so this set currently matches nothing.
        return {'rpb'}

    def forward(self, KQs, CONCAT_ok: bool=False):
        """Run every level and return the output of the last one.

        Args:
            KQs: sequence of channels-first feature maps, one per level;
                ``KQs[0]`` supplies keys/values for all levels.
            CONCAT_ok: forwarded to each ViTLayer (concat vs. residual add).

        NOTE(review): ``K`` is never updated from ``x``, so every level attends
        over ``embed(KQs[0])`` and each level's output overwrites the previous
        one; only the last level's result is returned. Confirm whether levels
        were meant to be chained (e.g. ``K = x``). With zero levels ``x`` is
        never assigned and the return raises UnboundLocalError.
        """

        # patch pyramid
        for i, (patch_embed_, level) in enumerate(zip(self.patch_embed, self.levels)):
            #print('\n ViT level', i, '\n input->', KQs[i].size())

            if i == 0:
                K = patch_embed_(KQs[i])
                x = level(K, None, CONCAT_ok=CONCAT_ok)
            else:
                Q = patch_embed_(KQs[i])
                x = level(K, Q, CONCAT_ok=CONCAT_ok)
                #print(f'\t Q->{Q.size()}')

            #print(f'\t K->{K.size()}')
            #print(f'\t x->{x.size()}')

        return x

def get_norm(name, **kwargs):
    """Build a normalization layer for ``ndims``-dimensional feature maps by name.

    Args:
        name: case-insensitive; 'batchnorm' -> nn.BatchNorm{ndims}d,
            'instance' / 'instancenorm' -> nn.InstanceNorm{ndims}d,
            'none' -> nn.Identity (kwargs ignored).
        **kwargs: forwarded to the norm constructor (e.g. ``num_features``).

    Raises:
        NotImplementedError: for any other name. Raising (rather than
            returning the exception class) makes the failure happen here with
            a clear message instead of later inside nn.Sequential.
    """
    key = name.lower()
    if key == 'batchnorm':
        return getattr(nn, 'BatchNorm%dd' % ndims)(**kwargs)
    if key in ('instance', 'instancenorm'):
        return getattr(nn, 'InstanceNorm%dd' % ndims)(**kwargs)
    if key == 'none':
        return nn.Identity()
    raise NotImplementedError(f"Unsupported normalization: {name!r}")


def get_activation(name, **kwargs):
    """Build an activation module by (case-insensitive) name.

    Args:
        name: 'relu' -> nn.ReLU, 'gelu' -> nn.GELU, 'none' -> nn.Identity.
        **kwargs: accepted for call-site compatibility; currently ignored.

    Raises:
        NotImplementedError: for any other name. Raising (rather than
            returning the exception class) makes an unknown name fail here
            with a clear message.
    """
    factories = {'relu': nn.ReLU, 'gelu': nn.GELU, 'none': nn.Identity}
    try:
        factory = factories[name.lower()]
    except KeyError:
        raise NotImplementedError(f"Unsupported activation: {name!r}") from None
    return factory()


def downsampler_fn(data, out_size):
    """Resize a 5-D channels-first volume with trilinear interpolation.

    Args:
        data: tensor of shape (B, C, Hi, Wi, Di).
        out_size: target spatial size [H, W, D].

    Returns:
        Tensor of shape (B, C, H, W, D) on the same device as ``data``.
    """
    # interpolate already returns its result on data's device, so no explicit
    # move is needed (Tensor.get_device() is -1 for CPU tensors, which is not a
    # valid target for .to()).
    return nn.functional.interpolate(data,
                                     size=out_size,
                                     mode='trilinear',
                                     align_corners=None,
                                     recompute_scale_factor=None,
                                     #antialias=False
    )


class MSAEncoder(nn.Module):
    """Multi-scale attention (MSA) refinement of a dict of CNN feature maps.

    Config keys read:
        data_size: spatial input size; ``int(log2(min(data_size))) - 1`` stages.
        enabled: when False, ``forward`` only renames ``k -> 'P<k>'``.
        use_msa_seg_loss: when True (and ``enabled``), adds a 1x1x1 Conv3d
            head per level with ``num_organs + 1`` output channels.
        out_fmaps, fpn_channels: number of levels and their channel count.
        depths, patch_size, mlp_ratio, num_heads, drop_path_rate, qkv_bias,
        drop_rate, attn_drop_rate: ViT settings for levels > 0.

    In ``forward`` the maps are ordered by descending key. Output ``i = 0``
    (largest key) is passed through unchanged; output ``i > 0`` is a ViT over
    the maps for keys ``key_i, key_{i-1}, ..., key_0``.
    """
    def __init__(self, config):
        super().__init__()
        self._num_stages = int(math.log2(min(config['data_size']))) - 1
        self._msa = config['enabled']
        self.msa_seg = config['use_msa_seg_loss']

        if not self._msa:
            return

        n_levels = len(config['out_fmaps'])
        fpn = int(config['fpn_channels'])
        level_channels = [fpn] * n_levels

        self.msa_enc = nn.ModuleList()
        sizes, channels = [], []
        for level, ch in enumerate(level_channels):
            shrink = 2 ** (self._num_stages - level - 1)
            sizes.append([int(s / shrink) for s in config['data_size']])
            channels.append(ch)

            if level == 0:
                self.msa_enc.append(nn.Identity())
                continue

            count = len(channels)
            self.msa_enc.append(
                ViT(
                    PYR_SCALES=[1.,],
                    feats_num=channels,
                    hid_dim=fpn,
                    depths=[int(config['depths'])] * count,
                    patch_size=[int(config['patch_size'])] * count,
                    mlp_ratio=int(config['mlp_ratio']),
                    num_heads=[int(config['num_heads'])] * count,
                    MLP_type='basic',
                    norm_type='BatchNorm2d',
                    act_layer='gelu',
                    drop_path_rate=config['drop_path_rate'],
                    qkv_bias=config['qkv_bias'],
                    qk_scale=None,
                    drop_rate=config['drop_rate'],
                    attn_drop_rate=config['attn_drop_rate'],
                    norm_layer=nn.LayerNorm,
                    layer_scale=1e-5,
                    img_size=sizes,
                )
            )

        if self.msa_seg:
            self._seg_head = nn.ModuleList(
                nn.Conv3d(ch, config['num_organs'] + 1, kernel_size=1, stride=1)
                for ch in level_channels
            )

    def forward(self, cnn_outputs):
        if not self._msa:
            return {'P' + str(k): v for k, v in cnn_outputs.items()}

        keys = list(range(max(cnn_outputs.keys()), min(cnn_outputs.keys()) - 1, -1))
        xs = [cnn_outputs[key].clone() for key in keys]

        out_dict = {}
        for i, key in enumerate(keys):
            # xs[i::-1] == [xs[i], xs[i-1], ..., xs[0]]
            Pi = xs[0] if i == 0 else self.msa_enc[i](xs[i::-1])
            out_dict['P' + str(key)] = Pi

            # get segmentation map
            if self.msa_seg:
                out_dict['S' + str(key)] = self._seg_head[i](Pi)

        return out_dict
