# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
import subprocess
import json
import logging
from pathlib import Path
from typing import Optional

import torch
from datasets import load_dataset
from fire import Fire
from infinite_bench.calculate_metrics import calculate_metrics as infinite_bench_scorer
from longbench.calculate_metrics import calculate_metrics as longbench_scorer
from longbench.calculate_metrics import calculate_metrics_e as longbench_scorer_e
from longbenchv2.calculate_metrics import calculate_metrics as longbenchv2_scorer
from loogle.calculate_metrics import calculate_metrics as loogle_scorer
from ruler.calculate_metrics import calculate_metrics as ruler_scorer
from tqdm import tqdm
from transformers import pipeline
from zero_scrolls.calculate_metrics import calculate_metrics as zero_scrolls_scorer

import pickle
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from kvpress import (
    AdaKVPress,
    ChunkKVPress,
    ComposedPress,
    CriticalAdaKVPress,
    CriticalKVPress,
    DuoAttentionPress,
    ExpectedAttentionPress,
    KnormPress,
    ObservedAttentionPress,
    RandomPress,
    SnapKVPress,
    StreamingLLMPress,
    ThinKPress,
    TOVAPress,
    QFilterPress,
    CURPress,
    ExactCURPress,
)
from kvpress import BasePress, ScorerPress
from dataclasses import dataclass
import wandb

# Uncomment to keep Weights & Biases in offline mode:
# os.environ['WANDB_MODE'] = "offline"

logger = logging.getLogger(__name__)

# Maps each supported benchmark name (the `dataset` argument of `evaluate`) to the
# dataset identifier that `evaluate` hands to `datasets.load_dataset`.
# "longbench" and "longbench-e" resolve to the same identifier; they are paired with
# different scorers in SCORER_DICT.
DATASET_DICT = {
    "loogle": "simonjegou/loogle",
    "ruler": "simonjegou/ruler",
    "zero_scrolls": "simonjegou/zero_scrolls",
    "infinitebench": "MaxJeblick/InfiniteBench",
    "longbench": "Xnhyacinth/LongBench",
    "longbench-e": "Xnhyacinth/LongBench",
    "longbench-v2": "Xnhyacinth/LongBench-v2",
}

# Maps each supported benchmark name to the metrics function that `evaluate` calls on
# the results DataFrame once all answers have been generated. Must contain the same
# keys as DATASET_DICT (`evaluate` asserts membership in both).
SCORER_DICT = {
    "loogle": loogle_scorer,
    "ruler": ruler_scorer,
    "zero_scrolls": zero_scrolls_scorer,
    "infinitebench": infinite_bench_scorer,
    "longbench": longbench_scorer,
    "longbench-e": longbench_scorer_e,
    "longbench-v2": longbenchv2_scorer,
}

# Maps the `press_name` argument of `evaluate` to a press instance built once at import
# time. `evaluate` mutates the selected instance in place (compression ratio, CUR
# options, ThinK channel ratio), so the constructor arguments below only act as defaults.
# Note: for "ada_cur", `evaluate` does not use the instance stored here; it reconfigures
# PRESS_DICT["cur"] and wraps it in a fresh AdaKVPress.
PRESS_DICT = {
    "criti_adasnapkv": CriticalAdaKVPress(SnapKVPress()),
    "criti_ada_expected_attention": CriticalAdaKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "criti_snapkv": CriticalKVPress(SnapKVPress()),
    "criti_expected_attention": CriticalKVPress(ExpectedAttentionPress(use_vnorm=False)),
    "adasnapkv": AdaKVPress(SnapKVPress()),
    "ada_expected_attention": AdaKVPress(ExpectedAttentionPress()),
    "ada_cur": AdaKVPress(CURPress(compression_ratio=0.1, num_sinks=4, use_random_leverage=False, leverage_type='value')),
    "expected_attention": ExpectedAttentionPress(),
    "ada_expected_attention_e2": AdaKVPress(ExpectedAttentionPress(epsilon=1e-2)),
    "knorm": KnormPress(),
    "observed_attention": ObservedAttentionPress(),
    "random": RandomPress(),
    "snapkv": SnapKVPress(),
    "streaming_llm": StreamingLLMPress(),
    "think": ThinKPress(),
    "tova": TOVAPress(),
    "duo_attention": DuoAttentionPress(),
    "chunkkv": ChunkKVPress(press=SnapKVPress(), chunk_length=20),
    "qfilter": QFilterPress(),
    "snap_think": ComposedPress([SnapKVPress(), ThinKPress()]),
    "cur": CURPress(compression_ratio=0.1, num_sinks=4, use_random_leverage=False, leverage_type='value'),
    "exactcur": ExactCURPress(compression_ratio=0.1, num_sinks=4, leverage_type='value')
}

class Sinker(ScorerPress):
    """
    Scorer wrapper that pins the score of "sink" positions to 1.

    The scores are obtained from the wrapped press; the first ``n_sink`` positions
    (and, when ``pad_right`` is set, the last ``right_sink_length`` positions) along
    the third axis of the score tensor are then overwritten with 1.

    NOTE(review): a score of 1 only protects those positions if the wrapped press
    produces scores that stay below 1 elsewhere -- confirm for each press in use.

    Parameters
    ----------
    press :
        Press providing the base scores. Either it exposes ``score`` itself, or it
        wraps another press in its ``press`` attribute that does.
    compression_ratio : float
        Compression ratio stored on this wrapper.
    pad_right : bool
        Whether the trailing positions are pinned as well.
    n_sink : int, optional
        Number of leading positions to pin, by default 4
    right_sink_length : int, optional
        Number of trailing positions to pin when ``pad_right`` is True, by default 4
    """

    def __init__(self, press, compression_ratio, pad_right, n_sink=4, right_sink_length=4):
        self.press = press
        self.compression_ratio = compression_ratio
        self.n_sink = n_sink
        self.pad_right = pad_right
        self.right_sink_length = right_sink_length

    def score(
        self,
        module: nn.Module,
        hidden_states: torch.Tensor,
        keys: torch.Tensor,
        values: torch.Tensor,
        attentions: torch.Tensor,
        kwargs: dict,
    ) -> torch.Tensor:
        """
        Return the wrapped press's scores with the sink positions set to 1.

        All arguments are forwarded unchanged to the wrapped press's ``score``.
        The returned tensor is the one produced by the wrapped press, modified in place.
        """
        try:
            score = self.press.score(module, hidden_states, keys, values, attentions, kwargs)
        except (AttributeError, NotImplementedError):
            # The outer press does not score by itself; fall back to the press it wraps.
            # Only these two errors trigger the fallback so that genuine failures
            # (out-of-memory, interrupts, shape errors, ...) are not silently retried.
            score = self.press.press.score(module, hidden_states, keys, values, attentions, kwargs)

        # Assumes the third axis indexes token positions -- TODO confirm against the presses used.
        score[:, :, : self.n_sink] = 1
        # Guard against a zero length: ``score[:, :, -0:]`` would select every position.
        if self.pad_right and self.right_sink_length > 0:
            score[:, :, -self.right_sink_length :] = 1

        return score
    
def _append_to_stem(path: Path, suffix: str) -> Path:
    """Return ``path`` with ``suffix`` inserted between its stem and its extension."""
    return path.with_name(path.stem + suffix + path.suffix)


def _resolve_attn_implementation(press_name: str, attn_implementation: str) -> Optional[str]:
    """Return the attention backend to load the model with, or None to keep the library default."""
    if press_name == "observed_attention":
        # This press is always run with the eager implementation
        # (presumably because it needs the attention weights -- TODO confirm).
        return "eager"
    if attn_implementation == "flash_attention_2":
        try:
            import flash_attn  # noqa: F401
        except ImportError:
            logger.warning("flash_attn is not installed, falling back to the default attention implementation")
            return None
    return attn_implementation


def _configure_press(
    press_name: str,
    compression_ratio: float,
    key_channel_compression_ratio: float,
    cur_press_random_leverage: bool,
    cur_press_leverage_type: str,
    cur_press_local_window_size: int,
    cur_press_no_local_approximation: bool,
    right_sink: bool,
    save_filename: Path,
):
    """Build the press for ``press_name`` and return it with the (possibly suffixed) results path."""
    if press_name == "none":
        return None, save_filename

    assert press_name in PRESS_DICT, f"No press found for {press_name}"
    press = PRESS_DICT[press_name]

    if press_name in ("cur", "ada_cur"):
        # "ada_cur" is rebuilt from the plain CUR press and wrapped in AdaKVPress at the end
        press = PRESS_DICT["cur"]
        press.use_random_leverage = cur_press_random_leverage
        press.leverage_type = cur_press_leverage_type
        press.local_window_size = cur_press_local_window_size
        press.use_local_approximation = not cur_press_no_local_approximation
    elif press_name == "exactcur":
        press.leverage_type = cur_press_leverage_type

    channel_suffix = f"__channel{key_channel_compression_ratio}"
    if isinstance(press, DuoAttentionPress):
        press.head_compression_ratio = compression_ratio
    elif isinstance(press, ComposedPress):
        for inner_press in press.presses:
            if isinstance(inner_press, ThinKPress):
                inner_press.key_channel_compression_ratio = key_channel_compression_ratio
                save_filename = _append_to_stem(save_filename, channel_suffix)
            else:
                inner_press.compression_ratio = compression_ratio  # type:ignore[attr-defined]
    elif isinstance(press, ThinKPress):
        press.key_channel_compression_ratio = key_channel_compression_ratio
        save_filename = _append_to_stem(save_filename, channel_suffix)
    else:
        press.compression_ratio = compression_ratio  # type:ignore[attr-defined]

    # These presses are used as configured; every other press is wrapped in Sinker
    presses_without_sinks = ("adasnapkv", "criti_adasnapkv", "chunkkv", "ada_expected_attention", "ada_cur")
    if press_name not in presses_without_sinks:
        press = Sinker(press=press, compression_ratio=compression_ratio, pad_right=right_sink)

    if press_name == "ada_cur":
        logger.info("CUR press before AdaKV wrapping: %s", press)
        press = AdaKVPress(press)
        logger.info("CUR press after AdaKV wrapping: %s", press)

    return press, save_filename


def evaluate(
    dataset: str,
    data_dir: Optional[str] = None,
    model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct",
    device: Optional[str] = None,
    press_name: str = "expected_attention",
    cur_press_random_leverage: bool = False,
    cur_press_leverage_type: str = 'value',
    cur_press_local_window_size: int = 16,
    cur_press_no_local_approximation: bool = False,
    no_wandb_flag: bool = False,
    right_sink: bool = False,
    compression_ratio: float = 0.1,
    fraction: float = 1.0,
    max_new_tokens: Optional[int] = None,
    max_context_length: Optional[int] = None,
    compress_questions: bool = False,
    key_channel_compression_ratio: float = 0.5,
    attn_implementation: str = "flash_attention_2",
    wandb_project_name: str = "KV Compression"
):
    """
    Evaluate a model on a dataset using a press and save the results

    Predictions are written to a CSV file under ``results/`` next to this script and the
    metrics to a JSON file with the same name. If the CSV file already exists, the
    evaluation is skipped.

    Parameters
    ----------
    dataset : str
        Dataset to evaluate
    data_dir : str, optional
        Subdirectory of the dataset to evaluate, by default None
    model : str, optional
        Model to use, by default "meta-llama/Meta-Llama-3.1-8B-Instruct"
    device : str, optional
        Model device, by default cuda:0 if available else cpu. For multi-GPU use "auto"
    press_name : str, optional
        Press to use (see PRESS_DICT), or "none" to run without a press, by default "expected_attention"
    cur_press_random_leverage : bool, optional
        Value assigned to ``use_random_leverage`` of the CUR press ("cur" and "ada_cur" only), by default False
    cur_press_leverage_type : str, optional
        Value assigned to ``leverage_type`` of the CUR presses ("cur", "ada_cur", "exactcur"), by default "value"
    cur_press_local_window_size : int, optional
        Value assigned to ``local_window_size`` of the CUR press ("cur" and "ada_cur" only), by default 16
    cur_press_no_local_approximation : bool, optional
        If True, ``use_local_approximation`` of the CUR press is set to False, by default False
    no_wandb_flag : bool, optional
        If True, do not log the run to Weights & Biases, by default False
    right_sink : bool, optional
        Whether the Sinker wrapper also pins the trailing tokens. Always enabled when
        ``compress_questions`` is True, by default False
    compression_ratio : float, optional
        Compression ratio for the press, by default 0.1
    fraction : float, optional
        Fraction of the dataset to evaluate, by default 1.0
    max_new_tokens : int, optional
        Maximum number of new tokens to generate, by default use the default for the task (recommended)
    max_context_length : int, optional
        Maximum number of tokens to use in the context. By default will use the maximum length supported by the model.
    compress_questions : bool, optional
        Whether to compress the questions as well, by default False
    key_channel_compression_ratio : float, optional
        key Channel Compression ratio for the channel press, by default 0.5
    attn_implementation : str, optional
        Attention implementation used to load the model, by default "flash_attention_2".
        Falls back to the library default if "flash_attention_2" is requested but flash_attn
        is not installed. Ignored for the "observed_attention" press, which always uses "eager".
    wandb_project_name : str, optional
        Weights & Biases project to log to, by default "KV Compression"
    """

    # With compressed questions the question sits at the end of the context, so the
    # trailing tokens are always pinned in that case; otherwise honour the argument.
    right_sink = right_sink or compress_questions

    assert dataset in DATASET_DICT, f"No dataset found for {dataset}"
    assert dataset in SCORER_DICT, f"No scorer found for {dataset}"
    data_dir = str(data_dir) if data_dir else None

    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    save_dir = Path(__file__).parent / "results"
    save_dir.mkdir(exist_ok=True)
    name_parts = [
        dataset,
        data_dir or "",
        model.replace("/", "--"),
        press_name,
        str(compression_ratio),
        str(cur_press_leverage_type),
        str(cur_press_random_leverage),
        attn_implementation,
    ]
    save_filename = save_dir / ("__".join(name_parts) + ".csv")

    # Load dataframe
    df = load_dataset(DATASET_DICT[dataset], data_dir=data_dir, split="test").to_pandas()
    if dataset == "ruler":
        # Only the RULER tasks whose name contains "niah" are evaluated
        df = df[df.task.str.contains("niah")].reset_index(drop=True)
        logger.info("RULER dataframe shape after task filtering: %s", df.shape)

    if fraction < 1.0:
        df = df.sample(frac=fraction, random_state=42)
        save_filename = _append_to_stem(save_filename, f"__fraction{fraction:.2f}")

    if max_context_length is not None:
        save_filename = _append_to_stem(save_filename, f"__max_context{max_context_length}")

    if compress_questions:
        df["context"] = df["context"] + df["question"]
        df["question"] = ""
        save_filename = _append_to_stem(save_filename, "__compressed_questions")

    # Load press
    press, save_filename = _configure_press(
        press_name=press_name,
        compression_ratio=compression_ratio,
        key_channel_compression_ratio=key_channel_compression_ratio,
        cur_press_random_leverage=cur_press_random_leverage,
        cur_press_leverage_type=cur_press_leverage_type,
        cur_press_local_window_size=cur_press_local_window_size,
        cur_press_no_local_approximation=cur_press_no_local_approximation,
        right_sink=right_sink,
        save_filename=save_filename,
    )

    # The check is done on the final file name, once every suffix has been appended
    if save_filename.exists():
        logger.warning("Results already exist at %s, skipping evaluation", save_filename)
        return

    # Initialize pipeline with the correct attention implementation
    model_kwargs = {"torch_dtype": "auto"}
    resolved_attn = _resolve_attn_implementation(press_name, attn_implementation)
    if resolved_attn is not None:
        model_kwargs["attn_implementation"] = resolved_attn

    placement = {"device_map": "auto"} if device == "auto" else {"device": device}
    pipe = pipeline("kv-press-text-generation", model=model, model_kwargs=model_kwargs, **placement)

    # Run pipeline on each context
    df["predicted_answer"] = None
    df_context = df.groupby("context")
    assert all(df_context["answer_prefix"].nunique() == 1)

    run_config = {"model_name": model, "press_name": press_name}
    if press_name != "none":
        run_config["compression_ratio"] = compression_ratio
    run_config.update(
        {
            "dataset": f"{dataset}/{data_dir}",
            "use_random_leverage": cur_press_random_leverage,
            "leverage_type": cur_press_leverage_type,
            "cur_press_local_window_size": cur_press_local_window_size,
            "cur_press_no_local_approximation": cur_press_no_local_approximation,
            "compress_questions": compress_questions,
            "attn_implementation": attn_implementation,
        }
    )

    run = None
    if not no_wandb_flag:
        run = wandb.init(project=wandb_project_name, config=run_config)

    for context, df_ in tqdm(df_context, total=df["context"].nunique()):
        questions = df_["question"].to_list()
        max_new_tokens_ = max_new_tokens if max_new_tokens is not None else df_["max_new_tokens"].iloc[0]
        answer_prefix = df_["answer_prefix"].iloc[0]
        output = pipe(
            context,
            questions=questions,
            answer_prefix=answer_prefix,
            press=press,
            max_new_tokens=max_new_tokens_,
            max_context_length=max_context_length,
        )
        df.loc[df_.index, "predicted_answer"] = output["answers"]
        # No press (or a press without a compression ratio) is recorded as 0
        df.loc[df_.index, "compression_ratio"] = getattr(press, "compression_ratio", 0)
        torch.cuda.empty_cache()

    # Save answers
    df.to_csv(str(save_filename), index=False)

    # Calculate metrics
    scorer = SCORER_DICT[dataset]
    metrics = scorer(df)
    # with_suffix only touches the extension, unlike a string replace of ".csv" on the whole path
    with open(save_filename.with_suffix(".json"), "w", encoding="utf-8") as f:
        json.dump(metrics, f)
    print(f"Average compression ratio: {df['compression_ratio'].mean():.2f}")
    print(metrics)

    if run is not None:
        run.log({"Accuracy": metrics})
        run.finish()

# Command-line entry point: `evaluate` is handed to python-fire so that its parameters
# can be supplied as command-line arguments when this file is run as a script.
if __name__ == "__main__":
    Fire(evaluate)
