# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset, load_from_disk
from fire import Fire
from infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from longbench.calculate_metrics import calculate_metrics as longbench_scorer
from longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from loogle.calculate_metrics import calculate_metrics as loogle_scorer
from scbench.caculate_metrics import calculate_metrics as scbench_scorer

from ruler.calculate_metrics import calculate_metrics as ruler_scorer
from tqdm import tqdm
from transformers import pipeline
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer

from kvpress import (
    AdaKVPress,
    ChunkKVPress,
    ComposedPress,
    CriticalAdaKVPress,
    CriticalKVPress,
    DuoAttentionPress,
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    ThinKPress,
    TOVAPress,
    QFilterPress,
    PyramidKVPress,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Hub model id -> relative path of the local checkpoint directory. The
# directory name is the part of the hub id that follows the organisation.
MODEL_PATH_DICT = {
    hub_id: "../../models/" + hub_id.split("/", 1)[1]
    for hub_id in (
        "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "mistralai/Mistral-7B-Instruct-v0.2",
        "Qwen/Qwen2.5-7B-Instruct-1M",
    )
}

# Benchmark name -> dataset location (hub repository id or local path).
DATASET_DICT = {
    # Repositories named after the benchmark itself.
    **{name: f"simonjegou/{name}" for name in ("loogle", "ruler", "zero_scrolls")},
    "infinitebench": "MaxJeblick/InfiniteBench",
    # LongBench and LongBench-E are read from the same repository.
    **dict.fromkeys(("longbench", "longbench-e"), "Xnhyacinth/LongBench"),
    "longbench-v2": "Xnhyacinth/LongBench-v2",
    # SCBench splits are referenced by relative local paths.
    **{name: f"../../models/{name}" for name in ("scbench_mt", "scbench_nm")},
}

# Benchmark name -> metric function; `evaluate` calls it on the results dataframe.
SCORER_DICT = {
    "loogle": loogle_scorer,
    "ruler": ruler_scorer,
    "zero_scrolls": zero_scrolls_scorer,
    "infinitebench": infinite_bench_scorer,
    # LongBench and LongBench-E share a dataset but are scored differently.
    "longbench": longbench_scorer,
    "longbench-e": longbench_scorer_e,
    "longbench-v2": longbenchv2_scorer,
    # Both SCBench splits use the same scorer.
    **dict.fromkeys(("scbench_mt", "scbench_nm"), scbench_scorer),
}

# Press name (the `press_name` argument of `evaluate`) -> press instance.
# NOTE: the instances are created once, at import time, and `evaluate` sets
# their compression attributes in place, so they are shared mutable state
# within a single process.
PRESS_DICT = {
    "criti_adasnapkv": CriticalAdaKVPress(SnapKVPress()),
    "criti_ada_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "criti_snapkv": CriticalKVPress(SnapKVPress()),
    "criti_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "adasnapkv": AdaKVPress(SnapKVPress()),
    "ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
    "expected_attention": ExpectedAttentionPress(),
    "ada_expected_attention_e2": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
    "knorm": KnormPress(),
    "observed_attention": ObservedAttentionPress(),
    "random": RandomPress(),
    "snapkv": SnapKVPress(),
    "streaming_llm": StreamingLLMPress(),
    "think": ThinKPress(),
    "tova": TOVAPress(),
    "duo_attention": DuoAttentionPress(),
    "duo_attention_on_the_fly": DuoAttentionPress(on_the_fly_scoring=True),
    "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
    "qfilter": QFilterPress(),
    "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
    "pyramidkv": PyramidKVPress(),
}

# Instruction id (the `instruction` argument of `evaluate`) -> instruction text
# forwarded to the pipeline. The id -1 maps to None, i.e. no instruction.
# NOTE: these strings are prompts used at runtime; they are kept verbatim
# (including punctuation and trailing spaces) so results stay comparable.
INSTRUCTION_DICT = {
    -1: None,
    0: "Next, you will be presented with a series of questions regarding the context above, including specific details about the narrative, content, and numerical information. Please retain these details and provide accurate responses.",
    1: "Next, you will be presented with a series of questions regarding the context above. including specific details about the narrative, content, numerical information and key global information. The questions may also involve small fragmented relationships from the context, including term relationships, causal relations, and temporal relations. Please retain these details and provide accurate responses. ",
    2: "Next, you will be presented with a series of questions regarding the context above. Pay attention to specific details like names, places, and numbers (e.g., dates, quantities). Try to grasp the overall message or main theme, such as the primary purpose discussed. Note how people, events, or concepts are related, including family ties or event linkages. Be aware of pronouns (e.g., he, she, it, they) and what they refer to, clarifying discussions.Identify important times, dates, or events mentioned, crucial for sequence or significance. Recognize people, organizations, or entities involved in key events or actions. Be ready to infer information not explicitly stated, such as logical conclusions from context. Please retain these details and provide accurate responses. ",
    3: "Please remember the specific information of context above: specific details like names, places, and numbers, main theme, like overall message, relations, like family ties and event linkages, details, like grammar dependencies and narrative information. Please retain these details and provide accurate responses for following questions.",
    4: "Next, you will be presented with a series of questions regarding the context above. Please remember the following information: 1. specific details like names, places, and numbers; 2. main theme, like overall message; 3. relations, like family ties and event linkages, 4. semantic details, like grammar dependencies and narrative information between words.",
    5: "Next, you will be presented with a series of questions regarding the context above. Please remember the following information: 1. specific details like names, places, and numbers; 2. main theme, like overall message; 3. relations, like family ties and event linkages, 4. semantic details, like grammar dependencies and narrative information between words. 5.logical chains, such as dependencies in program code.",
    6: "Please remember the following information in above context: 1. specific details like names, places, and numbers; 2. main theme, like overall message; 3. relations, like family ties and event linkages, 4. semantic details, like grammar dependencies and narrative information between words. 5.logical chains, such as dependencies in program code or causal relations in events/things.",
    88: "Next, you will be presented with a series of questions regarding the context above. These question involves: 1. specific details like names, places; 2. the unusual information. 3. things about San Francisco. 4. semantic details, like grammar dependencies. 5. facts/sentences detached from the main theme",
}


def evaluate(
    dataset: str,
    data_dir: Optional[str] = None,
    model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    device: Optional[str] = None,
    press_name: str = "expected_attention",
    compression_ratio: float = 0.1,
    fraction: float = 1.0,
    max_new_tokens: Optional[int] = None,
    max_context_length: Optional[int] = None,
    compress_questions: bool = False,
    key_channel_compression_ratio: float = 0.5,
    instruction: Optional[int] = -1,
    window_size: Optional[int] = 64,
    save_folder: Optional[str] = "results_60k",
):
    """
    Evaluate a model on a dataset using a press and save the results

    The predictions are written to a CSV file and the metrics to a JSON file with the same stem,
    both inside ``save_folder`` (relative to the directory of this script).

    Parameters
    ----------
    dataset : str
        Dataset to evaluate (see DATASET_DICT)
    data_dir : str, optional
        Subdirectory of the dataset to evaluate, by default None
    model : str, optional
        Model to use, by default "meta-llama/Meta-Llama-3.1-8B-Instruct". Names listed in MODEL_PATH_DICT
        are mapped to their local path, any other value is passed to the pipeline unchanged.
    device : str, optional
        Model device, by default cuda:0 if available else cpu. For multi-GPU use "auto"
    press_name : str, optional
        Press to use (see PRESS_DICT), by default "expected_attention"
    compression_ratio : float, optional
        Compression ratio for the press, by default 0.1
    fraction : float, optional
        Fraction of the dataset to evaluate, by default 1.0
    max_new_tokens : int, optional
        Maximum number of new tokens to generate, by default use the default for the task (recommended)
    max_context_length : int, optional
        Maximum number of tokens to use in the context. By default will use the maximum length supported by the model.
    compress_questions : bool, optional
        Whether to compress the questions as well, by default False
    key_channel_compression_ratio : float, optional
        key Channel Compression ratio for the channel press, by default 0.5
    instruction : int, optional
        Id of the instruction to forward to the pipeline (see INSTRUCTION_DICT), by default -1 (no instruction).
        None is treated like -1.
    window_size : int, optional
        Window size set on the press when it is a SnapKVPress, by default 64. None leaves the press value untouched.
    save_folder : str, optional
        Folder, relative to the directory of this script, where the results are saved, by default "results_60k"

    Raises
    ------
    AssertionError
        If ``dataset`` or ``press_name`` is unknown
    KeyError
        If ``instruction`` is not a key of INSTRUCTION_DICT
    """

    # Validate every name up front so that a typo fails before any dataset or model is loaded
    assert dataset in DATASET_DICT, f"No dataset found for {dataset}"
    assert dataset in SCORER_DICT, f"No scorer found for {dataset}"
    assert press_name in PRESS_DICT, f"No press found for {press_name}"
    data_dir = str(data_dir) if data_dir else None

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # The signature allows None: treat it as "no instruction" instead of failing in the filename formatting
    if instruction is None:
        instruction = -1
    instruction_text = INSTRUCTION_DICT[instruction]

    save_dir = Path(__file__).parent / save_folder
    save_dir.mkdir(parents=True, exist_ok=True)

    # Base name of the result files; optional tags are appended below, in a fixed order
    press_tag = press_name if instruction == -1 else f"{press_name}__instruct_{instruction:02d}"
    name_parts = [
        dataset,
        data_dir if data_dir else "",
        model.replace("/", "--"),
        press_tag,
        str(compression_ratio),
    ]
    filename_tags = []
    if fraction < 1.0:
        filename_tags.append(f"__fraction{fraction:.2f}")
    if max_context_length is not None:
        filename_tags.append(f"__max_context{max_context_length}")
    if compress_questions:
        filename_tags.append("__compressed_questions")

    # Load press (shared instance from PRESS_DICT, configured in place)
    press = PRESS_DICT[press_name]

    if isinstance(press, DuoAttentionPress):
        press.head_compression_ratio = compression_ratio
    elif isinstance(press, ComposedPress):
        for sub_press in press.presses:
            if isinstance(sub_press, ThinKPress):
                sub_press.key_channel_compression_ratio = key_channel_compression_ratio
                filename_tags.append(f"__channel{key_channel_compression_ratio}")
            else:
                sub_press.compression_ratio = compression_ratio  # type:ignore[attr-defined]
    elif isinstance(press, ThinKPress):
        press.key_channel_compression_ratio = key_channel_compression_ratio
        filename_tags.append(f"__channel{key_channel_compression_ratio}")
    elif isinstance(press, SnapKVPress):
        if window_size is not None:
            press.window_size = window_size
        press.compression_ratio = compression_ratio
    else:
        press.compression_ratio = compression_ratio  # type:ignore[attr-defined]

    # Only warn once the name is final, so that the check looks at the file that will actually be written
    save_filename = save_dir / ("__".join(name_parts) + "".join(filename_tags) + ".csv")
    if save_filename.exists():
        logger.warning("Results already exist at %s", save_filename)

    # Load dataframe (scbench datasets are read from disk, the others through load_dataset)
    is_scbench = "scbench" in dataset
    if is_scbench:
        df = load_from_disk(DATASET_DICT[dataset]).to_pandas()
    else:
        df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas()

    if fraction < 1.0:
        df = df.sample(frac=fraction, random_state=42)

    if compress_questions:
        df["context"] = df["context"] + df["question"]
        df["question"] = ""

    # Initialize pipeline with the correct attention implementation
    model_kwargs = {"torch_dtype": "auto"}
    if isinstance(press, ObservedAttentionPress):
        model_kwargs["attn_implementation"] = "eager"
    else:
        # Best effort: use flash attention when the package is installed, else keep the default implementation
        try:
            import flash_attn  # noqa: F401

            model_kwargs["attn_implementation"] = "flash_attention_2"
        except ImportError:
            pass

    # Fall back to the given identifier for models without a registered local path
    model_path = MODEL_PATH_DICT.get(model, model)
    if device == "auto":
        pipe = pipeline("kv-press-text-generation", model=model_path, device_map="auto", model_kwargs=model_kwargs)
    else:
        pipe = pipeline("kv-press-text-generation", model=model_path, device=device, model_kwargs=model_kwargs)

    # Run pipeline on each context: the pipeline is called once per group with all the questions of the group
    df["predicted_answer"] = None
    group_key = "__index_session__" if is_scbench else "context"
    df_context = df.groupby(group_key)
    n_total = df[group_key].nunique()

    for context, df_ in tqdm(df_context, total=n_total):
        if is_scbench:
            # Groups are keyed by session id here, the context text is read from the rows
            context = df_["context"].values[0]

        questions = df_["question"].to_list()
        # Only the first row of a group is read: max_new_tokens and answer_prefix are assumed constant in a group
        max_new_tokens_ = max_new_tokens if max_new_tokens is not None else df_["max_new_tokens"].iloc[0]
        answer_prefix = df_["answer_prefix"].iloc[0]
        output = pipe(
            context,
            questions=questions,
            answer_prefix=answer_prefix,
            press=press,
            max_new_tokens=max_new_tokens_,
            max_context_length=max_context_length,
            instruction=instruction_text,
        )
        df.loc[df_.index, "predicted_answer"] = output["answers"]
        df.loc[df_.index, "compression_ratio"] = press.compression_ratio  # type:ignore[attr-defined]

    # Save answers
    df[["predicted_answer", "compression_ratio"]].to_csv(str(save_filename), index=False)

    # Calculate metrics
    scorer = SCORER_DICT[dataset]
    metrics = scorer(df)
    # with_suffix only touches the extension, unlike a string replace of ".csv" over the whole path
    with open(save_filename.with_suffix(".json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f)
    print(f"Average compression ratio: {df['compression_ratio'].mean():.2f}")
    print(metrics)


if __name__ == "__main__":
    # Run `evaluate` through Fire when the file is executed as a script.
    Fire(evaluate)
