# SPDX-License-Identifier: Apache-2.0
from typing import Iterable, Optional, Set, Tuple

import torch
from torch import nn
from transformers import ModernBertConfig

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from .interfaces import SupportsCrossEncoding
from .utils import WeightsMapper, maybe_prefix


class ModernBertEmbeddings(nn.Module):
    """Token-embedding lookup followed by a LayerNorm.

    ``forward`` accepts either token ids or pre-computed embeddings; in both
    cases the result is passed through ``self.norm``.
    """

    def __init__(self, config: ModernBertConfig):
        """Build the embedding table and its LayerNorm from ``config``.

        Reads ``vocab_size``, ``hidden_size``, ``norm_bias`` and the norm
        epsilon (see below) from ``config``.
        """
        super().__init__()
        self.config = config
        self.tok_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size)
        # Every other LayerNorm in this file reads ``config.norm_eps``; the
        # embedding norm used to read ``config.layer_norm_eps`` instead, which
        # only exists when the checkpoint's config happens to carry that key.
        # Prefer ``norm_eps`` and fall back so both kinds of config work.
        eps = getattr(config, "norm_eps", None)
        if eps is None:
            eps = getattr(config, "layer_norm_eps", None)
        if eps is None:
            # Last-resort fallback when neither key is present.
            # NOTE(review): 1e-5 is assumed to match the HF default - confirm.
            eps = 1e-5
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=eps,
                                 bias=config.norm_bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return normalized embeddings.

        Args:
            input_ids: Token ids; only used when ``inputs_embeds`` is None.
            inputs_embeds: Optional pre-computed embeddings. When given, the
                embedding table is skipped and only the norm is applied.

        Returns:
            ``self.norm`` applied to the (looked-up or provided) embeddings.
        """
        # Must be an identity check: ``if tensor:`` raises for tensors with
        # more than one element and is False for an all-zero scalar tensor.
        if inputs_embeds is not None:
            return self.norm(inputs_embeds)
        return self.norm(self.tok_embeddings(input_ids))


class ModernBertRotaryEmbedding(RotaryEmbedding):
    """``RotaryEmbedding`` whose maximum position count comes from the config.

    The base class is always initialised with ``is_neox_style=True`` and
    ``dtype=torch.float16``; the remaining arguments are forwarded from the
    caller.
    """

    def __init__(self, config: ModernBertConfig, head_size: int, dim: int,
                 base: float):
        """Initialise the rotary embedding.

        Args:
            config: Supplies ``max_position_embeddings``; also kept on
                ``self.config``.
            head_size: Forwarded to the base class as ``head_size``.
            dim: Forwarded to the base class as ``rotary_dim``.
            base: Forwarded to the base class as ``base``.
        """
        rotary_kwargs = {
            "head_size": head_size,
            "rotary_dim": dim,
            "max_position_embeddings": config.max_position_embeddings,
            "base": base,
            "is_neox_style": True,
            # NOTE(review): dtype is fixed to float16 regardless of the model
            # dtype - confirm this is intended.
            "dtype": torch.float16,
        }
        super().__init__(**rotary_kwargs)
        self.config = config


class ModernBertAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    encoder-only attention and an output projection.

    Layers whose ``layer_id`` is a multiple of
    ``config.global_attn_every_n_layers`` are marked global
    (``local_attention == (-1, -1)``); all others record a symmetric window of
    ``config.local_attention // 2`` on each side and, when available, use
    ``config.local_rope_theta`` as the rotary base.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 layer_id: Optional[int] = None):
        """Create the projections, rotary embedding and attention backend.

        Args:
            config: Model configuration.
            layer_id: Index of the layer inside the encoder. Although typed
                optional, it is used in a modulo below, so an integer is
                required in practice.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.layer_id = layer_id
        self.deterministic_flash_attn = config.deterministic_flash_attn

        # Head geometry.
        # NOTE(review): num_heads is only checked for divisibility by the TP
        # size, not divided by it - confirm behaviour when tp_size > 1.
        self.num_heads = config.num_attention_heads
        assert self.num_heads % tp_size == 0
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.head_dim * self.num_heads
        self.scaling = self.head_dim**-0.5

        self.Wqkv = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.num_heads,
            bias=config.attention_bias,
        )

        # Global layers use the (-1, -1) sentinel; the rest get a window.
        # NOTE(review): local_attention is stored here but is not passed to
        # the Attention layer in this block - confirm whether the window is
        # applied elsewhere.
        if layer_id % config.global_attn_every_n_layers == 0:
            self.local_attention = (-1, -1)
        else:
            self.local_attention = (config.local_attention // 2,
                                    config.local_attention // 2)

        # Windowed layers may override the rotary base.
        rope_theta = config.global_rope_theta
        is_windowed = self.local_attention != (-1, -1)
        if is_windowed and config.local_rope_theta is not None:
            rope_theta = config.local_rope_theta

        self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                    head_size=self.head_dim,
                                                    dim=self.head_dim,
                                                    base=rope_theta)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              prefix=f"{layer_id}.attn",
                              attn_type=AttentionType.ENCODER_ONLY)
        self.Wo = RowParallelLinear(config.hidden_size,
                                    config.hidden_size,
                                    bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: Input to the fused QKV projection.
            position_ids: Positions handed to the rotary embedding.

        Returns:
            The first element returned by the output projection ``Wo``.
        """
        fused_qkv, _ = self.Wqkv(hidden_states)
        # The fused projection is cut into three equal slices on the last dim.
        query, key, value = fused_qkv.split([self.all_head_size] * 3, dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.Wo(context)
        return projected


class ModernBertMLP(nn.Module):
    """Gated feed-forward block.

    ``Wi`` projects to twice the intermediate size; the result is split into
    two halves, GELU is applied to the first half, it is multiplied by the
    second half, and ``Wo`` projects back to the hidden size.
    """

    def __init__(self, config: ModernBertConfig):
        """Create the up-projection, activation and down-projection.

        Reads ``hidden_size``, ``intermediate_size`` and ``mlp_bias`` from
        ``config``.
        """
        super().__init__()
        self.config = config
        hidden_dim = config.hidden_size
        use_bias = config.mlp_bias
        # One matmul produces both the activated half and the gate half.
        self.Wi = nn.Linear(hidden_dim,
                            int(config.intermediate_size) * 2,
                            bias=use_bias)
        self.act = nn.GELU()
        self.Wo = RowParallelLinear(config.intermediate_size,
                                    hidden_dim,
                                    bias=use_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP and return the first element of ``Wo``'s
        result."""
        up_projected = self.Wi(hidden_states)
        activated_half, gate_half = up_projected.chunk(2, dim=-1)
        gated = self.act(activated_half) * gate_half
        return self.Wo(gated)[0]


class ModernBertLayer(nn.Module):
    """One encoder layer: pre-norm attention and pre-norm MLP, each wrapped in
    a residual connection.

    The layer with ``layer_id == 0`` uses ``nn.Identity`` in place of the
    attention LayerNorm.
    """

    def __init__(self,
                 config: ModernBertConfig,
                 prefix: str = "",
                 layer_id: Optional[int] = None):
        """Build the norms, attention and MLP sub-modules.

        Args:
            config: Model configuration.
            prefix: Accepted for interface compatibility; not used here.
            layer_id: Index of this layer; forwarded to the attention block.
        """
        super().__init__()
        self.config = config

        def make_norm() -> nn.LayerNorm:
            return nn.LayerNorm(config.hidden_size,
                                eps=config.norm_eps,
                                bias=config.norm_bias)

        # Sub-modules are created in the same order as the attributes are
        # named: attn_norm, attn, mlp_norm, mlp.
        self.attn_norm = nn.Identity() if layer_id == 0 else make_norm()
        self.attn = ModernBertAttention(config=config, layer_id=layer_id)
        self.mlp_norm = make_norm()
        self.mlp = ModernBertMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Return ``hidden_states`` after the attention and MLP residual
        updates."""
        normed_for_attn = self.attn_norm(hidden_states)
        after_attn = hidden_states + self.attn(normed_for_attn,
                                               position_ids=position_ids)
        normed_for_mlp = self.mlp_norm(after_attn)
        return after_attn + self.mlp(normed_for_mlp)


class ModernBertEncoderLayer(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``ModernBertLayer`` modules
    applied one after another."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
        """Create the layer stack.

        Args:
            vllm_config: Its ``model_config.hf_config`` provides the model
                configuration.
            prefix: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.layers = nn.ModuleList()
        for layer_index in range(config.num_hidden_layers):
            self.layers.append(
                ModernBertLayer(config=config, layer_id=layer_index))

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer in order, passing the
        same ``position_ids`` to each."""
        for layer in self.layers:
            hidden_states = layer(hidden_states, position_ids)
        return hidden_states


@support_torch_compile
class ModernBertModel(nn.Module):
    """ModernBERT backbone: embeddings, encoder stack and a final LayerNorm."""

    # Checkpoint names starting with "layers." are re-rooted under the
    # ``encoder_layer`` attribute of this module.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"layers.": "encoder_layer.layers."})

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the backbone from ``vllm_config.model_config.hf_config``.

        ``prefix`` is accepted for interface compatibility; it is not used in
        this constructor.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.embeddings = ModernBertEmbeddings(hf_config)
        self.encoder_layer = ModernBertEncoderLayer(vllm_config)
        self.final_norm = nn.LayerNorm(hf_config.hidden_size,
                                       eps=hf_config.norm_eps,
                                       bias=hf_config.norm_bias)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs; names are first rewritten by
                ``hf_to_vllm_mapper``.

        Returns:
            The set of (rewritten) parameter names that were loaded.

        Raises:
            KeyError: If a rewritten name is not a parameter of this module,
                unless it ends in ".bias" (those are skipped).
        """
        known_params = dict(self.named_parameters())
        loaded: Set[str] = set()
        for param_name, tensor in self.hf_to_vllm_mapper.apply(weights):
            # A bias in the checkpoint with no matching parameter is ignored.
            if param_name not in known_params and param_name.endswith(
                    ".bias"):
                continue
            target = known_params[param_name]
            # Parameters may carry their own loader; otherwise use the default.
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            load_fn(target, tensor)
            loaded.add(param_name)
        return loaded

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode the input and return the final-normed hidden states.

        When ``inputs_embeds`` is given it is used directly as the encoder
        input (the embedding module is bypassed); otherwise ``input_ids`` are
        embedded first.
        """
        if inputs_embeds is None:
            hidden_states = self.embeddings(input_ids=input_ids,
                                            inputs_embeds=inputs_embeds)
        else:
            hidden_states = inputs_embeds

        encoded = self.encoder_layer(
            hidden_states=hidden_states,
            position_ids=position_ids,
        )
        return self.final_norm(encoded)


class ModernBertPooler(nn.Module):

    def __init__(self, config: ModernBertConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size,
                               config.classifier_bias)
        self.act = nn.GELU()
        self.norm = nn.LayerNorm(config.hidden_size,
                                 eps=config.norm_eps,
                                 bias=config.norm_bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        pooled_output = hidden_states
        pooled_output = pooled_output.mean(dim=0, keepdim=False)
        pooled_output = self.norm(self.act(self.dense(pooled_output)))
        return pooled_output


class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):
    """ModernBERT backbone with a pooling head and a linear classifier, wired
    together through ``CrossEncodingPooler``."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, classifier and pooler.

        Args:
            vllm_config: Its ``model_config.hf_config`` provides the model
                configuration.
            prefix: Parent prefix; combined with "modernbert" for the
                backbone.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        backbone_prefix = maybe_prefix(prefix, "modernbert")
        self.model = ModernBertModel(vllm_config=vllm_config,
                                     prefix=backbone_prefix)
        self.classifier = nn.Linear(hf_config.hidden_size,
                                    hf_config.num_labels)
        head = ModernBertPooler(hf_config)
        self._pooler = CrossEncodingPooler(hf_config, self.classifier, head)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load backbone, classifier and head weights from a checkpoint.

        Names starting with "model." are stripped of that prefix and handed to
        the backbone. Of the remaining names, those starting with "classifier"
        are loaded under the same name and those starting with "head" are
        loaded under "_pooler.pooler.<rest>"; anything else is ignored.
        """
        non_backbone = []

        def backbone_weights():
            # Lazily splits the stream: backbone entries are yielded, the
            # rest are parked until the backbone has consumed the iterator.
            for name, tensor in weights:
                if not name.startswith("model."):
                    non_backbone.append((name, tensor))
                    continue
                yield name[len("model."):], tensor

        self.model.load_weights(backbone_weights())

        own_params = dict(self.named_parameters())

        def load_into(param_name, tensor):
            target = own_params[param_name]
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            load_fn(target, tensor)

        for name, tensor in non_backbone:
            if name.startswith("classifier"):
                load_into(name, tensor)
            if name.startswith("head"):
                # NOTE(review): assumes CrossEncodingPooler registers the head
                # under the attribute name "pooler" - confirm.
                load_into("_pooler.pooler." + name[len("head") + 1:], tensor)

    def pooler(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> Optional[PoolerOutput]:
        """Delegate pooling to the ``CrossEncodingPooler`` instance."""
        return self._pooler(hidden_states, pooling_metadata)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone.

        ``positions`` is passed to the backbone as ``position_ids``;
        ``intermediate_tensors`` is accepted but not used.
        """
        backbone_inputs = {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "position_ids": positions,
        }
        return self.model(**backbone_inputs)
