from typing import Iterable, Literal, cast, overload

import torch
from peft.tuners.lora import Linear as LoraLinear
from peft.tuners.lora import LoraLayer
from torch import nn


class LoraLinearPreHook(nn.Module):
    def __init__(
        self, layer_num: int, shared_expert_names: str | list[str] = "shared"
    ):
        """
        Pre-hook for LoraLayer to handle the input tensor before the forward pass.
        This is used to ensure that the input tensor is cast to the correct dtype.
        """
        super().__init__()
        self.layer_num = layer_num
        self.routing_score: dict[str, torch.Tensor] | None = None

        if isinstance(shared_expert_names, str):
            shared_expert_names = [shared_expert_names]
        self.shared_expert_names = shared_expert_names

    def assign_routing_score(self, score: dict[str, torch.Tensor]):
        """
        Assign the routing score to be used in the forward pass.
        This is called before the forward pass to ensure the score is available.
        """
        example = next(iter(score.values()))
        self.routing_score = {
            **score,
            **{
                name: torch.ones_like(
                    example, dtype=example.dtype, device=example.device
                )
                for name in self.shared_expert_names
            },
        }

    def forward(
        self,
        lora_layer: LoraLayer,
        input: tuple[torch.Tensor, ...],
    ):
        assert (
            self.routing_score is not None
        ), "Routing score must be assigned before calling forward."
        return input[0], self.routing_score


def lora_linear_forward(
    self: LoraLinear,
    x: torch.Tensor,
    routing_score: dict[str, torch.Tensor] | None = None,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Forward pass of a LoRA linear layer with per-adapter routing scores.

    Written with an explicit ``self`` so it can stand in for the layer's
    ``forward``. Each active adapter's update is multiplied by
    ``routing_score[adapter]``; when ``routing_score`` is ``None`` the layer's
    own ``self.scaling`` mapping supplies the factors instead.

    Args:
        x: Input passed to the base layer and to every active adapter.
        routing_score: Mapping from adapter name to its scaling factor. Tensor
            factors are moved to the adapter's device and given trailing
            singleton dimensions until they have as many dimensions as ``x``.
        *args, **kwargs: Forwarded to the base layer. ``adapter_names`` is
            popped from ``kwargs`` and selects the mixed-batch path.

    Returns:
        The base layer output. On the unmerged-adapter path the summed adapter
        updates are added to it and the result is cast back to the base
        output's dtype.
    """
    self._check_forward_args(x, *args, **kwargs)
    adapter_names = kwargs.pop("adapter_names", None)

    per_adapter_scale = self.scaling if routing_score is None else routing_score

    # Paths that never apply the routed adapters return straight away.
    if self.disable_adapters:
        if self.merged:
            self.unmerge()
        return self.base_layer(x, *args, **kwargs)

    if adapter_names is not None:
        return self._mixed_batch_forward(
            x, *args, adapter_names=adapter_names, **kwargs
        )

    if self.merged:
        return self.base_layer(x, *args, **kwargs)

    base_out = self.base_layer(x, *args, **kwargs)
    out_dtype = base_out.dtype

    # Adapter contributions are accumulated separately from the base output so
    # that the optional `lora_mlp` below sees only the adapter part.
    delta = torch.zeros_like(base_out)

    known_adapters = self.lora_A.keys()
    for adapter in self.active_adapters:
        if adapter not in known_adapters:
            continue

        down_proj = self.lora_A[adapter]
        up_proj = self.lora_B[adapter]
        drop = self.lora_dropout[adapter]
        x = self._cast_input_dtype(x, down_proj.weight.dtype)  # type: ignore

        scale = per_adapter_scale[adapter]
        if isinstance(scale, torch.Tensor):
            scale = scale.to(down_proj.weight.device)  # type: ignore
            # Pad with trailing singleton dims so the factor broadcasts
            # against the adapter output.
            for _ in range(x.dim() - scale.dim()):
                scale = scale.unsqueeze(-1)

        if self.use_dora[adapter]:
            if self.training and not isinstance(drop, nn.Identity):
                # Dropout is active: apply it here (the dropped-out `x` is
                # also what later adapters in this loop receive) and pass no
                # precomputed base result.
                x = drop(x)
                base_result = None
            else:
                base_result = base_out + delta

            delta = delta + self.lora_magnitude_vector[adapter](
                x,
                lora_A=down_proj,
                lora_B=up_proj,
                scaling=scale,
                base_layer=self.get_base_layer(),
                base_result=base_result,
            )
        else:
            delta = delta + up_proj(down_proj(drop(x))) * scale

    # Optional module attached to the layer under the name `lora_mlp`.
    lora_mlp = getattr(self, "lora_mlp", None)
    if lora_mlp is not None:
        delta = lora_mlp(delta)

    return (base_out + delta).to(out_dtype)


def iterate_lora_layers(model: nn.Module, target_modules: str | Iterable[str]):
    """
    Iterate through all LoraLayer instances in the model.
    """
    if isinstance(target_modules, str):
        target_modules = [target_modules]

    while not hasattr(model, "layers"):
        model = cast(nn.Module, model.base_model)

    for layer_num, layer in enumerate(model.layers):  # type: ignore
        base_attn = layer.self_attn
        for name in target_modules:
            if not hasattr(base_attn, name):
                continue

            lora_layer = getattr(base_attn, name)
            if not (
                isinstance(lora_layer, LoraLayer)
                and isinstance(lora_layer, nn.Module)
            ):
                continue

            yield name, layer_num, lora_layer


@overload
def get_router_entropy(
    routing_score: dict[str, torch.Tensor],
    layerwise: Literal[True],
) -> list[float]:
    ...


@overload
def get_router_entropy(
    routing_score: dict[str, torch.Tensor],
    layerwise: Literal[False] = False,
) -> float:
    ...


def get_router_entropy(
    routing_score: dict[str, torch.Tensor], layerwise: bool = False
):
    """
    Average the ``-p * log(p)`` terms of the routing scores.

    With ``layerwise`` set, the work is delegated to
    ``_get_router_entropy_layerwise`` and a list of floats is returned.
    Otherwise each score tensor is iterated along its first dimension, every
    entry greater than zero contributes one ``-p * log(p)`` term, and the mean
    of all terms across all tensors is returned (``0.0`` when there are none).

    NOTE(review): each iterated entry is truth-tested with ``p > 0`` and passed
    to ``float``, so this path presumably expects 1-D score tensors -- confirm
    against the caller.
    """
    if not layerwise:
        terms = []
        for expert_scores in routing_score.values():
            for p in expert_scores:
                if p > 0:  # zero / negative entries contribute no term
                    terms.append(float(-p * torch.log(p)))
        if not terms:
            return 0.0
        return sum(terms) / len(terms)

    return _get_router_entropy_layerwise(routing_score)


def _get_router_entropy_layerwise(
    routing_score: dict[str, torch.Tensor],
) -> list[float]:
    """
    Compute one averaged ``-p * log(p)`` value per index of the score tensors.

    For each index ``i``, ``p`` is ``score[i].mean()`` for every tensor in
    ``routing_score``; tensors whose ``p`` is not greater than zero are
    skipped, and the remaining ``-p * log(p)`` terms are averaged (``0.0`` if
    none remain).

    Returns:
        A list with one float per index, or an empty list when
        ``routing_score`` is empty (mirroring the scalar path, which returns
        ``0.0`` for empty input instead of raising).
    """
    if not routing_score:
        return []

    first_score = next(iter(routing_score.values()))
    # NOTE(review): the number of indices comes from the LAST dimension of the
    # first tensor while `score[i]` indexes the FIRST dimension. The two agree
    # for 1-D scores; for other shapes confirm which axis is the layer axis.
    num_indices = first_score.shape[-1]

    entropies = []
    for i in range(num_indices):
        entropy = [
            float(-x * torch.log(x))
            for score in routing_score.values()
            if (x := score[i].mean()) > 0
        ]
        entropies.append(sum(entropy) / len(entropy) if entropy else 0.0)
    return entropies
