from adapters import AutoAdapterModel, SeqBnConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, LoraConfig
import torch

def LMAdapter(model_type, task_type, out_size, reduction_factor=16):
    """Load an adapter-capable model plus tokenizer and attach a SeqBn adapter.

    Args:
        model_type: Name or path handed to ``from_pretrained`` for both the
            model and the tokenizer. For summarization it is also checked
            for the substring ``'llama'``.
        task_type: ``'classification'`` or ``'summarization'``. Any other
            value returns the freshly loaded model without a head or adapter.
        out_size: Number of labels for the classification head; not used
            for summarization.
        reduction_factor: Forwarded to ``SeqBnConfig``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoAdapterModel.from_pretrained(model_type)
    tokenizer = AutoTokenizer.from_pretrained(model_type)

    # Pick the task-specific head and the name the adapter is registered under.
    if task_type == 'classification':
        adapter_name = 'seqbn'
        model.add_classification_head('clf', num_labels=out_size)
    elif task_type == 'summarization':
        adapter_name = 'summarization'
        # A causal-LM head is only added for model names containing 'llama'.
        if 'llama' in model_type:
            model.add_causal_lm_head('summarization')
    else:
        # Unrecognised task: hand back the model exactly as loaded.
        return model, tokenizer

    bottleneck_config = SeqBnConfig(
        mh_adapter=True,
        output_adapter=True,
        reduction_factor=reduction_factor,
        non_linearity='relu',
    )
    model.add_adapter(adapter_name, config=bottleneck_config)
    # Freeze the pretrained weights and activate training of the adapter.
    model.train_adapter(adapter_name)
    return model, tokenizer

def PromptTuning(model_type, task_type, out_size, reduction_factor=16):
    """Load a causal LM plus tokenizer and wrap the model with a PEFT config.

    For ``task_type == "summarization"`` the model is wrapped with a
    prompt-tuning config (20 virtual tokens initialised from the text
    ``"Summarize:"``); for every other task type a LoRA config is used.
    The trainable-parameter summary is printed before returning.

    Args:
        model_type: Name or path handed to ``from_pretrained`` for the model
            and the tokenizer, and used as ``tokenizer_name_or_path`` in the
            prompt-tuning config.
        task_type: ``"summarization"`` selects prompt tuning; anything else
            selects LoRA.
        out_size: Accepted but not used by this function.
        reduction_factor: Accepted but not used by this function.

    Returns:
        A ``(model, tokenizer)`` tuple where ``model`` is the PEFT-wrapped
        model.
    """
    base_model = AutoModelForCausalLM.from_pretrained(model_type)
    tokenizer = AutoTokenizer.from_pretrained(model_type)

    if task_type != "summarization":
        # NOTE(review): this pairs a causal-LM base model with a "SEQ_CLS"
        # LoRA config and never uses `out_size` -- confirm whether a
        # sequence-classification base model was intended here.
        lora_kwargs = {
            "r": 2,
            "lora_alpha": 2,
            "lora_dropout": 0.1,
            "bias": "none",
            "task_type": "SEQ_CLS",
        }
        peft_config = LoraConfig(**lora_kwargs)
    else:
        prompt_kwargs = {
            "task_type": TaskType.CAUSAL_LM,
            "prompt_tuning_init": PromptTuningInit.TEXT,
            "num_virtual_tokens": 20,
            "prompt_tuning_init_text": "Summarize:",
            "tokenizer_name_or_path": model_type,
        }
        peft_config = PromptTuningConfig(**prompt_kwargs)

    peft_model = get_peft_model(base_model, peft_config)
    peft_model.print_trainable_parameters()
    return peft_model, tokenizer