import torch
import torch.nn as nn
import re
from functools import partial
from torch.nn.init import trunc_normal_, normal_
from torch.nn import functional as F


class IdentityMap(nn.Module):
    """Projector that performs no projection: the input is returned as-is."""

    def __init__(self):
        super(IdentityMap, self).__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` itself; any extra arguments are accepted and ignored."""
        return x

    @property
    def config(self):
        """Configuration dict naming this projector type."""
        return dict(mm_projector_type="identity")


class SimpleResBlock(nn.Module):
    """LayerNorm followed by a Linear-GELU-Linear branch with a skip connection.

    The skip connection starts *after* the layer norm, i.e. the output is
    ``norm(x) + proj(norm(x))``.

    Args:
        channels: Size of the last dimension of the input; every layer keeps
            this width.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        fc_in = nn.Linear(channels, channels)
        activation = nn.GELU()
        fc_out = nn.Linear(channels, channels)
        self.proj = nn.Sequential(fc_in, activation, fc_out)

    def forward(self, x):
        """Normalise ``x`` and add the projected branch to the normalised value."""
        normed = self.pre_norm(x)
        branch = self.proj(normed)
        return normed + branch


# ----------------------------------------------------------#
class TokenPacker(nn.Module):
    """Pack a square grid of visual tokens into fewer tokens via region attention.

    Queries come from the single-level features, bilinearly down-sampled from
    ``raw_grid x raw_grid`` to ``grid_size x grid_size``.  Keys and values come
    from the multi-level features.  Each query attends only to the
    ``scale_factor x scale_factor`` window of keys/values at its grid position
    ("Point2Region" attention), and the attended result is passed through a
    two-layer MLP that maps ``embed_dim`` to ``hidden_size``.

    Args:
        raw_grid: Side length of the input token grid of the single-level
            features.
        embed_dim: Width of queries, keys and values inside the attention.
        num_heads: Number of attention heads.
        kv_dim: Channel width of the single-level features fed to the query
            projection.
        hidden_size: Output channel width of the final MLP.
        scale_factor: Side length of each key/value window; the token count is
            reduced by ``scale_factor ** 2``.
        norm_layer: Callable building a normalisation layer from a width.
        multi_level_dim: Channel width of the multi-level features fed to the
            key and value projections.
    """

    def __init__(
        self,
        raw_grid=24,
        embed_dim=1024,
        num_heads=1024 // 128,
        kv_dim=1024,
        hidden_size=4096,
        scale_factor=2,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        multi_level_dim=4096,
    ):
        super().__init__()
        self.raw_grid = raw_grid
        self.scale_factor = scale_factor  # downsampling ratio = scale_factor ** 2
        self.grid_size = self.raw_grid // self.scale_factor
        self.num_queries = self.grid_size**2
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.q_proj_1 = nn.Linear(kv_dim, embed_dim, bias=False)

        # Key/value projections must end at embed_dim: their outputs feed
        # layer norms and an attention module that are both built for embed_dim.
        self.k_proj_1 = self._two_layer_mlp(multi_level_dim, embed_dim, embed_dim)
        self.v_proj_1 = self._two_layer_mlp(multi_level_dim, embed_dim, embed_dim)

        self.ln_q_1 = norm_layer(embed_dim)
        self.ln_k_1 = norm_layer(embed_dim)
        self.ln_v_1 = norm_layer(embed_dim)

        self.clip_attn = nn.MultiheadAttention(embed_dim, num_heads)

        # The attention output has width embed_dim, so the MLP starts there.
        self.mlp = self._two_layer_mlp(embed_dim, hidden_size, hidden_size)

        self.apply(self._init_weights)

    @staticmethod
    def _two_layer_mlp(in_dim, mid_dim, out_dim):
        """Build ``Linear(in, mid) -> GELU -> Linear(mid, out)``."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim), nn.GELU(), nn.Linear(mid_dim, out_dim)
        )

    def _init_weights(self, m):
        """Initialise one submodule; intended to be used with ``Module.apply``.

        ``nn.Linear`` weights are drawn from N(0, 0.02) and biases zeroed;
        ``nn.LayerNorm`` is reset to weight 1 and bias 0.
        """
        if isinstance(m, nn.Linear):
            # normal_ is used instead of trunc_normal_ to avoid
            # RuntimeError: "erfinv_cuda" not implemented for 'BFloat16'.
            normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def divide_feature(self, x, kernel_size, token_num, N, c):
        """Regroup a row-major token grid into ``kernel_size`` x ``kernel_size`` windows.

        Args:
            x: Tensor of shape ``(token_num, N, c)`` whose first axis is a
                flattened ``h x w`` grid with ``h = w = int(token_num ** 0.5)``.
            kernel_size: Window side length; ``h`` must be divisible by it.
            token_num: Number of tokens along the first axis of ``x``.
            N: Size of the second axis of ``x``.
            c: Size of the last axis of ``x``.

        Returns:
            Tensor of shape ``(kernel_size ** 2, num_windows * N, c)`` with
            ``num_windows = (h // kernel_size) * (w // kernel_size)``.  Axis 0
            indexes the position inside a window (row-major); axis 1 enumerates
            windows (row-major) with the ``N`` axis varying fastest.
        """
        h = w = int(token_num**0.5)
        rows, cols = h // kernel_size, w // kernel_size

        # Split each grid axis into (window index, offset inside window), then
        # move both in-window offsets to the front and flatten.
        windows = x.reshape(rows, kernel_size, cols, kernel_size, N, c)
        windows = windows.permute(1, 3, 0, 2, 4, 5)
        return windows.reshape(kernel_size * kernel_size, -1, c)

    def forward(self, x, attn_mask=None):
        """Compress visual tokens.

        Args:
            x: Indexable pair. ``x[0]`` holds the single-level features and is
                reshaped to ``(batch, raw_grid, raw_grid, channels)``; ``x[1]``
                holds the multi-level features with batch on axis 0 and tokens
                on axis 1.
            attn_mask: Optional mask forwarded unchanged to the attention
                module.

        Returns:
            Tensor of shape ``(batch, num_queries, hidden_size)``.
        """
        x_multi = x[1]  # multi-level features: source of keys and values
        x = x[0]  # original single-level features: source of queries

        # Move tokens to axis 0: (token_num, batch, embed_dim).
        key = self.ln_k_1(self.k_proj_1(x_multi)).permute(1, 0, 2)
        value = self.ln_v_1(self.v_proj_1(x_multi)).permute(1, 0, 2)

        token_num, N, c = key.shape

        # Down-sample the single-level token grid to grid_size x grid_size.
        # Interpolation runs in float32 and the result is cast back afterwards.
        grid = x.reshape(x.shape[0], self.raw_grid, self.raw_grid, -1)
        q = F.interpolate(
            grid.float().permute(0, 3, 1, 2),
            size=(self.grid_size, self.grid_size),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        q = q.reshape(q.shape[0], -1, q.shape[-1]).to(x.dtype)

        query = self.ln_q_1(self.q_proj_1(q)).permute(1, 0, 2)

        # Every query becomes its own batch entry of length 1, paired with the
        # scale_factor ** 2 keys/values of the window at the same grid position.
        reshape_query = self.divide_feature(query, 1, self.num_queries, N, c)
        reshape_key = self.divide_feature(key, self.scale_factor, token_num, N, c)
        reshape_value = self.divide_feature(
            value, self.scale_factor, token_num, N, value.shape[-1]
        )
        # Point2Region attention
        out = self.clip_attn(
            reshape_query, reshape_key, reshape_value, attn_mask=attn_mask
        )[0]

        # (1, num_queries * N, embed_dim) -> (N, num_queries, embed_dim)
        x = out.reshape(self.num_queries, N, -1)
        x = x.permute(1, 0, 2)

        return self.mlp(x)

    def _repeat(self, query, N: int):
        """Insert a new axis 1 into ``query`` and tile it ``N`` times along it."""
        return query.unsqueeze(1).repeat(1, N, 1)


# ----------------------------------------------------------#


def build_vision_projector(config, delay_load=False, **kwargs):
    """Create the vision projector selected by ``config.mm_projector_type``.

    Recognised types (``"linear"`` is used when the attribute is absent):

    * ``"tokenpacker"`` -- a ``TokenPacker`` built from ``config.hidden_size``
      and ``config.scale_factor``.
    * ``"linear"`` -- one ``nn.Linear`` from ``config.mm_hidden_size`` to
      ``config.hidden_size``.
    * ``"mlp<N>x_gelu"`` -- a first ``nn.Linear`` as above followed by
      ``N - 1`` pairs of ``nn.GELU`` and a ``hidden_size``-wide ``nn.Linear``.
    * ``"identity"`` -- an ``IdentityMap``.

    Args:
        config: Object exposing the attributes listed above.
        delay_load: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Raises:
        ValueError: If the projector type is none of the above.
    """
    kind = getattr(config, "mm_projector_type", "linear")

    if kind == "tokenpacker":
        return TokenPacker(
            hidden_size=config.hidden_size,
            scale_factor=config.scale_factor,
        )

    if kind == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    matched = re.match(r"^mlp(\d+)x_gelu$", kind)
    if matched is not None:
        depth = int(matched.group(1))
        layers = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        # One (GELU, Linear) pair for every layer beyond the first.
        for _ in range(depth - 1):
            layers.extend(
                (nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size))
            )
        return nn.Sequential(*layers)

    if kind == "identity":
        return IdentityMap()

    raise ValueError(f"Unknown projector type: {kind}")
