from peft import LoraConfig, TaskType, get_peft_model

from egu.models.base import HFModel


class LoraTuneHFModel(HFModel):
    """An ``HFModel`` whose underlying model is wrapped with LoRA adapters.

    After the base class has been initialised, ``self.model`` is replaced by
    the PEFT model returned from ``get_peft_model`` using a causal-LM
    ``LoraConfig`` that targets the ``q_proj`` and ``v_proj`` modules.
    """

    def __init__(
        self,
        model_name,
        model_path=None,
        config_path="./config",
        generation_config=None,
        lora_bits: int = 4,
        ref: bool = False,
        lora_rank: int = 8,
        lora_alpha: int = 32,
    ):
        """Initialise the base model and attach LoRA adapters to it.

        Args:
            model_name: Forwarded unchanged to ``HFModel.__init__``.
            model_path: Forwarded unchanged to ``HFModel.__init__``.
            config_path: Forwarded unchanged to ``HFModel.__init__``.
            generation_config: Forwarded unchanged to ``HFModel.__init__``.
            lora_bits: Currently unused. It was intended as the bit width of
                a LoftQ configuration (``loftq_bits``) that is not enabled;
                kept so existing callers passing it keep working.
            ref: Forwarded unchanged to ``HFModel.__init__``.
            lora_rank: LoRA rank, passed to ``LoraConfig`` as ``r``.
            lora_alpha: LoRA scaling factor, passed to ``LoraConfig`` as
                ``lora_alpha``.
        """
        # NOTE(review): assumes the base initialiser populates ``self.model``,
        # which is read below -- confirm against HFModel.
        super().__init__(model_name, model_path, config_path, generation_config, ref)

        print("Lora class")
        print(f"lora rank {lora_rank}")
        print(f"lora alpha {lora_alpha}")

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.1,
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
        # Replace the plain model with its LoRA-wrapped counterpart.
        self.model = get_peft_model(self.model, lora_config)

    def __call__(self, *args, **kwargs):
        """Call the wrapped model, ignoring ``prompts``/``answers`` kwargs.

        All positional arguments and every other keyword argument are
        forwarded to ``self.model`` unchanged; its result is returned.
        """
        # Drop the "prompts" and "answers" keys if present; presumably they
        # are batch metadata the wrapped model does not accept -- TODO confirm.
        for key in ("prompts", "answers"):
            kwargs.pop(key, None)
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Call ``self.model.generate``, ignoring a ``prompts`` kwarg.

        All positional arguments and every other keyword argument are
        forwarded unchanged; the result of ``generate`` is returned.
        """
        # Only "prompts" is dropped here; "answers" is still forwarded.
        kwargs.pop("prompts", None)
        return self.model.generate(*args, **kwargs)
