import re
from dataclasses import dataclass
from typing import Callable

import torch
import transformers
from torch import nn, Tensor
from transformers import AutoConfig
from transformers.modeling_outputs import ModelOutput


@dataclass
class RouterWithLLMEmbeddingsOutput(ModelOutput):
    """Output of the router's ``forward`` pass.

    ``routing_logits`` is the only required field; ``query_embedding`` defaults to ``None``.
    """

    # Similarity scores between each query embedding and every row of the router's
    # ``llm_embeddings`` table, i.e. one score per (query, candidate) pair.
    routing_logits: Tensor
    # Backbone hidden state at sequence position 0 that the scores were computed from.
    query_embedding: Tensor | None = None

def _base_model_class_name(architecture: str) -> str:
    """Map a task-head class name to the name of its bare backbone class.

    ``BertForMaskedLM`` becomes ``BertModel``; a name without a task suffix is returned unchanged.
    Only a ``For`` that is immediately followed by an uppercase letter is treated as the
    task-head separator, so the ``For`` inside ``RoFormer`` / ``MaskFormer`` is left alone.
    """
    separator = re.search(r"For(?=[A-Z])", architecture)
    if separator is None:
        return architecture
    return architecture[: separator.start()] + "Model"


def load_router_with_llm_embeddings(
    pretrained_model_name_or_path: str, n_candidates: int, similarity_function: str, **kwargs
):
    """Load a checkpoint as a router that scores queries against learned candidate embeddings.

    The backbone class is looked up on the ``transformers`` module from the checkpoint's
    ``config.architectures[0]`` (task-head names such as ``XxxForMaskedLM`` are mapped to
    ``XxxModel``). A subclass of that backbone is created on the fly which adds an
    ``nn.Embedding`` table with one row per candidate and returns, for every input sequence,
    the similarity between the hidden state at position 0 and each row of that table.

    Args:
        pretrained_model_name_or_path: Checkpoint identifier or path, passed to
            ``AutoConfig.from_pretrained`` and to the router's ``from_pretrained``.
        n_candidates: Number of rows in the candidate embedding table.
        similarity_function: ``"cos"`` selects cosine similarity; any other value selects
            the plain dot product.
        **kwargs: Forwarded to ``from_pretrained``. Download/location options
            (``revision``, ``token``, ``cache_dir``, ...) are also used for the config lookup.

    Returns:
        An instance of the dynamically created ``AutoRouterWithLLMEmbeddings`` class.

    Raises:
        ValueError: If the checkpoint config lists no architecture, or if the router is
            built without ``n_candidates`` / ``similarity_function`` being available.
        AttributeError: If ``transformers`` exposes no class with the resolved name.
    """
    # Resolve the config from the same revision / cache / credentials as the weights;
    # otherwise the architecture lookup can hit a different (or inaccessible) checkpoint.
    hub_option_names = (
        "cache_dir",
        "force_download",
        "local_files_only",
        "proxies",
        "revision",
        "subfolder",
        "token",
        "use_auth_token",
    )
    hub_kwargs = {name: kwargs[name] for name in hub_option_names if name in kwargs}
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **hub_kwargs)

    architectures = getattr(config, "architectures", None)
    if not architectures:
        raise ValueError(
            f"The config of {pretrained_model_name_or_path!r} does not list any architecture, "
            "so the backbone model class cannot be determined."
        )
    model_class_name = architectures[0]
    if model_class_name == "AutoRouterWithLLMEmbeddings":
        # Checkpoint was saved by a router created here: the real backbone name was
        # recorded on the config by ``__init__`` below.
        model_class_name = config.original_model_class
    model_class_name = _base_model_class_name(model_class_name)
    model_class = getattr(transformers, model_class_name)

    class AutoRouterWithLLMEmbeddings(model_class):
        """Backbone model extended with a candidate embedding table and a similarity head."""

        def __init__(self, *args, **kwargs):
            # When a saved router is reloaded, the config already carries these two settings
            # and the config loader absorbs the matching keyword arguments before they reach
            # this constructor, so they must be optional here.
            n_candidates = kwargs.pop("n_candidates", None)
            similarity_function = kwargs.pop("similarity_function", None)
            super().__init__(*args, **kwargs)
            if n_candidates is not None:
                self.config.n_candidates = n_candidates
            if similarity_function is not None:
                self.config.similarity_function = similarity_function
            missing = [
                name
                for name in ("n_candidates", "similarity_function")
                if getattr(self.config, name, None) is None
            ]
            if missing:
                raise ValueError(
                    f"Missing router setting(s) {missing}: pass them as keyword arguments "
                    "or provide them on the config."
                )
            # Recorded so that a saved router can find its backbone class again on reload.
            self.config.original_model_class = model_class_name
            self.llm_embeddings = nn.Embedding(self.config.n_candidates, self.config.hidden_size)

        def compute_similarity(self, x: Tensor, y: Tensor) -> Tensor:
            """Return the pairwise similarity between the rows of ``x`` and the rows of ``y``.

            Uses cosine similarity when ``config.similarity_function == "cos"`` and the
            plain dot product otherwise.
            """
            scores = x @ y.T
            if self.config.similarity_function != "cos":
                return scores
            norms = x.norm(dim=1).unsqueeze(1) * y.norm(dim=1).unsqueeze(0)
            # Clamp so an all-zero row yields a score of 0 instead of NaN.
            return scores / norms.clamp_min(torch.finfo(norms.dtype).tiny)

        def forward(
            self,
            input_ids: Tensor,
            attention_mask: Tensor | None = None,
            **kwargs,
        ) -> RouterWithLLMEmbeddingsOutput:
            """Encode the inputs with the backbone and score them against every candidate.

            Extra keyword arguments are forwarded to the backbone's ``forward``.
            """
            if "return_dict" in kwargs:
                # ``last_hidden_state`` is read by attribute below, so the backbone must
                # return a ModelOutput whatever the caller asked for.
                kwargs["return_dict"] = True
            outputs = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kwargs,
            )
            # The hidden state at sequence position 0 is used as the query representation.
            query_embedding = outputs.last_hidden_state[:, 0, :]
            logits = self.compute_similarity(query_embedding, self.llm_embeddings.weight)
            return RouterWithLLMEmbeddingsOutput(routing_logits=logits, query_embedding=query_embedding)

        def _init_weights(self, module: nn.Module):
            """Initialise the candidate table here; delegate every other module to the backbone."""
            # Identity check rather than ``isinstance(module, nn.Embedding)`` so the backbone's
            # own embedding layers keep the backbone's initialisation. ``getattr`` with a
            # default because this hook can run before ``llm_embeddings`` has been created.
            if module is getattr(self, "llm_embeddings", None):
                nn.init.normal_(module.weight, mean=0, std=0.78)
            else:
                super()._init_weights(module)

    return AutoRouterWithLLMEmbeddings.from_pretrained(
        pretrained_model_name_or_path, n_candidates=n_candidates, similarity_function=similarity_function, **kwargs
    )
