# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import Optional

import torch
from torch import nn
from transformers import ModernBertConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import ClassifierPooler
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsCrossEncoding
from .utils import WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):
    """Token-embedding lookup followed by a LayerNorm.

    Args:
        config: Model configuration. ``vocab_size``, ``hidden_size``,
            ``layer_norm_eps`` and ``norm_bias`` are read from it.
    """

    def __init__(self, config: ModernBertConfig):

        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
        # NOTE(review): the other LayerNorms in this file read
        # ``config.norm_eps``; confirm ``layer_norm_eps`` is intended here.
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.layer_norm_eps,
                                 bias=config.norm_bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return normalized embeddings.

        Args:
            input_ids: Token ids to look up; only used when
                ``inputs_embeds`` is ``None``.
            inputs_embeds: Optional precomputed embeddings. When given, the
                lookup is skipped and these are normalized directly.

        Returns:
            The output of ``self.norm`` applied to the embeddings.
        """
        # Compare against None explicitly: truth-testing a tensor raises
        # RuntimeError for multi-element tensors and would silently treat a
        # single zero element as "not provided".
        if inputs_embeds is None:
            inputs_embeds = self.tok_embeddings(input_ids)
        return self.norm(inputs_embeds)


class ModernBertRotaryEmbedding(RotaryEmbedding):
    """``RotaryEmbedding`` configured from a ModernBERT config.

    ``max_position_embeddings`` is taken from ``config``; the base class is
    always constructed with ``is_neox_style=True`` and
    ``dtype=torch.float16``.

    Args:
        config: Model configuration (also kept as ``self.config``).
        head_size: Forwarded to the base class as ``head_size``.
        dim: Forwarded to the base class as ``rotary_dim``.
        base: Forwarded to the base class as ``base``.
    """

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int,
                 base: float):
        rope_kwargs = dict(
            head_size=head_size,
            rotary_dim=dim,
            max_position_embeddings=config.max_position_embeddings,
            base=base,
            is_neox_style=True,
            dtype=torch.float16,
        )
        super().__init__(**rope_kwargs)
        self.config = config


class ModernBertAttention(nn.Module):
    """Attention sub-block: QKV projection, rotary embedding, encoder-only
    attention and an output projection.

    Layers whose index is not a multiple of
    ``config.global_attn_every_n_layers`` record a
    ``(local_attention // 2, local_attention // 2)`` window in
    ``self.local_attention``; the others record ``(-1, -1)``. Within this
    class the window only selects the RoPE base (``local_rope_theta`` when it
    is set, otherwise ``global_rope_theta``); it is not passed to
    ``Attention`` here.

    Args:
        config: Model configuration.
        layer_id: Index of the owning layer. Although the default is
            ``None``, an integer is required because the index is used with
            ``%`` during construction.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn

        self.num_heads = config.num_attention_heads
        assert self.num_heads % tp_size == 0
        head_dim = config.hidden_size // config.num_attention_heads
        self.head_dim = head_dim
        self.all_head_size = head_dim * self.num_heads
        self.scaling = head_dim**-0.5

        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        is_global = layer_id % config.global_attn_every_n_layers == 0
        if is_global:
            self.local_attention = (-1, -1)
        else:
            half_window = config.local_attention // 2
            self.local_attention = (half_window, half_window)

        # Start from the global base and switch to the local one only for
        # windowed layers that actually define it.
        rope_theta = config.global_rope_theta
        windowed = self.local_attention != (-1, -1)
        if windowed and config.local_rope_theta is not None:
            rope_theta = config.local_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                    head_size=head_dim,
                                                    dim=head_dim,
                                                    base=rope_theta)
        self.attn = Attention(self.num_heads,
                              head_dim,
                              self.scaling,
                              prefix=f"{layer_id}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)
        self.Wo = RowParallelLinear(config.hidden_size,
                                    config.hidden_size,
                                    bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Input to the QKV projection.
            position_ids: Positions handed to the rotary embedding.

        Returns:
            The first element returned by the output projection ``Wo``.
        """
        fused_qkv, _ = self.Wqkv(hidden_states)
        # The projection output is cut into three equal slices along the
        # last dimension: query, key, value.
        slice_sizes = [
            self.all_head_size, self.all_head_size, self.all_head_size
        ]
        query, key, value = fused_qkv.split(slice_sizes, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.Wo(context)
        return projected


class ModernBertMLP(nn.Module):
    """Gated feed-forward block.

    ``Wi`` produces ``2 * intermediate_size`` features that are split in
    half; GELU is applied to the first half, the result is multiplied by the
    second half, and ``Wo`` projects the product to ``hidden_size``.

    Args:
        config: Model configuration. ``hidden_size``, ``intermediate_size``
            and ``mlp_bias`` are read from it.
    """

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.config = config
        # One linear layer emits both the activation input and the gate.
        fused_width = int(config.intermediate_size) * 2
        self.Wi = nn.Linear(config.hidden_size,
                            fused_width,
                            bias=config.mlp_bias)
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(config.intermediate_size,
                                    config.hidden_size,
                                    bias=config.mlp_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP and return the first element of ``Wo``'s
        result."""
        up, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        gated = self.act(up) * gate
        projected = self.Wo(gated)
        return projected[0]


class ModernBertLayer(nn.Module):
    """One encoder layer: pre-norm attention and pre-norm MLP, each wrapped
    in a residual connection.

    Args:
        config: Model configuration.
        prefix: Accepted for interface compatibility; not used in this class.
        layer_id: Index of this layer. Layer ``0`` uses ``nn.Identity`` in
            place of the attention LayerNorm.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 prefix: str = "",
                 layer_id: Optional[int] = None):
        super().__init__()
        self.config = config

        def make_norm() -> nn.LayerNorm:
            # Both norms in the layer share the same hyper-parameters.
            return nn.LayerNorm(config.hidden_size,
                                eps=config.norm_eps,
                                bias=config.norm_bias)

        self.attn_norm = nn.Identity() if layer_id == 0 else make_norm()
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = make_norm()
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Apply attention then MLP, adding each result to its input.

        Args:
            hidden_states: Layer input.
            position_ids: Passed through to the attention module.

        Returns:
            The hidden states after both residual additions.
        """
        normed = self.attn_norm(hidden_states)
        after_attn = hidden_states + self.attn(normed,
                                               position_ids=position_ids)
        return after_attn + self.mlp(self.mlp_norm(after_attn))


class ModernBertEncoderLayer(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ModernBertLayer`` modules.

    Args:
        vllm_config: vLLM configuration; the HF model config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Accepted for interface compatibility; not used in this class.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        # Each layer receives its own index as ``layer_id``.
        self.layers = nn.ModuleList(
            ModernBertLayer(config=hf_config, layer_id=idx)
            for idx in range(hf_config.num_hidden_layers))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer in order.

        Args:
            hidden_states: Input to the first layer.
            position_ids: Passed unchanged to every layer.

        Returns:
            The output of the last layer.
        """
        for block in self.layers:
            hidden_states = block(hidden_states, position_ids)
        return hidden_states


@support_torch_compile
class ModernBertModel(nn.Module):
    """ModernBERT backbone: embeddings, encoder stack and a final LayerNorm.

    Args:
        vllm_config: vLLM configuration; the HF model config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Accepted for interface compatibility; not used in this class.
    """
    # Checkpoint names starting with "layers." are loaded under
    # "encoder_layer.layers.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."})

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.embeddings = ModernBertEmbeddings(hf_config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(hf_config.hidden_size,
                                       eps=hf_config.norm_eps,
                                       bias=hf_config.norm_bias)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs; names are first rewritten by
                ``hf_to_vllm_mapper``.

        Returns:
            The set of (rewritten) parameter names that were loaded.

        Raises:
            KeyError: If a name that is not a skipped ``.bias`` entry has no
                matching parameter.
        """
        remapped = self.hf_to_vllm_mapper.apply(weights)
        own_params = dict(self.named_parameters())
        loaded: set[str] = set()
        for name, tensor in remapped:
            # Bias tensors without a matching parameter are skipped.
            missing_bias = name.endswith(".bias") and name not in own_params
            if missing_bias:
                continue
            target = own_params[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)
            loaded.add(name)
        return loaded

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        positions: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode the input and apply the final LayerNorm.

        Args:
            input_ids: Token ids; used only when ``inputs_embeds`` is
                ``None``.
            positions: Position indices; takes precedence over
                ``position_ids`` when not ``None``.
            intermediate_tensors: Accepted but not used in this method.
            inputs_embeds: Optional hidden states fed to the encoder directly,
                bypassing ``self.embeddings`` (including its norm).
            position_ids: Fallback position indices.

        Returns:
            The normalized output of the encoder stack.
        """
        if positions is not None:
            position_ids = positions

        if inputs_embeds is None:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            inputs_embeds=None)
        else:
            hidden_states = inputs_embeds

        encoded = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=position_ids,
        )
        return self.final_norm(encoded)


class ModernBertPooler(nn.Module):

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                               config.classifier_bias)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled_output = hidden_states
        pooled_output = pooled_output.mean(dim=0, keepdim=False)
        pooled_output = self.norm(self.act(self.dense(pooled_output)))
        return pooled_output


class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """ModernBERT backbone with a classification head.

    The head is a ``ClassifierPooler`` built from a linear classifier
    (``hidden_size`` -> ``num_labels``) and a ``ModernBertPooler``.

    Args:
        vllm_config: vLLM configuration; the HF model config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Name prefix; combined with ``"modernbert"`` via
            ``maybe_prefix`` and passed to the backbone.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        backbone_prefix = maybe_prefix(prefix, "modernbert")
        self.model = ModernBertModel(vllm_config=vllm_config,
                                     prefix=backbone_prefix)
        self.classifier = nn.Linear(hf_config.hidden_size,
                                    hf_config.num_labels)
        self._pooler = ClassifierPooler(vllm_config.model_config,
                                        self.classifier,
                                        ModernBertPooler(hf_config))

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load backbone, classifier and pooling-head weights.

        Names starting with ``"model."`` are stripped of that prefix and
        handed to the backbone's ``load_weights``. Of the remaining names,
        ``classifier*`` entries are loaded under their own name and ``head*``
        entries under ``"_pooler.pooler." + <name without "head.">``; anything
        else is ignored.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a ``classifier``/``head`` name has no matching
                parameter.
        """
        backbone_tag = "model."
        # Filled as a side effect while the backbone consumes the generator.
        head_weights = []

        def backbone_weights():
            for name, tensor in weights:
                if not name.startswith(backbone_tag):
                    head_weights.append((name, tensor))
                    continue
                yield name[len(backbone_tag):], tensor

        self.model.load_weights(backbone_weights())

        own_params = dict(self.named_parameters())

        for name, tensor in head_weights:
            if name.startswith("classifier"):
                key = name
            elif name.startswith("head"):
                # Drop "head" plus the following separator character.
                key = "_pooler.pooler." + name[len("head") + 1:]
            else:
                continue
            target = own_params[key]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Delegate pooling to the ``ClassifierPooler`` held in
        ``self._pooler``."""
        pooled = self._pooler(hidden_states, pooling_metadata)
        return pooled

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone and return its output.

        Args:
            input_ids: Token ids.
            positions: Position indices; passed to the backbone as
                ``position_ids``.
            intermediate_tensors: Accepted but not forwarded to the backbone.
            inputs_embeds: Optional precomputed embeddings.

        Returns:
            The backbone's output hidden states.
        """
        encoded = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            position_ids=positions,
        )
        return encoded
