# Modified from https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py

import os
import torch
import transformers

from transformers import Trainer
from peft import PeftModel, LoraConfig, C3AConfig, TaskType, get_peft_model

import sys; sys.path.append("src")
from sft_utils import (
    MODEL_DICT, ModelArguments, DataArguments, TrainingArguments,
    make_supervised_data_module, check_special_tokens, smart_tokenizer_and_embedding_resize
)

# Maps each supported adapter type to the PEFT config class used to build it.
# "dora" shares LoraConfig with "lora"; train() distinguishes the two by
# passing `use_dora=True` for the former.
CONFIG_DICT = dict(
    c3a=C3AConfig,
    lora=LoraConfig,
    dora=LoraConfig,
)


def train():
    """Fine-tune a causal language model with a PEFT adapter (C3A, LoRA or DoRA).

    Command-line arguments are parsed into ``ModelArguments``, ``DataArguments``
    and ``TrainingArguments``. The function then:

    1. Resolves the base model path from ``MODEL_DICT`` when none is given,
       derives a default run name from the adapter hyper-parameters, and
       rewrites ``training_args.output_dir`` to
       ``<output_dir>/models/<data_tag>/<model_tag>-bs_<batch>/<run_name>``.
    2. Loads the base model in bfloat16, either directly
       (``model_args.full_precision``) or with a bitsandbytes NF4 quantization
       config whose 4-bit loading is controlled by
       ``training_args.load_in_4bit``.
    3. Wraps the model in a freshly initialised adapter
       (``model_args.lora_init``) or loads a trainable adapter from
       ``model_args.adapter_name_or_path``.
    4. Loads the tokenizer, adds any missing special tokens (resizing the
       embeddings accordingly), builds the supervised data module and runs
       ``Trainer.train()``.

    Raises:
        ValueError: If the model name matches none of the supported families
            (LLaMA, Mistral, Falcon, Phi), if ``training_args.adapter_type`` is
            not one of ``"c3a"``, ``"lora"``, ``"dora"``, or if neither
            ``lora_init`` nor ``adapter_name_or_path`` is provided.
    """
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Initialization
    if model_args.model_name_or_path is None:
        model_args.model_name_or_path = MODEL_DICT[model_args.model_tag]

    if not training_args.run_name:
        training_args.run_name = f"{training_args.adapter_type}-lr_{training_args.learning_rate}"
        # Render whole-number alphas without a trailing ".0" in the run name.
        if int(training_args.lora_alpha) == training_args.lora_alpha:
            training_args.lora_alpha = int(training_args.lora_alpha)

        if training_args.adapter_type in ['lora', 'dora']:
            training_args.run_name += f'-r_{training_args.lora_r}-a_{training_args.lora_alpha}-d_{training_args.lora_dropout}'
        elif training_args.adapter_type in ['c3a',]:
            training_args.run_name += f'-block_size_{training_args.block_size}'

    # Per-device batch size times accumulation steps; the number of devices is
    # not factored in here.
    batch_size = int(training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)
    output_dir = os.path.join(training_args.output_dir, 'models', data_args.data_tag,
                              f'{model_args.model_tag}-bs_{batch_size}', training_args.run_name)
    training_args.output_dir = output_dir

    ##################################################
    ### Load LLM Model
    ##################################################
    if model_args.full_precision:
        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
            device_map="auto"
        )
    else:
        model = transformers.AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
            quantization_config=transformers.BitsAndBytesConfig(
                load_in_4bit=training_args.load_in_4bit,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type='nf4',
            ),
        )

    ##################################################
    ### Peft Model
    ##################################################
    if model_args.lora_init:
        config_kwargs = {
            "task_type": TaskType.CAUSAL_LM
        }
        # Pick the projection layers to adapt based on the model family,
        # detected by substring match on the (lower-cased) model name/path.
        model_name = model_args.model_name_or_path.lower()
        if any(name in model_name for name in ["llama", "mistral", "falcon"]):
            target_modules = ["q_proj", "v_proj"]
        elif any(name in model_name for name in ["phi"]):
            target_modules = ["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"]
        else:
            raise ValueError(f"Only support LLAMA, Mistral, Falcon, Phi-2, but got {model_args.model_name_or_path}.")
        config_kwargs['target_modules'] = target_modules

        if training_args.adapter_type in ["c3a",]:
            config_kwargs.update({
                "block_size": training_args.block_size,
                "init_weights": "xavier_uniform",
            })
        elif training_args.adapter_type in ["lora", "dora"]:
            config_kwargs.update({
                "r": training_args.lora_r,
                "lora_alpha": training_args.lora_alpha,
                "lora_dropout": training_args.lora_dropout,
                "use_dora": (training_args.adapter_type == "dora")
            })
        else:
            raise ValueError(f"Unsupported lora adapter: {training_args.adapter_type}")

        adapter_config = CONFIG_DICT[training_args.adapter_type](**config_kwargs)
        model = get_peft_model(model, adapter_config, autocast_adapter_dtype=False)
    elif model_args.adapter_name_or_path is not None:
        # Resume from an existing adapter and keep its weights trainable.
        model = PeftModel.from_pretrained(
            model,
            model_args.adapter_name_or_path,
            is_trainable=True,
        )
    else:
        # Neither a fresh adapter nor a checkpoint was requested; the adapter
        # type itself is not what is wrong here.
        raise ValueError(
            "No adapter specified: set `lora_init` to initialize a new adapter "
            "or provide `adapter_name_or_path` to load an existing one."
        )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        model_max_length=model_args.model_max_length,
        padding_side="left",
        use_fast=False,
    )
    special_tokens_dict = check_special_tokens(tokenizer, model_args.model_tag)
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )
    # Keep generation padding in sync with the (possibly newly added) pad token.
    model.generation_config.pad_token_id = tokenizer.pad_token_id

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    # NOTE(review): recent transformers releases deprecate the `tokenizer=`
    # keyword in favour of `processing_class=` — confirm against the pinned
    # version before changing.
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()