# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset, load_from_disk
import pandas as pd
from fire import Fire
import transformers
from longbench.calculate_metrics import calculate_metrics as longbench_scorer
from kvpress.ada_attn import replace_var_flash_attn
from ruler.calculate_metrics import calculate_metrics as ruler_scorer
from tqdm import tqdm
from transformers import pipeline
import warnings

# Ignore this one specific warning message. It appears to come from bitsandbytes' 8-bit
# matmul when the model runs in bfloat16 (the Use_8bit path) — NOTE(review): confirm source.
warnings.filterwarnings(
    action="ignore",
    message="MatMul8bitLt: inputs will be cast from torch.bfloat16 to float16 during quantization",
)

from kvpress import (
    AdaKVPress,
    CriticalKVPress,
    SnapKVPress,
    StreamingLLMPress,
    EfficientAdaSnapKVPress,
    EfficientAdaScorerPress,
    EfficientAdaCriticalKVPress,
)

logger = logging.getLogger(__name__)

# Datasets accepted by `evaluate`. Within this file only key membership is checked; the
# data itself is read from a local `data_dir` (longbench has no hub ID here).
DATASET_DICT = dict(
    ruler="simonjegou/ruler",
    longbench=None,
)

# Metric function used to score the predictions of each dataset.
SCORER_DICT = dict(
    ruler=ruler_scorer,
    longbench=longbench_scorer,
)

# Press instances, created once at import time and shared. `evaluate` in this file
# assigns `compression_ratio` on the selected instance before use.
PRESS_DICT = dict(
    ada_snapkv=AdaKVPress(SnapKVPress()),
    snapkv=SnapKVPress(),
    streaming_llm=StreamingLLMPress(),
    efficient_ada_snapkv=EfficientAdaSnapKVPress(),
    criticalkv=CriticalKVPress(SnapKVPress()),
    efficient_ada_criticalkv=EfficientAdaCriticalKVPress(SnapKVPress()),
)


def _results_filename(
    save_dir: Path,
    dataset: str,
    model: str,
    press,
    compression_ratio: float,
    fraction: float,
    max_context_length: Optional[int],
    compress_questions: bool,
) -> Path:
    """Return the results CSV path for one evaluation configuration.

    The name depends only on the arguments, so it can be computed (and checked for
    existence) before any data or model is loaded. The scheme is unchanged from earlier
    versions of this script so existing result files are still recognised; in particular
    the fraction is encoded twice (``frac...`` always, ``__fraction...`` when < 1).
    """
    stem = "__".join(
        [
            dataset,
            model.replace("/", "--"),
            str(press),
            str(compression_ratio),
            f"frac{fraction:.2f}",
        ]
    )
    if fraction < 1.0:
        stem += f"__fraction{fraction:.2f}"
    if max_context_length is not None:
        stem += f"__max_context{max_context_length}"
    if compress_questions:
        stem += "__compressed_questions"
    return save_dir / (stem + ".csv")


def _stratified_sample(df: pd.DataFrame, fraction: float) -> pd.DataFrame:
    """Return ``fraction`` of the rows of every task, sampled independently per task.

    Each task is sampled with ``random_state=42`` so the subset is reproducible. An empty
    frame is returned unchanged because ``pd.concat`` rejects an empty list.
    """
    if df.empty:
        return df
    return pd.concat([task_df.sample(frac=fraction, random_state=42) for _, task_df in df.groupby("task")])


def _resolve_evaluated_tasks(dataset: str, tasks=None) -> Optional[list]:
    """Return the list of task names to evaluate, or ``None`` for every task.

    ``tasks=None`` keeps this script's historical default (a fixed RULER subset, all tasks
    for other datasets); ``"all"`` disables filtering; a string selects one task; any other
    iterable (e.g. the list/tuple produced by Fire) selects those tasks.
    """
    if tasks is None:
        return ["niah_multivalue", "niah_single_2"] if dataset == "ruler" else None
    if isinstance(tasks, str):
        return None if tasks == "all" else [tasks]
    return list(tasks)


def _build_pipeline(model: str, device: str, press, use_8bit: bool):
    """Create the ``kv-press-text-generation`` pipeline for ``model``.

    With ``use_8bit`` the model is loaded through ``BitsAndBytesConfig(load_in_8bit=True)``
    and placed with ``device_map="auto"`` regardless of ``device``.
    """
    model_kwargs = {}
    if use_8bit:
        model_kwargs["quantization_config"] = transformers.BitsAndBytesConfig(
            load_in_8bit=True, llm_int8_skip_modules=["lm_head", "attn"]
        )
        device = "auto"
        print("using auto device mapping")

    # Support AdaKV: the efficient Ada presses patch flash attention themselves.
    if isinstance(press, EfficientAdaScorerPress):
        replace_var_flash_attn(model_name=model)
    else:
        try:
            import flash_attn  # noqa: F401
        except ImportError:
            pass
        else:
            # Add to (not replace) model_kwargs so an 8-bit quantization config is kept.
            model_kwargs["attn_implementation"] = "flash_attention_2"

    model_kwargs["torch_dtype"] = "auto"
    if device == "auto":
        return pipeline("kv-press-text-generation", model=model, device_map="auto", model_kwargs=model_kwargs)
    return pipeline("kv-press-text-generation", model=model, device=device, model_kwargs=model_kwargs)


# Tasks prompted without the chat template and BOS token.
_RAW_PROMPT_TASKS = frozenset({"trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"})


def evaluate(
    dataset: str = "ruler",
    data_dir: Optional[str] = None,
    model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    device: Optional[str] = None,
    press_name: Optional[str] = "snapkv",
    compression_ratio: float = 0.75,
    fraction: float = 1.0,
    max_new_tokens: Optional[int] = None,
    max_context_length: Optional[int] = None,
    compress_questions: bool = False,
    Use_8bit: bool = False,
    tasks=None,
):
    """
    Evaluate a model on a dataset using a press and save the results

    Outputs go to ``results/`` next to this file: the predictions CSV, a ``*_df.csv`` with
    the full dataframe, and a ``.json`` with the scorer metrics. If the predictions CSV for
    this exact configuration already exists, the function returns without running.

    Parameters
    ----------
    dataset : str
        Dataset to evaluate (a key of DATASET_DICT / SCORER_DICT), by default "ruler"
    data_dir : str, optional
        Path of the dataset to load with ``datasets.load_from_disk``. Required in practice;
        if loading fails the process exits with status 1.
    model : str, optional
        Model to use, by default "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device : str, optional
        Model device, by default cuda:0 if available else cpu. For multi-GPU use "auto".
        Forced to "auto" when ``Use_8bit`` is set.
    press_name : str, optional
        Press to use (see PRESS_DICT), by default "snapkv". ``None`` runs without a press.
    compression_ratio : float, optional
        Compression ratio for the press, by default 0.75
    fraction : float, optional
        Fraction of each task to evaluate (stratified, seed 42), by default 1.0
    max_new_tokens : int, optional
        Maximum number of new tokens to generate, by default use the default for the task (recommended)
    max_context_length : int, optional
        Maximum number of tokens to use in the context. By default will use the maximum length supported by the model.
    compress_questions : bool, optional
        Whether to compress the questions as well, by default False
    Use_8bit : bool, optional
        Load the model in 8-bit via bitsandbytes (``llm_int8_skip_modules=["lm_head", "attn"]``)
        with ``device_map="auto"``, by default False
    tasks : str or sequence of str, optional
        Tasks to evaluate. ``None`` keeps the historical default ("niah_multivalue" and
        "niah_single_2" for RULER, every task otherwise); "all" evaluates every task; a
        task name or list of names selects those tasks. Other tasks are dropped entirely
        (not saved and not scored).
    """
    assert dataset in DATASET_DICT, f"No dataset found for {dataset}"
    assert dataset in SCORER_DICT, f"No scorer found for {dataset}"
    data_dir = str(data_dir) if data_dir else None

    # Load press. PRESS_DICT holds shared instances, so the ratio is set on the chosen one.
    if press_name is not None:
        assert press_name in PRESS_DICT, f"Unknown press {press_name!r}; choose from {sorted(PRESS_DICT)}"
        press = PRESS_DICT[press_name]
        press.compression_ratio = compression_ratio  # type:ignore[attr-defined]
    else:
        press = None

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Build the final filename (all suffixes included) BEFORE the existence check, so that
    # runs with fraction / max_context_length / compress_questions are detected correctly.
    save_dir = Path(__file__).parent / "results"
    save_dir.mkdir(exist_ok=True)
    save_filename = _results_filename(
        save_dir, dataset, model, press, compression_ratio, fraction, max_context_length, compress_questions
    )
    print("try save to:", save_filename)
    if save_filename.exists():
        logger.warning(f"Results already exist at {save_filename}")
        print("Results already exist at", save_filename)
        return

    # Load dataframe
    try:
        print("Loading from disk, data_dir:", data_dir)
        df = load_from_disk(data_dir).to_pandas()
    except Exception as e:
        print(f"Failed to load from disk: {e}")
        # Non-zero status so callers/scripts notice the failure.
        raise SystemExit(1) from e

    # Drop unselected tasks up front: leaving them in with no prediction would feed None
    # answers to the scorer and make the progress-bar total wrong.
    evaluated_tasks = _resolve_evaluated_tasks(dataset, tasks)
    if evaluated_tasks is not None:
        n_rows = len(df)
        df = df[df["task"].isin(evaluated_tasks)].copy()
        print(f"Evaluating tasks {evaluated_tasks}: kept {len(df)} of {n_rows} rows")
        if df.empty:
            logger.warning(f"No rows left for tasks {evaluated_tasks}")

    if fraction < 1.0:
        # Stratified sampling by task category
        df = _stratified_sample(df, fraction)

    if compress_questions:
        df["context"] = df["context"] + df["question"]
        df["question"] = ""

    pipe = _build_pipeline(model, device, press, Use_8bit)
    print("model dtype: ", pipe.model.dtype, flush=True)

    # Run pipeline on each context (all questions sharing a context are answered together)
    df["predicted_answer"] = None
    df["compression_ratio"] = float("nan")
    df_context = df.groupby("context")
    assert all(df_context["answer_prefix"].nunique() == 1)

    failure_count = 0
    for context, df_ in tqdm(df_context, total=df["context"].nunique()):
        task_name = df_["task"].iloc[0]
        questions = df_["question"].to_list()
        max_new_tokens_ = max_new_tokens if max_new_tokens is not None else df_["max_new_tokens"].iloc[0]
        answer_prefix = df_["answer_prefix"].iloc[0]

        chat_template_bak = pipe.tokenizer.chat_template
        bos_bak = pipe.tokenizer.bos_token
        gen_config_eos_id_bak = pipe.model.generation_config.eos_token_id
        try:
            if task_name in _RAW_PROMPT_TASKS:
                pipe.tokenizer.chat_template = None
                pipe.tokenizer.bos_token = ""
                if task_name == "samsum":
                    # Also stop generation at the first newline.
                    pipe.model.generation_config.eos_token_id = [
                        pipe.tokenizer.eos_token_id,
                        pipe.tokenizer.encode("\n", add_special_tokens=False)[-1],
                    ]
            try:
                output = pipe(
                    context,
                    questions=questions,
                    answer_prefix=answer_prefix,
                    press=press,
                    max_new_tokens=max_new_tokens_,
                    max_context_length=max_context_length,
                )
            except Exception as e:
                # Best effort: record the failure as the answer and keep going.
                print("An error occurred:", e)
                output = {"answers": "Failure:" + str(e)}
                failure_count += 1
        finally:
            # Always restore the per-task overrides, even if an unexpected error escapes.
            pipe.tokenizer.chat_template = chat_template_bak
            pipe.tokenizer.bos_token = bos_bak
            pipe.model.generation_config.eos_token_id = gen_config_eos_id_bak

        df.loc[df_.index, "predicted_answer"] = output["answers"]
        df.loc[df_.index, "compression_ratio"] = (
            press.compression_ratio if press is not None else 0  # type:ignore[attr-defined]
        )
        torch.cuda.empty_cache()

    # Save answers
    print("Saving DataFrame to", save_filename)
    df[["predicted_answer", "compression_ratio"]].to_csv(save_filename, index=False)
    df.to_csv(save_filename.with_name(save_filename.stem + "_df.csv"), index=False)

    # Calculate metrics
    scorer = SCORER_DICT[dataset]
    metrics = scorer(df)
    with open(save_filename.with_suffix(".json"), "w") as f:
        json.dump(metrics, f)
    print(f"Average compression ratio: {df['compression_ratio'].mean():.2f}")
    print(f"Failure count: {failure_count}")
    print(metrics)


if __name__ == "__main__":
    # Expose `evaluate` as a command-line interface; each keyword argument becomes a flag.
    Fire(component=evaluate)
