from functools import partial
from typing import List, Tuple, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers.helpers import to_2tuple
from model.ops.modules import MSDeformAttn
from model.module.repvgg import RepVGGBlock


def with_pos_embed(tensor, pos):
    """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
    if pos is None:
        return tensor
    return tensor + pos

def get_reference_points(spatial_shapes, valid_ratios):
    """Build reference points for every location of every feature level.

    For each level ``(H_, W_)`` a grid of pixel centers (``0.5, 1.5, ...``) is
    generated, flattened row by row and divided by ``valid_ratio * size``.
    The per-level grids are concatenated along the location axis and finally
    multiplied by the valid ratios of every level.

    Args:
        spatial_shapes: Iterable of ``(H_, W_)`` pairs, one per feature level.
        valid_ratios: Tensor indexed as ``[batch, level, 2]``; index 0 of the
            last axis scales the width (x), index 1 the height (y).

    Returns:
        Tensor of shape ``(batch, sum(H_ * W_), num_levels, 2)`` holding
        ``(x, y)`` in the last axis.
    """
    # Create the grids on the same device as ``valid_ratios`` so the divisions
    # below do not mix devices; falls back to the default device when the
    # argument carries no ``device`` attribute.
    device = getattr(valid_ratios, "device", None)

    reference_points_list = []
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        ys = torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device)
        xs = torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)
        # "ij" is what the implicit default produced; passing it explicitly
        # avoids the deprecation warning raised when ``indexing`` is omitted.
        ref_y, ref_x = torch.meshgrid(ys, xs, indexing="ij")
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
        reference_points_list.append(torch.stack((ref_x, ref_y), -1))

    reference_points = torch.cat(reference_points_list, 1)
    # Broadcast (batch, locations, 1, 2) against (batch, 1, levels, 2).
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points


class Mlp(nn.Module):
    """Two-layer MLP as used in Vision Transformer, MLP-Mixer and related networks.

    The forward pass runs ``fc1 -> act -> drop1 -> norm -> fc2 -> drop2``.

    Args:
        in_features: Size of the input feature dimension.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        out_features: Size of the output; falls back to ``in_features`` when
            falsy.
        act_layer: Zero-argument factory for the activation module.
        norm_layer: Optional factory called with ``hidden_features``; when
            None an ``nn.Identity`` is used instead.
        bias: Expanded with ``to_2tuple``; element 0 is used for ``fc1`` and
            element 1 for ``fc2``.
        drop: Expanded with ``to_2tuple``; element 0 is the probability of
            ``drop1`` and element 1 that of ``drop2``.
        use_conv: Use 1x1 ``nn.Conv2d`` projections instead of ``nn.Linear``.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)

        # Pick the projection type once; both projections share it.
        if use_conv:
            make_proj = partial(nn.Conv2d, kernel_size=1)
        else:
            make_proj = nn.Linear

        self.fc1 = make_proj(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_features)
        self.fc2 = make_proj(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """Apply the six stages in their fixed order and return the result."""
        for stage in (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2):
            x = stage(x)
        return x


class ConvLayerNorm(nn.Module):
    """Two stacked ``RepVGGBlock`` convolutions followed by a channel LayerNorm.

    A residual connection from the input is added after the second block when
    ``in_channels == out_channels``.

    Args:
        in_channels: Channels of the input feature map.
        hidden_channels: Channels between the two blocks.
        out_channels: Channels of the output; also the LayerNorm width.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = RepVGGBlock(in_channels, hidden_channels)
        self.conv2 = RepVGGBlock(hidden_channels, out_channels)
        self.norm = nn.LayerNorm(out_channels)
        # The skip connection is only shape-compatible when channels match.
        self.use_identy = in_channels == out_channels

    def forward(self, x):
        """Run both blocks, optionally add the input, then normalize channels.

        The 4-D input is permuted to channels-last so ``nn.LayerNorm`` acts on
        the channel axis, then permuted back to channels-first.
        """
        feat = self.conv2(self.conv1(x))
        if self.use_identy:
            feat = feat + x
        channels_last = feat.permute(0, 2, 3, 1).contiguous()
        normed = self.norm(channels_last)
        return normed.permute(0, 3, 1, 2).contiguous()


class DeformableTransformerEncoderLayer(nn.Module):
    """Pre-norm encoder layer: deformable self-attention followed by an MLP.

    Each sub-block normalizes its input, runs, and adds the result back to the
    un-normalized stream.

    Args:
        dim: Feature width of the tokens.
        ffn_dim: Hidden width of the feed-forward MLP.
        dropout: Accepted for configuration compatibility; not used by this
            layer.
        num_levels: Passed to ``MSDeformAttn``.
        num_heads: Passed to ``MSDeformAttn``.
        num_points: Passed to ``MSDeformAttn``.
        activation: Activation factory handed to the MLP.
    """

    def __init__(self, dim=256, ffn_dim=2048, dropout=0.1, num_levels=3, num_heads=8, num_points=4, activation=nn.GELU):
        super().__init__()

        # self attention
        self.self_attn = MSDeformAttn(dim, num_levels, num_heads, num_points)
        self.norm1 = nn.LayerNorm(dim)

        # ffn
        self.ffn = Mlp(in_features=dim, out_features=dim, hidden_features=ffn_dim,
                       act_layer=activation)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
        """Apply the attention and FFN sub-blocks with residual connections.

        Args:
            src: Token features; normalized copies feed both sub-blocks.
            pos: Positional embedding added to the attention query only, or
                None.
            reference_points, spatial_shapes, level_start_index, padding_mask:
                Forwarded unchanged to ``MSDeformAttn``.

        Returns:
            The updated token features.
        """
        # self attention
        normed = self.norm1(src)
        attn_out = self.self_attn(
            with_pos_embed(normed, pos), reference_points, normed,
            spatial_shapes, level_start_index, padding_mask)
        src = src + attn_out

        # ffn -- normalized with its own LayerNorm. The previous code reused
        # ``norm1`` here and left ``norm2`` unused; weights trained with that
        # code will therefore carry an untrained ``norm2``.
        src = src + self.ffn(self.norm2(src))
        return src


class DeformableTransformerEncoder(nn.Module):
    """Stack of ``DeformableTransformerEncoderLayer`` with fixed reference points.

    Args:
        num_layers: Number of encoder layers to build.
        layer: Keyword arguments used to construct every layer.
        spatial_shapes: Per-level ``(H, W)`` pairs; stored and passed to each
            layer.
        valid_ratios: Used once, at construction, to compute the reference
            points.
        level_start_index: Stored and passed to each layer.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, num_layers, layer, spatial_shapes, valid_ratios, level_start_index, *args, **kwargs) -> None:
        super().__init__()

        self.num_layers = num_layers
        stacked = [DeformableTransformerEncoderLayer(**layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stacked)

        self.spatial_shapes = spatial_shapes
        # Computed once and kept as a non-trainable parameter.
        self.reference_points = nn.Parameter(
            get_reference_points(spatial_shapes, valid_ratios), requires_grad=False)
        self.level_start_index = level_start_index

    def forward(self, src, pos=None):
        """Feed ``src`` through all layers in order and return the final output."""
        for encoder_layer in self.layers:
            src = encoder_layer(
                src, pos, self.reference_points, self.spatial_shapes, self.level_start_index)
        return src


class DeformableTransformerDecoderLayer(nn.Module):
    """Post-norm decoder layer: self-attention, deformable cross-attention, MLP.

    Args:
        dim: Feature width of the queries.
        num_heads: Heads for both attention modules.
        ffn_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout used inside ``nn.MultiheadAttention`` and on the two
            attention residual branches.
        num_levels: Passed to ``MSDeformAttn``.
        num_points: Passed to ``MSDeformAttn``.
        activation: Activation factory handed to the MLP.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 dim=256,
                 num_heads=8,
                 ffn_dim=2048,
                 dropout=0.0,
                 num_levels=3,
                 num_points=4,
                 activation=nn.GELU,
                 *args,
                 **kwargs) -> None:
        super().__init__()

        # self attention over the queries
        self.self_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, bias=False, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dim)

        # deformable cross attention from the queries into ``src``
        self.cross_attn = MSDeformAttn(dim, num_levels, num_heads, num_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)

        # feed-forward
        self.ffn = Mlp(in_features=dim, out_features=dim, hidden_features=ffn_dim,
                       act_layer=activation)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, query, reference_points, src, spatial_shapes, level_start_index, query_pos):
        """Run the three sub-blocks, each as residual add followed by LayerNorm.

        Args:
            query: Query features.
            reference_points, spatial_shapes, level_start_index: Forwarded to
                ``MSDeformAttn``.
            src: Source features attended to by the cross-attention.
            query_pos: Positional embedding for the queries, or None.

        Returns:
            The updated query features.
        """
        # self attention: query, key and value are the same pos-augmented tensor
        qkv = with_pos_embed(query, query_pos)
        attended = self.self_attn(qkv, qkv, qkv, need_weights=False)[0]
        query = self.norm1(query + self.dropout1(attended))

        # cross attention
        # NOTE(review): ``query`` is rebound to ``query + query_pos`` here, so the
        # positional embedding also enters the residual sum below -- confirm
        # this is intended.
        query = with_pos_embed(query, query_pos)
        sampled = self.cross_attn(
            query, reference_points, src, spatial_shapes, level_start_index)
        query = self.norm2(query + self.dropout2(sampled))

        # feed-forward (no dropout on this residual branch)
        return self.norm3(query + self.ffn(query))



class DeformableTransformerDecoder(nn.Module):
    """Stack of ``DeformableTransformerDecoderLayer`` returning every layer's output.

    Args:
        num_layers: Number of decoder layers to build.
        layer: Keyword arguments used to construct every layer.
        spatial_shapes: Per-level ``(H, W)`` pairs; stored and passed to each
            layer.
        valid_ratios: Used once, at construction, to compute the reference
            points.
        level_start_index: Stored and passed to each layer.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, num_layers, layer, spatial_shapes, valid_ratios, level_start_index, *args, **kwargs) -> None:
        super().__init__()

        self.num_layers = num_layers
        stacked = [DeformableTransformerDecoderLayer(**layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stacked)

        self.spatial_shapes = spatial_shapes
        # Computed once and kept as a non-trainable parameter.
        self.reference_points = nn.Parameter(
            get_reference_points(spatial_shapes, valid_ratios), requires_grad=False)
        self.level_start_index = level_start_index

    def forward(self, query, src, query_pos):
        """Chain the layers and collect the output of each one.

        Returns:
            List with one entry per layer, in layer order; the last entry is
            the final decoder output.
        """
        per_layer = []
        hidden = query
        for decoder_layer in self.layers:
            hidden = decoder_layer(
                hidden, self.reference_points, src,
                self.spatial_shapes, self.level_start_index, query_pos)
            per_layer.append(hidden)
        return per_layer

class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention then feed-forward.

    Args:
        dim: Feature width of the tokens.
        num_heads: Number of attention heads.
        ffn_dim: Hidden width of the feed-forward block.
        dropout: Dropout probability used inside the attention module, on
            both residual branches and inside the feed-forward block.
    """

    def __init__(self, dim=256, num_heads=8, ffn_dim=2048, dropout=0.1) -> None:
        super().__init__()

        # self attention
        # NOTE(review): ``batch_first`` is left at its default here, whereas the
        # other attention modules in this file pass ``batch_first=True`` --
        # confirm the intended input layout.
        self.self_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dim)

        # ffn
        self.linear1 = nn.Linear(dim, ffn_dim)
        self.activation = nn.GELU()
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(ffn_dim, dim)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(dim)

    def ffn(self, src):
        """Feed-forward block with residual connection and LayerNorm."""
        src2 = self.linear2(self.dropout3(self.activation(self.linear1(src))))
        src = src + self.dropout4(src2)
        src = self.norm3(src)
        return src

    def forward(self, query):
        """Apply self-attention and the feed-forward block to ``query``.

        Returns:
            Tensor with the same shape as ``query``.
        """
        # self attention
        # ``nn.MultiheadAttention`` needs query, key and value and returns an
        # ``(output, weights)`` tuple; the previous single-argument call raised
        # a TypeError before any computation happened.
        out1 = self.self_attn(query, query, query, need_weights=False)[0]
        query = query + self.dropout1(out1)
        query = self.norm1(query)
        # ffn
        query = self.ffn(query)
        return query


class TransformerEncoder(nn.Module):
    """Stack of ``TransformerEncoderLayer`` returning every layer's output.

    Args:
        num_layers: Number of layers to build.
        layer: Keyword arguments used to construct every layer.
    """

    def __init__(self, num_layers, layer) -> None:
        super().__init__()

        self.num_layers = num_layers
        stacked = [TransformerEncoderLayer(**layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stacked)

    def forward(self, x):
        """Chain the layers and return a list with each layer's output, in order."""
        per_layer = []
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden)
            per_layer.append(hidden)
        return per_layer


class TransformerDecoderLayer(nn.Module):
    """Post-norm decoder layer: cross-attention into ``src`` followed by an MLP.

    Args:
        dim: Feature width of queries and source tokens.
        num_heads: Number of attention heads.
        ffn_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout probability for the attention module and the
            attention residual branch.
        activation: Activation factory handed to the MLP.
    """

    def __init__(self, dim=256, num_heads=8, ffn_dim=2048, dropout=0.1, activation=nn.GELU, ) -> None:
        super().__init__()

        # cross attention
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, bias=False, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dim)

        # ffn
        self.ffn = Mlp(in_features=dim, out_features=dim, hidden_features=ffn_dim, act_layer=activation)
        # Kept so existing module structure is unchanged; ``forward`` does not
        # apply it to the FFN residual branch.
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, query, src, padding_mask=None):
        """Attend from ``query`` to ``src`` and refine with the MLP.

        Args:
            query: Query features, batch first.
            src: Source features used as both key and value, batch first.
            padding_mask: Optional mask over the source tokens, handed to
                ``nn.MultiheadAttention`` as ``key_padding_mask`` (for a
                boolean mask, True marks positions to ignore). None disables
                masking.

        Returns:
            The updated query features.
        """
        # cross attention -- the mask was previously accepted but never used
        out1 = self.cross_attn(
            query, src, src, key_padding_mask=padding_mask, need_weights=False)[0]
        query = query + self.dropout1(out1)
        query = self.norm1(query)
        # ffn
        query = query + self.ffn(query)
        query = self.norm2(query)

        return query


class TransformerDecoder(nn.Module):
    """Stack of ``TransformerDecoderLayer`` returning only the final output.

    Args:
        num_layers: Number of layers to build.
        layer: Keyword arguments used to construct every layer.
    """

    def __init__(self, num_layers, layer) -> None:
        super().__init__()

        self.num_layers = num_layers
        stacked = [TransformerDecoderLayer(**layer) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stacked)

    def forward(self, query, src):
        """Pass ``query`` through every layer, each attending to the same ``src``."""
        for decoder_layer in self.layers:
            query = decoder_layer(query, src)
        return query


class ELFDecoderLayer(nn.Module):
    """Post-norm layer: cross-attention to ``src``, deformable self-attention, MLP.

    Args:
        dim: Feature width of the queries.
        num_heads: Heads for both attention modules.
        ffn_dim: Hidden width of the feed-forward MLP.
        dropout: Dropout used inside ``nn.MultiheadAttention`` and on the two
            attention residual branches.
        num_levels: Passed to ``MSDeformAttn``.
        num_points: Passed to ``MSDeformAttn``.
        activation: Activation factory handed to the MLP.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 dim=256,
                 num_heads=8,
                 ffn_dim=2048,
                 dropout=0.0,
                 num_levels=3,
                 num_points=4,
                 activation=nn.GELU,
                 *args,
                 **kwargs) -> None:
        super().__init__()

        # dense cross attention from the queries into ``src``
        self.cross_attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, bias=False, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(dim)

        # deformable self attention among the queries
        self.self_attn = MSDeformAttn(dim, num_levels, num_heads, num_points)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)

        # feed-forward
        self.ffn = Mlp(in_features=dim, out_features=dim, hidden_features=ffn_dim,
                       act_layer=activation)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, query, reference_points, src, spatial_shapes, level_start_index, query_pos):
        """Run the three sub-blocks, each as residual add followed by LayerNorm.

        Args:
            query: Query features; also the value input of the deformable
                self-attention.
            reference_points, spatial_shapes, level_start_index: Forwarded to
                ``MSDeformAttn``.
            src: Key and value of the cross-attention.
            query_pos: Accepted but not used by this layer.

        Returns:
            The updated query features.
        """
        # cross attention
        attended = self.cross_attn(query, src, src, need_weights=False)[0]
        query = self.norm1(query + self.dropout1(attended))

        # self attention: the queries sample from themselves
        sampled = self.self_attn(
            query, reference_points, query, spatial_shapes, level_start_index)
        query = self.norm2(query + self.dropout2(sampled))

        # feed-forward (no dropout on this residual branch)
        return self.norm3(query + self.ffn(query))


class ELFDecoder(nn.Module):
    """Decoder built from conv focus layers, dense decoder layers and ELF layers.

    Every input scale is first refined by two ``ConvLayerNorm`` blocks. The
    refined maps are flattened into one token sequence, passed through a
    ``TransformerDecoder`` and then through a stack of ``ELFDecoderLayer``.

    Args:
        num_layers: Sequence of exactly three entries. Entry 1 is the number
            of ``TransformerDecoder`` layers and entry 2 the number of
            ``ELFDecoderLayer``; entry 0 is not used by this class.
        in_channels: Input channels of each scale; its length sets the number
            of scales that get a focus layer.
        layer: Layer configuration. Must contain ``'dim'``; it is also used as
            keyword arguments for the decoder layers.
        spatial_shapes: Per-level ``(H, W)`` pairs; stored and passed to the
            ELF layers.
        valid_ratios: Used once, at construction, to compute the reference
            points.
        level_start_index: Stored and passed to the ELF layers.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(self, num_layers: List, in_channels: List, layer, spatial_shapes, valid_ratios, level_start_index,
                 *args, **kwargs) -> None:
        super().__init__()
        assert len(num_layers) == 3, "The length of num_layers should be 3"
        self.num_layers = num_layers
        self.valid_scales = len(in_channels)
        self.focus_layers = nn.ModuleList(
            [
                nn.Sequential(
                    ConvLayerNorm(in_channels=in_channels[i], hidden_channels=layer['dim'], out_channels=layer['dim']),
                    ConvLayerNorm(in_channels=layer['dim'], hidden_channels=layer['dim'] * 4,
                                  out_channels=layer['dim'])
                )
                for i in range(self.valid_scales)
            ])
        self.inter_layers = TransformerDecoder(num_layers=num_layers[1], layer=layer)
        self.s_layers = nn.ModuleList(
            [ELFDecoderLayer(**layer) for _ in range(num_layers[2])]
        )
        self.spatial_shapes = spatial_shapes
        reference_points = get_reference_points(spatial_shapes, valid_ratios)
        # Computed once and kept as a non-trainable parameter.
        self.reference_points = nn.Parameter(reference_points, requires_grad=False)
        self.level_start_index = level_start_index

    def forward(self, queries: List, src, query_pos):
        """Refine the multi-scale maps and decode them against ``src``.

        Args:
            queries: Sequence of 4-D feature maps, one per scale. The caller's
                sequence is not modified.
            src: Source tokens attended to by the decoder layers.
            query_pos: Forwarded to the ELF layers.

        Returns:
            Tuple ``(query_flatten, mask_feature)``: the decoded token sequence
            and the list of feature maps after the focus layers.
        """
        # Work on a shallow copy: assigning into ``queries`` itself would
        # replace the caller's tensors as a hidden side effect.
        mask_feature = list(queries)
        for i, focus in enumerate(self.focus_layers):
            mask_feature[i] = focus(mask_feature[i])

        # (B, C, H, W) -> (B, H*W, C) per scale, concatenated over locations.
        query_flatten = torch.cat([q.flatten(2).transpose(1, 2) for q in mask_feature], dim=1)
        query_flatten = self.inter_layers(query_flatten, src)
        for layer in self.s_layers:
            query_flatten = layer(query_flatten, self.reference_points, src, self.spatial_shapes,
                                  self.level_start_index, query_pos)
        return query_flatten, mask_feature
