#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import copy
import logging
import os
import sys
import pprint
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import numpy as np
import torch
import transformers
from datasets import load_dataset
from torch.utils.data import Dataset, IterableDataset
from transformers import Trainer as TrainerHf

import utils
from data import register_data_module_makers, make_supervised_data_module
# from model.modeling_llama import LlamaForCausalLM as LlamaWFlashAttnForCausalLM
from transformers import LlamaForCausalLM
from peft import get_peft_model, LoraConfig
from peft_lora_utils import SavePeftModelCallback, SavePeftModelAtEndCallback, resume_lora_model_from_checkpoint
from tokenizer import IGNORE_INDEX, DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN

from model import ModelArchNames

from args import LlamaTrainingArguments


@dataclass
class ModelArguments:
    """Command-line arguments selecting the model weights, config and tokenizer."""

    # Checkpoint passed to `from_pretrained`; also the directory in which
    # `config.json` is looked up when model_config_path is not given.
    model_name_or_path: Optional[str] = None
    # Explicit path of the model config JSON; takes precedence over the
    # lookup inside model_name_or_path.
    model_config_path: Optional[str] = None
    # Deprecated, don't use it: when non-empty, `config_<model_size>.json`
    # under model_name_or_path is used as the config file.
    model_size: Optional[str] = ""
    # NOTE(review): the default names an OPT tokenizer, while train() loads
    # the tokenizer with LlamaTokenizer — this presumably must be overridden.
    tokenizer_path: Optional[str] = "facebook/opt-125m"
    # Architecture selector; see `ModelArchNames`.
    model_arch: Optional[str] = ModelArchNames.LLAMA2


@dataclass
class DataArguments:
    """Command-line arguments describing the training / evaluation data."""

    # Optional[...] because the default is None (was mis-annotated as `str`).
    data_path: Optional[str] = field(default=None,
                                     metadata={"help": "Path to the training data."})
    eval_num: int = field(default=0, metadata={"help": "Num of eval samples."})
    # Not really a "module": the odd name comes from Alpaca, which calls the
    # bundle of train / eval / test datasets a "data module".
    data_module_maker: Optional[str] = field(
        default=None, metadata={"help": "which dataset module maker to use"}
    )

def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
                                   output_dir: str):
    """Write ``trainer.model``'s weights to ``output_dir`` via ``trainer._save``.

    Every tensor of the state dict is copied to CPU first. Nothing is
    written unless ``trainer.args.should_save`` is true.

    Args:
        trainer: trainer whose model is saved.
        output_dir: destination directory handed to ``trainer._save``.
    """
    # NOTE(review): state_dict() is taken before the should_save check, i.e.
    # on every process — presumably so that sharded setups can gather it on
    # all ranks; confirm before moving it inside the save branch.
    model_state = trainer.model.state_dict()
    if not trainer.args.should_save:
        return
    host_copy = {name: tensor.cpu() for name, tensor in model_state.items()}
    del model_state  # release our reference to the original state dict
    trainer._save(output_dir, state_dict=host_copy)  # noqa


def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict, tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel, model_args: ModelArguments):
    """Add missing special tokens to ``tokenizer`` and fit ``model``'s embeddings.

    Carried over from Alpaca: LLaMA and Baichuan ship without a pad token, so
    one is added for the model. (An old Baichuan release mistakenly shipped a
    tokenizer with a pad token, but that has been fixed.) Qwen, although also
    LLaMA-shaped, uses a different tokenizer whose bos / eos / pad share one
    token id, and its embedding vocab is much larger than ``len(tokenizer)``,
    so it needs no resize.

    The embedding matrices are therefore only ever grown. If they already
    have at least ``len(tokenizer)`` rows, they keep their size instead of
    being truncated. Rows of newly added tokens (ids
    ``[len(tokenizer) - num_new, len(tokenizer))``) are set to the mean of
    the rows of the pre-existing tokens, for both input and output
    embeddings.

    Args:
        special_tokens_dict: passed to ``tokenizer.add_special_tokens``.
        tokenizer: tokenizer, extended in place.
        model: model whose embeddings are resized / initialized in place.
        model_args: unused; kept for call compatibility.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    vocab_size = len(tokenizer)

    # Resize unless that would shrink the matrix: truncating would drop rows
    # of models whose embedding is deliberately larger than the tokenizer.
    current_rows = model.get_input_embeddings().weight.shape[0]
    if vocab_size >= current_rows:
        model.resize_token_embeddings(vocab_size)

    if num_new_tokens > 0:
        num_old_tokens = vocab_size - num_new_tokens
        for embeddings in (model.get_input_embeddings(),
                           model.get_output_embeddings()):
            if embeddings is None:  # e.g. a model without an output head
                continue
            weight = embeddings.weight.data
            # Only rows [num_old_tokens, vocab_size) are written, so the mean
            # over the old rows is unaffected even if the weights are tied.
            weight[num_old_tokens:vocab_size] = weight[:num_old_tokens].mean(
                dim=0, keepdim=True)


def _resolve_config_path(model_args: ModelArguments) -> str:
    """Return the path of the model config JSON selected by ``model_args``.

    ``model_config_path`` wins when set. Otherwise the config is looked up
    inside ``model_name_or_path``: ``config.json``, or the deprecated
    ``config_<model_size>.json`` when ``model_size`` is non-empty.
    """
    if model_args.model_config_path is not None:
        return model_args.model_config_path
    if model_args.model_size == "":
        return os.path.join(model_args.model_name_or_path, "config.json")
    print('use model_config_path instead of deprecated model_size')
    return os.path.join(model_args.model_name_or_path,
                        "config_{}.json".format(model_args.model_size))


def _load_tokenizer(model_args: ModelArguments,
                    training_args: LlamaTrainingArguments):
    """Load the LLaMA tokenizer and collect the special tokens it lacks.

    Returns:
        ``(tokenizer, special_tokens_dict)``. The dict holds a ``pad_token``
        entry when the tokenizer has no pad token.

    Raises:
        ValueError: for an unsupported ``model_arch``, or when the tokenizer
            has no bos / eos / unk token.
    """
    llama_archs = (ModelArchNames.LLAMA, ModelArchNames.LLAMA_W_FLASH_ATTN,
                   ModelArchNames.LLAMA2)
    if model_args.model_arch not in llama_archs:
        raise ValueError(f'unknown model arch {model_args.model_arch}')

    tokenizer = transformers.LlamaTokenizer.from_pretrained(
        model_args.tokenizer_path,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    special_tokens_dict = {}
    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
    # Explicit checks rather than `assert`, which `python -O` strips.
    for token_attr in ("bos_token", "eos_token", "unk_token"):
        if getattr(tokenizer, token_attr) is None:
            raise ValueError(
                f'tokenizer {model_args.tokenizer_path} has no {token_attr}')
    return tokenizer, special_tokens_dict


def _load_model(model_args: ModelArguments,
                training_args: LlamaTrainingArguments):
    """Instantiate the causal LM selected by ``model_args.model_arch``.

    Returns:
        ``(model, is_flash_attn_model)``. The flag tells whether the custom
        flash-attention LLaMA class (``model.modeling_llama``) was loaded
        rather than the stock HF ``LlamaForCausalLM``.

    Raises:
        ValueError: for an unsupported ``model_arch``.
    """
    arch = model_args.model_arch
    is_flash_attn_model = (
        arch == ModelArchNames.LLAMA_W_FLASH_ATTN
        or (arch == ModelArchNames.LLAMA and training_args.use_flash_attn))
    if is_flash_attn_model:
        # Imported lazily. The module-level import of this class is commented
        # out, so referring to it by name raised NameError. A local import
        # keeps the stock-HF paths free of this dependency.
        from model.modeling_llama import (
            LlamaForCausalLM as LlamaWFlashAttnForCausalLM)
        model_cls = LlamaWFlashAttnForCausalLM
    elif arch in (ModelArchNames.LLAMA, ModelArchNames.LLAMA2):
        # NOTE: for LLAMA2 the stock HF class is used even if use_flash_attn
        # is set.
        model_cls = LlamaForCausalLM
    else:
        raise ValueError(f'unknown model arch {arch}')

    # New float tensors default to CUDA fp16 while loading. The default is
    # restored even when loading fails.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    try:
        model = model_cls.from_pretrained(model_args.model_name_or_path)
    finally:
        torch.set_default_tensor_type(torch.FloatTensor)
    print(f'using model arch {type(model)}')
    return model, is_flash_attn_model


def _apply_lora(model, training_args: LlamaTrainingArguments,
                use_fused_qkv: bool):
    """Wrap ``model`` with LoRA adapters and optionally resume their weights.

    Args:
        model: base model to wrap.
        training_args: supplies lora_dim / lora_alpha / lora_dropout and the
            optional ``lora_model_path`` checkpoint to resume from.
        use_fused_qkv: target the ``"Wqkv"`` projection (flash-attention
            model) instead of ``q_proj`` / ``v_proj``.

    Returns:
        The PEFT-wrapped model.
    """
    # "Wqkv" is presumably the fused q/k/v projection of the flash-attention
    # model. The stock HF LLaMA has no such module, so targeting it there
    # makes PEFT fail to find any target module.
    if use_fused_qkv:
        lora_target_modules = ["Wqkv"]
    else:
        lora_target_modules = ["q_proj", "v_proj"]
    config = LoraConfig(
        r=training_args.lora_dim,
        lora_alpha=training_args.lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=training_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)
    model.print_trainable_parameters()
    if training_args.lora_model_path:
        resume_lora_model_from_checkpoint(model,
                                          training_args.lora_model_path)
    return model


def train():
    """Fine-tune a LLaMA-family model (optionally with LoRA) with the HF Trainer.

    Parses ``ModelArguments``, ``DataArguments`` and ``LlamaTrainingArguments``
    from the command line, then does the following:
    - loads the tokenizer and model;
    - adds a pad token if needed and fits the embeddings;
    - optionally wraps the model with LoRA;
    - builds the supervised data module and runs ``Trainer.train()``.

    Raises:
        ValueError: if pipeline parallelism is requested, or for an
            unsupported ``model_arch``.
    """
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, LlamaTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Fail fast, before any tokenizer or weights are loaded.
    if training_args.use_pipeline:
        raise ValueError('HF alpaca + PP is deprecated')

    # The parsed config is currently unused. Loading it is kept so that a
    # bad config path is still reported before model loading (assuming
    # utils.load_json raises on a missing/invalid file).
    utils.load_json(_resolve_config_path(model_args))

    tokenizer, special_tokens_dict = _load_tokenizer(model_args, training_args)
    model, is_flash_attn_model = _load_model(model_args, training_args)

    # Resize before wrapping with PEFT. Any embedding / LM-head modules
    # created by the resize are then covered by PEFT's freezing of the
    # base-model weights, and print_trainable_parameters reports the final
    # model.
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
        model_args=model_args)

    if training_args.use_lora:
        model = _apply_lora(
            model, training_args,
            use_fused_qkv=training_args.use_flash_attn and is_flash_attn_model)

    # setup dataset
    register_data_module_makers()
    data_module = make_supervised_data_module(tokenizer=tokenizer,
                                              data_args=data_args,
                                              training_args=training_args)

    callbacks = ([SavePeftModelCallback, SavePeftModelAtEndCallback]
                 if training_args.use_lora else None)
    trainer = TrainerHf(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        callbacks=callbacks,
        **data_module)
    trainer.train()


if __name__ == "__main__":
    # Executed as a script: parse command-line arguments and run training.
    train()
