import torch
import torch.nn as nn

from torch.nn.functional import pad
from torch.nn.init import trunc_normal_
from einops import rearrange
    
from natten import NeighborhoodAttention2D as NeighborhoodAttention
from natten.functional import natten2dav, natten2dqkrpb

from einops import rearrange

import spconv
import spconv.pytorch
from opencood.models.spconv_utils import ConvertSparseTensor

class CAVNeighborhoodAttention(NeighborhoodAttention):
    """Neighborhood attention applied independently to each agent's map.

    Extends natten's ``NeighborhoodAttention2D`` to accept a 5-D input
    ``(B, L, H, W, C)``. The batch and agent axes are folded together so
    every agent's feature map is attended separately. The result is then
    unfolded back to ``(B, L, H, W, C)``.
    """

    def __init__(self, dim, num_heads, kernel_size, dilation, qkv_bias, qk_scale, attn_drop, proj_drop):
        # Forward everything by keyword: the positional order of the parent
        # constructor differs across natten releases (newer ones insert a
        # ``bias`` flag before ``qkv_bias``). Positional forwarding can then
        # silently route ``qk_scale`` into ``qkv_bias`` and ``proj_drop``
        # into ``attn_drop``.
        super().__init__(
            dim,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

    def forward(self, x):
        """Apply neighborhood attention to every agent's feature map.

        Args:
            x: tensor of shape ``(B, L, H, W, C)``.

        Returns:
            Tensor of shape ``(B, L, H, W, C)`` after attention and the output
            projection (``self.proj`` followed by ``self.proj_drop``).
        """
        B, L, _, _, _ = x.shape
        # Fold the agent axis into the batch so each agent is attended on its own.
        x = rearrange(x, 'b l h w c -> (b l) h w c')
        BL, Hp, Wp, C = x.shape
        H, W = int(Hp), int(Wp)
        pad_r = pad_b = 0
        # When either spatial extent is smaller than ``self.window_size`` the
        # map is zero-padded on the right/bottom; the padding is cropped off
        # again after attention.
        if H < self.window_size or W < self.window_size:
            pad_r = max(0, self.window_size - W)
            pad_b = max(0, self.window_size - H)
            x = pad(x, (0, 0, 0, pad_r, 0, pad_b))
            _, H, W, _ = x.shape
        # (BL, H, W, 3*C) -> (3, BL, heads, H, W, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(BL, H, W, 3, self.num_heads, self.head_dim)
            .permute(3, 0, 4, 1, 2, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = natten2dqkrpb(q, k, self.rpb, self.kernel_size, self.dilation)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = natten2dav(attn, v, self.kernel_size, self.dilation)
        # (BL, heads, H, W, head_dim) -> (BL, H, W, C)
        x = x.permute(0, 2, 3, 1, 4).reshape(BL, H, W, C)
        if pad_r or pad_b:
            x = x[:, :Hp, :Wp, :]
        x = rearrange(x, '(b l) h w c -> b l h w c', b=B, l=L)

        return self.proj_drop(self.proj(x))

def f_drop_path(x, drop_prob : float=0., training: bool = False):
    """Stochastic depth: randomly zero out whole samples of ``x`` while training.

    Each sample along dim 0 is kept with probability ``1 - drop_prob``.
    Kept samples are scaled by ``1 / (1 - drop_prob)`` so the expected value
    is unchanged; dropped samples become all zeros.

    Args:
        x: input tensor; dim 0 is treated as the sample axis.
        drop_prob: probability of dropping a sample. ``None`` or ``0``
            disables dropping.
        training: when False, ``x`` is returned unchanged.

    Returns:
        A tensor with the same shape as ``x`` (``x`` itself when disabled).
    """
    if not drop_prob or not training:
        return x
    if drop_prob >= 1.0:
        # Every sample is dropped; avoid the 0/0 that the rescale would produce.
        return torch.zeros_like(x)
    keep_prob = 1 - drop_prob
    # One draw per sample, broadcast over all remaining dimensions.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    # Must be in-place (or reassigned): a bare ``.floor()`` is discarded.
    random_tensor.floor_()
    return x.div(keep_prob) * random_tensor

class DropPath(nn.Module):
    """Module wrapper around ``f_drop_path`` (per-sample stochastic depth).

    Args:
        drop_prob: probability of dropping a sample. ``None`` (the default)
            is treated as 0, i.e. no dropping.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Coerce ``None`` to 0.0 so the default never reaches ``1 - None``
        # inside the helper when the module is in training mode.
        return f_drop_path(x, self.drop_prob or 0.0, self.training)

    def extra_repr(self):
        return f"drop_prob={self.drop_prob}"
    
class Mlp(nn.Module):
    """Position-wise feed-forward network.

    Layout: ``fc1 -> act -> drop -> fc2 -> drop``, applied over the last axis.

    Args:
        in_features: size of the input's last dimension.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation class, instantiated with no arguments.
        drop: dropout probability, shared by both dropout applications.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.0):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features

        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class ConvTokenizer(nn.Module):
    """Embed each agent's feature map with a small convolution stack.

    Both variants contain one stride-2, padding-1, 3x3 convolution, so the
    spatial size is downsampled by 2 (``ceil(H / 2)`` for the dense path).

    Args:
        in_dim: number of input channels.
        em_dim: number of output (embedding) channels.
        embed_type: ``'normal'`` for dense ``nn.Conv2d`` layers, ``'sparse'``
            for spconv sparse convolutions.
        norm_layer: optional normalization class, called as
            ``norm_layer(em_dim)`` and applied to the channel-last output.

    Raises:
        ValueError: if ``embed_type`` is not ``'normal'`` or ``'sparse'``.
    """

    def __init__(self, in_dim, em_dim, embed_type='normal', norm_layer=None):
        super().__init__()

        self.embed_type = embed_type

        if embed_type == 'normal':
            # Full-resolution 3x3 conv, then a stride-2 3x3 conv to downsample.
            self.proj = nn.Sequential(
                nn.Conv2d(in_dim, em_dim, kernel_size=3, stride=1, padding=1),
                nn.Conv2d(em_dim, em_dim, kernel_size=3, stride=2, padding=1)
            )
        elif embed_type == 'sparse':
            # Stride-2 sparse conv, then a submanifold conv at the new resolution.
            self.proj = spconv.pytorch.SparseSequential(
                spconv.pytorch.SparseConv2d(in_channels=in_dim, out_channels=em_dim, kernel_size=3,
                                            stride=2, padding=1, indice_key="spconv1"),
                spconv.pytorch.SubMConv2d(in_channels=em_dim, out_channels=em_dim, kernel_size=3,
                                          stride=1, padding=1, indice_key="submconv1")
            )
        else:
            # Fail fast: otherwise ``self.proj`` would be undefined and the
            # error would only surface later, inside ``forward``.
            raise ValueError(
                f"Unsupported embed_type {embed_type!r}; expected 'normal' or 'sparse'."
            )

        self.norm = norm_layer(em_dim) if norm_layer is not None else None

    def forward(self, x):
        """Embed ``x`` of shape ``(B, L, H, W, C)``.

        Returns:
            Tensor of shape ``(B, L, H', W', em_dim)`` where ``H'``/``W'`` are
            the spatially downsampled sizes.
        """
        B, L, _, _, _ = x.shape

        if self.embed_type == 'normal':
            x = rearrange(x, 'b l h w c -> (b l) c h w')
            x = self.proj(x).permute(0, 2, 3, 1)
        elif self.embed_type == 'sparse':
            # NOTE(review): assumes ConvertSparseTensor takes a channel-last
            # dense tensor (as arranged here) — confirm against spconv_utils.
            x = rearrange(x, 'b l h w c -> (b l) h w c')
            x = ConvertSparseTensor(x)
            x = self.proj(x).dense().permute(0, 2, 3, 1)

        if self.norm is not None:
            x = self.norm(x)

        return rearrange(x, '(b l) h w c -> b l h w c', b=B, l=L)
    
class NATLayer(nn.Module):
    """Pre-norm transformer layer built on ``CAVNeighborhoodAttention``.

    ``forward`` computes ``x = x + drop_path(attn(norm1(x)))`` and then
    ``x = x + drop_path(mlp(norm2(x)))``. The same ``drop_path`` module is
    used for both residual branches.

    Args:
        dim: channel count (last axis of the input).
        num_heads: number of attention heads.
        kernel_size: neighborhood size of the attention.
        dilation: attention dilation, forwarded unchanged to the attention.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: forwarded to the attention module.
        drop: dropout for the attention output projection and the MLP.
        drop_path: stochastic-depth probability; ``0`` uses ``nn.Identity``.
        act_layer: activation class used inside the MLP.
        norm_layer: normalization class applied before each branch.
    """

    def __init__(self, dim, num_heads, kernel_size=5, dilation=None, mlp_ratio=1.0,
                 qkv_bias=True, qk_scale=None, drop=0.0,
                 attn_drop=0.0, drop_path=0.0, act_layer=nn.ReLU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads

        self.norm1 = norm_layer(dim)
        self.attn = CAVNeighborhoodAttention(
            dim,
            num_heads=num_heads,
            kernel_size=kernel_size,
            dilation=dilation,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.layer_scale = False

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(B, L, H, W, dim)``; shape is preserved."""
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
    
class NATBlock(nn.Module):
    """Sequential stack of ``NATLayer`` modules, one per entry of ``dilation``.

    Args:
        dim: channel count.
        num_head: attention heads per layer.
        kernel: neighborhood kernel size for every layer.
        dilation: sequence of dilation values; its length sets the depth.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        drop_path: stochastic-depth probability used by every layer.
        drop: dropout for attention projection and MLP (default ``0.3``).
        attn_drop: dropout on attention weights (default ``0.0``).
    """

    def __init__(self, dim, num_head, kernel, dilation, mlp_ratio, drop_path,
                 drop=0.3, attn_drop=0.0):
        super().__init__()

        self.blocks = nn.ModuleList(
            NATLayer(dim=dim,
                     num_heads=num_head,
                     kernel_size=kernel,
                     dilation=d,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=True,
                     qk_scale=None,
                     drop=drop,
                     attn_drop=attn_drop,
                     drop_path=drop_path,
                     norm_layer=nn.LayerNorm)
            for d in dilation
        )

    def forward(self, x):
        """Run ``x`` through every layer in order; the shape is preserved."""
        for block in self.blocks:
            x = block(x)
        return x
        

class Spatial_wise_Attention(nn.Module):
    """Per-agent spatial neighborhood attention at half resolution.

    The pipeline has three stages:

    1. ``ConvTokenizer``: stride-2 embedding followed by LayerNorm.
    2. ``NATBlock``: one neighborhood-attention layer per dilation.
    3. ``ConvTranspose2d`` (kernel 2, stride 2): upsamples back to the input
       resolution.

    Args:
        dim: channel count of both input and output.
        num_heads, kernel, dilations, mlp_ratio: forwarded to ``NATBlock``.
        embed_type: ``'normal'`` (dense) or ``'sparse'`` (spconv) tokenizer.
        drop_path: stochastic-depth probability for every NAT layer.
    """

    def __init__(self, dim, num_heads, kernel, dilations, mlp_ratio,
                 embed_type='normal', drop_path=0.3):
        super(Spatial_wise_Attention, self).__init__()

        self.dim = dim
        self.embed_type = embed_type

        self.emb = ConvTokenizer(in_dim=dim, em_dim=dim, embed_type=embed_type, norm_layer=nn.LayerNorm)
        self.deconv = nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2)

        self.nats = NATBlock(dim=dim, num_head=num_heads, kernel=kernel, dilation=dilations, mlp_ratio=mlp_ratio, drop_path=drop_path)

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(B, L, H, W, dim)``.

        Returns:
            Tensor of shape ``(B, L, H, W, dim)``, i.e. matching the input.
        """
        H, W = x.shape[2], x.shape[3]
        tokens = self.emb(x)
        B, L, _, _, _ = tokens.shape
        out = self.nats(tokens)

        out = rearrange(out, 'b l h w c -> (b l) c h w')
        # The tokenizer's stride-2, padding-1, 3x3 conv yields ceil(H/2), so
        # the 2x transposed conv overshoots by one row/column for odd sizes.
        # Crop back to the input size (a no-op when H and W are even).
        out = self.deconv(out)[:, :, :H, :W].permute(0, 2, 3, 1)
        return rearrange(out, '(b l) h w c -> b l h w c', b=B, l=L)


