import numpy as np

import torch
from torch import nn
from torch.nn import functional
from torch.utils import checkpoint

from timm.models.layers import DropPath

from einops import rearrange

from utils import logs_handler
from models.modules import PatchEmbed3D, PatchMerging, WindowAttention3D, Mlp
from models.ops import window_partition, window_reverse, get_window_size,\
    generate_shifted_window_mask, generate_causal_mask, process_rollout_attention

# Module-level logger obtained from the project's logs_handler under the name 'SwinTransformer3D'.
logger = logs_handler.get_logger('SwinTransformer3D')


class SwinTransformerBlock3D(nn.Module):
    """ Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size.
        shift_size (tuple[int]): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        causal (bool, optional): If True, a causal mask built from the temporal window length
            is passed to the attention module as ``temporal_mask``. Default: False
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
        use_checkpoint (bool, optional): If True, both halves of the block are run through
            ``torch.utils.checkpoint``. Incompatible with ``return_attention=True``. Default: False
    """

    def __init__(self, dim, num_heads, window_size=(2, 7, 7), shift_size=(0, 0, 0),
                 mlp_ratio=4., causal=False, qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.causal = causal
        self.use_checkpoint = use_checkpoint

        assert 0 <= self.shift_size[0] < self.window_size[0], "shift_size must in 0-window_size"
        assert 0 <= self.shift_size[1] < self.window_size[1], "shift_size must in 0-window_size"
        assert 0 <= self.shift_size[2] < self.window_size[2], "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention3D(
            dim, window_size=self.window_size, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=[drop, drop])

    def forward_part1(self, x, return_attention=False):
        """ Normalise, pad, (optionally) shift, run windowed attention and undo shift/padding.
        Args:
            x: Input feature, tensor size (B, D, H, W, C).
            return_attention (bool): Forwarded to the attention module as ``return_weights``.
        Returns:
            dict with keys:
                'outputs': attention output cropped back to (B, D, H, W, C),
                'attentions_weights': whatever the attention module returned under 'weights',
                'size': padded (Dp, Hp, Wp),
                'paddings': [(front, back)] padding per D/H/W axis,
                'num_windows': number of windows per axis,
                'window_size': effective window extent per axis.
        """
        B, D, H, W, C = x.shape
        window_size, shift_size = get_window_size((D, H, W), self.window_size, self.shift_size)

        x = self.norm1(x)
        # pad feature maps to multiples of window size (padding is only added at the far end)
        pad_l = pad_t = pad_d0 = 0
        pad_d1 = (window_size[0] - D % window_size[0]) % window_size[0]
        pad_b = (window_size[1] - H % window_size[1]) % window_size[1]
        pad_r = (window_size[2] - W % window_size[2]) % window_size[2]
        x = functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_d0, pad_d1))
        _, Dp, Hp, Wp, _ = x.shape

        Wd, Wh, Ww = [min(a, b) for a, b in zip([Dp, Hp, Wp], self.window_size)]
        nWd, nWh, nWw = Dp//Wd, Hp//Wh, Wp//Ww

        # cyclic shift
        if any(i > 0 for i in shift_size):
            shifted_x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1], -shift_size[2]), dims=(1, 2, 3))
            cyclic_attn_mask = generate_shifted_window_mask(Dp, Hp, Wp, window_size, shift_size, x.device)
        else:
            shifted_x = x
            cyclic_attn_mask = None

        temporal_mask = None
        if self.causal:
            temporal_mask = generate_causal_mask(Wd, device=x.device)

        # partition windows
        x_windows = window_partition(shifted_x, window_size)  # B*nW, Wd*Wh*Ww, C

        # W-MSA/SW-MSA
        attn_outputs = self.attn(x_windows, mask=cyclic_attn_mask, temporal_mask=temporal_mask,
                                 return_weights=return_attention)  # B*nW, Wd*Wh*Ww, C
        x = attn_outputs.get('outputs')
        attn_weights = attn_outputs.get('weights')

        # merge windows
        x = x.view(-1, *(window_size + (C,)))
        shifted_x = window_reverse(x, window_size, B, Dp, Hp, Wp)  # B D' H' W' C
        # reverse cyclic shift
        if any(i > 0 for i in shift_size):
            x = torch.roll(shifted_x, shifts=(shift_size[0], shift_size[1], shift_size[2]), dims=(1, 2, 3))
        else:
            x = shifted_x

        # drop the padded region again
        if pad_d1 > 0 or pad_r > 0 or pad_b > 0:
            x = x[:, :D, :H, :W, :].contiguous()
        out = {
            'outputs': x,
            'attentions_weights': attn_weights,
            'size': (Dp, Hp, Wp),
            'paddings': [(pad_d0, pad_d1), (pad_t, pad_b), (pad_l, pad_r)],
            'num_windows': (nWd, nWh, nWw),
            'window_size': (Wd, Wh, Ww)
        }
        return out

    def forward_part2(self, x):
        """ Feed-forward half of the block: norm -> MLP -> drop-path (residual is added by the caller). """
        return self.drop_path(self.mlp(self.norm2(x)))

    def forward(self, x, return_attention=False):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, D, H, W, C).
            return_attention (bool): If True, ask the attention module for its weights.
                Not supported together with ``use_checkpoint``.
        Returns:
            dict: 'outputs' holds the block output; the remaining keys are the non-'outputs'
            entries produced by ``forward_part1``.
        """
        assert (not return_attention) or (not self.use_checkpoint), 'not supported with ckpt'

        shortcut = x
        if self.use_checkpoint:
            # The reentrant checkpoint implementation does not treat tensors nested inside
            # dicts as participating in autograd, so checkpointing ``forward_part1`` directly
            # (it returns a dict) would detach the attention branch. Checkpoint a wrapper
            # that returns the bare tensor and stash the remaining metadata on the side.
            part1_outs = {}

            def _part1_tensor_only(inp):
                outs = self.forward_part1(inp)
                attended = outs.pop('outputs')
                # The wrapper may run again during backward recomputation; keep the
                # values captured by the first (forward) invocation.
                if not part1_outs:
                    part1_outs.update(outs)
                return attended

            x = checkpoint.checkpoint(_part1_tensor_only, x)
        else:
            part1_outs = self.forward_part1(x, return_attention=return_attention)
            x = part1_outs.pop('outputs')
        x = shortcut + self.drop_path(x)

        if self.use_checkpoint:
            x = x + checkpoint.checkpoint(self.forward_part2, x)
        else:
            x = x + self.forward_part2(x)
        out = {'outputs': x}
        out.update(part1_outs)
        return out


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of feature channels
        depth (int): Depths of this stage.
        num_heads (int): Number of attention head.
        window_size (tuple[int]): Local window size. Default: (1,7,7).
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        causal (bool, optional): Forwarded to every block. Default: False
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: False
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate. A sequence
            is indexed per block; a scalar is shared by all blocks. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool, optional): Forwarded to every block. Default: False
    """

    def __init__(self,
                 dim,
                 depth,
                 num_heads,
                 window_size=(1, 7, 7),
                 mlp_ratio=4.,
                 causal=False,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.window_size = window_size
        self.shift_size = tuple(i // 2 for i in window_size)
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even-indexed blocks use regular windows, odd-indexed ones shifted windows
        self.blocks = nn.ModuleList([
            SwinTransformerBlock3D(
                dim=dim,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=(0, 0, 0) if (i % 2 == 0) else self.shift_size,
                mlp_ratio=mlp_ratio,
                causal=causal,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=self._resolve_drop_path(drop_path, i),
                norm_layer=norm_layer,
                use_checkpoint=use_checkpoint,
            )
            for i in range(depth)])

        self.downsample = downsample
        if self.downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)

    @staticmethod
    def _resolve_drop_path(drop_path, index):
        """ Return the stochastic-depth rate for block ``index``.

        A list or tuple is indexed per block; anything else is returned unchanged.
        """
        if isinstance(drop_path, (list, tuple)):
            return drop_path[index]
        return drop_path

    def forward(self, x, rollout_attention=False):
        """ Forward function.
        Args:
            x: Input feature, tensor size (B, C, D, H, W).
            rollout_attention (bool): If True, each block is asked for its attention weights and
                they are chained across blocks into a rollout.
        Returns:
            dict: 'outputs' is the stage output in (B, C, D, H, W) layout (after the optional
            downsample); 'attentions_rollout' is the rollout tensor, or None when
            ``rollout_attention`` is False.
        """
        B, _, D, H, W = x.shape

        x = rearrange(x, 'b c d h w -> b d h w c')
        attn_rollout, attn_weights = None, None
        for idx, blk in enumerate(self.blocks):
            block_outputs = blk(x, return_attention=rollout_attention)
            x = block_outputs.get('outputs')
            attn_weights = block_outputs.get('attentions_weights')
            if rollout_attention:
                attn_weights = process_rollout_attention(attn_weights)
            # chain the current block's attention onto the running rollout
            if (idx > 0) and (attn_rollout is not None):
                attn_rollout = attn_weights.bmm(attn_rollout)
            else:
                attn_rollout = attn_weights

        if rollout_attention:
            # geometry reported by the last block is used to fold windows back into a volume
            Dp, Hp, Wp = block_outputs.get('size')
            padD, padH, padW = block_outputs.get('paddings')
            num_windows = block_outputs.get('num_windows')
            window_size = block_outputs.get('window_size')
            num_tokens = attn_rollout.shape[1]
            attn_rollout = attn_rollout.reshape(B, num_tokens, *num_windows, *window_size)
            attn_rollout = attn_rollout.permute(0, 1, 2, 5, 3, 6, 4, 7).contiguous()
            attn_rollout = attn_rollout.reshape(B, num_tokens, Dp, Hp, Wp)
            # strip the window padding
            attn_rollout = attn_rollout[:,:,padD[0]:Dp-padD[1],padH[0]:Hp-padH[1],padW[0]:Wp-padW[1]]

        x = x.view(B, D, H, W, -1)

        if self.downsample is not None:
            x = self.downsample(x)
        x = rearrange(x, 'b d h w c -> b c d h w')
        out = {'outputs': x, 'attentions_rollout': attn_rollout}
        return out


class SwinTransformer3D(nn.Module):
    """ Swin Transformer backbone.
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
    Args:
        patch_size (int | tuple(int)): Patch size. Default: (4,4,4).
        num_channels (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        depths (tuple[int]): Depths of each Swin Transformer stage. Default: [2, 2, 6, 2].
        num_heads (tuple[int]): Number of attention head of each stage. Default: [3, 6, 12, 24].
        window_size (tuple[int]): Window size. Default: (2, 7, 7).
        pool_size (tuple): Output size handed to ``nn.AdaptiveAvgPool3d``; entries 1 and 2 must be
            integers because they are multiplied into ``out_features``. Default: (None, 1, 1).
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        causal (bool): Forwarded to every stage. Default: False.
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
        drop_rate (float): Dropout rate. Default: 0.
        attn_drop_rate (float): Attention dropout rate. Default: 0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
        norm_layer: Normalization layer. Default: nn.LayerNorm.
        patch_norm (bool): If True, add normalization after patch embedding. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters.
        use_checkpoint (bool): Forwarded to every stage. Default: False.
    """

    def __init__(self,
                 patch_size=(4, 4, 4),
                 num_channels=3,
                 embed_dim=96,
                 depths=None,
                 num_heads=None,
                 window_size=(2, 7, 7),
                 pool_size=(None, 1, 1),
                 mlp_ratio=4.,
                 causal=False,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.2,
                 norm_layer=nn.LayerNorm,
                 patch_norm=False,
                 frozen_stages=-1,
                 use_checkpoint=False):
        super().__init__()

        if num_heads is None:
            num_heads = [3, 6, 12, 24]
        if depths is None:
            depths = [2, 2, 6, 2]
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.frozen_stages = frozen_stages
        self.window_size = window_size
        self.patch_size = patch_size

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed3D(
            patch_size=patch_size, num_channels=num_channels, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers: channel width doubles per stage, every stage but the last downsamples
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(embed_dim * 2 ** i_layer),
                depth=depths[i_layer],
                num_heads=num_heads[i_layer],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                causal=causal,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging if i_layer < self.num_layers - 1 else None,
                use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.pool_size = pool_size
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.out_features = pool_size[1] * pool_size[2] * self.num_features
        self.ada_pool = nn.AdaptiveAvgPool3d(pool_size)

        # add a norm layer for each output
        self.norm = norm_layer(self.num_features)

        self._freeze_stages()

    def _freeze_stages(self):
        """Put the first ``frozen_stages`` stages (and the patch embedding) in eval mode without grads."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

        if self.frozen_stages >= 1:
            self.pos_drop.eval()
            for i in range(0, self.frozen_stages):
                m = self.layers[i]
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def init_weights(self):
        """Initialize the weights in backbone.

        Linear weights get a truncated normal (std 0.02) with zero bias; LayerNorm gets
        unit weight and zero bias.
        """
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)

    def forward(self, x, rollout_attention=False, return_feature_maps=False):
        """Forward function.
        Args:
            x: 5-D input tensor; axes 1 and 2 are swapped before patch embedding
                (presumably (B, D, C, H, W) -> (B, C, D, H, W) -- confirm against PatchEmbed3D).
            rollout_attention (bool): Forwarded to every stage.
            return_feature_maps (bool): If True, also return the normalised, un-pooled features.
        Returns:
            dict: 'outputs' is the pooled feature rearranged to (n, d, c*h*w); 'feature_maps' is
            the (n, d, c, h, w) feature map or None; 'attentions_rollout' is a list with one
            entry per stage that produced a rollout.
        """
        x = x.permute((0, 2, 1, 3, 4)).contiguous()
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        attentions_rollouts = []
        for layer in self.layers:
            layer_outputs = layer(x.contiguous(), rollout_attention=rollout_attention)
            x = layer_outputs.get('outputs')
            attentions_rollout = layer_outputs.get('attentions_rollout')
            if attentions_rollout is not None:
                attentions_rollouts.append(attentions_rollout)
        # the norm layer acts on the last axis, so move channels there and back
        x = rearrange(x, 'n c d h w -> n d h w c')
        x = self.norm(x)
        x = rearrange(x, 'n d h w c -> n c d h w')
        x_pool = self.ada_pool(x)
        x_pool = rearrange(x_pool, 'n c d h w -> n d (c h w)')
        fmaps = None
        if return_feature_maps:
            fmaps = rearrange(x, 'n c d h w -> n d c h w')
        out = {'outputs': x_pool, 'feature_maps': fmaps,
               'attentions_rollout': attentions_rollouts}
        return out

    def train(self, mode=True):
        """Convert the model into training mode while keep layers frozen.

        Returns:
            self, matching ``nn.Module.train`` so that ``model.train()`` / ``model.eval()``
            can be chained or assigned.
        """
        super(SwinTransformer3D, self).train(mode)
        self._freeze_stages()
        # nn.Module.eval() is implemented as ``return self.train(False)``; without this
        # return both train() and eval() would hand back None.
        return self
