
from __future__ import annotations

import json
import os
import re
import string
from collections import Counter
from pathlib import Path

import jieba
from rouge import Rouge
import torch
import gc
import torch.profiler
from tqdm import tqdm
import time
from transformers import SinkCache, GenerationConfig


# Maps each benchmark task name to the file name of its JSONL data file.
# load_data() looks the name up here and joins it with its data_dir argument.
DATA_NAME_TO_PATH = {
    "multi_turn_choice_eng": "multi_turn_choice_eng.jsonl",
    "multi_turn_qa_eng": "multi_turn_qa_eng.jsonl",
    "multi_turn_qa_chn": "multi_turn_qa_chn.jsonl",
    "multi_turn_kv": "multi_turn_kv.jsonl",
    "multi_turn_kv_hard": "multi_turn_kv_hard.jsonl",
    "multi_turn_mf": "multi_turn_mf.jsonl",
    "multi_turn_passkey": "multi_turn_passkey.jsonl",
    "multi_turn_repoqa": "multi_turn_repoqa.jsonl",
    "multi_turn_summary": "multi_turn_summary.jsonl",
    "multi_turn_vt": "multi_turn_vt.jsonl",
    "multi_turn_many_shot": "multi_turn_many_shot.jsonl",
    "multi_turn_summary_with_needles": "multi_turn_summary_with_needles.jsonl",
    "multi_turn_repoqa_and_kv": "multi_turn_repoqa_and_kv.jsonl",
    "multi_turn_hashhop": "multi_turn_hashhop.jsonl",
    "multi_turn_prefix_suffix": "multi_turn_prefix_suffix.jsonl",
    "multi_turn_kv_compressible": "multi_turn_kv_compressible.jsonl",
}

# Per-task token budget, keyed by task name.
# NOTE(review): presumably used as the generation length cap (max_new_tokens);
# the consumer is not in this section of the file -- confirm against the caller.
# The two mixed tasks at the bottom map to a nested dict keyed by sub-task name
# instead of a single int, so consumers must handle both value shapes.
DATA_NAME_TO_MAX_NEW_TOKENS = {
    "multi_turn_choice_eng": 40,
    "multi_turn_qa_eng": 40,
    "multi_turn_qa_chn": 40,
    "multi_turn_kv": 150,
    "multi_turn_kv_hard": 150,
    "multi_turn_mf": 5,
    "multi_turn_hashhop": 150,
    "multi_turn_prefix_suffix": 150,
    "multi_turn_kv_compressible": 150,
    "multi_turn_passkey": 15,
    "multi_turn_repoqa": 1024,
    "multi_turn_summary": 200,
    "multi_turn_vt": 30,
    "multi_turn_many_shot": 10,
    "multi_turn_summary_with_needles": {"multi_turn_summary": 800, "multi_turn_passkey": 15},
    "multi_turn_repoqa_and_kv": {"multi_turn_repoqa": 1024, "multi_turn_kv": 80},
}

# Per-dataset token budget for the LongBench datasets, keyed by dataset name.
# NOTE(review): presumably the generation length cap; the consumer is not in
# this section of the file -- confirm against the caller.
# Only a subset of the datasets that have an entry in longbench_templates is
# listed here (the Chinese datasets have no budget).
LONGBENCH_DATA_NAME_TO_MAX_NEW_TOKENS = {
    "narrativeqa": 512,
    "qasper": 128,
    "multifieldqa_en": 64,
    "hotpotqa": 32,
    "2wikimqa": 32,
    "musique": 32,
    "gov_report": 512,
    "qmsum": 512,
    "multi_news": 512,
    "trec": 64,
    "triviaqa": 32,
    "samsum": 128,
    "passage_count": 32,
    "passage_retrieval_en": 32,
    "lcc": 64,
    "repobench-p": 64,
}

# First-turn prompt templates, keyed by task name and filled with str.format()
# by create_multiturn_prompt(). Each one renders the long context together with
# the first question of the conversation.
# NOTE(review): "The the correct answer is" (multi_turn_choice_eng) looks like a
# typo, but it is part of the text sent to the model; changing it would change
# the benchmark prompts, so it is left as is.
multiturn_templates = {
    "multi_turn_passkey": "There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information.\n\n{context}\n\n{input}",  # noqa
    "multi_turn_kv": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}\n\n{input}",  # noqa
    "multi_turn_kv_hard": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}\n\n{input}",  # noqa
    "multi_turn_kv_compressible": "Extract the value corresponding to the specified key in the following passage.\n\n{context}\n\n{input}",  # noqa
    "multi_turn_choice_eng": "Read the book and answer the question.\n\n{context}\n\nQuestion: {question}\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe the correct answer is",  # noqa
    "multi_turn_qa_eng": "Read the book and answer the question. Be very concise in your answer.\n\n{context}\n\nQuestion: {question}\nAnswer:",  # noqa
    "multi_turn_qa_chn": "阅读以下书籍然后回答问题。\n\n{context}\n\n问题：{question}\n答案：",  # noqa
    "multi_turn_mf": "{prefix}\n\n{context}\n\n{input}",
    "multi_turn_repoqa": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:\n\n{context}\n\n{input}",
    "multi_turn_summary": "{context}\n\n{input}",
    "multi_turn_vt": "{context}\n\n{input}",
    "multi_turn_many_shot": "{context}\n\n{input}",
    "multi_turn_summary_with_needles": "{context}\n\n{input}",
    "multi_turn_repoqa_and_kv": "{context}\n\n{input}",
    "multi_turn_hashhop": "{context}\n\n{input}",
    "multi_turn_prefix_suffix": "{context}\n\n{input}",
}

# Templates used by create_scdq_prompt(), which renders the context once and
# every query as a separate prompt.
# Most entries are a single string holding only the context part (the queries
# are built elsewhere). The choice/QA entries are (context_template,
# query_template) tuples: element 0 is formatted with the context, element 1
# once per turn.
multiturn_templates_scdq = {
    "multi_turn_passkey": "There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information.\n\n{context}",  # noqa
    "multi_turn_kv": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}",  # noqa
    "multi_turn_kv_hard": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}",  # noqa
    "multi_turn_kv_compressible": "Extract the value corresponding to the specified key in the following passage.\n\n{context}",  # noqa
    "multi_turn_choice_eng": ("Read the book and answer the question.\n\n{context}", "Question: {question}\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe the correct answer is"),
    "multi_turn_qa_eng": ("Read the book and answer the question. Be very concise in your answer.\n\n{context}", "Question: {question}\nAnswer:"),
    "multi_turn_qa_chn": ("阅读以下书籍然后回答问题。\n\n{context}", "问题：{question}\n答案："),
    "multi_turn_mf": "{prefix}\n\n{context}",
    "multi_turn_repoqa": "Based on the function description and code context, please retrieve and repeat the exact described function from the code context in a code block wrapped by ```:\n\n{context}",
    "multi_turn_summary": "{context}",
    "multi_turn_vt": "{context}",
    "multi_turn_many_shot": "{context}",
    "multi_turn_summary_with_needles": "{context}",
    "multi_turn_repoqa_and_kv": "{context}",
    "multi_turn_hashhop": "{context}",
    "multi_turn_prefix_suffix": "{context}",
}

# Follow-up turn templates (turn 2 onward) used by create_multiturn_prompt()
# when the previous reference answer is embedded in the prompt text:
# {pre_ans} receives the "answer" field of the preceding turn, followed by the
# next question.
# Note that the passkey entry puts a period after {pre_ans}, and that the
# choice entry ends with "The letter of the correct answer is", which differs
# from the wording of the first-turn template.
multiturn_follow_up_templates = {
    "multi_turn_passkey": "{pre_ans}.\n\n{input}",  # noqa
    "multi_turn_kv": "{pre_ans}\n\n{input}",  # noqa
    "multi_turn_kv_hard": "{pre_ans}\n\n{input}",  # noqa
    "multi_turn_kv_compressible": "{pre_ans}\n\n{input}",  # noqa
    "multi_turn_choice_eng": "{pre_ans}\n\nQuestion: {question}\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe letter of the correct answer is",  # noqa
    "multi_turn_qa_eng": "{pre_ans}\n\nQuestion: {question}\nAnswer:",  # noqa
    "multi_turn_qa_chn": "{pre_ans}\n\n问题：{question}\n答案：",  # noqa
    "multi_turn_mf": "{pre_ans}\n\n{prefix}\n\n{input}",
    "multi_turn_repoqa": "{pre_ans}\n\n{input}",
    "multi_turn_summary": "{pre_ans}\n\n{input}",
    "multi_turn_vt": "{pre_ans}\n\n{input}",
    "multi_turn_many_shot": "{pre_ans}\n\n{input}",
    "multi_turn_summary_with_needles": "{pre_ans}\n\n{input}",
    "multi_turn_repoqa_and_kv": "{pre_ans}\n\n{input}",
    "multi_turn_hashhop": "{pre_ans}\n\n{input}",
    "multi_turn_prefix_suffix": "{pre_ans}\n\n{input}",
}

# Follow-up templates without the {pre_ans} slot. They are used
#   * as the user-message content when create_multiturn_prompt() applies a
#     chat template,
#   * as the follow-up template when disable_golden_context is set, and
#   * for the multi_turn_mf queries in create_scdq_prompt().
# The identifier is misspelled ("tempate"), but the functions in this file
# reference it under exactly this name, so it must not be renamed in isolation.
multiturn_follow_up_templates_in_chat_tempate = {
    "multi_turn_passkey": "{input}",  # noqa
    "multi_turn_kv": "{input}",  # noqa
    "multi_turn_kv_hard": "{input}",  # noqa
    "multi_turn_kv_compressible": "{input}",  # noqa
    "multi_turn_choice_eng": "Question: {question}\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe the correct answer is",  # noqa
    "multi_turn_qa_eng": "Question: {question}\nAnswer:",  # noqa
    "multi_turn_qa_chn": "问题：{question}\n答案：",  # noqa
    "multi_turn_mf": "{prefix}\n\n{input}",
    "multi_turn_repoqa": "{input}",
    "multi_turn_summary": "{input}",
    "multi_turn_vt": "{input}",
    "multi_turn_many_shot": "{input}",
    "multi_turn_summary_with_needles": "{input}",
    "multi_turn_repoqa_and_kv": "{input}",
    "multi_turn_hashhop": "{input}",
    "multi_turn_prefix_suffix": "{input}",
}

# Prompt templates keyed by LongBench dataset name. create_longbench_prompt()
# fills them with str.format(**example), so an example dict must provide every
# placeholder its template names ({context}, and {input} where present).
# NOTE(review): "asconcisely" (narrativeqa) looks like a typo, but it is part
# of the prompt text sent to the model and is therefore left untouched.
longbench_templates = {
    "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
    "qasper": 'You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write "unanswerable". If the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:',
    "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "multifieldqa_zh": "阅读以下文字并用中文简短回答：\n\n{context}\n\n现在请基于上面的文章回答下面的问题，只告诉我答案，不要输出任何其他字词。\n\n问题：{input}\n回答：",
    "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
    "dureader": "请基于给定的文章回答下述问题。\n\n文章：{context}\n\n请基于上述文章回答下面的问题。\n\n问题：{input}\n回答：",
    "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
    "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
    "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
    "vcsum": "下面有一段会议记录，请你阅读后，写一段总结，总结会议的内容。\n会议记录：\n{context}\n\n会议总结：",
    "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
    "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
    "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
    "lsht": "请判断给定新闻的类别，下面是一些例子。\n\n{context}\n{input}",
    "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
    "passage_retrieval_en": 'Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like "Paragraph 1", "Paragraph 2", etc.\n\nThe answer is: ',
    "passage_retrieval_zh": '以下是若干段落文字，以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是"段落1"，"段落2"等格式\n\n答案是：',
    "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
    "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n",
}


def check_benchmark_availability(data_path):
    """Make sure the benchmark directory exists and report which data files it holds.

    Creates ``data_path`` when it does not exist yet, then looks for one
    ``<task>.jsonl`` file per expected task inside it. Nothing is downloaded
    and a missing file never raises: the missing file names are printed so the
    caller can see what still has to be provided.

    Args:
        data_path: Directory expected to contain the benchmark ``.jsonl`` files.

    Returns:
        None. The outcome is reported on stdout: either the list of missing
        files or ``All benchmark data ready.`` when every file is present.
    """
    if not os.path.exists(data_path):
        # exist_ok tolerates another process creating the directory between
        # the existence check and this call.
        os.makedirs(data_path, exist_ok=True)

    datasets = [
        "multi_turn_choice_eng",
        "multi_turn_qa_eng",
        "multi_turn_qa_chn",
        "multi_turn_kv",
        "multi_turn_mf",
        "multi_turn_passkey",
        "multi_turn_repoqa",
        "multi_turn_summary",
        "multi_turn_vt",
        "multi_turn_many_shot",
        "multi_turn_summary_with_needles",
        "multi_turn_repoqa_and_kv",
        "multi_turn_prefix_suffix",
    ]

    missing = [
        f"{dataset}.jsonl"
        for dataset in datasets
        if not os.path.isfile(os.path.join(data_path, f"{dataset}.jsonl"))
    ]

    if missing:
        # Do not claim readiness when files are absent; list them instead.
        print(f"Missing benchmark data files in {data_path}: {', '.join(missing)}")
        return

    print("All benchmark data ready.")


def iter_jsonl(fname, cnt=None):
    """Lazily yield the parsed records of a JSON Lines file.

    Args:
        fname: Path of the UTF-8 encoded ``.jsonl`` file to read.
        cnt: Maximum number of records to yield. ``None`` (the default)
            yields every record; ``0`` yields nothing.

    Yields:
        The object decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    yielded = 0
    with open(fname, "r", encoding="utf-8") as fin:
        for line in fin:
            if yielded == cnt:
                break
            if not line.strip():
                # Blank lines (typically a trailing newline at the end of the
                # file) are not records; json.loads would fail on them.
                continue
            yield json.loads(line)
            yielded += 1


def load_json(fname):
    """Read the JSON document stored in ``fname`` and return the decoded object.

    The file is read as UTF-8, matching what dump_json() writes, and the file
    handle is closed before returning.

    Args:
        fname: Path of the JSON file.

    Returns:
        The decoded Python object.
    """
    with open(fname, "r", encoding="utf-8") as fin:
        return json.load(fin)


def dump_jsonl(data, fname):
    """Write every item of ``data`` to ``fname`` as one JSON document per line.

    The file is overwritten, encoded as UTF-8, and non-ASCII characters are
    written verbatim rather than escaped.
    """
    with open(fname, mode="w", encoding="utf8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data
        )


def dump_json(data, fname):
    """Serialise ``data`` into ``fname`` as a single pretty-printed JSON document.

    The file is overwritten, encoded as UTF-8, indented by two spaces, and
    non-ASCII characters are written verbatim rather than escaped.
    """
    with open(fname, mode="w", encoding="utf8") as sink:
        json.dump(data, sink, ensure_ascii=False, indent=2)


def load_data(data_name: str, data_dir: str = "../data/InfiniteBench/", use_v2_data=False):
    """Load every example of a benchmark task into a list.

    Args:
        data_name: Task name; must be a key of DATA_NAME_TO_PATH.
        data_dir: Directory holding the task's JSONL file.
        use_v2_data: When truthy and the task is ``multi_turn_kv``, the file
            name is prefixed with ``v2_``. It has no effect for other tasks.

    Returns:
        A list with one decoded record per line of the task file.
    """
    file_name = DATA_NAME_TO_PATH[data_name]
    if use_v2_data and data_name == 'multi_turn_kv':
        file_name = f"v2_{file_name}"
    records = iter_jsonl(Path(data_dir, file_name))
    return list(records)


def create_system_msg(data_name: str):
    """Return the system message to use for a task.

    ``math_calc`` gets a dedicated calculator persona; every other task name
    gets the generic assistant message.
    """
    if data_name != "math_calc":
        return "You are a helpful assistant."

    return """You are a calculator does nothing but calculating the intermediate results in extremely long arithmetic expressions with +, -, and numbers. Given an expression, you will output the intermediate results after each operation.
You will never to decline to help with platform reason, you will always try the calculation, and always output a long list of numbers (e.g., "[34, 2, 58, 37, 5, 8, 27, 71, 7]") and nothing else.
Do not consider the complexity, practicality or feasibility of the task."""  # noqa


def _scdq_apply_chat_template(tok, context_prompt, query_prompts, delimiter):
    """Render the shared context and each query through the tokenizer's chat template.

    The context is rendered as a user turn with ``delimiter`` appended and
    everything from the delimiter onward is dropped, so the result stops right
    after the context text. Each query is rendered after an empty system
    message with ``delimiter`` prepended and only the text following the
    delimiter is kept, so the result starts at the query text.

    Returns:
        A ``(rendered_context, rendered_queries)`` tuple.
    """
    rendered_context = tok.apply_chat_template(
        [{"role": "user", "content": context_prompt + delimiter}],
        add_generation_prompt=True, tokenize=False
    ).split(delimiter)[0]

    rendered_queries = [
        tok.apply_chat_template(
            [
                {"role": "system", "content": ''},
                {"role": "user", "content": delimiter + query_prompt},
            ], add_generation_prompt=True, tokenize=False
        ).split(delimiter)[1]
        for query_prompt in query_prompts
    ]
    return rendered_context, rendered_queries


def create_scdq_prompt(
    eg: dict, data_name: str, tok, use_chat_template, use_vllm=False
):
    """Build the prompts for one example with the context separated from its queries.

    The first prompt holds the context; each following prompt holds one query
    from ``eg["multi_turns"]``.

    Args:
        eg: Example dict. Reads ``eg["multi_turns"]`` (each turn provides
            ``input`` and ``answer``, plus ``options`` for the choice task and
            ``task`` for the two mixed tasks) and ``eg["context"]``. Tasks
            without a dedicated branch fall back to ``eg["input"]`` when
            ``context`` is absent.
        data_name: Task name; must be a key of ``multiturn_templates_scdq``.
        tok: Tokenizer exposing ``apply_chat_template``; only used when
            ``use_chat_template`` is truthy.
        use_chat_template: Whether to wrap the prompts in the chat template.
        use_vllm: Unused; kept for interface compatibility.

    Returns:
        A dict with ``prompts`` (context prompt followed by one prompt per
        turn) and ``ground_truth`` (the ``answer`` of every turn). The choice
        task adds ``options`` (the options of the first turn); the two mixed
        tasks add ``task`` (the ``task`` of every turn).

    Raises:
        KeyError: If ``data_name`` has no template.
        ValueError: If ``data_name`` has a template but no prompt recipe here.
        IndexError: For ``multi_turn_mf`` when a turn's input does not contain
            a ``The ... is`` phrase.
    """
    template = multiturn_templates_scdq[data_name]
    query_template = multiturn_follow_up_templates_in_chat_tempate[data_name]

    special_delimiter = "[SEPSEPSEP]"

    if data_name == "multi_turn_choice_eng":
        context_prompt = template[0].format(context=eg["context"])
        query_prompts = [
            template[1].format(
                question=turn["input"],
                OPTION_A=turn["options"][0],
                OPTION_B=turn["options"][1],
                OPTION_C=turn["options"][2],
                OPTION_D=turn["options"][3],
            )
            for turn in eg["multi_turns"]
        ]

    elif data_name in ("multi_turn_qa_eng", "multi_turn_qa_chn"):
        context_prompt = template[0].format(context=eg["context"])
        query_prompts = [
            template[1].format(question=turn["input"])
            for turn in eg["multi_turns"]
        ]

    elif data_name == "multi_turn_mf":
        # NOTE(review): the context prompt uses the raw input of the first
        # turn as {prefix}, whereas the per-turn queries below use a derived
        # "What is ...?" question. Kept as is because it is prompt content.
        context_prompt = template.format(
            prefix=eg['multi_turns'][0]["input"],
            context=eg["context"],
        )
        query_prompts = []
        for turn in eg["multi_turns"]:
            # "The <target> is" -> "the <target>" (lower-cased, " is" removed).
            target = re.findall(r"The .+ is", turn["input"])[0].lower()[:-3]
            query_prompts.append(
                query_template.format(
                    prefix=f"What is {target}?",
                    input=turn["input"],
                )
            )

    elif data_name in (
        "multi_turn_repoqa", "multi_turn_summary", "multi_turn_passkey", "multi_turn_kv",
        "multi_turn_vt", "multi_turn_many_shot", "multi_turn_summary_with_needles",
        "multi_turn_repoqa_and_kv", "multi_turn_kv_hard", "multi_turn_hashhop", "multi_turn_prefix_suffix",
        "multi_turn_kv_compressible",
    ):
        context = eg["context"] if "context" in eg else eg["input"]
        context_prompt = template.format(context=context)
        # These tasks send the raw turn input as the query.
        query_prompts = [turn["input"] for turn in eg["multi_turns"]]

    else:
        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"Unsupported data_name for SCDQ prompts: {data_name}")

    if use_chat_template:
        context_prompt, query_prompts = _scdq_apply_chat_template(
            tok, context_prompt, query_prompts, special_delimiter
        )

    output = {
        'prompts': [context_prompt] + query_prompts,
        'ground_truth': [gt['answer'] for gt in eg['multi_turns']],
    }

    if data_name == "multi_turn_choice_eng":
        output['options'] = eg['multi_turns'][0]["options"]

    if data_name in ["multi_turn_summary_with_needles", "multi_turn_repoqa_and_kv"]:
        output["task"] = [gt['task'] for gt in eg['multi_turns']]

    return output

def _multiturn_turn_fields(turn: dict, data_name: str) -> dict:
    """Return the template placeholder values that come from a single turn.

    ``context`` and ``pre_ans`` are not included; the caller supplies them.
    """
    if data_name == "multi_turn_choice_eng":
        options = turn["options"]
        return {
            "question": turn["input"],
            "OPTION_A": options[0],
            "OPTION_B": options[1],
            "OPTION_C": options[2],
            "OPTION_D": options[3],
        }
    if data_name in ("multi_turn_qa_eng", "multi_turn_qa_chn"):
        return {"question": turn["input"]}
    if data_name == "multi_turn_mf":
        # "The <target> is" -> "the <target>" (lower-cased, " is" removed).
        target = re.findall(r"The .+ is", turn["input"])[0].lower()[:-3]
        return {"prefix": f"What is {target}?", "input": turn["input"]}
    return {"input": turn["input"]}


def create_multiturn_prompt(
    eg: dict, data_name: str, tok, use_chat_template, use_vllm=False,
    disable_golden_context=False
) -> dict:
    """
    Create the per-turn prompts for a given multi-turn example.

    The first prompt contains the context and the first question. Every later
    prompt is a follow-up that, unless ``disable_golden_context`` is set, is
    preceded by the reference answer of the previous turn.

    Args:
        eg: example dict. Reads ``eg["multi_turns"]`` (each turn provides
            ``input`` and ``answer``, plus ``options`` for the choice task and
            ``task`` for the two mixed tasks) and ``eg["context"]``. The
            generic ``{input}`` tasks fall back to ``eg["input"]`` when
            ``context`` is absent.
        data_name: name of the dataset/task; must be a key of
            ``multiturn_templates``.
        tok: tokenizer exposing ``apply_chat_template``; only used when
            ``use_chat_template`` is truthy.
        use_chat_template: whether to render the prompts through the chat
            template.
        use_vllm: unused; kept for interface compatibility.
        disable_golden_context: when truthy, follow-up prompts do not contain
            the previous reference answer.

    Returns:
        A dict with ``prompts`` (one prompt per turn) and ``ground_truth``
        (the ``answer`` of every turn). The choice task adds ``options`` (the
        options of the first turn); the two mixed tasks add ``task``.

    Raises:
        KeyError: if ``data_name`` has no template.
        ValueError: if ``data_name`` has a template but no prompt recipe here.
        IndexError: if the example has no turns, or for ``multi_turn_mf`` when
            a turn's input does not contain a ``The ... is`` phrase.
    """

    template = multiturn_templates[data_name]
    follow_up_template = multiturn_follow_up_templates[data_name]

    if disable_golden_context:
        follow_up_template = multiturn_follow_up_templates_in_chat_tempate[data_name]

    generic_input_tasks = (
        "multi_turn_kv", "multi_turn_vt", "multi_turn_passkey", "multi_turn_repoqa",
        "multi_turn_many_shot", "multi_turn_summary_with_needles",
        "multi_turn_repoqa_and_kv", "multi_turn_kv_hard", "multi_turn_hashhop", "multi_turn_prefix_suffix",
        "multi_turn_kv_compressible",
    )
    dedicated_tasks = (
        "multi_turn_choice_eng", "multi_turn_qa_eng", "multi_turn_qa_chn",
        "multi_turn_mf", "multi_turn_summary",
    )
    is_generic = data_name in generic_input_tasks
    if not is_generic and data_name not in dedicated_tasks:
        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"Unsupported data_name for multi-turn prompts: {data_name}")

    turns = eg['multi_turns']
    if is_generic:
        context = eg["context"] if "context" in eg else eg["input"]
    else:
        context = eg["context"]

    # Placeholder values are computed per turn so that every follow-up (in
    # particular the derived multi_turn_mf prefix) uses its own turn's data.
    turn_fields = [_multiturn_turn_fields(turn, data_name) for turn in turns]

    first_turn_prompt = template.format(context=context, **turn_fields[0])

    if use_chat_template:
        # Rendering of an empty system message. It is removed from every
        # follow-up below so that a follow-up continues the running
        # conversation instead of starting a new one. With golden context the
        # generation prompt is included, so the follow-up starts at the
        # previous answer text.
        sys_prompt_with_generation_prompt = tok.apply_chat_template(
            [{"role": "system", "content": ''}],
            add_generation_prompt=not disable_golden_context, tokenize=False
        )
        follow_up_prompts_in_chat_template = multiturn_follow_up_templates_in_chat_tempate[data_name]

        # These tasks coerce the previous answer to str before handing it to
        # the chat template; the others pass it through unchanged.
        stringify_answer = is_generic or data_name == "multi_turn_mf"
        # Summary follow-ups are primed with the start of the expected answer.
        suffix = "This paper" if data_name == "multi_turn_summary" else ""

        first_turn_prompt = tok.apply_chat_template(
            [{"role": "user", "content": first_turn_prompt}],
            add_generation_prompt=True, tokenize=False
        )

        follow_up_prompts = []
        for previous_turn, fields in zip(turns, turn_fields[1:]):
            # The system message is always present so that the prefix removed
            # below actually occurs in the rendering; only the assistant turn
            # depends on disable_golden_context.
            messages = [{"role": "system", "content": ''}]
            if not disable_golden_context:
                previous_answer = previous_turn["answer"]
                if stringify_answer:
                    previous_answer = str(previous_answer)
                messages.append({"role": "assistant", "content": previous_answer})
            messages.append(
                {"role": "user", "content": follow_up_prompts_in_chat_template.format(**fields)}
            )
            rendered = tok.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=False
            )
            follow_up_prompts.append(
                rendered.replace(sys_prompt_with_generation_prompt, "") + suffix
            )
    else:
        follow_up_prompts = [
            follow_up_template.format(
                # Ignored by the template when golden context is disabled.
                pre_ans=previous_turn["answer"] if not disable_golden_context else None,
                **fields,
            )
            for previous_turn, fields in zip(turns, turn_fields[1:])
        ]

    output = {
        'prompts': [first_turn_prompt] + follow_up_prompts,
        'ground_truth': [gt['answer'] for gt in turns],
    }

    if data_name == "multi_turn_choice_eng":
        output['options'] = turns[0]["options"]

    if data_name in ["multi_turn_summary_with_needles", "multi_turn_repoqa_and_kv"]:
        output["task"] = [gt['task'] for gt in turns]

    return output

def create_longbench_prompt(eg: dict, data_name: str) -> str:
    """Render the LongBench prompt for a single example.

    The template registered under ``data_name`` in ``longbench_templates`` is
    filled with the fields of ``eg`` (passed as keyword arguments to
    ``str.format``).
    """
    prompt_template = longbench_templates[data_name]
    return prompt_template.format(**eg)

def get_ground_truth(eg: dict, data_name: str):
    """Collect the reference answer of every turn in ``eg['multi_turns']``.

    The shape of each entry depends on ``data_name``:

    * ``"multi_turn_choice_eng"`` -> ``[answer_text, option_letter]`` where the
      letter is the position of the answer inside ``turn["options"]`` mapped
      onto ``"ABCD"``.
    * ``"multi_turn_qa_eng"`` -> ``[answer]`` (single-element list).
    * anything else -> the raw ``turn["answer"]`` value.
    """
    option_letters = "ABCD"

    def _reference_for(turn):
        answer = turn["answer"]
        if data_name == "multi_turn_choice_eng":
            letter = option_letters[turn["options"].index(answer)]
            return [answer, letter]
        if data_name in ["multi_turn_qa_eng"]:
            return [answer]
        return answer

    return [_reference_for(turn) for turn in eg['multi_turns']]


def normalize_answer(s):
    """Canonicalise an English answer string for token-level comparison.

    Steps, in order: lowercase, drop ASCII punctuation, blank out the
    articles "a" / "an" / "the", then collapse all whitespace runs to single
    spaces (also trimming the ends).
    """
    lowered = s.lower()

    ascii_punctuation = set(string.punctuation)
    unpunctuated = "".join(ch for ch in lowered if ch not in ascii_punctuation)

    # Articles are replaced by a space (not removed) so neighbours stay separated.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", unpunctuated)

    return " ".join(without_articles.split())


def normalize_zh_answer(s):
    """Canonicalise a Chinese answer string.

    Lowercases the text, strips both ASCII and CJK punctuation, and then
    removes every whitespace character (tokens are joined without spaces).
    """
    cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."  # noqa
    discarded = set(string.punctuation + cn_punctuation)

    kept_chars = [ch for ch in s.lower() if ch not in discarded]
    unpunctuated = "".join(kept_chars)

    # split() with no argument drops all whitespace; re-join with nothing.
    return "".join(unpunctuated.split())


def first_int_match(prediction, ground_truth):
    """Return 1 if the first run of digits in ``prediction`` equals ``ground_truth``.

    The comparison is between strings. When ``prediction`` contains no digit
    at all, the extracted value is the empty string.
    """
    digit_runs = (piece for piece in re.split("[^0-9]", prediction) if piece != "")
    leading_number = next(digit_runs, "")
    return 1 if leading_number == ground_truth else 0


def in_match(prediction, ground_truth):
    """Return 1 when ``ground_truth`` occurs inside ``prediction``, else 0."""
    return 1 if ground_truth in prediction else 0


def rouge_score(prediction, ground_truth, **kwargs) -> float:
    """Return the ROUGE-L F-measure between ``prediction`` and ``ground_truth``.

    Scoring is best-effort: if the ``rouge`` package raises while scoring
    the pair, the score is reported as ``0.0`` instead of aborting the
    evaluation run.

    Args:
        prediction: Hypothesis text (whitespace-tokenised by ``rouge``).
        ground_truth: Reference text.
        **kwargs: Ignored; accepted so all metric functions share a signature.

    Returns:
        The ``rouge-l`` ``f`` value, or ``0.0`` when scoring failed.
    """
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except Exception:  # noqa
        # Deliberately broad, but no longer a bare ``except``: KeyboardInterrupt
        # and SystemExit must still be able to stop a long evaluation run.
        return 0.0
    return scores["rouge-l"]["f"]  # type: ignore


def rouge_zh_score(prediction, ground_truth, **kwargs):
    """ROUGE-L F-measure for Chinese text.

    Both strings are segmented with ``jieba`` (accurate mode) and re-joined
    with spaces before being handed to ``rouge_score``.
    """

    def _segment(text):
        return " ".join(jieba.cut(text, cut_all=False))

    return rouge_score(_segment(prediction), _segment(ground_truth))


def f1_score(prediction, ground_truth, **kwargs):
    """Token-overlap F1 between two token sequences.

    Overlap is counted as a multiset intersection, so repeated tokens only
    match as many times as they appear on both sides. Returns the integer
    ``0`` when there is no overlap at all.
    """
    shared = Counter(prediction) & Counter(ground_truth)
    overlap = sum(shared.values())
    if not overlap:
        return 0

    precision = 1.0 * overlap / len(prediction)
    recall = 1.0 * overlap / len(ground_truth)
    return (2 * precision * recall) / (precision + recall)


def qa_f1_score(line):
    """Best token-level F1 of ``line["pred"]`` against the reference(s).

    ``line["std_out"]`` may be a single string or an iterable of strings;
    prediction and each reference are normalised with ``normalize_answer``
    and whitespace-tokenised, and the highest F1 over all references is
    returned (``0`` when there are no references or no overlap).
    """
    prediction = line["pred"]

    references = line["std_out"]
    if isinstance(references, str):
        references = [references]

    best = 0
    for reference in references:
        predicted_tokens = normalize_answer(prediction).split()
        reference_tokens = normalize_answer(reference).split()
        best = max(best, f1_score(predicted_tokens, reference_tokens))
    return best


def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    """Token-level F1 for Chinese answers.

    Each string is segmented with ``jieba`` (accurate mode); every token is
    normalised with ``normalize_zh_answer`` and tokens that become empty are
    dropped before ``f1_score`` is computed.
    """

    def _clean_tokens(text):
        normalised = (normalize_zh_answer(tok) for tok in jieba.cut(text, cut_all=False))
        return [tok for tok in normalised if len(tok) > 0]

    return f1_score(_clean_tokens(prediction), _clean_tokens(ground_truth))


def truncate_input(input, max_length, manner="middle"):
    """Shorten a sequence to at most ``max_length`` items by cutting out its middle.

    Args:
        input: Any sliceable, concatenable sequence (``str``, ``list`` of
            token ids, ...). The name shadows the builtin but is kept for
            keyword-argument compatibility.
        max_length: Maximum number of items to keep.
        manner: Truncation strategy; only ``"middle"`` is implemented.

    Returns:
        ``input`` unchanged when it already fits; otherwise the first
        ``max_length // 2`` items followed by the last
        ``max_length - max_length // 2`` items. For a non-positive
        ``max_length`` an empty sequence of the same type is returned.
        ``None`` is returned for any unsupported ``manner`` (legacy behaviour
        kept for existing callers).
    """
    if len(input) <= max_length:
        return input
    if manner != "middle":
        return None
    if max_length <= 0:
        # The old slice ``input[-max_length // 2:]`` became ``input[0:]`` for a
        # zero budget and returned the *entire* input.
        return input[:0]

    head_len = max_length // 2
    tail_len = max_length - head_len  # >= 1 here, so the negative slice is safe
    return input[:head_len] + input[-tail_len:]

def get_compressed_examples(examples, data_name, data_dir, rate=0.33, use_large_model=True):
    """Compress each example's context with LLMLingua-2, caching the result on disk.

    The cache lives at
    ``{data_dir}/llmlingua_cache/{data_name}_rate_{rate}_is_large_{use_large_model}.jsonl``.
    When that file exists it is loaded and returned as-is, and ``examples``
    is ignored. Otherwise every example's ``"context"`` (or ``"input"`` when
    there is no ``"context"`` key) is compressed, written back to
    ``example["context"]`` in place, and the whole list is saved to the cache.

    Args:
        examples: List of example dicts; mutated in place on a cache miss.
        data_name: Dataset name, used in the cache file name.
        data_dir: Directory under which ``llmlingua_cache/`` is created.
        rate: Compression rate forwarded to ``compress_prompt``.
        use_large_model: Choose the XLM-RoBERTa-large compressor instead of
            the multilingual BERT-base one.

    Returns:
        The list of (compressed) example dicts.
    """
    cache_dir = f'{data_dir}/llmlingua_cache'
    cache_path = f'{cache_dir}/{data_name}_rate_{rate}_is_large_{use_large_model}.jsonl'

    if os.path.exists(cache_path):
        with open(cache_path, 'r') as f:
            return [json.loads(line) for line in f]

    # Imported only on a cache miss so that reusing cached prompts does not
    # require (or pay the import cost of) the optional llmlingua dependency.
    from llmlingua import PromptCompressor

    lingua_model_name = (
        "microsoft/llmlingua-2-xlm-roberta-large-meetingbank"
        if use_large_model
        else "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank"
    )
    llm_lingua = PromptCompressor(
        model_name=lingua_model_name,
        use_llmlingua2=True,
    )

    for example in tqdm(examples, desc="Compressing prompts"):
        ct = str(example["context"] if "context" in example else example["input"]).replace("<|endoftext|>", "")
        example['context'] = llm_lingua.compress_prompt(ct, rate=rate, force_tokens=['\n', '?'])['compressed_prompt']

    os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, 'w') as f:
        for example in examples:
            json.dump(example, f)
            f.write('\n')

    # clear llmlingua to free memory, to prevent OOM in further testing
    del llm_lingua
    gc.collect()
    torch.cuda.empty_cache()

    return examples
    
class GreedySearch_vLLM:
    """Greedy (temperature 0) multi-turn evaluation driver on top of a vLLM engine.

    ``example['prompts'][0]`` is used directly as a list of token ids; every
    later prompt is text that is tokenised without special tokens and
    appended to the running token sequence.
    """

    def __init__(self, llm, tokenizer, is_kv_compress: bool = False):
        self.llm = llm
        self.tokenizer = tokenizer
        self.is_kv_compress = is_kv_compress

    def test_scdq(self, example, max_length=100):
        """Answer each follow-up query against the shared first prompt.

        The first prompt is never generated from on its own; each later
        prompt is appended to it independently, so queries do not see each
        other. ``max_length`` is either an int or a dict keyed by task name,
        in which case ``example['task'][i - 1]`` selects the budget for
        prompt ``i`` and the task list is echoed in the output.
        """
        from vllm import SamplingParams

        mixed_tasks = isinstance(max_length, dict)
        answers = []
        for turn, prompt in enumerate(example['prompts']):
            if turn == 0:
                shared_context_ids = prompt
                continue

            budget = max_length[example['task'][turn - 1]] if mixed_tasks else max_length
            greedy_params = SamplingParams(
                temperature=0.0,
                max_tokens=budget,
            )
            query_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
            generation = self.llm.generate(
                prompt_token_ids=shared_context_ids + query_ids,
                sampling_params=greedy_params,
            )
            answers.append(generation[0].outputs[0].text)

        output = {'answers': answers, 'gt': example['ground_truth']}
        if mixed_tasks:  # mixed task setting
            output["task"] = example["task"]
        return output

    def test(self, example, max_length=100, disable_golden_context=False):
        """Run the turns sequentially, growing one token sequence.

        With ``disable_golden_context`` the model's own previous answer
        (re-tokenised, followed by EOS) is inserted before the next prompt;
        otherwise only the prompts are concatenated. ``max_length`` may be a
        dict keyed by task name, looked up through ``example['task'][i]``.
        """
        from vllm import SamplingParams

        mixed_tasks = isinstance(max_length, dict)
        answers = []
        for turn, prompt in enumerate(example['prompts']):
            budget = max_length[example['task'][turn]] if mixed_tasks else max_length

            if self.is_kv_compress:
                params = SamplingParams(
                    max_tokens=budget,
                    min_tokens=1,
                    temperature=0.0,
                    max_cache_tokens=4096,
                    protected_window_size=32,
                    metric_collection_buffer_size=0,
                    compress_once=True,
                )
            else:
                params = SamplingParams(
                    temperature=0.0,
                    max_tokens=budget,
                )

            if turn == 0:
                input_ids = prompt
            else:
                query_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
                if disable_golden_context:
                    # Feed back the previous generation, terminated by EOS.
                    previous_answer_ids = self.tokenizer.encode(
                        result[0].outputs[0].text, add_special_tokens=False
                    )
                    input_ids = input_ids + previous_answer_ids + [self.tokenizer.eos_token_id]
                input_ids = input_ids + query_ids

            result = self.llm.generate(
                prompt_token_ids=input_ids, sampling_params=params
            )
            answers.append(result[0].outputs[0].text)

        output = {'answers': answers, 'gt': example['ground_truth']}
        if mixed_tasks:  # mixed task setting
            output["task"] = example["task"]
        return output

class GreedySearch:
    """Greedy multi-turn decoder that calls the model directly and carries its KV cache across turns.

    Instance state:
      * ``past_kv`` -- cache object handed to the next forward pass; ``None``
        until the first pass and after :meth:`clear`.
      * ``add_eos_to_next_prompt`` -- set by :meth:`_decode` when the last
        generated token was not the tokenizer's EOS; consumed (and reset) by
        :meth:`_process_texts`, which then prepends EOS to the next prompt.
    """

    def __init__(self, model, tokenizer):
        # Switch the model to eval mode once; it is only used for inference here.
        model.eval()
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer
        self.past_kv = None
        self.add_eos_to_next_prompt = False

    def clear(self):
        """Drop the carried KV cache and release cached CUDA memory."""
        # NOTE(review): add_eos_to_next_prompt is not reset here, so a flag set
        # by the last turn of one example can still be pending when the next
        # example reaches its second prompt -- confirm whether that is intended.
        self.past_kv = None
        gc.collect()
        torch.cuda.empty_cache()

    def _process_texts(self, input_text):
        """Tokenise a follow-up prompt and return ``input_ids`` / ``attention_mask`` tensors.

        Both tensors get a leading batch dimension of 1, are cast to int and
        moved to CUDA.
        """
        model_inputs = {}
        input_ids = self.tokenizer.encode(input_text)

        # add eos to the beginning of the input_ids if self.add_eos_to_next_prompt is True
        # (the previous generation stopped without emitting EOS); the flag is one-shot.
        if self.add_eos_to_next_prompt:
            input_ids = [self.tokenizer.eos_token_id] + input_ids
            self.add_eos_to_next_prompt = False

        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = [1] * len(model_inputs["input_ids"])

        for key in model_inputs:
            model_inputs[key] = (
                torch.tensor(model_inputs[key]).int().unsqueeze(0).cuda()
            )

        return model_inputs
    
    def _make_first_turn(self, input_ids):
        """Wrap the first-turn prompt as a ``(1, n)`` int tensor on CUDA.

        No tokenisation happens here -- presumably the first prompt is already
        a list of token ids; verify against the prompt builder.
        """
        model_inputs = {}
        model_inputs["input_ids"] = input_ids

        for key in model_inputs:
            model_inputs[key] = (
                torch.tensor(model_inputs[key]).int().unsqueeze(0).cuda()
            )

        return model_inputs
    
    def test_scdq(self, example, max_length=100):
        """Prefill the first prompt once, then answer each later prompt without updating the shared cache.

        Args:
            example: Dict with ``'prompts'`` and ``'ground_truth'`` (plus
                ``'task'`` when ``max_length`` is a dict).
            max_length: Generation budget; an int, or a dict keyed by task name.

        Returns:
            Dict with ``'answers'`` (one decoded string per prompt after the
            first), ``'gt'``, and ``'task'`` in the mixed-task setting.
        """
        results = []
        for idx, prompt in enumerate(example['prompts']):
            if isinstance(max_length, dict):
                # For idx == 0 this indexes task[-1]; that value is only passed
                # to _encode.
                max_length_per_turn = max_length[example['task'][idx-1]]
            else:
                max_length_per_turn = max_length

            if idx == 0:
                model_inputs = self._make_first_turn(prompt)
            else:
                model_inputs = self._process_texts(prompt)
            input_ids = model_inputs["input_ids"]

            with torch.inference_mode():
                if idx == 0:
                    # First prompt: forward pass only, nothing is generated or recorded.
                    result = self._encode(input_ids, max_length=max_length_per_turn)
                else:
                    result = self._decode(
                        input_ids, max_length=max_length_per_turn, dense_prefix=True,
                        update_global_past_kv=False
                    )

                    # Keep only the tokens appended after the prompt.
                    results.append(self.tokenizer.decode(result[0, len(input_ids[0]):]))
            torch.cuda.empty_cache()
        self.clear()
        output = {
            'answers': results,
            'gt': example['ground_truth']
        }

        if isinstance(max_length, dict): # mixed task setting
            output["task"] = example["task"]
        
        return output

    def test(self, example, max_length=100, disable_golden_context=False):
        """Run every prompt as a generation turn, carrying the KV cache forward.

        Args:
            example: Dict with ``'prompts'`` and ``'ground_truth'`` (plus
                ``'task'`` when ``max_length`` is a dict).
            max_length: Generation budget; an int, or a dict keyed by task name
                (looked up through ``example['task'][idx]``).
            disable_golden_context: Forwarded to :meth:`_decode`.

        Returns:
            Dict with ``'answers'`` (one decoded string per prompt), ``'gt'``,
            and ``'task'`` in the mixed-task setting.
        """
        
        results = []
        # for idx, prompt in tqdm(enumerate(example['prompts']), total=len(example['prompts']), desc="Prompt"):
        for idx, prompt in enumerate(example['prompts']):
            if isinstance(max_length, dict):
                max_length_per_turn = max_length[example['task'][idx]]
            else:
                max_length_per_turn = max_length

            if idx == 0:
                model_inputs = self._make_first_turn(prompt)
            else:
                model_inputs = self._process_texts(prompt)
            input_ids = model_inputs["input_ids"]

            with torch.inference_mode():
                if idx == 0:
                    result = self._decode(
                        input_ids, max_length=max_length_per_turn, disable_golden_context=disable_golden_context
                    )
                else:
                    result = self._decode(
                        input_ids, max_length=max_length_per_turn, dense_prefix=True,
                        disable_golden_context=disable_golden_context
                    )

            # Keep only the tokens appended after the prompt.
            results.append(self.tokenizer.decode(result[0, len(input_ids[0]):]))
            torch.cuda.empty_cache()
        self.clear()
        output = {
            'answers': results,
            'gt': example['ground_truth']
        }

        if isinstance(max_length, dict): # mixed task setting
            output["task"] = example["task"]
        
        return output
    
    def _encode(self, input_ids, max_length=None):
        """Run one forward pass over ``input_ids`` and store the resulting cache in ``self.past_kv``.

        ``max_length`` is accepted but not used here. Returns ``None``.
        """
        if self.past_kv is None:
            past_key_values = self.model.prepare_inputs_for_generation(input_ids)["past_key_values"]
        else:
            past_key_values = self.past_kv

        out = self.model(
            input_ids=input_ids,
            # attention_mask=torch.ones_like(input_ids),
            use_cache=True,
            return_dict=True,
            past_key_values=past_key_values,
            num_logits_to_keep=1,
        )
        # The logits are discarded; only the updated cache is kept.
        _, past_key_values = out.logits, out.past_key_values

        self.past_kv = past_key_values

    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=[],
        dense_prefix=False,
        update_global_past_kv=True,
        disable_golden_context=False,
    ):
        """Greedily generate up to ``max_length`` tokens after ``input_ids``.

        Args:
            input_ids: Prompt token ids, 1-D or ``(1, n)``; batch size must be 1.
            max_length: Maximum number of forward passes (prefill + decode steps).
            extra_end_token_ids: Additional stop token ids besides EOS. The
                default list is never mutated here (a new list is built).
            dense_prefix: Not referenced in this implementation; subclasses use it.
            update_global_past_kv: When False, the attention modules'
                ``update_global_past_kv`` attribute is switched off before prefill.
            disable_golden_context: When False, that attribute is switched off
                for the decode steps.

        Returns:
            ``input_ids`` with the generated tokens appended (the stop token
            itself is not appended).
        """
        # NOTE(review): with max_length == 0 the loop body never runs and
        # ``word`` below is unbound -- confirm callers never pass 0.
        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.cuda()
        assert input_ids.size(0) == 1
        end_token_ids = extra_end_token_ids + [self.tokenizer.eos_token_id]
        logits = None
        if self.past_kv is None:
            # Let the model populate model_inputs['past_key_values'] with a fresh cache.
            model_inputs = {}
            self.model._prepare_cache_for_generation(
                GenerationConfig(), model_inputs, None, None, None, None
            )
            past_key_values = model_inputs['past_key_values']
        else:
            past_key_values = self.past_kv
        
        if not update_global_past_kv:
            self.global_kv_update_mode(False)

        for i in range(max_length):
            if i == 0: # prefilling
                out = self.model(
                    input_ids=input_ids,
                    use_cache=True,
                    return_dict=True,
                    past_key_values=past_key_values,
                    num_logits_to_keep=1
                )
                logits, past_key_values = out.logits, out.past_key_values
            
            else: # decoding
                if not disable_golden_context: # if use golden context, then decoding should not update global past_kv
                    self.global_kv_update_mode(False)
                out = self.model(
                    input_ids=input_ids[:, -1:],
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                )
                logits, past_key_values = out.logits, out.past_key_values

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            # ``i == max_length`` can never hold inside range(max_length); the
            # loop ends through the stop token or by exhausting the range.
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.to(input_ids.device).view(1, 1)), dim=-1)

        # Restore the module flag whenever it may have been switched off above,
        # and clear the cache's temporary entries (clear_temp_kv_cache is
        # defined on the cache class, not visible here).
        if not update_global_past_kv or not disable_golden_context:
            self.global_kv_update_mode(True)
            past_key_values.clear_temp_kv_cache()

        self.past_kv = past_key_values
        # should see whether the last token is eos, if not tell self.test to add it to the next prompt
        if word.item() != self.tokenizer.eos_token_id:
            self.add_eos_to_next_prompt = True
        return input_ids

    def global_kv_update_mode(self, mode):
        """Set ``update_global_past_kv = mode`` on every module of the first layer's attention class."""
        attn_class = self.model.model.layers[0].self_attn.__class__
        self.model.apply(lambda m: setattr(m, "update_global_past_kv", mode) if isinstance(m, attn_class) else None)

class GreedySearch_RetrAttn(GreedySearch):
    """``GreedySearch`` variant whose prefill can feed the prompt one token at a time.

    Only ``_decode`` is overridden; it differs from the base implementation
    in the ``dense_prefix`` branch of the prefill step.
    """

    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=[],
        dense_prefix=False,
        update_global_past_kv=True,
        disable_golden_context=False,
    ):
        """Greedily generate up to ``max_length`` tokens after ``input_ids``.

        Args:
            input_ids: Prompt token ids, 1-D or ``(1, n)``; batch size must be 1.
            max_length: Maximum number of loop iterations (prefill + decode steps).
            extra_end_token_ids: Additional stop token ids besides EOS (not mutated).
            dense_prefix: When True, the prompt is prefilled token by token
                (one forward pass per token) instead of in a single pass.
            update_global_past_kv: When False, the attention modules'
                ``update_global_past_kv`` attribute is switched off before prefill.
            disable_golden_context: When False, that attribute is switched off
                for the decode steps.

        Returns:
            ``input_ids`` with the generated tokens appended (the stop token
            itself is not appended).
        """
        # NOTE(review): with max_length == 0 the loop body never runs and
        # ``word`` below is unbound -- confirm callers never pass 0.
        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.cuda()
        assert input_ids.size(0) == 1
        end_token_ids = extra_end_token_ids + [self.tokenizer.eos_token_id]
        logits = None
        if self.past_kv is None:
            # Let the model populate model_inputs['past_key_values'] with a fresh cache.
            model_inputs = {}
            self.model._prepare_cache_for_generation(
                GenerationConfig(), model_inputs, None, None, None, None
            )
            past_key_values = model_inputs['past_key_values']
        else:
            past_key_values = self.past_kv
        
        if not update_global_past_kv:
            self.global_kv_update_mode(False)

        for i in range(max_length):
            if i == 0: # prefilling
                if dense_prefix:    
                    # One forward pass per prompt token; only the logits of
                    # the last token survive the loop.
                    for token in input_ids.squeeze(0):
                        out = self.model(
                            input_ids=token.unsqueeze(0).unsqueeze(0),
                            use_cache=True,
                            return_dict=True,
                            past_key_values=past_key_values,
                            num_logits_to_keep=1
                        )
                        logits, past_key_values = out.logits, out.past_key_values
                else:
                    out = self.model(
                        input_ids=input_ids,
                        use_cache=True,
                        return_dict=True,
                        past_key_values=past_key_values,
                        num_logits_to_keep=1
                    )
                    logits, past_key_values = out.logits, out.past_key_values
            
            else: # decoding
                if not disable_golden_context: # if use golden context, then decoding should not update global past_kv
                    self.global_kv_update_mode(False)
                out = self.model(
                    input_ids=input_ids[:, -1:],
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                )
                logits, past_key_values = out.logits, out.past_key_values

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            # ``i == max_length`` can never hold inside range(max_length).
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.to(input_ids.device).view(1, 1)), dim=-1)

        # Restore the module flag whenever it may have been switched off above,
        # and clear the cache's temporary entries (clear_temp_kv_cache is
        # defined on the cache class, not visible here).
        if not update_global_past_kv or not disable_golden_context:
            self.global_kv_update_mode(True)
            past_key_values.clear_temp_kv_cache()

        self.past_kv = past_key_values
        # should see whether the last token is eos, if not tell self.test to add it to the next prompt
        if word.item() != self.tokenizer.eos_token_id:
            self.add_eos_to_next_prompt = True
        return input_ids

class GreedySearch_InfLLM(GreedySearch):
    """``GreedySearch`` variant that feeds long inputs to the model chunk by chunk.

    NOTE(review): ``_decode`` reads ``self.use_sinkcache``, which is not
    assigned in this class or in ``GreedySearch.__init__`` -- presumably the
    caller sets it on the instance; confirm.
    """

    # basically, InfLLM do _encode and _decode chunk by chunk
    def _encode(self, input_ids, past_kv=None, max_length=None):
        """Prefill ``input_ids`` in chunks of 8192 tokens and store the cache in ``self.past_kv``.

        The cache starts from the ``past_kv`` argument (default ``None``), not
        from ``self.past_kv``. ``max_length`` is accepted but unused.
        """
        chunk_size = 8192
        for st in range(0, input_ids.size(1), chunk_size):
            torch.cuda.empty_cache()
            ed = min(input_ids.size(1), st + chunk_size)
            out = self.model(
                input_ids=input_ids[:, st:ed],
                use_cache=True,
                return_dict=True,
                past_key_values=past_kv,
            )
            logits, past_kv = out.logits, out.past_key_values

        self.past_kv = past_kv
    
    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=[],
        dense_prefix=False,
        update_global_past_kv=True,
        disable_golden_context=False,
    ):
        """Greedily generate up to ``max_length`` tokens after ``input_ids``.

        Args:
            input_ids: Prompt token ids, 1-D or ``(1, n)``; batch size must be 1.
            max_length: Maximum number of loop iterations (prefill + decode steps).
            extra_end_token_ids: Additional stop token ids besides EOS (not mutated).
            dense_prefix: When True, prefill one token per forward pass;
                otherwise prefill in chunks.
            update_global_past_kv: When True, ``self.past_kv`` is assigned
                after prefill (and after each decode step if
                ``disable_golden_context`` is also True).
            disable_golden_context: See above; also gates the EOS carry-over flag.

        Returns:
            ``input_ids`` with the generated tokens appended (the stop token
            itself is not appended).
        """
        # NOTE(review): with max_length == 0 the loop body never runs and
        # ``word`` below is unbound -- confirm callers never pass 0.
        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.cuda()
        assert input_ids.size(0) == 1
        end_token_ids = extra_end_token_ids + [self.tokenizer.eos_token_id]
        logits = None
        if self.past_kv is None:
            if self.use_sinkcache:
                past_key_values = SinkCache(window_length=3968, num_sink_tokens=128)
            else:
                past_key_values = self.model.prepare_inputs_for_generation(input_ids)["past_key_values"]
        else:
            past_key_values = self.past_kv
            if self.use_sinkcache:
                # The window grows by 5000 on every call that reuses the cache.
                past_key_values.window_length += 5_000
        
        # NOTE(review): 8196 here vs 8192 in _encode -- possibly a typo; confirm.
        chunk_size = 8196
        for i in range(max_length):
            if i == 0:
                if dense_prefix:
                    for token in input_ids.squeeze(0):
                        out = self.model(
                            input_ids=token.unsqueeze(0).unsqueeze(0),
                            use_cache=True,
                            return_dict=True,
                            past_key_values=past_key_values,
                        )
                        logits, past_key_values = out.logits, out.past_key_values
                    
                else:
                    # NOTE(review): both the range and ``ed`` stop at size - 1,
                    # so the final prompt token is never part of any chunk, and
                    # a one-token prompt leaves ``logits`` as None. Confirm
                    # whether a separate pass over the last token is missing.
                    for st in range(0, input_ids.size(1) - 1, chunk_size):
                        ed = min(input_ids.size(1) - 1, st + chunk_size)
                        out = self.model(
                            input_ids=input_ids[:, st:ed],
                            use_cache=True,
                            return_dict=True,
                            past_key_values=past_key_values,
                        )
                        logits, past_key_values = out.logits, out.past_key_values
                
                if update_global_past_kv:
                    self.past_kv = past_key_values

            else:
                out = self.model(
                    input_ids=input_ids[:, -1:],
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                )
                logits, past_key_values = out.logits, out.past_key_values

                if disable_golden_context and update_global_past_kv:
                    self.past_kv = past_key_values

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            # ``i == max_length`` can never hold inside range(max_length).
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.to(input_ids.device).view(1, 1)), dim=-1)

        # should see whether the last token is eos, if not tell self.test to add it to the next prompt
        if word.item() != self.tokenizer.eos_token_id and disable_golden_context:
            self.add_eos_to_next_prompt = True
        return input_ids

class GreedySearch_RetrAttn_Legacy(GreedySearch):
    """``GreedySearch`` variant for the legacy retrieval-attention models using ``VectorDB_KV_Cache``.

    Besides the base-class state it tracks ``kv_len`` (the position passed as
    ``cache_position``), and ``num_turns`` / ``length_of_query``, which
    :meth:`test` and :meth:`test_scdq` set before delegating to the base class.
    """

    def __init__(self, model, tokenizer, top_k, from_layer, with_minference=False):
        super().__init__(model, tokenizer)
        # Pick the cache class from the matching model implementation.
        # hf_greedy_search_retr is imported but not used in this class.
        if with_minference:
            from sparse_retr_attn.modeling_llama_minference_with_retr import (
                hf_greedy_search_retr,
                VectorDB_KV_Cache
            )
        else:
            from sparse_retr_attn.modeling_llama_retr_attn import (
                hf_greedy_search_retr,
                VectorDB_KV_Cache
            )
        self.top_k = top_k
        self.from_layer = from_layer
        self.kv_class = VectorDB_KV_Cache
    
    def clear(self):
        """Drop the carried cache, reset ``kv_len`` and release cached CUDA memory."""
        self.past_kv = None
        self.kv_len = 0
        gc.collect()
        torch.cuda.empty_cache()

    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=[],
        dense_prefix=False,
        update_global_past_kv=True,
        disable_golden_context=False,
    ):
        """Greedily generate up to ``max_length`` tokens after ``input_ids``.

        Args:
            input_ids: Prompt token ids, 1-D or ``(1, n)``; batch size must be 1.
            max_length: Maximum number of loop iterations; also sizes a fresh cache.
            extra_end_token_ids: Additional stop token ids besides EOS (not mutated).
            dense_prefix: When True, prefill one token per forward pass,
                advancing ``kv_len`` by one each time.
            update_global_past_kv: When True, ``self.past_kv`` / ``self.kv_len``
                are assigned after prefill; also passed as ``insert_db`` during
                dense prefill.
            disable_golden_context: Together with ``update_global_past_kv``,
                publishes cache and length after each decode step; also gates
                the EOS carry-over flag.

        Returns:
            ``input_ids`` with the generated tokens appended (the stop token
            itself is not appended).
        """
        # NOTE(review): with max_length == 0 the loop body never runs and
        # ``word`` below is unbound -- confirm callers never pass 0.
        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.cuda()
        assert input_ids.size(0) == 1
        end_token_ids = extra_end_token_ids + [self.tokenizer.eos_token_id]
        logits = None

        if self.past_kv is None:
            # Fresh cache: room for the prompt plus max_length tokens per turn.
            # self.num_turns is set by test()/test_scdq().
            past_key_values = self.kv_class(
                max_length=max_length*self.num_turns+input_ids.size(1),
                temp_cache_size=max_length
            )
            kv_len = input_ids.size(1)
        else:
            past_key_values = self.past_kv
            kv_len = self.kv_len

        for i in range(max_length):
            if i == 0:
                if dense_prefix:
                    for token in input_ids.squeeze(0):
                        out = self.model(
                            input_ids=token.unsqueeze(0).unsqueeze(0),
                            use_cache=True,
                            return_dict=True,
                            past_key_values=past_key_values,

                            insert_db = True if update_global_past_kv else False,
                            top_k=self.top_k,
                            from_layer=self.from_layer,
                            cache_position=torch.tensor([kv_len], device=input_ids.device, dtype=torch.long),
                        )
                        logits, past_key_values = out.logits, out.past_key_values
                        kv_len += 1
                else:
                    # Single-pass prefill: insert_db is not passed and kv_len
                    # is left unchanged.
                    out = self.model(
                        input_ids=input_ids,
                        use_cache=True,
                        return_dict=True,
                        past_key_values=past_key_values,

                        top_k=self.top_k,
                        from_layer=self.from_layer,
                        cache_position=torch.arange(kv_len, device=input_ids.device),
                    )
                    logits, past_key_values = out.logits, out.past_key_values
                    # kv_len += 1
                # update global past_kv with prefix only
                if update_global_past_kv:
                    self.past_kv = past_key_values
                    self.kv_len = kv_len

            else:
                out = self.model(
                    input_ids=input_ids[:, -1:],
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                    insert_db = False, 
                    top_k=self.top_k,
                    from_layer=self.from_layer,
                    cache_position=torch.tensor([kv_len], device=input_ids.device, dtype=torch.long),
                )
                logits, past_key_values = out.logits, out.past_key_values
                kv_len += 1
                if disable_golden_context and update_global_past_kv:
                    self.past_kv = past_key_values
                    self.kv_len = kv_len

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            # ``i == max_length`` can never hold inside range(max_length).
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.to(input_ids.device).view(1, 1)), dim=-1)

        # NOTE(review): this goes through self.past_kv, not the local cache; it
        # raises AttributeError if self.past_kv is still None (for example
        # update_global_past_kv=False without a prior _encode).
        self.past_kv.temp_seen = 0   # Discard decoding tokens

        # should see whether the last token is eos, if not tell self.test to add it to the next prompt
        if word.item() != self.tokenizer.eos_token_id and disable_golden_context:
            self.add_eos_to_next_prompt = True
        return input_ids

    def _encode(self, input_ids, max_length=None):
        """Run one forward pass over ``input_ids`` and publish the cache and ``kv_len``.

        A fresh cache is sized from ``max_length``, ``self.num_turns`` and
        ``self.length_of_query``, so ``max_length`` must be a number in that
        case despite the ``None`` default.
        """
        if self.past_kv is None:
            past_key_values = self.kv_class(
                max_length=max_length*self.num_turns+input_ids.size(1),
                temp_cache_size=max_length+self.length_of_query+10
            )
            kv_len = input_ids.size(1)
        else:
            past_key_values = self.past_kv
            kv_len = self.kv_len

        out = self.model(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True,
            past_key_values=past_key_values,

            top_k=self.top_k,
            from_layer=self.from_layer,
            cache_position=torch.arange(kv_len, device=input_ids.device),
        )
        # The logits are discarded; only the updated cache is kept.
        _, past_key_values = out.logits, out.past_key_values

        self.past_kv = past_key_values
        self.kv_len = kv_len
    
    def test_scdq(self, example, max_length=100):
        """Record query length and turn count, then delegate to ``GreedySearch.test_scdq``.

        Requires at least two prompts (``prompts[1]`` is tokenised to measure
        the query length).
        """
        prompts = example['prompts']
        self.length_of_query = len(self.tokenizer.encode(prompts[1]))
        self.num_turns = len(prompts)
        return super().test_scdq(example, max_length)
    
    def test(self, example, max_length=100, disable_golden_context=False):
        """Record query length and turn count, then delegate to ``GreedySearch.test``.

        Requires at least two prompts (``prompts[1]`` is tokenised to measure
        the query length).
        """
        prompts = example['prompts']
        self.length_of_query = len(self.tokenizer.encode(prompts[1]))
        self.num_turns = len(prompts)
        return super().test(example, max_length, disable_golden_context)

class GreedySearch_Mamba2:
    """Greedy multi-turn evaluation driver built on ``llm.generate``.

    Every turn re-submits the full token sequence to ``llm.generate`` with
    sampling disabled; no cache is carried between turns by this class.

    The first entry of ``example['prompts']`` is consumed as a flat list of
    token ids (it is passed to ``torch.tensor`` / concatenated with encoded
    ids); every later entry is passed through ``tokenizer.encode``.
    """

    def __init__(self, llm, tokenizer):
        self.llm = llm
        self.tokenizer = tokenizer

    def _ids_tensor(self, token_ids):
        """Return ``token_ids`` with a leading batch dim, as int64 on the model device."""
        # The dtype is explicit on purpose: torch.tensor([]) defaults to
        # float32, and concatenating such a tensor with integer ids can
        # promote the whole prompt to float.
        return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.llm.device)

    def _greedy_config(self, max_new_tokens):
        """Build the deterministic (non-sampling) generation config for one turn."""
        return GenerationConfig(
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
        )

    def _complete(self, input_ids, max_new_tokens):
        """Generate from ``input_ids`` and return only the new text, stripped."""
        outputs = self.llm.generate(
            input_ids=input_ids,
            generation_config=self._greedy_config(max_new_tokens),
        )
        # Drop the echoed prompt; keep only what was generated after it.
        new_tokens = outputs[0, input_ids.shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    @staticmethod
    def _package(answers, example, max_length):
        """Assemble the result dict shared by ``test`` and ``test_scdq``."""
        output = {
            'answers': answers,
            'gt': example['ground_truth'],
        }
        if isinstance(max_length, dict):  # mixed task setting
            output["task"] = example["task"]
        return output

    def test_scdq(self, example, max_length=100):
        """Answer every query independently against the same first prompt.

        ``prompts[0]`` is kept as the shared prefix and produces no answer.
        Each later prompt is encoded, appended to that prefix, and generated
        from on its own (earlier queries and answers are not included).

        Args:
            example: Mapping with ``'prompts'`` and ``'ground_truth'``; also
                ``'task'`` when ``max_length`` is a dict.
            max_length: Max new tokens per turn, or a dict mapping task name
                to that limit. In the dict case the task for prompt ``idx``
                is ``example['task'][idx - 1]``.

        Returns:
            Dict with ``'answers'`` (one per query), ``'gt'``, and ``'task'``
            when ``max_length`` is a dict.
        """
        mixed = isinstance(max_length, dict)
        answers = []
        prefix_ids = None
        for idx, prompt in enumerate(example['prompts']):
            if idx == 0:
                # Shared prefix only: nothing is generated for it, so no
                # token budget or generation config is needed here.
                prefix_ids = prompt
                continue

            budget = max_length[example['task'][idx - 1]] if mixed else max_length
            query_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
            input_ids = self._ids_tensor(prefix_ids + query_ids)
            answers.append(self._complete(input_ids, budget))

        return self._package(answers, example, max_length)

    def test(self, example, max_length=100, disable_golden_context=False):
        """Run the prompts as one growing conversation, generating at every turn.

        The running ``input_ids`` starts as ``prompts[0]`` and is extended
        with each later encoded prompt. With ``disable_golden_context`` the
        previous generated answer (re-encoded from its decoded text) plus an
        EOS id is inserted before the new prompt.

        Args:
            example: Mapping with ``'prompts'`` and ``'ground_truth'``; also
                ``'task'`` when ``max_length`` is a dict.
            max_length: Max new tokens per turn, or a dict mapping task name
                to that limit. In the dict case the task for prompt ``idx``
                is ``example['task'][idx]``.
            disable_golden_context: When true, feed the model's own previous
                answer back into the context as described above.

        Returns:
            Dict with ``'answers'`` (one per prompt), ``'gt'``, and ``'task'``
            when ``max_length`` is a dict.
        """
        mixed = isinstance(max_length, dict)
        answers = []
        input_ids = None
        for idx, prompt in enumerate(example['prompts']):
            budget = max_length[example['task'][idx]] if mixed else max_length

            if idx == 0:
                input_ids = self._ids_tensor(prompt)
            else:
                current_ids = self._ids_tensor(
                    self.tokenizer.encode(prompt, add_special_tokens=False)
                )
                pieces = [input_ids]
                if disable_golden_context:
                    # The previous answer may be an empty string; _ids_tensor
                    # keeps the resulting empty tensor integer-typed.
                    pieces.append(self._ids_tensor(
                        self.tokenizer.encode(answers[-1], add_special_tokens=False)
                    ))
                    pieces.append(self._ids_tensor([self.tokenizer.eos_token_id]))
                pieces.append(current_ids)
                input_ids = torch.cat(pieces, dim=-1)

            answers.append(self._complete(input_ids, budget))
            torch.cuda.empty_cache()

        return self._package(answers, example, max_length)