import torch
from torch import nn
from typing import Dict
from transformers import PreTrainedModel, AutoModelForCausalLM

from configs import Config
from inference.torch.layers.encoder import Encoder
from inference.torch.codebook import CodebookManager
from inference.torch.layers.linear import HyperLinear
from inference.torch.layers.embedding import HyperEmbedding


class LZWModel:
    """Wrapper bundling a causal LM with an ``Encoder`` and a ``CodebookManager``.

    Attribute lookups that fail on the wrapper and do not start with ``_`` are
    forwarded to the wrapped model, so ``LZWModel`` can be used mostly like the
    underlying ``PreTrainedModel``. Wrapper-specific operations are exposed as
    underscore methods (``_to``, ``_eval``, ``_load_state_dict``) so they never
    shadow the model's own public API.
    """

    def __init__(
        self,
        model: "PreTrainedModel",
        encoder: "Encoder",
        codebook_manager: "CodebookManager",
    ) -> None:
        self._model = model
        self._encoder = encoder
        self._codebook_manager = codebook_manager

    def __getattr__(self, attr: str):
        """Forward public attribute lookups to the wrapped model.

        Python only calls this after normal lookup has failed. An underscore
        name reaching this point is therefore genuinely missing and must raise
        AttributeError. Calling ``getattr(self, attr)`` here would re-enter
        this method and recurse until RecursionError, which would also break
        ``hasattr``, ``copy``, and access to ``_model`` before ``__init__``.

        Raises:
            AttributeError: if ``attr`` starts with ``_`` or the wrapped model
                has no such attribute.
        """
        if attr.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        # If ``_model`` is not set yet, this lookup lands in the branch above
        # and raises AttributeError instead of recursing.
        return getattr(self._model, attr)

    def _load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
        """Load a checkpoint into both the encoder and the wrapped model.

        Encoder weights are taken from keys under
        ``hyper_embedding._orig_mod.embedding_encoder.`` (presumably the
        ``_orig_mod`` segment comes from saving a torch.compile'd module;
        confirm). They are loaded strictly into the encoder:

        * ``...pos_embed.weight`` is renamed to ``position_embeddings``.
        * ``...embedding_encoder.layers.*`` keys keep everything after
          ``embedding_encoder.``.

        The full ``state_dict`` is then loaded non-strictly into the wrapped
        model, so keys the model does not have are ignored.

        Raises:
            KeyError: if the positional-embedding key is missing.
            RuntimeError: if the encoder keys do not match the encoder exactly.
        """
        encoder_state_dict = {
            "position_embeddings": state_dict[
                "hyper_embedding._orig_mod.embedding_encoder.pos_embed.weight"
            ]
        }

        for key, value in state_dict.items():
            if "hyper_embedding" in key and "embedding_encoder.layers" in key:
                # partition keeps the whole suffix after the first marker.
                encoder_state_dict[key.partition("embedding_encoder.")[2]] = value

        self._encoder.load_state_dict(encoder_state_dict, strict=True)
        self._model.load_state_dict(state_dict, strict=False)

    def _to(self, *args, **kwargs) -> None:
        """Call ``.to(*args, **kwargs)`` on both the encoder and the wrapped model."""
        self._encoder.to(*args, **kwargs)
        self._model.to(*args, **kwargs)

    def _eval(self) -> None:
        """Put both the encoder and the wrapped model into eval mode."""
        self._encoder.eval()
        self._model.eval()

    @classmethod
    def from_config(
        cls,
        config: "Config",
        device: torch.device,
        pad_token_id: int,
    ) -> "LZWModel":
        """Build an ``LZWModel`` from ``config``.

        The steps are:

        1. Load the pretrained causal LM and move it to ``device``.
        2. Build, compile, and warm up the encoder.
        3. Create the codebook manager.
        4. Replace the model's input embeddings with a ``HyperEmbedding`` and
           its output embeddings with a ``HyperLinear``. Both share the same
           encoder callable and codebook manager.

        Args:
            config: project configuration. Supplies the model name, dtype,
                vocab sizes, compression settings and encoder settings.
            device: device for the model, encoder and warm-up tensors.
            pad_token_id: padding id passed to the encoder, the codebook
                manager and the hyper layers.

        Returns:
            A new ``LZWModel`` wrapping the patched model.
        """
        model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
            config.pretrained_model_name_or_path
        ).to(device)

        encoder = Encoder.from_config(config).to(device, config.dtype)
        encoder.compile()
        # NOTE(review): not every HF model exposes ``.vocab_size`` directly;
        # ``model.config.vocab_size`` may be the safer source. Confirm.
        cls._warm_up_encoder(encoder, config, device, model.vocab_size, pad_token_id)

        codebook_manager = CodebookManager(
            config.initial_vocab_size,
            config.extra_vocab_size,
            config.compression.max_subtokens,
            config.embedding_encoder.embedding_size,
            config.dtype,
            device,
            pad_token_id,
        )

        model.set_input_embeddings(
            HyperEmbedding.from_embedding(
                model.get_input_embeddings(),
                config.initial_vocab_size,
                config.extra_vocab_size,
                create_encoder_callable(encoder),
                pad_token_id,
                codebook_manager,
            )
        )

        model.set_output_embeddings(
            HyperLinear.from_linear(
                model.get_output_embeddings(),
                config.initial_vocab_size,
                create_encoder_callable(encoder),
                pad_token_id,
                codebook_manager,
            )
        )

        return cls(model, encoder, codebook_manager)

    @staticmethod
    def _warm_up_encoder(encoder, config, device, vocab_size, pad_token_id) -> None:
        """Run dummy forward passes to trigger encoder compilation up front.

        The passes use ``config.extra_vocab_size`` updates (the maximum), then
        1 update, then 2 updates. All inputs are zero-filled and run under
        ``torch.no_grad()``. Outputs are discarded.
        """
        with torch.no_grad():
            for num_updates in (config.extra_vocab_size, 1, 2):
                updates = torch.zeros(
                    num_updates,
                    config.compression.max_subtokens,
                    device=device,
                    dtype=torch.long,
                )
                # Allocate a fresh zero weight per call, as before, in case the
                # encoder writes into it.
                weight = torch.zeros(
                    vocab_size,
                    config.embedding_encoder.embedding_size,
                    dtype=config.dtype,
                    device=device,
                )
                encoder(updates, weight, pad_token_id)

            print("encoder compiled")


def create_encoder_callable(encoder):
    """Return a plain callable that forwards its three arguments to ``encoder``.

    The returned callable accepts ``(codebook, embeddings, pad_token_id)``,
    positionally or by keyword. It passes them to ``encoder`` positionally in
    that order and returns the encoder's result unchanged.
    """
    return lambda codebook, embeddings, pad_token_id: encoder(
        codebook, embeddings, pad_token_id
    )
