#    Modified from https://github.com/haotian-liu/LLaVA

import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    """Projector that hands its input back untouched.

    Has no parameters or buffers; any extra positional or keyword arguments
    passed to ``forward`` are accepted and ignored.
    """

    def __init__(self):
        super(IdentityMap, self).__init__()

    def forward(self, x, *args, **kwargs):
        # Extra arguments exist only so this can stand in for other projectors.
        output = x
        return output

    @property
    def config(self):
        """A freshly built dict describing this projector's type."""
        settings = dict(mm_projector_type="identity")
        return settings


class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear block.

    Args:
        channels: size of the last input dimension; the output has the same size.
    """

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)
        layers = [
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, x):
        normed = self.pre_norm(x)
        # NOTE(review): the skip connection adds to the *normalized* tensor,
        # not the raw input. Kept as-is because changing it would alter the
        # behavior of already-trained weights.
        update = self.proj(normed)
        return normed + update


class RangeClip(nn.Module):
    """Clamp each feature of the input into a fixed per-feature range.

    The bounds are stored as non-trainable buffers named ``min`` and ``max``.
    Those names are kept so existing state_dict keys still load. Both bounds
    are reshaped to ``(1, D)``, so they broadcast against the last dimension
    of the input.

    Defined at module level (not inside ``build_vision_projector``) so that
    models containing it can be pickled, e.g. by ``torch.save(model)``.
    """

    def __init__(self, lower, upper) -> None:
        super().__init__()
        # Buffers follow .to()/.cuda() and are saved in the state_dict, but
        # the optimizer never updates them.
        self.register_buffer("min", torch.as_tensor(lower).detach().reshape(1, -1))
        self.register_buffer("max", torch.as_tensor(upper).detach().reshape(1, -1))

    def forward(self, x):
        return torch.clamp(x, self.min, self.max)


def build_vision_projector(config, delay_load=False, **kwargs):
    """Build the projector selected by ``config.mm_projector_type``.

    Supported types (``"linear"`` is used when the attribute is absent):

    * ``"linear"``: ``nn.Linear(config.mm_hidden_size, config.hidden_size)``.
    * ``"mlp{N}x_gelu"``: a Linear ``mm_hidden_size -> hidden_size`` followed
      by ``N - 1`` pairs of (GELU, Linear ``hidden_size -> hidden_size``).
    * ``"identity"``: an ``IdentityMap`` that returns its input unchanged.
    * ``"linearclip"``: the ``"linear"`` layer followed by a ``RangeClip``
      whose bounds are read with ``torch.load`` from
      ``config.min_max_range_path``. The file must hold exactly two bounds,
      e.g. a ``(min, max)`` tuple or a stacked ``(2, D)`` tensor.

    Args:
        config: object exposing ``mm_hidden_size`` and ``hidden_size``, plus
            ``mm_projector_type`` / ``min_max_range_path`` where relevant.
        delay_load: unused here; accepted for call-site compatibility.
        **kwargs: unused here; accepted for call-site compatibility.

    Returns:
        The projector ``nn.Module``.

    Raises:
        ValueError: if the projector type is unknown, or the min/max file
            does not contain exactly two bounds.
    """
    projector_type = getattr(config, "mm_projector_type", "linear")

    if projector_type == "linear":
        return nn.Linear(config.mm_hidden_size, config.hidden_size)

    # Non-string values (e.g. None) fall through to the ValueError below
    # instead of crashing inside the regex engine.
    mlp_gelu_match = (
        re.fullmatch(r"mlp(\d+)x_gelu", projector_type)
        if isinstance(projector_type, str)
        else None
    )
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    if projector_type == "identity":
        return IdentityMap()

    if projector_type == "linearclip":
        # SECURITY NOTE: torch.load unpickles the file. On torch versions where
        # weights_only is not the default this can execute arbitrary code, so
        # min_max_range_path must point to a trusted file.
        # map_location="cpu" lets bounds saved from a GPU load on CPU-only
        # hosts; the buffers move with the module on a later .to(device).
        min_max_range = torch.load(config.min_max_range_path, map_location="cpu")
        try:
            # Accepts any 2-element iterable, including a stacked (2, D) tensor.
            lower, upper = min_max_range
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"min_max_range_path {config.min_max_range_path!r} must contain "
                f"exactly two bounds (min, max); got {type(min_max_range).__name__}"
            ) from err

        return nn.Sequential(
            nn.Linear(config.mm_hidden_size, config.hidden_size),
            RangeClip(lower, upper),
        )

    raise ValueError(f"Unknown projector type: {projector_type}")
