
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training 
from peft import PeftModel
from torch import nn
from transformers import AutoModel
from transformers import AutoTokenizer
from transformers import LlamaTokenizer, LlamaForCausalLM, GenerationConfig
from watermark import watermark

class LlmEmbeddingExtractionModel(nn.Module):
    """Wrap a Hugging Face ``AutoModel`` and expose one embedding per sequence.

    The backbone is loaded with ``AutoModel.from_pretrained``. When a config
    is supplied and ``config.TRAINER.LLM_LORA.USE`` is truthy, the backbone is
    additionally passed through ``prepare_model_for_int8_training`` and
    wrapped with LoRA adapters via ``get_peft_model``.
    """

    def __init__(self, model_name, config=None):
        """Load the backbone and optionally attach LoRA adapters.

        Args:
            model_name: huggingface model name, forwarded to
                ``AutoModel.from_pretrained``.
            config: used to config lora. Read through
                ``config.TRAINER.LLM_LORA`` (``USE``, ``RANK``,
                ``LORA_ALPHA``, ``TARGET_MODULES``, ``LORA_DROPOUT``,
                ``BIAS``, ``MODULES_TO_SAVE``, ``TASK_TYPE``). ``None``
                (the default) means "no LoRA": the plain backbone is used.
        """
        super().__init__()
        print(f'[INFO] Load {model_name}')

        # ``config`` defaults to None, so it must be checked before being
        # dereferenced; previously the default argument raised AttributeError.
        if config is not None and config.TRAINER.LLM_LORA.USE:
            lora_options = config.TRAINER.LLM_LORA

            # NOTE(review): the log line mentions 8-bit, but from_pretrained
            # is called without any quantization argument here -- confirm
            # whether 8-bit loading is actually intended.
            print('[INFO] Config 8bit model...')
            model = AutoModel.from_pretrained(model_name)

            model = prepare_model_for_int8_training(model)
            # Presumably disabled because the generation cache is not needed
            # while training -- TODO confirm.
            model.config.use_cache = False

            print('[INFO] Config LoRA...')
            lora_config = LoraConfig(
                r=lora_options.RANK,
                lora_alpha=lora_options.LORA_ALPHA,
                target_modules=lora_options.TARGET_MODULES,
                lora_dropout=lora_options.LORA_DROPOUT,
                bias=lora_options.BIAS,
                modules_to_save=lora_options.MODULES_TO_SAVE,
                task_type=lora_options.TASK_TYPE,
            )
            model = get_peft_model(model, lora_config)
            print('[INFO]: LLM LoRA parameters: ')
            model.print_trainable_parameters()
        else:
            model = AutoModel.from_pretrained(model_name)
        self.model = model

    def forward(self, input_ids, attention_mask):
        """Return the hidden state at the last sequence position.

        Args:
            input_ids: passed positionally to the wrapped model
                (presumably token ids of shape (batch, seq) -- TODO confirm).
            attention_mask: passed to the wrapped model as
                ``attention_mask``.

        Returns:
            ``output['last_hidden_state'][:, -1, :]``: for every item along
            the first axis, the vector stored at the final index of the
            second axis.
        """
        output = self.model(
            input_ids, attention_mask=attention_mask
        )

        # NOTE(review): index -1 is the last sequence *slot*, not necessarily
        # the last attended token. With right-padded batches this would pick
        # a padding position -- confirm the tokenizer pads on the left.
        embeddings = output['last_hidden_state'][:, -1, :]

        return embeddings


