from dataclasses import dataclass
from typing import Optional

from serde import deserialize_from_json

from .model import CompletionModel, ModelType


class HFModelType(ModelType):
    """Model identifiers accepted by ``HFCompletionModel``."""
    GPT2 = "gpt-2"
    GEMMA_2B = "gemma-2b"


@dataclass
class HFModelConfig:
    """Settings for loading and running a Hugging Face causal language model.

    ``HFCompletionModel.initialize_model`` builds an instance of this class
    from the model's ``config_dict`` via ``deserialize_from_json``.
    """
    # Identifier passed to ``from_pretrained`` for both the model and tokenizer.
    model_link: str
    # NOTE(review): not read anywhere in this file -- confirm where (or whether)
    # the seed is applied.
    seed: int
    # Passed to ``generate`` as ``max_length``.
    max_tokens: int
    # Copied onto the loaded model's ``config.use_cache``.
    use_cache: bool = False
    # When true, the model is loaded with an 8-bit ``BitsAndBytesConfig``.
    do_quant: bool = False


class HFCompletionModel(CompletionModel):
    """Completion model backed by a Hugging Face causal language model.

    The model and tokenizer are loaded lazily: nothing is loaded until
    ``model`` is first accessed, either directly or through ``completion``.

    Attributes:
        adapter_link: optional adapter identifier handed to the model's
            ``load_adapter`` once the base weights are loaded.
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
            Only set after initialization.
        config: the parsed ``HFModelConfig``. Only set after initialization.
        tokenizer: tokenizer loaded from ``config.model_link``. Only set
            after initialization.
    """

    def __init__(self, model_type: HFModelType, adapter_link: Optional[str] = None):
        """Store the model type and optional adapter; defer all loading.

        Args:
            model_type: which model this instance represents; forwarded to
                the base class constructor.
            adapter_link: adapter to load on top of the base model, or
                ``None`` to use the base model as-is.
        """
        super().__init__(model_type)
        self._model = None
        self.adapter_link = adapter_link

    @property
    def model(self):
        """Return the loaded model, initializing it on first access."""
        if self._model is None:
            self.initialize_model()
        assert self._model is not None
        return self._model

    def initialize_model(self):
        """Load the model, optional adapter and tokenizer described by the config.

        Everything is assembled in local variables and only published to
        ``self`` once every step has succeeded. If loading the adapter or the
        tokenizer raises, ``_model`` stays ``None`` and the next access to
        ``model`` retries the whole initialization instead of returning a
        half-configured model.

        Raises:
            Whatever ``deserialize_from_json``, ``from_pretrained`` or
            ``load_adapter`` raise; no exception is swallowed here.
        """
        # Imported lazily so that importing this module does not pull in
        # torch/transformers.
        import torch as t
        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

        device = "cuda" if t.cuda.is_available() else "cpu"

        # ``config_dict`` is not defined in this class -- presumably provided
        # by the base class; verify against ``CompletionModel``.
        config = deserialize_from_json(HFModelConfig, self.config_dict)

        if config.do_quant:
            # The quantized model is not moved with ``.to``; its placement is
            # left to the quantization backend.
            model = AutoModelForCausalLM.from_pretrained(
                config.model_link,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(config.model_link).to(device)

        if self.adapter_link is not None:
            model.load_adapter(self.adapter_link)

        tokenizer = AutoTokenizer.from_pretrained(config.model_link)
        if tokenizer.pad_token_id is None:
            # Fall back to EOS as the padding token when the tokenizer
            # defines none.
            tokenizer.pad_token = tokenizer.eos_token

        model.config.use_cache = config.use_cache

        # Publish only after every step above succeeded.
        self.device = device
        self.config = config
        self.tokenizer = tokenizer
        self._model = model

    def completion(self, prompt: str) -> str:
        """Run a completion on the model.

        Args:
            prompt: the text to complete

        Returns:
            the first generated sequence decoded to text, with special tokens
            stripped
        """
        # Resolve the model first: this triggers lazy initialization, which is
        # what creates ``self.tokenizer``, ``self.device`` and ``self.config``.
        model = self.model

        encoded = self.tokenizer(prompt, return_tensors="pt")
        # Pass the attention mask explicitly. The pad token id given below is
        # the EOS id, so ``generate`` cannot reliably infer the mask itself.
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)

        # NOTE(review): ``max_length`` is given ``config.max_tokens``; confirm
        # whether a cap on newly generated tokens was intended instead.
        ret = model.generate(
            encoded.input_ids.to(self.device),
            attention_mask=attention_mask,
            max_length=self.config.max_tokens,
            num_return_sequences=1,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        return self.tokenizer.decode(ret[0], skip_special_tokens=True)
