
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import logging
import math
import os
import shutil

import torch
import transformers
import numpy as np

from peft import PeftModel
from accelerate.utils import set_seed
from tqdm import tqdm
from vllm import LLM, SamplingParams

import sys; sys.path.append("src")
from sft_utils import (
    MODEL_DICT, TEST_BATCH_SIZE, TEST_GEN_TOKENS, STOP_TOKENS, COMMONSENSE_DICT,
    ModelArguments, DataArguments, OpArguments,
    prepare_evaluation_data, check_special_tokens, smart_tokenizer_and_embedding_resize,
    retrieve_prediction_answer, compute_accuracy, compute_accuracy_math
)


def save_peft_for_vllm(model_args):
    """Merge a PEFT adapter into its base model and save the result for vLLM.

    The base model ``model_args.model_name_or_path`` is loaded with
    ``torch_dtype=torch.bfloat16`` and ``device_map='auto'``. When
    ``model_args.full_precision`` is falsy, it is instead loaded with a
    bitsandbytes 4-bit NF4 quantization config. The tokenizer is loaded with
    left padding, and the tokens returned by ``check_special_tokens`` are
    passed to ``smart_tokenizer_and_embedding_resize``. The adapter at
    ``model_args.adapter_name_or_path`` is then merged into the model with
    ``merge_and_unload``. Finally, the tokenizer and merged model
    (safetensors) are written to ``<adapter_name_or_path>-vllm``.

    Side effects:
        Sets ``model_args.temp_out_dir`` to the output directory. Any
        existing directory at that path is deleted before saving.

    Raises:
        ValueError: if ``model_args.adapter_name_or_path`` is None.
    """
    # Validate first: the adapter path is needed for the commonsense check and
    # the output directory below. Without this early check, a None path would
    # raise an opaque TypeError only after the base model was already loaded.
    if model_args.adapter_name_or_path is None:
        raise ValueError("Adapter path is not provided!")

    ##########################
    #     Initialization     #
    ##########################
    model_kwargs = dict(
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        device_map='auto',
    )
    if not model_args.full_precision:
        model_kwargs['quantization_config'] = transformers.BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False,
            bnb_4bit_quant_type='nf4',
        )
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        **model_kwargs,
    )

    # Commonsense adapters are recognised by a commonsense dataset name
    # appearing anywhere in the adapter path.
    is_common_sense = any(
        data_name in model_args.adapter_name_or_path for data_name in COMMONSENSE_DICT
    )
    if is_common_sense:
        # Default (fast) tokenizer, no explicit max length.
        tokenizer_kwargs = dict(padding_side="left")
    else:
        tokenizer_kwargs = dict(
            model_max_length=model_args.model_max_length,
            padding_side="left",
            use_fast=False,
        )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        **tokenizer_kwargs,
    )
    special_tokens_dict = check_special_tokens(tokenizer, model_args.model_tag)
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    ##########################
    #       Peft Model       #
    ##########################
    model = PeftModel.from_pretrained(
        model,
        model_args.adapter_name_or_path,
        is_trainable=False,
        autocast_adapter_dtype=False,
    ).merge_and_unload()

    ##########################
    #       Save Model       #
    ##########################
    model_args.temp_out_dir = model_args.adapter_name_or_path + '-vllm'

    # Start from an empty directory so no stale files from a previous merge remain.
    if os.path.exists(model_args.temp_out_dir):
        logging.warning("Removing legacy peft model...")
        shutil.rmtree(model_args.temp_out_dir)
    os.makedirs(model_args.temp_out_dir, exist_ok=True)

    logging.warning("Saving temporary peft model...")
    model.eval()
    tokenizer.save_pretrained(model_args.temp_out_dir)
    model.save_pretrained(
        model_args.temp_out_dir, safe_serialization=True
    )
    logging.warning("Saved temporary peft model...")


def evaluation(model_args, data_args, op_args):
    """Generate answers with vLLM for a test set and print the accuracy.

    The model is loaded from ``model_args.temp_out_dir`` with
    ``tensor_parallel_size`` equal to ``torch.cuda.device_count()``.
    Questions and reference answers come from
    ``prepare_evaluation_data(data_args.data_tag)``. They are generated in
    batches of ``data_args.batch_size`` with temperature 0, and predictions
    are extracted with ``retrieve_prediction_answer``.

    Reporting:
        * 'viggo', 'sql' and commonsense tags: accuracy on the examples whose
          extracted prediction is non-empty, full accuracy, and the
          percentage of non-empty predictions ("keep ratio").
        * 'math': ``compute_accuracy_math`` accuracy plus the number and
          ratio of invalid outputs.
        * anything else: ``compute_accuracy``.

    If ``op_args.retain == 0``, the temporary model directory is deleted
    after the results are printed.

    Returns:
        None. If generation raises, the error is logged, a skip message is
        printed, and the function returns early without deleting the
        temporary model directory.
    """
    ##########################
    #     Initialization     #
    ##########################
    llm = LLM(model=model_args.temp_out_dir, tensor_parallel_size=torch.cuda.device_count())
    commonsense_tags = list(COMMONSENSE_DICT.keys())

    stop_tokens = [llm.get_tokenizer().eos_token]
    if data_args.data_tag in ['math', 'viggo', 'sql'] + commonsense_tags:
        stop_tokens += STOP_TOKENS
    # Same decoding settings for every task; only the token budget and the
    # stop list depend on the task.
    sampling_params = SamplingParams(
        temperature=0,
        max_tokens=TEST_GEN_TOKENS[data_args.data_tag],
        stop=stop_tokens,
    )

    ######################
    #      Dataset       #
    ######################
    question, answer = prepare_evaluation_data(data_args.data_tag)

    logging.warning("Batching inputs...")
    batch_size = data_args.batch_size
    eval_step = math.ceil(len(question) / batch_size)
    logging.warning(f"Total example: {len(question)} | eval batch size: {data_args.batch_size} | "
                    f"eval steps: {eval_step}")
    # Slicing past the end is safe, so the final batch simply holds the remainder.
    question_data = [question[i * batch_size:(i + 1) * batch_size] for i in range(eval_step)]

    ans_pred_list = []
    set_seed(42)
    for step, batch in enumerate(tqdm(question_data, desc="Prediction")):
        try:
            outputs = llm.generate(
                batch,
                sampling_params,
            )
        except Exception:
            # Best-effort: a failed run is skipped rather than crashing the
            # caller, but log the traceback so the cause is not lost.
            logging.exception("Generation failed at batch %d", step)
            print("Inference not finished, skip!")
            return None

        ans_pred_list.extend(
            retrieve_prediction_answer(output.outputs[0].text, data_args.data_tag)
            for output in outputs
        )

    # Label built from the two path components just above the checkpoint directory.
    adapter_info = '-'.join(model_args.adapter_name_or_path.split('/')[-3:-1])
    if data_args.data_tag in ['viggo', 'sql'] + commonsense_tags:
        # Mask of examples for which an answer could be extracted at all.
        keep_idx = [ans != '' for ans in ans_pred_list]
        keep_ratio = sum(keep_idx) / len(keep_idx) * 100
        accuracy = compute_accuracy(np.array(answer)[keep_idx], np.array(ans_pred_list)[keep_idx])
        f_accuracy = compute_accuracy(answer, ans_pred_list)

        print(f"Adapter: {adapter_info} | epoch: {op_args.iter + 1} | {data_args.data_tag.upper()} test accuracy: {100*accuracy:.2f}% | "
            f"full accuracy: {100*f_accuracy:.2f}% | keep ratio: {keep_ratio:.2f}% | full precision: {model_args.full_precision}")
    else:
        if data_args.data_tag == 'math':
            accuracy, num_invalid = compute_accuracy_math(ans_pred_list, answer)
            print(f"num of invalid outputs: {num_invalid} | ratio of invalid outputs: {100*num_invalid/len(answer):.2f}%")
        else:
            # NOTE(review): argument order here is (pred, answer) while the
            # branch above passes (answer, pred); harmless only if
            # compute_accuracy is symmetric — confirm.
            accuracy = compute_accuracy(ans_pred_list, answer)

        print(f"Adapter: {adapter_info} | epoch: {op_args.iter + 1} | {data_args.data_tag.upper()} test accuracy: {100*accuracy:.2f}% | "
            f"full precision: {model_args.full_precision}")

    if op_args.retain == 0:
        logging.warning("Deleting temporary peft model...")
        shutil.rmtree(model_args.temp_out_dir)


def list_checkpoint_dirs(ckpt_dir):
    """Return the checkpoint entries of *ckpt_dir*, sorted by training step.

    An entry is kept when its name contains ``'checkpoint-'``, does not
    contain ``'vllm'`` (which excludes the merged ``-vllm`` copies), and ends
    in ``-<integer>``. Entries without a numeric suffix are skipped instead of
    crashing the sort. Returned paths are joined onto *ckpt_dir*.
    """
    def _step(name):
        return name.rsplit('-', 1)[-1]

    names = [
        name for name in os.listdir(ckpt_dir)
        if 'checkpoint-' in name
        and 'vllm' not in name
        # isdecimal (not isdigit) so that int() below is guaranteed to succeed.
        and _step(name).isdecimal()
    ]
    names.sort(key=lambda name: int(_step(name)))
    return [os.path.join(ckpt_dir, name) for name in names]


if __name__ == "__main__":
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, OpArguments))
    model_args, data_args, op_args = parser.parse_args_into_dataclasses()

    if model_args.model_name_or_path is None:
        model_args.model_name_or_path = MODEL_DICT[model_args.model_tag]

    data_args.batch_size = TEST_BATCH_SIZE[data_args.data_tag]

    if model_args.ckpt_dir is not None:
        adapter_dir_list = list_checkpoint_dirs(model_args.ckpt_dir)
    elif model_args.adapter_name_or_path is not None:
        adapter_dir_list = [model_args.adapter_name_or_path]
    else:
        raise ValueError("No checkpoint directory or adapter directory is provided!")

    # op_args.iter selects which checkpoint this invocation processes.
    try:
        model_args.adapter_name_or_path = adapter_dir_list[op_args.iter]
    except IndexError as err:
        raise IndexError(
            f"iter={op_args.iter} is out of range: found {len(adapter_dir_list)} adapter checkpoint(s)."
        ) from err

    if op_args.process == 'save':
        save_peft_for_vllm(model_args)
    elif op_args.process == 'eval':
        # Must match the directory save_peft_for_vllm writes to.
        model_args.temp_out_dir = model_args.adapter_name_or_path + '-vllm'
        evaluation(model_args, data_args, op_args)
    else:
        raise ValueError("Invalid operation!")
