import torch
import math
from torch import nn
import torch.nn.functional as F
from typing import Optional, Literal, Tuple, Dict, Any


from layers.encoders import EmbeddingEncoder


class LoRALinear(nn.Module):
    """Linear layer with an additive low-rank (LoRA) adapter.

    While not merged, the layer computes
    ``linear(x) + scaled_alpha * (x @ lora_a.T @ lora_b.T)``.
    After :meth:`merge` the low-rank product is folded into ``linear.weight``
    and only ``linear(x)`` is evaluated.

    ``scaled_alpha`` is ``alpha / sqrt(rank)`` when ``use_rslora`` is set and
    ``alpha / rank`` otherwise; the same factor is used by the forward pass,
    by merging, and by the PiSSA initialisation.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        alpha: int,
        bias: bool,
        init_lora_weight: Optional[Literal["default", "pissa"]] = None,
        use_rslora: bool = False,
    ) -> None:
        """Build the base linear layer and the two low-rank factors.

        Args:
            in_features: Input width of the wrapped ``nn.Linear``.
            out_features: Output width of the wrapped ``nn.Linear``.
            rank: Inner dimension of the low-rank factors.
            alpha: LoRA scaling numerator.
            bias: Whether the wrapped ``nn.Linear`` has a bias.
            init_lora_weight: ``"default"``, ``"pissa"``, or ``None`` to leave
                ``lora_a`` / ``lora_b`` uninitialised (the caller is then
                expected to call :meth:`reset_lora_parameters` or load
                weights before use).
            use_rslora: Use ``alpha / sqrt(rank)`` instead of ``alpha / rank``.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.rank = rank
        self.alpha = alpha
        self.scaled_alpha = alpha / math.sqrt(rank) if use_rslora else alpha / rank
        self.merged = False

        # The factors live on the same device as the base weight; their values
        # are undefined until reset_lora_parameters() fills them in.
        self.lora_a = nn.Parameter(
            torch.empty(
                (rank, in_features),
                device=self.linear.weight.device,
            )
        )
        self.lora_b = nn.Parameter(
            torch.empty(
                (out_features, rank),
                device=self.linear.weight.device,
            )
        )

        self.reset_lora_parameters(init_lora_weight)

    def reset_lora_parameters(
        self, init_lora_weight: Optional[Literal["default", "pissa"]]
    ) -> None:
        """(Re)initialise ``lora_a`` and ``lora_b``.

        Args:
            init_lora_weight:
                * ``"default"`` - Kaiming-normal ``lora_a`` and zero
                  ``lora_b``, so the adapter starts as a no-op.
                * ``"pissa"`` - the factors are taken from a rank-``rank``
                  low-rank SVD of ``linear.weight`` and that component is
                  subtracted from ``linear.weight`` in place, so the layer's
                  output is (up to rounding) unchanged.
                * ``None`` - do nothing.

        Raises:
            ValueError: If ``init_lora_weight`` is any other value. Silently
                ignoring it would leave the factors as uninitialised memory.
        """
        if init_lora_weight is None:
            return

        if init_lora_weight == "default":
            nn.init.kaiming_normal_(self.lora_a)
            nn.init.zeros_(self.lora_b)
            return

        if init_lora_weight == "pissa":
            with torch.no_grad():
                # svd_lowrank returns (U, S, V) with weight ~= U diag(S) V^T.
                u, s, v = torch.svd_lowrank(
                    self.linear.weight.data.to(torch.float32), q=self.rank
                )
                # Pre-divide by the factor applied in forward()/merge() so
                # that scaled_alpha * lora_b @ lora_a equals U diag(S) V^T
                # for both the classic and the rsLoRA scaling.
                sqrt_s = torch.diag(torch.sqrt(s / self.scaled_alpha))

                self.lora_a.data = sqrt_s @ v.t()
                self.lora_b.data = u @ sqrt_s

                # Keep only the residual in the base weight.
                principal = self.scaled_alpha * (self.lora_b @ self.lora_a)
                self.linear.weight.data -= principal.to(self.linear.weight.dtype)
            return

        raise ValueError(
            f"Unknown init_lora_weight: {init_lora_weight!r} "
            "(expected 'default', 'pissa' or None)"
        )

    def merge(self) -> None:
        """Fold the adapter into ``linear.weight``; a second call is a no-op.

        The sum is computed in the adapter's dtype and cast back to the base
        weight's dtype.
        """
        if self.merged:
            return

        lora_dtype = self.lora_a.dtype
        dtype = self.linear.weight.dtype

        with torch.no_grad():
            # Must use the same factor as forward(); otherwise merging would
            # change the layer's output when use_rslora is enabled.
            delta = self.scaled_alpha * (self.lora_b @ self.lora_a)
            self.linear.weight.data = (
                self.linear.weight.data.to(lora_dtype) + delta
            ).to(dtype)

        self.merged = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the base layer plus, unless merged, the scaled adapter.

        The adapter path runs in ``lora_a``'s dtype; the result is cast to the
        base weight's dtype.
        """
        output = self.linear(x)

        if self.merged:
            return output

        lora_output = F.linear(x.to(self.lora_a.dtype), self.lora_a)
        lora_output = F.linear(lora_output, self.lora_b)
        return (output + (self.scaled_alpha) * lora_output).to(self.linear.weight.dtype)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_linear: nn.Linear,
        rank: int,
        alpha: int,
        init_lora_weight: Literal["default", "pissa"],
        use_rslora: bool = False,
    ) -> "LoRALinear":
        """Wrap an existing ``nn.Linear`` without copying its parameters.

        The returned module shares ``weight`` (and ``bias``, if present) with
        ``pretrained_linear``. With ``init_lora_weight="pissa"`` that shared
        weight is therefore modified in place.

        Args:
            pretrained_linear: Layer whose parameters are reused.
            rank: Inner dimension of the low-rank factors.
            alpha: LoRA scaling numerator.
            init_lora_weight: Initialisation passed to
                :meth:`reset_lora_parameters`.
            use_rslora: Use ``alpha / sqrt(rank)`` instead of ``alpha / rank``.
        """
        # Build on the meta device so no throw-away base weight is allocated.
        with torch.device("meta"):
            lora_linear = cls(
                pretrained_linear.in_features,
                pretrained_linear.out_features,
                rank,
                alpha,
                pretrained_linear.bias is not None,
                use_rslora=use_rslora,
            )
        lora_linear.to_empty(device=pretrained_linear.weight.device)

        lora_linear.linear.weight = pretrained_linear.weight
        if pretrained_linear.bias is not None:
            lora_linear.linear.bias = pretrained_linear.bias

        # Initialise only now: PiSSA needs the real pretrained weight.
        lora_linear.reset_lora_parameters(init_lora_weight)
        return lora_linear


class HyperLinear(nn.Module):
    """Output projection whose vocabulary can be extended per batch.

    The fixed part is an ordinary ``nn.Linear``. When a codebook tensor is
    supplied, ``embedding_encoder`` produces additional per-batch weight rows
    and the corresponding logits are spliced into the base logits directly
    after position ``initial_vocab_size``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        initial_vocab_size: int,
        # Quoted so the annotation is not evaluated when the class is defined.
        embedding_encoder: "EmbeddingEncoder",
        bias: bool,
        pad_token_id: Optional[int] = None,
    ) -> None:
        """Create the base projection and store the encoder.

        Args:
            in_features: Input width of the base ``nn.Linear``.
            out_features: Output width of the base ``nn.Linear``.
            initial_vocab_size: Index along the last output dimension at which
                the extra logits are inserted.
            embedding_encoder: Callable invoked as
                ``embedding_encoder(codebook_tensor, linear.weight, pad_token_id)``
                and expected to return ``(hyper_weight, metadata_dict)``.
            bias: Whether the base ``nn.Linear`` has a bias.
            pad_token_id: Forwarded unchanged to ``embedding_encoder``.
        """
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        self.initial_vocab_size = initial_vocab_size
        self.pad_token_id = pad_token_id
        self.embedding_encoder = embedding_encoder

    def forward(
        self,
        x: torch.Tensor,
        codebook_tensor: Optional[torch.Tensor] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Compute logits, optionally extended with codebook-derived entries.

        Args:
            x: Input of shape ``(B, S, D)`` (3-D is required by ``torch.bmm``
                on the codebook path).
            codebook_tensor: Input for ``embedding_encoder``. When ``None`` or
                empty, only the base projection is evaluated.
            metadata: Optional dict to record side information in. If given it
                is updated in place and returned; if omitted a fresh dict is
                created for this call.

        Returns:
            ``(logits, metadata)``. On the codebook path ``metadata`` contains
            ``"hyper_weight"`` plus every entry returned by the encoder, and
            ``logits`` is ``[base[..., :initial_vocab_size], hyper,
            base[..., initial_vocab_size:]]`` concatenated on the last axis.
        """
        # A dict default in the signature would be shared (and mutated) across
        # calls, leaking hyper_weight tensors from one call into the next.
        if metadata is None:
            metadata = {}

        # x (B, S, D)
        output = self.linear(x)  # (B, S, V)
        if codebook_tensor is None or codebook_tensor.numel() == 0:
            return output, metadata

        hyper_weight, encoder_metadata = self.embedding_encoder(
            codebook_tensor, self.linear.weight, self.pad_token_id
        )  # (B, V_E, D)
        metadata["hyper_weight"] = hyper_weight

        hyper_output = torch.bmm(
            x, hyper_weight.transpose(-2, -1)
        )  # (B, S, V_E) where the [:,:,V_E_used:] are zeros
        # Sample of hyper_output[0] (unused trailing slots are exactly zero):
        #   [[-0.1182,  0.0226,  0.1094,  ...,  0.0000,  0.0000,  0.0000],
        #    [-0.2471,  0.0009,  0.2041,  ...,  0.0000,  0.0000,  0.0000],
        #    ...,
        #    [ 0.1338,  0.0564,  0.2432,  ...,  0.0000,  0.0000,  0.0000]]
        # TODO, maybe we need to mask out the empty slots

        # Encoder entries are applied after "hyper_weight", so an encoder key
        # of the same name takes precedence.
        metadata.update(encoder_metadata)

        extended = torch.cat(
            [
                output[..., : self.initial_vocab_size],
                hyper_output,
                output[..., self.initial_vocab_size :],
            ],
            dim=-1,
        )
        return extended, metadata

    @classmethod
    def from_pretrained(
        cls,
        pretrained_linear: nn.Linear,
        initial_vocab_size: int,
        # Quoted so the annotation is not evaluated when the class is defined.
        embedding_encoder: "EmbeddingEncoder",
        bias: bool,
        pad_token_id: Optional[int] = None,
    ) -> "HyperLinear":
        """Wrap an existing ``nn.Linear``, sharing its parameters.

        The new module (including ``embedding_encoder`` when it is an
        ``nn.Module``) is moved to the device of ``pretrained_linear.weight``.
        The returned layer's ``weight`` - and ``bias``, when the pretrained
        layer has one - are the very same ``Parameter`` objects.

        Args:
            pretrained_linear: Layer whose parameters are reused.
            initial_vocab_size: See :meth:`__init__`.
            embedding_encoder: See :meth:`__init__`.
            bias: Passed to the constructor of the base ``nn.Linear``.
            pad_token_id: See :meth:`__init__`.
        """
        hyper_linear = cls(
            pretrained_linear.in_features,
            pretrained_linear.out_features,
            initial_vocab_size,
            embedding_encoder,
            bias,
            pad_token_id,
        ).to(device=pretrained_linear.weight.device)

        hyper_linear.linear.weight = pretrained_linear.weight
        # NOTE(review): if ``bias`` is True but the pretrained layer has no
        # bias, the freshly initialised bias is kept - confirm this is wanted.
        if pretrained_linear.bias is not None:
            hyper_linear.linear.bias = pretrained_linear.bias

        return hyper_linear
