"""
Script to distill pretrained Transformers into linear attention variants
"""
import sys
import os
from os.path import join

import argparse
import torch
from omegaconf import OmegaConf
sys.path.append('./src')
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from utils.setup import (
    seed_everything,
    update_model_config_from_args,
)
from utils.logging import print_config, print_header

from model.pretrained import get_pretrained_loader
from model.load_model import load_and_convert_attns
import torch.distributed as dist
import datetime
from tqdm import tqdm
from utils.rotation_utils import add_rotations
from lm_eval.models.huggingface import HFLM
from lm_eval.utils import make_table
from lm_eval import evaluator
import json
from datasets import load_dataset
from utils.metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)
# Maps each LongBench dataset name to the metric function used to score it.
# `scorer` looks the metric up by dataset name and calls it as
# metric(prediction, ground_truth, all_classes=all_classes).
dataset2metric = {
    "narrativeqa": qa_f1_score,
    "qasper": qa_f1_score,
    "multifieldqa_en": qa_f1_score,
    "multifieldqa_zh": qa_f1_zh_score,
    "hotpotqa": qa_f1_score,
    "2wikimqa": qa_f1_score,
    "musique": qa_f1_score,
    "dureader": rouge_zh_score,
    "gov_report": rouge_score,
    "qmsum": rouge_score,
    "multi_news": rouge_score,
    "vcsum": rouge_zh_score,
    "trec": classification_score,
    "triviaqa": qa_f1_score,
    "samsum": rouge_score,
    "lsht": classification_score,
    "passage_retrieval_en": retrieval_score,
    "passage_count": count_score,
    "passage_retrieval_zh": retrieval_zh_score,
    "lcc": code_sim_score,
    "repobench-p": code_sim_score,
}

def get_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the project/model options, the rotation
        path, the evaluation task lists and the miscellaneous flags.
    """
    cli = argparse.ArgumentParser()

    # (flag, add_argument keyword options); registered in this exact order.
    option_specs = (
        # Project / model selection
        ("--project_name", dict(type=str, default='kvlinc')),
        ("--model_config", dict(type=str, default=None)),
        ("--pretrained_model_name_or_path", dict(type=str, default=None)),
        ("--load_checkpoint", dict(type=str, default=None)),
        ("--resq_rotation_path", dict(type=str, default='./rotations/R.bin')),
        # Evaluation: both task options are comma-separated lists
        ("--tasks", dict(type=str, default="")),
        ("--long_bench_tasks", dict(type=str, default="")),
        ("--apply_chat_template", dict(action='store_true', default=None)),
        # Miscellaneous
        ("--huggingface_token", dict(type=str, default=None)),
        ("--seed", dict(type=int, default=0)),
        ("--verbose", dict(action='store_true', default=None)),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def get_local_rank() -> int:
    """Return the rank of this process.

    Uses the ``LOCAL_RANK`` environment variable when it is set to a
    non-empty string; otherwise falls back to
    ``torch.distributed.get_rank()``.
    """
    env_rank = os.environ.get("LOCAL_RANK")
    if not env_rank:
        return torch.distributed.get_rank()
    return int(env_rank)

# Customized prompt builder for chat / instruct models.
def build_chat(tokenizer, prompt, model_name):
    """Wrap ``prompt`` in the chat format selected by ``model_name``.

    The match is done on the lower-cased model name, in this order:

    * contains both "llama-3" and "instruct": the tokenizer's chat template
      is applied to a single user message;
    * contains "longchat": a fastchat "vicuna" conversation is built
      (fastchat is only imported when this branch is taken);
    * contains "mistral-v0.2-instruct": the tokenizer's chat template is
      applied to a single user message.

    Any other model name gets the prompt back unchanged (for the results in
    the KIVI paper -- Llama, Llama-Chat, Mistral-7B-v0.1 -- no special
    treatment is applied to the prompt).

    Args:
        tokenizer: object exposing ``apply_chat_template``.
        prompt: the raw prompt text.
        model_name: name used to pick the chat format.

    Returns:
        The (possibly rewritten) prompt.
    """
    lowered = model_name.lower()

    def _single_user_turn():
        # One user message, rendered as text with the generation prompt appended.
        conversation = [{"role": "user", "content": prompt}]
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    if "llama-3" in lowered and "instruct" in lowered:
        return _single_user_turn()

    if "longchat" in lowered:
        from fastchat.model import get_conversation_template

        template = get_conversation_template("vicuna")
        template.append_message(template.roles[0], prompt)
        template.append_message(template.roles[1], None)
        return template.get_prompt()

    if "mistral-v0.2-instruct" in lowered:
        return _single_user_turn()

    return prompt

@torch.no_grad()
def scorer(dataset, predictions, answers, all_classes):
    """Compute the average score (0-100) of ``predictions`` for one dataset.

    For each prediction the metric registered for ``dataset`` in
    ``dataset2metric`` is evaluated against every ground truth, and the best
    value (never below 0.0) is kept. The result is the mean of those best
    values times 100, rounded to two decimals.

    Args:
        dataset: dataset name used to look up the metric.
        predictions: sequence of predicted strings.
        answers: sequence, parallel to ``predictions``, of iterables of
            ground-truth answers.
        all_classes: forwarded to the metric as the ``all_classes`` keyword.

    Returns:
        The rounded average score, or ``0.0`` when ``predictions`` is empty
        (previously this case raised ``ZeroDivisionError``).
    """
    # Nothing to average: avoid dividing by zero below.
    if len(predictions) == 0:
        return 0.0

    # For these datasets only the first non-empty-leading line of the
    # prediction is scored.
    first_line_only = dataset in ("trec", "triviaqa", "samsum", "lsht")

    total_score = 0.0
    for prediction, ground_truths in zip(predictions, answers):
        if first_line_only:
            prediction = prediction.lstrip("\n").split("\n")[0]
        best = 0.0
        for ground_truth in ground_truths:
            best = max(
                best,
                dataset2metric[dataset](
                    prediction, ground_truth, all_classes=all_classes
                ),
            )
        total_score += best
    return round(100 * total_score / len(predictions), 2)

@torch.no_grad()
def get_pred(
    model,
    tokenizer,
    data,
    max_length,
    max_gen,
    prompt_format,
    dataset,
    device,
    model_name,
):
    """Generate one greedy prediction per sample in ``data``.

    Args:
        model: object exposing ``generate``.
        tokenizer: callable tokenizer that also exposes ``decode``,
            ``encode`` and ``eos_token_id``.
        data: iterable of samples; each sample is formatted into
            ``prompt_format`` and must provide the "answers", "all_classes"
            and "length" keys.
        max_length: maximum prompt length in tokens before truncation.
        max_gen: maximum number of new tokens to generate.
        prompt_format: format string filled with each sample's fields.
        dataset: dataset name; controls chat wrapping and the samsum
            special case.
        device: device the tokenized prompt is moved to.
        model_name: forwarded to ``build_chat``.

    Returns:
        A list of dicts with the keys "pred", "answers", "all_classes" and
        "length".
    """
    # Chat models are better off without the chat wrapper on these tasks.
    raw_prompt_tasks = ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]

    results = []
    for sample in tqdm(data):
        prompt = prompt_format.format(**sample)

        # Truncate in the middle when the prompt is too long, since both the
        # left and the right end may contain crucial instructions.
        token_ids = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
        if len(token_ids) > max_length:
            half = int(max_length / 2)
            head_text = tokenizer.decode(token_ids[:half], skip_special_tokens=True)
            tail_text = tokenizer.decode(token_ids[-half:], skip_special_tokens=True)
            prompt = head_text + tail_text

        if dataset not in raw_prompt_tasks:
            prompt = build_chat(tokenizer, prompt, model_name)

        encoded = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
        context_length = encoded.input_ids.shape[-1]

        # Greedy decoding settings shared by every dataset.
        generate_kwargs = dict(
            max_new_tokens=max_gen,
            num_beams=1,
            do_sample=False,
            temperature=1.0,
        )
        if dataset == "samsum":
            # Prevent illegal output on samsum (the model endlessly repeats
            # "\nDialogue"); might be a prompting issue. Require at least one
            # new token and also stop at a newline token.
            generate_kwargs["min_length"] = context_length + 1
            generate_kwargs["eos_token_id"] = [
                tokenizer.eos_token_id,
                tokenizer.encode("\n", add_special_tokens=False)[-1],
            ]

        generated = model.generate(**encoded, **generate_kwargs)[0]
        # Only the tokens after the prompt are the model's answer.
        completion = tokenizer.decode(generated[context_length:], skip_special_tokens=True)
        results.append(
            {
                "pred": completion,
                "answers": sample["answers"],
                "all_classes": sample["all_classes"],
                "length": sample["length"],
            }
        )
    return results
def _read_json(path):
    """Load and return the JSON document stored at ``path`` (closes the file)."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _run_lm_eval(model, tokenizer, model_config, args):
    """Evaluate on the comma-separated lm-eval tasks in ``args.tasks`` and print the results table."""
    import datasets

    model.generation_config.max_new_tokens = None  # prevent the warning
    lm = HFLM(pretrained=model, tokenizer=tokenizer)
    lm._device = model.device

    model_args = {'trust_remote_code': True}
    if "ruler" in args.tasks:
        model_args['pretrained'] = model_config.model.pretrained_model_name_or_path
        model_args['max_seq_lengths'] = [4096, 8192]
    datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True

    results = evaluator.simple_evaluate(
        lm,
        tasks=args.tasks.split(","),
        model_args=model_args,
        apply_chat_template=args.apply_chat_template,
    )
    print(make_table(results))


def _generate_longbench_predictions(model, tokenizer, tasks, pred_dir, max_length, model_name):
    """Write ``<pred_dir>/<task>.jsonl`` for every LongBench task that has no prediction file yet."""
    # Each task has its own prompt format and max generation length; feel
    # free to modify them to optimize model output.
    dataset2prompt = _read_json("configs/longbench/dataset2prompt.json")
    dataset2maxlen = _read_json("configs/longbench/dataset2maxlen.json")

    # exist_ok avoids a FileExistsError when several ranks create the
    # directory at the same time.
    os.makedirs(pred_dir, exist_ok=True)

    for dataset in tasks:
        out_path = f"{pred_dir}/{dataset}.jsonl"
        # Existing predictions are reused; check before loading the dataset
        # so nothing is downloaded needlessly.
        if os.path.exists(out_path):
            continue
        data = load_dataset('THUDM/LongBench', f"{dataset}", split='test')
        preds = get_pred(
            model,
            tokenizer,
            data,
            max_length,
            dataset2maxlen[dataset],
            dataset2prompt[dataset],
            dataset,
            "cuda",
            model_name,
        )
        with open(out_path, "w", encoding="utf-8") as f:
            for pred in preds:
                json.dump(pred, f, ensure_ascii=False)
                f.write('\n')


def _score_longbench_predictions(pred_dir):
    """Score every ``*jsonl`` file in ``pred_dir`` and write the scores to ``result.json`` there."""
    all_files = os.listdir(pred_dir)
    scores = dict()

    print("Evaluating on:", all_files)
    for filename in all_files:
        if not filename.endswith("jsonl"):
            continue
        dataset = filename.split('.')[0]
        predictions, answers = [], []
        all_classes = None
        with open(f"{pred_dir}/{filename}", "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                record = json.loads(line)
                predictions.append(record["pred"])
                answers.append(record["answers"])
                all_classes = record["all_classes"]
        if not predictions:
            # An empty file previously crashed with NameError (all_classes unbound).
            print(f"Skipping {filename}: no predictions found")
            continue
        scores[dataset] = scorer(dataset, predictions, answers, all_classes)

    with open(f"{pred_dir}/result.json", "w", encoding="utf-8") as f:
        json.dump(scores, f, ensure_ascii=False, indent=4)
    return scores


def main():
    """Load a model, convert its attentions and evaluate it.

    Steps:
      1. Parse arguments, seed, and initialize the NCCL process group.
      2. Load ``./configs/model/<model_config>.yaml``, the tokenizer and the
         pretrained model, then convert the attention layers (optionally
         from ``--load_checkpoint``) and add rotations if the config's
         ``attention.apply_rotations`` is truthy.
      3. If ``--tasks`` is non-empty, run those lm-eval tasks and print the
         results table.
      4. If ``--long_bench_tasks`` is non-empty, generate predictions for
         those LongBench tasks into ``pred/<model_config>/`` (existing
         prediction files are kept), then score every ``*jsonl`` file in
         that directory and write ``result.json``.
      5. Move the model to CPU, release cached CUDA memory and tear down the
         process group.
    """
    # ------
    # SET UP
    # ------
    args = get_args()
    seed_everything(args.seed)
    args.device = torch.device('cuda')
    dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=8))
    local_rank = get_local_rank()

    print("the rank is {}".format(local_rank))
    dist.barrier()

    model_config_path = join('./configs/model', f'{args.model_config}.yaml')
    model_config = OmegaConf.load(model_config_path)
    model_config = update_model_config_from_args(model_config, args)

    print_header('Model Config')
    print_config(model_config)

    # Get pretrained model
    model_loader = get_pretrained_loader(**model_config.model,
                                         huggingface_token=args.huggingface_token)
    tokenizer = model_loader.load_tokenizer()
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = 'left'

    # ----------
    # CONVERSION
    # ----------
    args.attention_type = model_config['attention']['attention_type']
    model = model_loader.load(model_type=args.attention_type)
    model.config.kvquant = model_config.attention.kvquant

    if args.verbose:
        print_header('*** Initial Model ***')
        print(model)

    # Swap initial attentions if applicable (no attention training here)
    model = load_and_convert_attns(model, model_config,
                                   attention_type=args.attention_type,
                                   checkpoint_path=args.load_checkpoint,
                                   print_model=args.verbose,
                                   train_attention=False)

    dist.barrier()
    # Rotations are optional: a config without the key means "no rotations".
    args.apply_rot = getattr(model_config.attention, 'apply_rotations', False)
    if args.apply_rot:
        model = add_rotations(model, args)

    model.cuda()
    model.eval()

    dist.barrier()

    # ----------
    # EVALUATION
    # ----------
    if args.tasks != "":
        _run_lm_eval(model, tokenizer, model_config, args)

    if args.long_bench_tasks != "":
        # Task names are the keys of dataset2metric, e.g. "narrativeqa",
        # "qasper", "hotpotqa", "trec", "samsum", "lcc", "repobench-p".
        model_name = args.model_config
        pred_dir = f"pred/{model_name}"
        _generate_longbench_predictions(
            model,
            tokenizer,
            args.long_bench_tasks.split(","),
            pred_dir,
            model_config.model.max_position_embeddings,
            model_name,
        )
        _score_longbench_predictions(pred_dir)

    # --------
    # TEARDOWN
    # --------
    dist.barrier()
    model = model.cpu()
    torch.cuda.empty_cache()
    # Release the process group's resources before exiting.
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == '__main__':
    main()
