from __future__ import annotations

import copy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Iterable, Literal, cast, override

import torch
from peft import LoraConfig, get_peft_model
from peft.tuners.lora import Linear as LoraLinear
from peft.tuners.lora import LoraLayer
from sentence_transformers import SentenceTransformer
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Cache,
    GenerationConfig,
    GenerationMixin,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)
from transformers.convert_graph_to_onnx import ModelOutput
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.idefics.modeling_idefics import freeze_model

from mow.common import defaults
from mow.common.graph import map_observation_to_graph
from mow.modules.mlp import MLP
from mow.modules.routers import GraphRouter, GraphRouterConfig
from mow.modules.utils import (
    LoraLinearPreHook,
    iterate_lora_layers,
    lora_linear_forward,
)

# Debug switch: when True, `MoW.forward` prints a table of routing scores and
# then blocks on `input()` on every call. Keep it False outside interactive
# debugging sessions.
_PRINT_ROUTING_SCORE_TABLE = False


class MoWStrategy(str):
    """String-valued routing strategy.

    Recognised values are ``"scale"`` (see :attr:`SCALE`), ``"top_<k>"`` and
    ``"top_<k>_flatten"``, where ``<k>`` is a positive integer.
    """

    # Singleton-like constant; assigned right after the class body because it
    # needs the finished class to be instantiated.
    SCALE: "MoWStrategy"

    @staticmethod
    def TOP_K(k: int):
        """Return the ``"top_<k>"`` strategy.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be a positive integer greater than 0.")
        return MoWStrategy(f"top_{k}")

    @staticmethod
    def TOP_K_FLATTEN(k: int):
        """Return the ``"top_<k>_flatten"`` strategy.

        Raises:
            ValueError: If ``k`` is smaller than 1.
        """
        if k < 1:
            raise ValueError("k must be a positive integer greater than 0.")
        return MoWStrategy(f"top_{k}_flatten")

    @property
    def k_val(self):
        """Return ``k`` for ``top_<k>`` strategies and ``None`` for ``scale``.

        Raises:
            ValueError: If the value is neither ``"scale"`` nor a ``top_<k>``
                form, if ``<k>`` is not an integer, or if ``k < 1``.
        """
        if self == MoWStrategy.SCALE:
            return None
        if not self.startswith("top_"):
            raise ValueError(
                "Invalid MoWStrategy. Supported strategies are 'scale' and 'top_k'."
            )
        try:
            k = int(self.split("_")[1])
        except ValueError:
            raise ValueError(
                "Invalid MoWStrategy format. Expected 'top_k' where k is a positive integer."
            ) from None
        # Checked outside the ``try`` so this specific message is not replaced
        # by the generic format error above.
        if k < 1:
            raise ValueError("k must be a positive integer greater than 0.")
        return k

    @property
    def is_flatten(self):
        """Whether the strategy string ends with ``"_flatten"``."""
        return self.endswith("_flatten")


MoWStrategy.SCALE = MoWStrategy("scale")


class MoWConfig(PretrainedConfig):
    """Configuration for :class:`MoW`.

    Only the fields declared below are read by :meth:`from_dict` and written
    by :meth:`to_dict`; any other key in the input dictionary is dropped.
    """

    model_type = "mow"

    # Identifier/path of the base causal LM (also used for the tokenizer).
    base_model: str = ""
    # Adapter name -> path of the expert model to load that adapter from.
    expert_models: dict[str, Path] = {}
    # LoRA configuration; MoW falls back to `defaults.default_lora_config`.
    lora_config: LoraConfig | None = None
    # Step size used by `MoW.refine_router`.
    refinement_rate: float = 0.001
    # Softmax temperature for refinement; MoW falls back to
    # `routing_temperature` when this is unset.
    refinement_temperature: float | None = None
    # Pretrained GraphRouter identifier/path (optional).
    router_model: str | None = None
    # Config used to build fresh routers when `router_model` is not given.
    router_config: GraphRouterConfig | None = None
    # Temperature passed to the routers when computing similarities.
    routing_temperature: float = 0.1
    sentence_transformer_model: str = "all-MiniLM-L6-v2"
    # Number of shared adapters ("shared_1" .. "shared_<n>"); 0 disables them.
    shared_expert: int = 0
    similarity_type: Literal["cosine", "euclidean"] = "cosine"
    strategy: MoWStrategy = MoWStrategy.SCALE
    train_targets: list[
        Literal[
            "router",
            "embedding_set",
            "embedding_projection",
            "shared",
            "lora_mlp",
            "base",
        ]
    ] = ["router"]
    # Attach an extra MLP to every LoRA layer (optionally with bias).
    use_lora_mlp: bool = False
    use_lora_mlp_bias: bool = False

    @override
    @classmethod
    def from_dict(cls, config_dict: dict[str, Any], **kwargs):
        """Build a config from a plain dictionary.

        Nested ``lora_config`` / ``router_config`` entries may be given either
        as dictionaries or as already-constructed config objects.

        Raises:
            ValueError: If ``base_model`` is missing, ``expert_models`` is not
                a non-empty dict, or neither ``router_model`` nor
                ``router_config`` is provided.
        """
        base_model = config_dict.get("base_model", None)
        if base_model is None:
            raise ValueError("base_model must be specified in the config.")

        expert_models = config_dict.get("expert_models", {})
        if not isinstance(expert_models, dict):
            raise ValueError(
                "expert_models must be a dictionary mapping model names to paths."
            )
        if not expert_models:
            raise ValueError("At least one expert model must be specified.")

        # NOTE(review): the fallback here is "top_1" while the class-level
        # default is MoWStrategy.SCALE; kept as-is to preserve behavior.
        strategy = MoWStrategy(config_dict.get("strategy", "top_1"))

        # Accept both a serialized dict and a ready LoraConfig instance
        # (mirrors the handling of router_config below); anything else -> None.
        lora_config = config_dict.get("lora_config", None)
        if isinstance(lora_config, dict):
            lora_config = LoraConfig(**lora_config)
        elif not isinstance(lora_config, LoraConfig):
            lora_config = None

        router_model = config_dict.get("router_model", None)
        router_config = config_dict.get("router_config", None)
        if router_model is None and router_config is None:
            raise ValueError(
                "Either router_model or router_config must be provided."
            )
        if isinstance(router_config, dict):
            router_config = GraphRouterConfig(**router_config)

        config = {
            "base_model": base_model,
            "expert_models": {
                key: Path(value) for key, value in expert_models.items()
            },
            "lora_config": lora_config,
            "refinement_rate": config_dict.get("refinement_rate", 0.001),
            "refinement_temperature": config_dict.get(
                "refinement_temperature", None
            ),
            "router_model": router_model,
            "router_config": router_config,
            "routing_temperature": config_dict.get("routing_temperature", 0.1),
            "sentence_transformer_model": config_dict.get(
                "sentence_transformer_model", "all-MiniLM-L6-v2"
            ),
            # Integer count, consistent with the class-level default.
            "shared_expert": config_dict.get("shared_expert", 0),
            "similarity_type": config_dict.get("similarity_type", "cosine"),
            "strategy": strategy,
            "train_targets": config_dict.get("train_targets", ["router"]),
            "use_lora_mlp": config_dict.get("use_lora_mlp", False),
            "use_lora_mlp_bias": config_dict.get("use_lora_mlp_bias", False),
        }

        return super().from_dict(config, **kwargs)

    @override
    def to_dict(self) -> dict[str, Any]:
        """Serialize the config into plain, JSON-friendly Python values.

        Paths become strings, the strategy becomes a plain ``str`` and
        non-string, non-dict iterables inside the LoRA config (e.g. sets)
        become lists.
        """
        lora_config = None
        if self.lora_config is not None:
            lora_config = {
                key: (
                    list(value)
                    if isinstance(value, Iterable)
                    and not isinstance(value, (str, dict))
                    else value
                )
                for key, value in self.lora_config.to_dict().items()
            }

        router_config = None
        if self.router_config is not None:
            router_config = self.router_config.to_dict()

        return {
            "base_model": self.base_model,
            "expert_models": {
                key: str(value) for key, value in self.expert_models.items()
            },
            "lora_config": lora_config,
            "refinement_rate": self.refinement_rate,
            "refinement_temperature": self.refinement_temperature,
            "router_model": self.router_model,
            "router_config": router_config,
            "routing_temperature": self.routing_temperature,
            "sentence_transformer_model": self.sentence_transformer_model,
            "shared_expert": self.shared_expert,
            "similarity_type": self.similarity_type,
            "strategy": str(self.strategy),
            "train_targets": self.train_targets,
            "use_lora_mlp": self.use_lora_mlp,
            "use_lora_mlp_bias": self.use_lora_mlp_bias,
        }


@dataclass
class MoWOutput(CausalLMOutputWithPast):
    """Causal-LM output extended with the routing scores of the forward pass."""

    # Routing scores as returned by `MoW.get_routing_score`: the outer key is
    # the router name, the inner key is the adapter (expert) name.
    routing_score: dict[str, dict[str, torch.Tensor]] | None = None


class MoW(PreTrainedModel, GenerationMixin):
    """Causal LM that combines several LoRA adapters on one base model.

    Expert (and optional shared) adapters are attached to the LoRA layers of
    the base model; one ``GraphRouter`` per LoRA target module produces the
    routing scores that are handed to the LoRA layers before each forward.
    """

    config_class = MoWConfig
    # Prefix of the shared adapter names ("shared_1", "shared_2", ...).
    shared_expert_name: str = "shared"

    def __init__(self, config: MoWConfig):
        super().__init__(config)

        LoraLinear.forward = lora_linear_forward

        self.__config = config
        self.__train_targets = config.train_targets
        self.__lora_config = config.lora_config or defaults.default_lora_config
        lora_target_modules = self.__lora_config.target_modules
        if lora_target_modules is None:
            raise ValueError(
                "lora_target_modules must be set in the lora_config for MoW."
            )
        self.__lora_target_modules = list(lora_target_modules)

        self.__st = SentenceTransformer(config.sentence_transformer_model)

        self.__construct_mow()

        if config.router_model:
            self.router = cast(
                dict[str, GraphRouter],
                nn.ModuleDict(
                    {
                        module: GraphRouter.from_pretrained(config.router_model)
                        for module in self.__lora_target_modules
                    }
                ),
            )
            self.__router_config = cast(
                GraphRouterConfig,
                GraphRouterConfig.from_pretrained(config.router_model),
            )
            print(f"🚀 Loaded router model from {config.router_model}")
        else:
            if config.router_config is None:
                raise ValueError(
                    "router_embedding_dimension must be set in the config "
                    "if router_model is not provided."
                )
            self.__router_config = config.router_config
            self.router = cast(
                dict[str, GraphRouter],
                {
                    module: GraphRouter(config=self.__router_config)
                    for module in self.__lora_target_modules
                },
            )

        if (
            "embedding_projection" in config.train_targets
            and not self.__router_config.use_embedding_projection
        ):
            raise ValueError(
                "embedding_projection is in train_targets, but "
                "use_embedding_projection is False in router_config."
            )

        if "shared" in config.train_targets and not config.shared_expert:
            raise ValueError(
                "shared_expert is in train_targets, but shared_expert is False "
                "in config."
            )

        if "lora_mlp" in config.train_targets and not config.use_lora_mlp:
            raise ValueError(
                "lora_mlp is in train_targets, but use_lora_mlp is False in config."
            )

        self.base.eval()
        self.__trainable_modules = list(
            sorted(x for x, _ in self.__get_trainable_modules())
        )
        self.freeze_non_trainable_parameters()

        self.__strategy = config.strategy

        match config.similarity_type:
            case "cosine":
                self.__similarity_fn = partial(torch.cosine_similarity, dim=-1)
            case "euclidean":
                self.__similarity_fn = lambda a, b: 1 - torch.cdist(
                    a, b, p=2.0
                ).diagonal(dim1=-2, dim2=-1)

        self.refinement_rate = config.refinement_rate
        self.refinement_temperature = (
            config.refinement_temperature or config.routing_temperature
        )
        self.__router_backup: dict[str, GraphRouter] | None = None

        self.__experts = list(config.expert_models.keys())

    def __get_trainable_modules(self):
        for name, module in self.named_modules():
            if not isinstance(name, str) or not isinstance(module, nn.Module):
                continue

            if (
                "shared" in self.__train_targets
                and self.shared_expert_name in name
            ):
                yield name, module

            if "expert" in self.__train_targets and any(
                expert_name in name for expert_name in self.config.expert_models
            ):
                yield name, module

            if (
                "router" in self.__train_targets
                and "router" in name
                and "embedding_set" not in name
            ):
                yield name, module

            if (
                "embedding_set" in self.__train_targets
                and "router" in name
                and "embedding_set" in name
            ):
                yield name, module

            if (
                "embedding_projection" in self.__train_targets
                and "router" in name
                and "embedding_projection" in name
            ):
                yield name, module

            if "lora_mlp" in self.__train_targets and "lora_mlp" in name:
                yield name, module

            if (
                "base" in self.__train_targets
                and "base_model" in name
                and "layers" in name
                and "self_attn" not in name
            ):
                yield name, module

    @property
    def trainable_modules(self) -> list[str]:
        """
        Sorted names of the submodules selected for training by
        ``train_targets``; computed once in ``__init__``.
        """
        return self.__trainable_modules

    def freeze_router(self):
        """
        Freezes the router model.
        """
        for module in self.router.values():
            module.eval()
        freeze_model(self.router)

    def freeze_non_trainable_parameters(self):
        """
        Unfreezes the shared expert model.
        """
        for name, module in self.named_modules():
            if name in self.trainable_modules:
                module.requires_grad_(True)
            else:
                module.requires_grad_(False)

    @override
    def get_input_embeddings(self) -> nn.Module:
        """Return the input embedding module of the wrapped base model."""
        return self.base.get_input_embeddings()  # type: ignore

    @override
    def set_input_embeddings(self, value: nn.Module) -> None:
        if not isinstance(value, nn.Embedding):
            raise ValueError(
                f"Expected an instance of nn.Embedding, but got {type(value)}"
            )
        self.base.set_input_embeddings(value)  # type: ignore

    @property
    def tokenizer(self) -> PreTrainedTokenizerBase:
        """Tokenizer loaded from ``config.base_model`` during construction."""
        return self.__tokenizer

    @property
    def embedding_dim(self) -> int:
        """``embedding_dim`` of the base model's input embedding module."""
        return self.get_input_embeddings().embedding_dim  # type: ignore

    @property
    def device(self) -> torch.device:
        """Device reported by the model wrapped inside ``self.base``."""
        return self.base.base_model.device  # type: ignore

    @property
    def used_target_modules(self) -> set[str]:
        """
        LoRA target-module names for which the base model has at least one
        LoRA layer. Returns the internal set itself, not a copy.
        """
        return self.__used_target_modules

    @property
    def shared_expert_names(self) -> list[str]:
        return [
            f"{self.shared_expert_name}_{i}"
            for i in range(1, self.__config.shared_expert + 1)
        ]

    @property
    def sentence_transformer(self) -> SentenceTransformer:
        """Sentence encoder that ``obs_to_graph`` uses to embed text."""
        return self.__st

    @property
    def experts(self) -> list[str]:
        """Expert adapter names, i.e. the keys of ``config.expert_models``."""
        return self.__experts

    @staticmethod
    def __copy_adapter(
        src_layer: LoraLayer,
        dst_layer: LoraLayer,
        *,
        src_adapter_name: str,
        dst_adapter_name: str,
    ):
        """Register ``src_adapter_name`` of ``src_layer`` on ``dst_layer``
        under ``dst_adapter_name`` and activate it there.

        The adapter's entries are assigned by reference (no cloning), moved
        to the device of the destination base layer, and the new adapter is
        appended to the destination layer's active adapters.
        """
        src, dst = src_adapter_name, dst_adapter_name

        # Per-adapter hyper-parameters and the dropout module.
        dst_layer.r[dst] = src_layer.r[src]
        dst_layer.lora_alpha[dst] = src_layer.lora_alpha[src]
        dst_layer.lora_dropout.update({dst: src_layer.lora_dropout[src]})

        # Entries that carry the adapter weights.
        uses_dora = src_layer.use_dora[src]
        if uses_dora:
            dst_layer.lora_magnitude_vector[dst] = (
                src_layer.lora_magnitude_vector[src]
            )
        dst_layer.lora_A[dst] = src_layer.lora_A[src]
        dst_layer.lora_B[dst] = src_layer.lora_B[src]
        dst_layer.lora_bias[dst] = src_layer.lora_bias[src]
        dst_layer.use_dora[dst] = uses_dora

        dst_layer.scaling[dst] = src_layer.scaling[src]

        dst_layer._move_adapter_to_device_of_base_layer(dst)

        # Keep the already-active adapters and add the new one.
        dst_layer.set_adapter([*dst_layer.active_adapters, dst])

    def delete_adapter(self, adapter_name: str):
        """
        Deletes the specified adapter from the model.
        """
        if adapter_name == self.shared_expert_name:
            raise ValueError(
                "Cannot delete the shared expert adapter. Please use a different name."
            )

        for _, _, lora_layer in self.iterate_lora_layers():
            if adapter_name in lora_layer.active_adapters:
                lora_layer.delete_adapter(adapter_name=adapter_name)
        self.__used_target_modules.discard(adapter_name)

    def __construct_mow(self):
        self.base = AutoModelForCausalLM.from_pretrained(
            self.__config.base_model
        )
        self.base = get_peft_model(self.base, self.__lora_config)
        self.__tokenizer = AutoTokenizer.from_pretrained(
            self.__config.base_model, use_fast=True
        )

        active_adapters = self.base.active_adapters
        if len(active_adapters) > 1:
            raise ValueError(
                "Only one adapter is supported for MoW. Please set `active_adapters` to 1."
            )
        default_expert_name = active_adapters[0]

        self.__used_target_modules = set[str]()
        for name in self.__lora_target_modules:
            for lora_layer in self.iterate_lora_layers(name):
                if name not in self.__used_target_modules:
                    self.__used_target_modules.add(name)

        for name, _, lora_layer in self.iterate_lora_layers():
            if self.__config.shared_expert:
                self.__copy_adapter(
                    src_layer=lora_layer,
                    src_adapter_name=default_expert_name,
                    dst_layer=lora_layer,
                    dst_adapter_name=f"{self.shared_expert_name}_1",
                )
            lora_layer.delete_adapter(adapter_name=default_expert_name)

        if self.__config.shared_expert > 1:
            for i in range(2, self.__config.shared_expert + 1):
                another = AutoModelForCausalLM.from_pretrained(
                    self.__config.base_model
                )
                another = get_peft_model(another, self.__lora_config)
                for (name, _, base_layer), (_, _, target_layer) in zip(
                    self.iterate_lora_layers(),
                    iterate_lora_layers(another, self.__used_target_modules),
                ):
                    self.__copy_adapter(
                        src_layer=target_layer,
                        src_adapter_name=default_expert_name,
                        dst_layer=base_layer,
                        dst_adapter_name=f"{self.shared_expert_name}_{i}",
                    )
                del another

        for key, target in list(self.__config.expert_models.items()):
            expert = AutoModelForCausalLM.from_pretrained(target)
            for (name, _, base_layer), (_, _, target_layer) in zip(
                self.iterate_lora_layers(),
                iterate_lora_layers(expert, self.__used_target_modules),
            ):
                self.__copy_adapter(
                    src_layer=target_layer,
                    src_adapter_name=default_expert_name,
                    dst_layer=base_layer,
                    dst_adapter_name=key,
                )
            del expert

        self.__lora_layer_pre_hooks = dict[tuple[str, int], LoraLinearPreHook]()
        for name, layer_num, lora_layer in self.iterate_lora_layers():
            hook = LoraLinearPreHook(
                layer_num=layer_num,
                shared_expert_names=self.shared_expert_names,
            )
            lora_layer.register_forward_pre_hook(hook)
            self.__lora_layer_pre_hooks[(name, layer_num)] = hook
            if self.__config.use_lora_mlp:
                out_features = cast(int, lora_layer.out_features)
                lora_mlp = MLP(
                    input_dim=out_features,
                    output_dim=out_features,
                    hidden_dim=out_features * 2,
                    use_bias=self.__config.use_lora_mlp_bias,
                )
                setattr(lora_layer, "lora_mlp", lora_mlp)

    def init_embedding_set(self, targets: list[str]):
        num_nodes = 10
        len_context = 5
        embedding_dim = self.__router_config.embed_dim or self.embedding_dim
        for name in targets:
            for router in self.router.values():
                router.update_embedding_set(
                    name=name,
                    hidden_states=torch.randn(
                        num_nodes, embedding_dim, device=self.device
                    ),
                    adjacency_matrix=torch.randint(
                        0, 2, (num_nodes, num_nodes), device=self.device
                    ),
                    context=torch.randn(
                        len_context, embedding_dim, device=self.device
                    ),
                )

    def iterate_lora_layers(self, name: str | None = None):
        if name is not None:
            yield from iterate_lora_layers(self.base, name)
        else:
            yield from iterate_lora_layers(
                self.base, self.__used_target_modules
            )

    def obs_to_graph(self, observation: str, instruction: str):
        nodes, adj_mat, rel = map_observation_to_graph(observation)
        hidden_states = self.__st.encode(nodes, convert_to_tensor=True)
        if rel is not None:
            rel = self.__st.encode(rel, convert_to_tensor=True)
            rel[0, ...] = torch.zeros(
                (1, *rel.shape[1:]), dtype=rel.dtype, device=rel.device
            )
        else:
            rel = None
        context = self.__st.encode([instruction], convert_to_tensor=True)
        return hidden_states, adj_mat, rel, context

    def get_routing_score(
        self,
        hidden_states: torch.Tensor,
        adjacency_matrix: torch.Tensor,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
        keys: list[str] | None = None,
    ) -> dict[str, dict[str, torch.Tensor]]:
        return {
            name: router.get_similarities(
                hidden_states=hidden_states,
                adjacency_matrix=adjacency_matrix,
                relation_matrix=relation_matrix,
                context=context,
                temperature=self.__config.routing_temperature,
                similarity_fn=self.__similarity_fn,
                top_k=self.__strategy.k_val,
                flatten=self.__strategy.is_flatten,
                keys=keys,
            )
            for name, router in self.router.items()
        }

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | list[torch.FloatTensor] | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = True,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        cache_position: torch.LongTensor | None = None,
        hidden_states: torch.Tensor | None = None,
        adjacency_matrix: torch.Tensor | None = None,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
        **model_kwargs,
    ):
        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        if hidden_states is None:
            raise ValueError("You have to specify hidden_states")

        if adjacency_matrix is None:
            raise ValueError("You have to specify adjacency_matrix")

        routing_score = self.get_routing_score(
            hidden_states=hidden_states,
            adjacency_matrix=adjacency_matrix,
            relation_matrix=relation_matrix,
            context=context,
            keys=self.experts,
        )

        for name, layer_num, _ in self.iterate_lora_layers():
            hook = self.__lora_layer_pre_hooks[(name, layer_num)]
            hook.assign_routing_score(
                {
                    adapter: routing_score[name][adapter][..., layer_num]
                    for adapter in routing_score[name]
                }
            )

        if _PRINT_ROUTING_SCORE_TABLE:
            rooms = "18 20 22 24 26 28 29 31 32 34".split()
            print("=" * 100)
            print("    " + " ".join(f"room_{room}" for room in rooms))
            print("-" * 100)
            for i in range(16):
                print(f"{i:2d}  ", end="")
                for room in rooms:
                    print(
                        f"{routing_score["q_proj"][f"room_{room}"][i]:<7.2f} ",
                        end="",
                    )
                print()
            input()

        ret = self.base.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            **model_kwargs,
        )
        return MoWOutput(
            loss=ret.loss,
            logits=ret.logits,
            past_key_values=ret.past_key_values,
            hidden_states=ret.hidden_states,
            attentions=ret.attentions,
            routing_score=routing_score,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.Tensor,
        hidden_states: torch.Tensor | None = None,
        adjacency_matrix: torch.Tensor | None = None,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
        past_key_values: Cache | list[torch.FloatTensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        cache_position: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        use_cache: bool = True,
        **kwargs,
    ):
        """Build the keyword arguments for one generation step.

        Trims ``input_ids`` to the tokens not yet covered by the cache,
        derives ``position_ids`` and ``cache_position`` when they are not
        supplied, and passes the router inputs (``hidden_states``,
        ``adjacency_matrix``, ``relation_matrix``, ``context``) through
        unchanged so that ``forward`` receives them at every step.

        Returns:
            A dict with either ``inputs_embeds`` or ``input_ids`` plus the
            attention mask, position ids, cache objects and router inputs.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = (
                    cache_position[0]
                    if cache_position is not None
                    else past_key_values.get_seq_length()
                )
                # NOTE(review): relies on `Cache.get_max_length()`; confirm it
                # is available in the installed transformers version.
                max_cache_length = (
                    torch.tensor(
                        past_key_values.get_max_length(),
                        device=input_ids.device,
                    )
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = (
                    past_length
                    if max_cache_length is None
                    else torch.min(max_cache_length, past_length)
                )
            else:
                # Legacy cache format: the length is read from dimension 2 of
                # the first cached tensor of the first layer.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of
            # input_ids, then we are in a setting where some of the inputs are
            # exclusively passed as part of the cache (e.g. when passing
            # input_embeds as input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                input_ids = input_ids[
                    :, -(attention_mask.shape[1] - past_length) :
                ]
            # 2 - If the past_length is smaller than input_ids.shape[1], then
            # input_ids holds all input tokens. We can discard input_ids based on
            # the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume
            # input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop
            # the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length  # type: ignore
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st
        # generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds.contiguous()}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = (
            position_ids.shape[-1]
            if position_ids is not None
            else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device  # type: ignore
            )
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "attention_mask": attention_mask,  # type: ignore
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "cache_position": cache_position,
                "use_cache": use_cache,
                # Router inputs are forwarded unchanged on every step.
                "hidden_states": hidden_states,
                "adjacency_matrix": adjacency_matrix,
                "relation_matrix": relation_matrix,
                "context": context,
            }
        )

        return model_inputs

    @override
    def _validate_model_kwargs(self, model_kwargs: dict[str, Any]):
        if "routing_score_collector" in model_kwargs:
            kwargs = model_kwargs.copy()
            del kwargs["routing_score_collector"]
            super()._validate_model_kwargs(kwargs)
        else:
            super()._validate_model_kwargs(model_kwargs)

    @override
    def _prepare_generation_config(
        self, generation_config: GenerationConfig | None, **kwargs: dict
    ) -> tuple[GenerationConfig, dict]:
        generation_config, model_kwargs = super()._prepare_generation_config(
            generation_config, **kwargs
        )
        if (
            routing_score_collector := model_kwargs.get(
                "routing_score_collector", None
            )
        ) is not None:
            model_kwargs["routing_score_collector"] = routing_score_collector
        return generation_config, model_kwargs

    @override
    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> dict[str, Any]:
        routing_score_collector: list | None = model_kwargs.get(
            "routing_score_collector", None
        )
        if (
            routing_score_collector is not None
            and isinstance(outputs, MoWOutput)
            and outputs.routing_score is not None
        ):
            routing_score_collector.append(outputs.routing_score)
        return super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )

    def refine_router(
        self,
        hidden_states: torch.Tensor,
        adjacency_matrix: torch.Tensor,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
    ):
        if self.__router_backup is None:
            self.__router_backup = copy.deepcopy(self.router)

        experts = list(next(iter(self.router.values())).embedding_set.keys())

        # Shape: (num_types, num_layers, embedding_dim)
        embedding = torch.stack(
            [
                router.get_embedding(
                    hidden_states=hidden_states,
                    adjacency_matrix=adjacency_matrix,
                    relation_matrix=relation_matrix,
                    context=context,
                )
                for router in self.router.values()
            ]
        )

        # Shape: (num_types, num_experts, num_layers, embedding_dim)
        prototypes = torch.stack(
            [
                torch.stack(
                    [router.embedding_set[expert] for expert in experts]
                )
                for router in self.router.values()
            ]
        )

        # Shape: (num_types, num_experts, num_layers)
        domain_similarity = self.__similarity_fn(
            embedding[:, None, ...], prototypes
        )

        # Shape: (num_types, num_experts, num_experts, num_layers)
        prototype_similarities = self.__similarity_fn(
            prototypes[:, :, None, ...], prototypes[:, None, :, ...]
        )
        prototype_similarities = torch.softmax(
            prototype_similarities / self.refinement_temperature, dim=-1
        )

        # Shape: (num_types, num_layers, num_experts, embedding_dim)
        refinement_targets = torch.matmul(
            prototype_similarities.permute(0, 3, 1, 2),
            prototypes.permute(0, 2, 1, 3),
        )

        # Shape: (num_types, num_experts, num_layers, embedding_dim
        refinement_targets = refinement_targets.permute(0, 2, 1, 3)

        refinement_targets = {
            name: {key: target for key, target in zip(experts, targets)}
            for name, targets in zip(self.router, refinement_targets)
        }
        for i, (name, router) in enumerate(self.router.items()):
            for j, (key, emb) in enumerate(router.embedding_set.items()):
                amount = self.refinement_rate * domain_similarity[i, j]
                amount = amount[:, None]  # Shape: (num_layers, 1)
                t = (1 - amount) * emb + amount * refinement_targets[name][key]
                emb.copy_(t)

    def restore_router(self):
        if self.__router_backup is not None:
            self.router = self.__router_backup
            self.__router_backup = None
