import torch
from typing import List, Optional
from fast_compression import CodebookManager as FastCodebookManager

from inference.torch.layers.encoder import CallableEncoder


class CodebookManager:
    def __init__(
        self,
        initial_vocab_size: int,
        max_codebook_size: int,
        max_subtokens: int,
        embedding_dim: int,
        dtype: torch.dtype,
        device: torch.device,
        pad_token_id: int,
        disabled_ids: Optional[List[int]] = None,
    ):
        self.device = device
        self.pad_token_id = pad_token_id
        self.max_subtokens = max_subtokens
        self.embedding_dim = embedding_dim
        self.max_codebook_size = max_codebook_size
        self.initial_vocab_size = initial_vocab_size

        self.codebook_manager = FastCodebookManager(
            initial_vocab_size,
            max_codebook_size,
            max_subtokens,
            pad_token_id,
            disabled_ids,
        )

        self.index = 0
        self.hyper_embedding = torch.zeros(
            self.max_codebook_size, embedding_dim, dtype=dtype, device=self.device
        )
        self.hyper_linear = torch.zeros(
            self.max_codebook_size, embedding_dim, dtype=dtype, device=self.device
        )

        self.updates = None
        self.num_updates = 0

    def get_hyper_embedding(
        self,
        ids: torch.LongTensor,
        weight: torch.Tensor,
        callable_encoder: CallableEncoder,
    ) -> torch.Tensor:
        if not torch.compiler.is_compiling():
            ids_list = ids.tolist()
            updates, num_updates = self.codebook_manager.update_codebook(
                ids_list, len(ids_list) > 1
            )
            self.updates = torch.tensor(
                updates,
                device=self.device,
                dtype=torch.long,
            )
            self.num_updates = num_updates

            if self.num_updates > 0:
                self.hyper_embedding[self.index : self.index + self.num_updates] = (
                    callable_encoder(
                        self.updates, weight, self.pad_token_id
                    )[: self.num_updates]
                )

        return self.hyper_embedding

    def get_hyper_linear(
        self, weight: torch.Tensor, callable_encoder: CallableEncoder
    ) -> torch.Tensor:
        if not torch.compiler.is_compiling():
            if self.num_updates > 0:
                self.hyper_linear[self.index : self.index + self.num_updates] = (
                    callable_encoder(
                        self.updates, weight, self.pad_token_id
                    )[: self.num_updates]
                )
                self.index += self.num_updates

        return self.hyper_linear

    def get_index(self) -> int:
        return self.index

    def reset(self) -> None:
        self.codebook_manager.reset()

        self.index = 0
        self.hyper_embedding = torch.zeros(
            self.max_codebook_size,
            self.embedding_dim,
            dtype=self.dtype,
            device=self.device,
        )
        self.hyper_linear = torch.zeros(
            self.max_codebook_size,
            self.embedding_dim,
            dtype=self.dtype,
            device=self.device,
        )
