from peft import PrefixTuningConfig, TaskType, get_peft_model

from egu.models.base import HFModel


class PTuneHFModel(HFModel):
    """HFModel variant whose underlying model is wrapped with PEFT prefix tuning.

    After the base-class constructor runs, ``self.model`` is replaced by the
    result of ``get_peft_model`` using a causal-LM ``PrefixTuningConfig``.
    """

    def __init__(
        self,
        model_name,
        model_path=None,
        config_path="./config",
        generation_config=None,
        prefix_tune_len=1,
        prefix_projection=True,
        inference_mode=False,
    ):
        """Build the base model, then wrap it with a prefix-tuning adapter.

        Args:
            model_name: Forwarded positionally to ``HFModel.__init__``.
            model_path: Forwarded positionally to ``HFModel.__init__``.
            config_path: Forwarded positionally to ``HFModel.__init__``.
            generation_config: Forwarded positionally to ``HFModel.__init__``.
            prefix_tune_len: Passed to PEFT as ``num_virtual_tokens``.
            prefix_projection: Passed to PEFT as ``prefix_projection``.
            inference_mode: Passed to PEFT as ``inference_mode``.
        """
        super().__init__(model_name, model_path, config_path, generation_config)

        # NOTE(review): assumes HFModel.__init__ has populated self.model —
        # confirm against the base class.
        prefix_cfg = PrefixTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            num_virtual_tokens=prefix_tune_len,
            # Presumably enables an MLP reparameterisation of the prefix;
            # see the PEFT docs to confirm.
            prefix_projection=prefix_projection,
            # Presumably freezes the prefix parameters when True; see the
            # PEFT docs to confirm.
            inference_mode=inference_mode,
        )
        self.model = get_peft_model(self.model, prefix_cfg)

    def __call__(self, *args, **kwargs):
        """Run the wrapped model, ignoring ``prompts`` and ``answers`` kwargs.

        All positional arguments and every other keyword argument are
        forwarded unchanged to ``self.model``; its result is returned.
        """
        forwarded = {
            name: value
            for name, value in kwargs.items()
            if name not in ("prompts", "answers")
        }
        return self.model(*args, **forwarded)

    def generate(self, *args, **kwargs):
        """Call ``self.model.generate``, ignoring any ``prompts`` kwarg.

        All positional arguments and every other keyword argument are
        forwarded unchanged; the result of ``self.model.generate`` is returned.
        """
        forwarded = {
            name: value for name, value in kwargs.items() if name != "prompts"
        }
        return self.model.generate(*args, **forwarded)
