import torch
from torch import nn
import torch.nn.functional as F

from inference.torch.codebook import CodebookManager
from inference.torch.layers.encoder import CallableEncoder


class HyperLinear(nn.Module):
    """Linear projection whose output is extended with codebook-derived columns.

    The wrapped ``nn.Linear`` produces ``out_features`` columns. ``forward``
    additionally projects the input through a second weight obtained from
    ``codebook_manager.get_hyper_linear`` and splices those extra columns in
    directly after the first ``initial_vocab_size`` columns of the base
    output::

        [ base[..., :initial_vocab_size] | hyper (masked) | base[..., initial_vocab_size:] ]

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Number of columns produced by the wrapped ``nn.Linear``.
        initial_vocab_size: Column index in the base output at which the
            extra columns are inserted.
        callable_encoder: Passed, together with the base weight, to
            ``codebook_manager.get_hyper_linear`` on every forward call.
        bias: Whether the wrapped ``nn.Linear`` has a bias. The extra
            columns are always computed without a bias.
        pad_token_id: Stored on the instance; not used by this class itself.
        codebook_manager: Supplies the extra weight (``get_hyper_linear``)
            and the number of currently active extra columns (``get_index``).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        initial_vocab_size: int,
        callable_encoder: CallableEncoder,
        bias: bool,
        pad_token_id: int,
        codebook_manager: CodebookManager,
    ) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)

        self.pad_token_id = pad_token_id
        self.callable_encoder = callable_encoder
        self.codebook_manager = codebook_manager
        self.initial_vocab_size = initial_vocab_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and insert the masked codebook columns.

        Args:
            x: Input whose last dimension is ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features + H``, where ``H`` is the number of
            rows of the weight returned by ``get_hyper_linear``.
        """
        output = self.linear(x)

        # Requested from the codebook manager on every call, using the
        # current base weight.
        hyper_linear = self.codebook_manager.get_hyper_linear(
            self.linear.weight, self.callable_encoder
        )

        # F.linear takes an (H, in_features) weight and yields H extra columns
        # (no bias term).
        hyper_output = F.linear(x, hyper_linear)
        # Only the first get_index() extra columns are active; the remainder
        # are multiplied by 0.
        # NOTE(review): if these are logits, 0 is not a "disabled" score
        # (softmax still assigns it probability mass) — confirm that -inf is
        # not what is intended here.
        hyper_mask = (
            torch.arange(end=hyper_output.shape[-1], device=x.device)
            < self.codebook_manager.get_index()
        )

        return torch.cat(
            [
                output[..., : self.initial_vocab_size],
                hyper_output * hyper_mask,
                output[..., self.initial_vocab_size :],
            ],
            dim=-1,
        )

    @staticmethod
    def from_linear(
        linear: nn.Linear,
        initial_vocab_size: int,
        callable_encoder: CallableEncoder,
        pad_token_id: int,
        codebook_manager: CodebookManager,
    ) -> "HyperLinear":
        """Wrap an existing ``nn.Linear`` without copying its parameters.

        The returned module's inner ``linear`` holds the *same* ``weight``
        (and ``bias``, if present) Parameter objects as ``linear``, so the two
        modules share storage.

        Args:
            linear: Layer whose weight (and bias) is reused. Its weight shape
                ``(out, in)`` determines ``out_features`` and ``in_features``.
            initial_vocab_size: See :class:`HyperLinear`.
            callable_encoder: See :class:`HyperLinear`.
            pad_token_id: See :class:`HyperLinear`.
            codebook_manager: See :class:`HyperLinear`.

        Returns:
            A ``HyperLinear`` sharing ``linear``'s parameters.
        """
        # Construct on the meta device so the inner nn.Linear neither
        # allocates nor initializes a weight that is about to be replaced.
        with torch.device("meta"):
            hyper_linear = HyperLinear(
                linear.weight.shape[1],
                linear.weight.shape[0],
                initial_vocab_size,
                callable_encoder,
                linear.bias is not None,
                pad_token_id,
                codebook_manager,
            )

        # The inner nn.Linear's only tensors are ``weight`` and (when enabled)
        # ``bias``. Both are overwritten below, so no meta tensors remain.
        # ``to_empty`` is deliberately NOT called. It would allocate
        # uninitialized storage for every parameter/buffer in the module tree
        # only to discard it. If ``callable_encoder`` or ``codebook_manager``
        # is an nn.Module (and thus a registered child), it would also replace
        # that module's existing weights with uninitialized memory.
        hyper_linear.linear.weight = linear.weight
        if linear.bias is not None:
            hyper_linear.linear.bias = linear.bias

        return hyper_linear
