import torch
from torch import nn, Tensor
from mmcv.cnn import ConvModule
import torch.nn.functional as F
import torch.nn.init as init
from typing import List, Tuple, Union, Dict, Optional
from ..builder import HEADS
from .decode_head import BaseDecodeHead

class h_sigmoid(nn.Module):
    """Hard-sigmoid activation computed as ``ReLU6(x + 3) / 6``.

    Args:
        inplace: Forwarded to ``nn.ReLU6``. The ReLU6 is applied to the
            temporary ``x + 3`` created in ``forward``, not to ``x`` itself.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        """Return the element-wise hard sigmoid of ``x``."""
        shifted = x + 3
        clipped = self.relu(shifted)
        return clipped / 6


class Fusion_block(nn.Module):
    """Gate local features with an upsampled activation of global features.

    Both inputs are projected to ``embed_dim`` channels by a 1x1 ``ConvModule``
    (no activation). The projected global map is passed through ``h_sigmoid``,
    bilinearly resized to the spatial size of the local map and multiplied
    element-wise with the projected local map.

    Args:
        inp: Channel count of the local input ``x_l``.
        oup: Channel count of the global input ``x_g`` (it is used as the
            *input* channel count of ``global_act``).
        embed_dim: Channel count of the fused output.
        norm_cfg: Normalisation config handed to both ``ConvModule`` layers.
    """

    def __init__(
            self,
            inp: int,
            oup: int,
            embed_dim: int,
            norm_cfg=dict(type='BN', requires_grad=True),
    ) -> None:
        super().__init__()
        self.norm_cfg = norm_cfg
        self.local_embedding = ConvModule(inp, embed_dim, kernel_size=1, norm_cfg=self.norm_cfg, act_cfg=None)
        self.global_act = ConvModule(oup, embed_dim, kernel_size=1, norm_cfg=self.norm_cfg, act_cfg=None)
        self.act = h_sigmoid()

    def forward(self, x_l, x_g):
        """Fuse local and global features.

        Args:
            x_l: Local features, a 4-D tensor whose last two dims are (H, W).
            x_g: Global features; resized to (H, W) after gating activation.

        Returns:
            Projected local features multiplied by the resized global gate.
        """
        # Only the target spatial size is needed; the other dims were unused.
        _, _, H, W = x_l.shape

        local_feat = self.local_embedding(x_l)
        gate = self.act(self.global_act(x_g))
        gate = F.interpolate(gate, size=(H, W), mode='bilinear', align_corners=False)
        return local_feat * gate


class MultiFusionAttention(nn.Module):
    """Multi-branch window attention over a 2-D feature map.

    Every ``(window_length, dilated_ratio)`` pair defines one branch. For a
    branch the map is split into non-overlapping windows, the tokens of each
    window are optionally average-pooled over ``dr x dr`` cells, and scaled
    dot-product attention is run inside every window. Branch outputs are
    brought back to the input resolution and blended with weights obtained
    from a softmax (across branches) of each branch's log-sum-exp of the
    attention logits.

    Args:
        dim: Channel count of the input and of the output.
        num_heads: Number of attention heads.
        window_lengths: One window size per branch, either an int (square
            window) or an ``(h, w)`` pair.
        dilated_ratios: One pooling ratio per branch; paired with
            ``window_lengths`` via ``zip``.
        qk_head_dim: Per-head channel count of queries and keys.
        v_head_dim: Per-head channel count of values.
        attn_drop: Dropout probability applied to the attention weights.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_lengths: List[Union[int, tuple]] = [16, 32],
        dilated_ratios: List[int] = [1, 2],
        qk_head_dim: int = 64,
        v_head_dim: int = 64,
        attn_drop: float = 0.0,
        ):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.qk_dim = num_heads * qk_head_dim
        self.v_dim = num_heads * v_head_dim
        self.scale = qk_head_dim ** -0.5

        # Copy so neither the shared default lists nor a caller's list is aliased.
        self.window_lengths = list(window_lengths)
        self.dilated_ratios = list(dilated_ratios)

        self.wq = ConvModule(in_channels=dim, out_channels=self.qk_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=None)
        self.wk = ConvModule(in_channels=dim, out_channels=self.qk_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=None)
        self.wv = ConvModule(in_channels=dim, out_channels=self.v_dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=None)
        self.wo = ConvModule(in_channels=self.v_dim, out_channels=dim, kernel_size=1, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=None)

        # Gate branch: grouped 3x3 conv over concatenated q/k/v, then a 1x1 conv
        # down to one channel per head.
        # NOTE(review): groups=dim presumably requires the concatenated channel
        # count to be a multiple of ``dim`` - confirm for non-default head dims.
        fused_dim = self.qk_dim * 2 + self.v_dim
        self.dw_conv = ConvModule(in_channels=fused_dim, out_channels=fused_dim, kernel_size=3, stride=1, padding=1, groups=dim, norm_cfg=dict(type='BN', requires_grad=True), act_cfg=None)
        self.pw_conv = ConvModule(in_channels=fused_dim, out_channels=num_heads, kernel_size=1, stride=1, padding=0, act_cfg=None)

        self.attn_drop = nn.Dropout(attn_drop)

        self.act = nn.ReLU6(inplace=True)
        self.sigmoid = h_sigmoid()

    def window_partition(self, x: Tensor, wl: Tuple[int, int]):
        """Split ``(B, heads, d, H, W)`` into windows.

        Returns:
            Tensor of shape ``(B * num_windows, heads, wl[0], wl[1], d)``.
            ``H`` and ``W`` must be multiples of ``wl[0]`` and ``wl[1]``.
        """
        B, h, d, H, W = x.shape
        windows = x.view(B, h, d, H // wl[0], wl[0], W // wl[1], wl[1])
        # -> (B, nH, nW, heads, wh, ww, d), then fold the window grid into batch.
        windows = windows.permute(0, 3, 5, 1, 4, 6, 2).contiguous().view(-1, h, wl[0], wl[1], d)

        return windows

    def window_reverse(self,
                       x: Tensor,
                       wl: Tuple[int, int],
                       bsz: int,
                       input_res: Tuple[int, int]):
        """Inverse of ``window_partition`` with channels kept last.

        Args:
            x: ``(bsz * num_windows, heads, wh, ww, d)`` windowed tensor.
            wl: Window size used for the partition.
            bsz: Original batch size.
            input_res: ``(H, W)`` of the un-partitioned map.

        Returns:
            Tensor of shape ``(bsz, heads, H, W, d)``.
        """
        _, h, wh, ww, d = x.shape
        H, W = input_res
        num_h, num_w = H // wl[0], W // wl[1]
        x = x.view(bsz, num_h, num_w, h, wh, ww, d)
        x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(bsz, h, H, W, d)
        return x

    def gathering(self,
                  x: Tensor,
                  dr: int,
                  wl: Tuple[int, int],
                  kernel_norm: Optional[Tensor]):
        """Window ``x`` and turn each window into a token sequence.

        Args:
            x: ``(B, heads, d, H, W)`` tensor.
            dr: Pooling ratio; when ``> 1`` every ``dr x dr`` cell of a window
                is averaged into one token (window dims must be multiples of
                ``dr``).
            wl: Window size.
            kernel_norm: Optional ``(B, heads, H, W)`` gate multiplied onto
                ``x`` before windowing.

        Returns:
            ``(B * num_windows, heads, tokens, d)`` with
            ``tokens = (wl[0] // dr) * (wl[1] // dr)``.
        """
        _, h, d, _, _ = x.shape
        if kernel_norm is not None:
            x = x * kernel_norm.unsqueeze(2)

        x_windows = self.window_partition(x, wl)

        if dr > 1:
            x_windows = x_windows.view(-1, h, wl[0] // dr, dr, wl[1] // dr, dr, d)
            B_ = x_windows.shape[0]
            # Bring the two in-cell axes together and average over them.
            x_windows = x_windows.permute(0, 1, 2, 4, 3, 5, 6).contiguous().view(B_, h, -1, dr * dr, d).mean(dim=-2, keepdim=False)
        else:
            x_windows = x_windows.flatten(2, 3)

        return x_windows

    def scattering(self,
                   outs: List[Tensor],
                   lses: List[Tensor],
                   bsz: int,
                   window_lengths: List[Tuple[int]],
                   input_res: Tuple[int, int],
                   kernel_norm: Optional[Tensor] = None):
        """Restore every branch to full resolution and blend the branches.

        Args:
            outs: Per-branch attention outputs ``(B_, heads, tokens, d)``.
            lses: Per-branch log-sum-exp values ``(B_, heads, tokens)``.
            bsz: Original batch size.
            window_lengths: Window size actually used by each branch.
            input_res: ``(H, W)`` of the feature map.
            kernel_norm: Optional ``(B, heads, H, W)`` gate applied to the
                output and log-sum-exp of branches with ``dr > 1``.

        Returns:
            Tensor of shape ``(bsz, heads * d, H, W)``.
        """
        assert len(outs) == len(window_lengths) == len(lses)
        all_outs, all_lses = [], []

        for idx, (o, lse) in enumerate(zip(outs, lses)):
            dr = self.dilated_ratios[idx]
            wl = window_lengths[idx]

            B_, num_heads, fused_len, d = o.shape
            fused_wh, fused_ww = wl[0] // dr, wl[1] // dr
            assert fused_len == fused_wh * fused_ww
            if dr > 1:
                # Replicate each pooled token over the dr x dr cell it came from.
                o = o.view(B_, num_heads, fused_wh, 1, fused_ww, 1, d)
                o = o.expand(-1, -1, -1, dr, -1, dr, -1).contiguous()
                o = o.view(B_, num_heads, wl[0], wl[1], d)

                lse = lse.view(B_, num_heads, fused_wh, 1, fused_ww, 1)
                lse = lse.expand(-1, -1, -1, dr, -1, dr).contiguous()
                lse = lse.view(B_, num_heads, wl[0], wl[1], 1)
            else:
                o = o.view(B_, num_heads, wl[0], wl[1], d)
                lse = lse.view(B_, num_heads, wl[0], wl[1], 1)

            o = self.window_reverse(o, wl, bsz, input_res)
            lse = self.window_reverse(lse, wl, bsz, input_res)

            # kernel_norm is optional in the signature, so skip the gate when absent
            # (same convention as gathering()).
            if dr > 1 and kernel_norm is not None:
                gate = kernel_norm.unsqueeze(-1)
                o = o * gate
                lse = lse * gate

            all_outs.append(o)
            all_lses.append(lse)

        out = 0
        if len(all_outs) > 1:
            # Branch weights: softmax of the log-sum-exp values across branches,
            # computed without gradient and shifted by the max for stability.
            with torch.no_grad():
                max_lse = torch.stack(all_lses, dim=0).max(dim=0)[0]
                exp_lses = [torch.exp(lse - max_lse) for lse in all_lses]
                lse_sum = torch.stack(exp_lses, dim=0).sum(dim=0)
                exp_lses = [lse / lse_sum for lse in exp_lses]

            for idx, o in enumerate(all_outs):
                o = o * exp_lses[idx].type_as(o)
                out = out + o / len(outs)
        else:
            out = all_outs[-1]

        # (bsz, heads, H, W, d) -> (bsz, heads * d, H, W)
        out = out.permute(0, 1, 4, 2, 3).contiguous().flatten(1, 2)

        return out

    def window_attention_ops(self,
                             q: Tensor, k: Tensor, v: Tensor,
                             attn_mask: Optional[Tensor] = None):
        """Scaled dot-product attention inside each window.

        Args:
            q, k, v: ``(B_, heads, tokens, d)`` tensors.
            attn_mask: Optional additive mask of shape
                ``(num_windows, tokens, tokens)``.

        Returns:
            ``(out, lse)`` where ``out`` is ``(B_, heads, tokens, d_v)`` and
            ``lse`` is the log-sum-exp of the scaled (and masked) attention
            logits over the key axis, shape ``(B_, heads, tokens)``.
        """
        B_, n_heads, N, _ = q.shape
        attn = torch.einsum('bhnd,bhmd->bhnm', q, k) * self.scale

        if attn_mask is not None:
            # attn_mask: (num_windows, wh * ww, wh * ww)
            num_windows = attn_mask.shape[0]
            attn = attn.view(B_ // num_windows, num_windows, n_heads, N, N)
            attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(B_, n_heads, N, N)

        attn_weights = F.softmax(attn, dim=-1)
        attn_probs = self.attn_drop(attn_weights)
        out = torch.einsum('bhnm,bhmd->bhnd', attn_probs, v)

        # Take the log-sum-exp of the logits, not of the soft-maxed (and
        # dropped-out) probabilities: scattering() soft-maxes these values across
        # branches, which needs each branch's softmax normaliser.
        lse = torch.logsumexp(attn, dim=-1)

        return out, lse

    def forward(self, x: Tensor):
        """Apply every attention branch to ``x`` and project the blend.

        Args:
            x: ``(B, C, H, W)`` feature map.

        Returns:
            Output of the ``wo`` projection applied to the blended branches.
        """
        B, _, H, W = x.shape

        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        # One gate value per head and position, derived from q, k and v together.
        qkv = torch.cat([q, k, v], dim=1)
        qkv = self.act(self.dw_conv(qkv))
        kernel_norm = self.sigmoid(self.pw_conv(qkv))

        q = q.view(B, self.num_heads, self.qk_head_dim, H, W)
        k = k.view(B, self.num_heads, self.qk_head_dim, H, W)
        v = v.view(B, self.num_heads, self.v_head_dim, H, W)

        outs = []
        lses = []
        suited_windows = []

        for wl, dr in zip(self.window_lengths, self.dilated_ratios):
            wl = (wl, wl) if isinstance(wl, int) else wl
            wh, ww = wl
            # A window that does not tile the map is replaced by one global
            # window; later (larger) branches are then skipped.
            use_global = min(wl) > min((H, W)) or H % wh != 0 or W % ww != 0
            if use_global:
                wl = (H, W)

            suited_windows.append(wl)

            # Only pooled branches are gated before attention.
            gate = kernel_norm if dr > 1 else None
            qi = self.gathering(q, dr, wl, kernel_norm=gate)
            ki = self.gathering(k, dr, wl, kernel_norm=gate)
            vi = self.gathering(v, dr, wl, kernel_norm=gate)

            out, lse = self.window_attention_ops(qi, ki, vi)

            outs.append(out)
            lses.append(lse)

            if use_global:
                break

        all_out = self.scattering(outs, lses, bsz=B, window_lengths=suited_windows, input_res=(H, W), kernel_norm=kernel_norm)

        all_out = self.wo(all_out)

        return all_out


@HEADS.register_module()
class LightHead(BaseDecodeHead):
    """Decode head that cascades ``Fusion_block`` stages before classification.

    The original docstring attributes the design to "SEA-Former:
    Squeeze-enhanced Axial Transformer for Mobile Semantic Segmentation".

    Args:
        embed_dims: Output channel count of each fusion stage; one
            ``Fusion_block`` is created per entry and stored as ``fuse1``,
            ``fuse2``, ...
        is_dw: If True, ``linear_fuse`` uses one group per channel.
        use_attn: If True, the fused features are re-weighted by
            ``h_sigmoid(MultiFusionAttention(x))``.
        **kwargs: Forwarded to ``BaseDecodeHead``.
    """

    def __init__(self, embed_dims, is_dw=False, use_attn = False, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)

        head_channels = self.channels
        in_channels = self.in_channels
        self.linear_fuse = ConvModule(
            in_channels=head_channels,
            out_channels=head_channels,
            kernel_size=1,
            stride=1,
            groups=head_channels if is_dw else 1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg
        )
        # Stage i fuses the running local features (the first selected input for
        # stage 0, the previous stage's output afterwards) with selected input i + 1.
        for i, embed_dim in enumerate(embed_dims):
            local_channels = in_channels[0] if i == 0 else embed_dims[i - 1]
            fuse = Fusion_block(local_channels, in_channels[i + 1], embed_dim=embed_dim, norm_cfg=self.norm_cfg)
            setattr(self, f"fuse{i + 1}", fuse)
        self.embed_dims = embed_dims

        self.use_attn = use_attn
        if use_attn:
            self.attn = MultiFusionAttention(
                dim=self.channels,
                num_heads=4,
                window_lengths=[16, 32],
                dilated_ratios=[1, 2],
                qk_head_dim=self.channels // 4,
                v_head_dim=self.channels // 4,
                )
            self.act = h_sigmoid()

    def forward(self, ret: Dict[str, Tensor]):
        """Fuse the selected backbone features and classify them.

        Args:
            ret: Mapping whose ``'outs'`` entry holds the backbone feature maps
                passed to ``_transform_inputs``. Other entries are ignored.

        Returns:
            Output of ``cls_seg`` applied to the fused features.
        """
        feats = self._transform_inputs(ret['outs'])
        x_detail = feats[0]
        for stage in range(1, len(self.embed_dims) + 1):
            fuse = getattr(self, f"fuse{stage}")
            x_detail = fuse(x_detail, feats[stage])

        if self.use_attn:
            attn_detail = self.attn(x_detail)
            x_detail = self.act(attn_detail) * x_detail

        _c = self.linear_fuse(x_detail)
        x = self.cls_seg(_c)
        return x
