import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
)
from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, LoraConfig

def HFModel(model_name_or_path, peft_config, task, out_size=1, torch_dtype=None):
    print(f"Loading model in torch_dtype = {torch_dtype}")
    if task == "summarization":
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, 
                                                     torch_dtype=torch_dtype) 
    elif task == "classification":
        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path,
                                                                   num_labels=out_size,
                                                                   torch_dtype=torch_dtype)
    elif task == "qa":
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, 
                                                     torch_dtype=torch_dtype) 
        peft_config.pop("type", None)
        peft_config = LoraConfig(
            **peft_config
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        raise NotImplementedError(f"Task {task} is not supported.")
    
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer
