# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset, load_from_disk
from fire import Fire
from infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from longbench.calculate_metrics import calculate_metrics as longbench_scorer
from longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from loogle.calculate_metrics import calculate_metrics as loogle_scorer
from scbench.caculate_metrics import calculate_metrics as scbench_scorer

from ruler.calculate_metrics import calculate_metrics as ruler_scorer
from tqdm import tqdm
from transformers import pipeline
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer

from kvpress import (
    AdaKVPress,
    ChunkKVPress,
    ComposedPress,
    CriticalAdaKVPress,
    CriticalKVPress,
    DuoAttentionPress,
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    ThinKPress,
    TOVAPress,
    QFilterPress,
    PyramidKVPress,
    QuestionKVPress,
    OracleKVPress,
    AdaOracleKVPress
)

logger = logging.getLogger(__name__)

# Hugging Face model ids used by `evaluate`, mapped to local checkpoint directories.
# Every checkpoint lives in "../../models/<repository name>". The path is relative, so it is
# presumably resolved against the working directory the script is launched from.
MODEL_PATH_DICT = {
    model_id: "../../models/" + model_id.split("/", 1)[1]
    for model_id in (
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "mistralai/Mistral-7B-Instruct-v0.2",
        "Qwen/Qwen2.5-7B-Instruct-1M",
    )
}

# Benchmark name -> data source. `evaluate` loads every key containing "scbench" with
# `datasets.load_from_disk` and every other key with `datasets.load_dataset(..., split="test")`.
DATASET_DICT = {
    # Hugging Face Hub repositories
    "loogle": "simonjegou/loogle",
    "ruler": "simonjegou/ruler",
    "zero_scrolls": "simonjegou/zero_scrolls",
    "infinitebench": "MaxJeblick/InfiniteBench",
    # "longbench" and "longbench-e" read the same repository; only their scorers differ.
    "longbench": "Xnhyacinth/LongBench",
    "longbench-e": "Xnhyacinth/LongBench",
    "longbench-v2": "Xnhyacinth/LongBench-v2",
    # Local SCBench variants saved to disk ("mt" and "nm")
    **{f"scbench_{variant}": f"../../models/scbench_{variant}" for variant in ("mt", "nm")},
}

# Benchmark name -> metric function. `evaluate` calls it with the predictions DataFrame.
# Keys mirror DATASET_DICT.
SCORER_DICT = {
    "loogle": loogle_scorer,
    "ruler": ruler_scorer,
    "zero_scrolls": zero_scrolls_scorer,
    "infinitebench": infinite_bench_scorer,
    # Same LongBench data, two metric variants
    "longbench": longbench_scorer,
    "longbench-e": longbench_scorer_e,
    "longbench-v2": longbenchv2_scorer,
    # Both SCBench variants share one scorer
    **dict.fromkeys(("scbench_mt", "scbench_nm"), scbench_scorer),
}

# Registry of KV-cache compression presses selectable through `evaluate(press_name=...)`.
# Each press is instantiated once, at import time, and that same instance is reused: `evaluate`
# overwrites its compression hyper-parameters in place before running the pipeline.
# NOTE(review): every constructor below runs on import, even for presses that are never selected;
# confirm they are all cheap (no downloads / large allocations).
PRESS_DICT = {
    # Critical* wrappers. The expected-attention variants here pass use_vnorm=False, whereas the
    # plain / Ada expected-attention entries below keep ExpectedAttentionPress defaults.
    "criti_adasnapkv": CriticalAdaKVPress(SnapKVPress()),
    "criti_ada_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "criti_snapkv": CriticalKVPress(SnapKVPress()),
    "criti_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
    # AdaKV wrappers and plain expected attention
    "adasnapkv": AdaKVPress(SnapKVPress()),
    "ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
    "expected_attention": ExpectedAttentionPress(),
    "ada_expected_attention_e2": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
    "knorm": KnormPress(),
    # `evaluate` loads the model with eager attention for this press (flash attention otherwise).
    "observed_attention": ObservedAttentionPress(),
    "random": RandomPress(),
    # `evaluate` also applies its `window_size` argument to this press.
    "snapkv": SnapKVPress(),
    "streaming_llm": StreamingLLMPress(),
    # `evaluate` sets only `key_channel_compression_ratio` on ThinKPress.
    "think": ThinKPress(),
    "tova": TOVAPress(),
    # `evaluate` sets `head_compression_ratio` (not `compression_ratio`) on DuoAttentionPress.
    "duo_attention": DuoAttentionPress(),
    "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
    "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
    "qfilter": QFilterPress(),
    # ComposedPress: SnapKVPress receives `compression_ratio`, ThinKPress receives
    # `key_channel_compression_ratio`.
    "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
    "pyramidkv": PyramidKVPress(),
    "questionkv": QuestionKVPress(),
    "oraclekv": OracleKVPress(),
    "adaoraclekv": AdaOracleKVPress(OracleKVPress()),
}

# Instruction prompts selectable through `evaluate(instruction=<key>)`. The selected text is passed
# to the pipeline as its `instruction` argument, and key -1 maps to None (no instruction). Any other
# key is also encoded in the results filename as "__instruct_<key:02d>".
# NOTE: these strings are part of the model input. Their exact wording (including typos such as
# "retrival" or "asked with") is kept verbatim so results stay comparable across runs.
INSTRUCTION_DICT = {
    -1: None,
    # Task-list style prompts
    101: "Next, you will be asked with some questions about the context above. These questions will ask you to summarize the above context includes: question answering, summarization, code completion, in-context learning, paragraph counting, retrival, etc. Please remember the relevant information and answering the question.",
    102: "Next, you will be asked with some questions about the context above. These questions includes: question answering, summarization, code completion, in-context learning, paragraph counting, retrival, etc. Please remember the relevant information and answering the question.",
    # Prompt variants 0-11
    0: "Next, you will be presented with a series of questions regarding the context above, including specific details about the narrative, content, and numerical information. Please retain these details and provide accurate responses.",
    1: "Next, you will be presented with a series of questions regarding the context above. including specific details about the narrative, content, numerical information and key global information. The questions may also involve small fragmented relationships from the context, including term relationships, causal relations, and temporal relations. Please retain these details and provide accurate responses. ",
    2: "Next, you will be presented with a series of questions regarding the context above. Pay attention to specific details like names, places, and numbers (e.g., dates, quantities). Try to grasp the overall message or main theme, such as the primary purpose discussed. Note how people, events, or concepts are related, including family ties or event linkages. Be aware of pronouns (e.g., he, she, it, they) and what they refer to, clarifying discussions.Identify important times, dates, or events mentioned, crucial for sequence or significance. Recognize people, organizations, or entities involved in key events or actions. Be ready to infer information not explicitly stated, such as logical conclusions from context. Please retain these details and provide accurate responses. ",
    3: "Please remember the specific information of context above: specific details like names, places, and numbers, main theme, like overall message, relations, like family ties and event linkages, details, like grammar dependencies and narrative information. Please retain these details and provide accurate responses for following questions.",
    4: "Next, you will be presented with a series of questions regarding the context above. Please remember the following information: 1. specific details like names, places, and numbers; 2. main theme, like overall message; 3. relations, like family ties and event linkages, 4. semantic details, like grammar dependencies and narrative information between words.",
    5: "Next, you will be presented with a series of questions regarding the context above. Please remember the following information: 1. specific details like names, places, and numbers; 2. main theme, like overall message; 3. relations, like family ties and event linkages, 4. semantic details, like grammar dependencies and narrative information between words. 5.logical chains, such as dependencies in program code.",
    6: "Please remember the following information in above context: 1. specific details like names, places, and numbers; 2. main theme, like overall message; 3. relations, like family ties and event linkages, 4. semantic details, like grammar dependencies and narrative information between words. 5.logical chains, such as dependencies in program code or causal relations in events/things.",
    7: "Next, you will be asked with some questions about the context above. These questions may involve the following information: 1. specific details like named entities, places, and numbers; 2. main theme of the context, like narratives structures and overall messages; 3. relations, like family ties and event linkages;  4. semantic details, like grammar dependencies and narrative information between words; 5. structural patterns, code syntax and patterns in parallel sentences or paragraphs; 6. temporal/contextual features, such as causal relations. Please take care of these information and answer the following questions.",
    8: "Next, you will be asked questions about the context above. These questions may involve various types of information, including:\n 1.Specific details such as named entities, places, and numbers.\n 2.The main theme of the context, including narrative structures and overall messages.\n Relations between entities, such as family ties or event linkages.\n 3.Semantic details, including grammar dependencies and narrative information between words.\n 4.Structural patterns, such as code syntax or patterns in parallel sentences or paragraphs.\n 5.Temporal and contextual features, such as causal relations.\n Given that users often engage with long contexts in tasks requiring deep understanding and iterative refinement, it is crucial to maintain coherence and relevance throughout the entire context. Pay particular attention to how different parts of the context relate to each other and how they contribute to the overall meaning or task at hand. Be prepared to answer questions that require integrating information from multiple sections of the context, as this reflects how users typically interact with such information. Additionally, be aware that questions may require multi-step reasoning or retrieving and integrating information from various parts of the context, as these are common tasks in long-context processing.",
    9: "Next, you will be asked some questions regarding the context(s) above. These questions will cover: 1. Specific details about the narrative, content, and numerical information. 2. Key global information. 3. Small fragmented relationships within the context, including term relationships, causal relations, and temporal relations. 4. Timeline, focusing on the key event occurrences. 5. Frequency of words and extreme values.",
    10: "Analyze the given text carefully. Your tasks include: 1) Answering factual questions accurately. 2) Generating concise summaries. 3) Demonstrating in-context learning. 4) Writing code based on the text. 5) Counting paragraphs. 6) Retrieving specific strings. 7) Extracting numerical values. 8) Selecting correct answers in multiple-choice questions. 9) Calculating extreme values from arrays. Always ensure your answers are strictly based on the provided text.",
    11: "Carefully read and analyze the provided text. Your tasks involve multiple types of questions, each requiring precise information extraction from the text. Specifically, you will: 1) Answer factual questions by identifying accurate details directly from the text. 2) Generate concise and coherent summaries without introducing any external information. 3) Demonstrate in-context learning by recognizing patterns or concepts reflected in the text. 4) Write code accurately based on the textual instructions or examples. 5) Count the total number of paragraphs accurately. 6) Search and retrieve specific strings or terms mentioned in the text. 7) Extract and list numerical values, maintaining their original form. 8) Solve multiple-choice questions by selecting the most accurate answer based on the content. 9) Identify and calculate extreme values (maximum, minimum) from any given array of numbers in the text. Always ensure that your responses are strictly grounded in the provided text. Do not infer, assume, or generate information beyond what is explicitly stated. Maintain clarity, accuracy, and completeness in your answers. Stay focused on the input context and prioritize factual consistency."
}


def _with_stem_suffix(path: Path, suffix: str) -> Path:
    """Return ``path`` with ``suffix`` inserted between its stem and its extension."""
    return path.with_name(path.stem + suffix + path.suffix)


def _base_save_filename(
    save_dir: Path,
    dataset: str,
    data_dir: Optional[str],
    model: str,
    press_name: str,
    compression_ratio: float,
    instruction: int,
) -> Path:
    """Build the results CSV path from the core run settings (before the optional suffixes)."""
    model_tag = model.replace("/", "--")
    if instruction != -1:
        parts = [
            dataset,
            data_dir or "",
            model_tag,
            press_name + f"__instruct_{instruction:02d}" + "_question",
            str(compression_ratio),
        ]
    else:
        parts = [dataset, data_dir or "", model_tag + "_question", press_name, str(compression_ratio)]
    return save_dir / ("__".join(parts) + ".csv")


def _configure_press(
    press,
    compression_ratio: float,
    key_channel_compression_ratio: float,
    window_size: Optional[int],
) -> str:
    """
    Set the compression hyper-parameters on ``press`` in place.

    Returns the filename suffix recording the key-channel ratio: one "__channel<ratio>" per
    configured ThinKPress, or "" when no ThinKPress is involved.
    """
    channel_suffix = f"__channel{key_channel_compression_ratio}"
    if isinstance(press, DuoAttentionPress):
        press.head_compression_ratio = compression_ratio
        return ""
    if isinstance(press, ComposedPress):
        suffix = ""
        for sub_press in press.presses:
            if isinstance(sub_press, ThinKPress):
                sub_press.key_channel_compression_ratio = key_channel_compression_ratio
                suffix += channel_suffix
            else:
                sub_press.compression_ratio = compression_ratio  # type:ignore[attr-defined]
        return suffix
    if isinstance(press, ThinKPress):
        press.key_channel_compression_ratio = key_channel_compression_ratio
        return channel_suffix
    if isinstance(press, SnapKVPress):
        # None means "keep the press's own window size" rather than overwriting it with None.
        if window_size is not None:
            press.window_size = window_size
        press.compression_ratio = compression_ratio
        return ""
    press.compression_ratio = compression_ratio  # type:ignore[attr-defined]
    return ""


def evaluate(
    dataset: str,
    data_dir: Optional[str] = None,
    model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    device: Optional[str] = None,
    press_name: str = "expected_attention",
    compression_ratio: float = 0.1,
    fraction: float = 1.0,
    max_new_tokens: Optional[int] = None,
    max_context_length: Optional[int] = None,
    compress_questions: bool = False,
    key_channel_compression_ratio: float = 0.5,
    instruction: Optional[int] = -1,
    window_size: Optional[int] = 64,
    save_folder: Optional[str] = "results_60k",
):
    """
    Evaluate a model on a dataset using a press and save the results.

    Predictions are written to a CSV file in ``save_folder`` (columns ``predicted_answer`` and
    ``compression_ratio``). The scorer's metrics are written to a ``.json`` file with the same name.

    Parameters
    ----------
    dataset : str
        Dataset to evaluate; must be a key of DATASET_DICT and SCORER_DICT.
    data_dir : str, optional
        Subdirectory of the dataset to evaluate, by default None
    model : str, optional
        Model to use, by default "meta-llama/Meta-Llama-3.1-8B-Instruct". Ids registered in MODEL_PATH_DICT
        are loaded from their local checkpoint directory; any other value is passed to ``pipeline`` as is.
    device : str, optional
        Model device, by default cuda:0 if available else cpu. For multi-GPU use "auto"
    press_name : str, optional
        Press to use (see PRESS_DICT), by default "expected_attention"
    compression_ratio : float, optional
        Compression ratio for the press, by default 0.1. For DuoAttentionPress it is used as the head
        compression ratio.
    fraction : float, optional
        Fraction of the dataset to evaluate, by default 1.0. Rows are sampled with a fixed seed (42).
    max_new_tokens : int, optional
        Maximum number of new tokens to generate. By default, use the dataset's ``max_new_tokens`` column
        (first row of each context group); recommended.
    max_context_length : int, optional
        Maximum number of tokens to use in the context, forwarded to the pipeline. By default None, which
        leaves the limit to the pipeline (presumably the maximum length supported by the model).
    compress_questions : bool, optional
        Whether to append the questions to the context so they are compressed too, by default False
    key_channel_compression_ratio : float, optional
        Key channel compression ratio for ThinKPress (alone or inside a ComposedPress), by default 0.5
    instruction : int, optional
        Key of INSTRUCTION_DICT selecting the instruction text passed to the pipeline, by default -1
        (no instruction). None is treated as -1.
    window_size : int, optional
        Window size applied when the press is a SnapKVPress, by default 64. None keeps the press's own value.
    save_folder : str, optional
        Folder, relative to this file, where the results are written, by default "results_60k"

    Raises
    ------
    ValueError
        If ``dataset``, ``press_name`` or ``instruction`` is not a known key.
    """
    # Validate every lookup key up front so a typo fails before the slow dataset/model loading.
    if dataset not in DATASET_DICT:
        raise ValueError(f"No dataset found for {dataset}")
    if dataset not in SCORER_DICT:
        raise ValueError(f"No scorer found for {dataset}")
    if press_name not in PRESS_DICT:
        raise ValueError(f"No press found for {press_name}; available presses: {sorted(PRESS_DICT)}")
    if instruction is None:
        instruction = -1
    if instruction not in INSTRUCTION_DICT:
        raise ValueError(f"No instruction found for {instruction}; available keys: {sorted(INSTRUCTION_DICT)}")
    data_dir = str(data_dir) if data_dir else None

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Build the complete output filename (all suffixes included) before checking whether it exists.
    save_dir = Path(__file__).parent / save_folder
    save_dir.mkdir(parents=True, exist_ok=True)
    save_filename = _base_save_filename(save_dir, dataset, data_dir, model, press_name, compression_ratio, instruction)
    if fraction < 1.0:
        save_filename = _with_stem_suffix(save_filename, f"__fraction{fraction:.2f}")
    if max_context_length is not None:
        save_filename = _with_stem_suffix(save_filename, f"__max_context{max_context_length}")
    if compress_questions:
        save_filename = _with_stem_suffix(save_filename, "__compressed_questions")

    # PRESS_DICT holds shared instances; their hyper-parameters are overwritten in place.
    press = PRESS_DICT[press_name]
    save_filename = _with_stem_suffix(
        save_filename, _configure_press(press, compression_ratio, key_channel_compression_ratio, window_size)
    )

    if save_filename.exists():
        logger.warning(f"Results already exist at {save_filename}")

    instruction_text = INSTRUCTION_DICT[instruction]

    # Load dataframe
    if "scbench" in dataset:
        df = load_from_disk(DATASET_DICT[dataset]).to_pandas()
    else:
        df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas()

    if fraction < 1.0:
        df = df.sample(frac=fraction, random_state=42)

    if compress_questions:
        df["context"] = df["context"] + df["question"]
        df["question"] = ""

    # Initialize pipeline with the correct attention implementation
    model_kwargs = {"torch_dtype": "auto"}
    if isinstance(press, ObservedAttentionPress):
        model_kwargs["attn_implementation"] = "eager"
    else:
        try:
            import flash_attn  # noqa: F401

            model_kwargs["attn_implementation"] = "flash_attention_2"
        except ImportError:
            pass

    # Prefer a registered local checkpoint; otherwise let `pipeline` resolve the id itself.
    model_path = MODEL_PATH_DICT.get(model, model)
    if device == "auto":
        pipe = pipeline(
            "kv-press-text-generation-with-oracle", model=model_path, device_map="auto", model_kwargs=model_kwargs
        )
    else:
        pipe = pipeline("kv-press-text-generation-with-oracle", model=model_path, device=device, model_kwargs=model_kwargs)

    # Run the pipeline once per context (or per session for SCBench), answering all its questions
    df["predicted_answer"] = None
    group_key = "__index_session__" if "scbench" in dataset else "context"
    df_context = df.groupby(group_key)
    n_total = df[group_key].nunique()

    for context, df_ in tqdm(df_context, total=n_total):
        if "scbench" in dataset:
            # Groups are keyed by session id: take the context from the session's first row
            # (assumes every row of a session carries the same context).
            context = df_["context"].values[0]

        questions = df_["question"].to_list()
        max_new_tokens_ = max_new_tokens if max_new_tokens is not None else df_["max_new_tokens"].iloc[0]
        answer_prefix = df_["answer_prefix"].iloc[0]
        try:
            output = pipe(
                context,
                questions=questions,
                answer_prefix=answer_prefix,
                press=press,
                max_new_tokens=max_new_tokens_,
                max_context_length=max_context_length,
                instruction=instruction_text,
            )
            df.loc[df_.index, "predicted_answer"] = output["answers"]
        except Exception:
            # Best effort: keep the full traceback in the log and record empty answers so that one
            # failing context does not abort the whole evaluation run.
            logger.exception(
                "Generation failed for a group of %d question(s); recording empty answers", len(questions)
            )
            df.loc[df_.index, "predicted_answer"] = ""
        df.loc[df_.index, "compression_ratio"] = press.compression_ratio  # type:ignore[attr-defined]

    # Save answers
    df[["predicted_answer", "compression_ratio"]].to_csv(str(save_filename), index=False)

    # Calculate metrics
    scorer = SCORER_DICT[dataset]
    metrics = scorer(df)
    with open(save_filename.with_suffix(".json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f)
    print(f"Average compression ratio: {df['compression_ratio'].mean():.2f}")
    print(metrics)


if __name__ == "__main__":
    # Command-line entry point: Fire maps CLI flags (e.g. --dataset, --press_name) onto the
    # keyword arguments of `evaluate`.
    Fire(evaluate)
