# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
KVPress Pipeline for ManifoldKV
Simplified version for ICML 2026 reproduction
"""

import contextlib
import logging
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, Cache, DynamicCache, Pipeline, QuantizedCache
from transformers.pipelines import PIPELINE_REGISTRY
from transformers.pipelines.base import GenericTensor

from kvpress.presses.base_press import BasePress

logger = logging.getLogger(__name__)


class KVPressTextGenerationPipeline(Pipeline):
    """
    Pipeline for key-value cache compression in causal language models.

    The context is pre-filled once (optionally compressing the KV cache with a
    press while doing so); each question is then answered with greedy decoding
    on top of its own copy of the context cache.

    Example:
    ```python
    pipeline = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
    press = ManifoldKVPress(compression_ratio=0.5)
    result = pipeline(context="Long text...", question="A question about the long context.", press=press)
    ```
    """

    def _sanitize_parameters(
        self,
        question: Optional[str] = None,
        questions: Optional[list[str]] = None,
        answer_prefix: Optional[str] = None,
        press: Optional[BasePress] = None,
        max_new_tokens: int = 50,
        max_context_length: Optional[int] = None,
        cache: Optional[Cache] = None,
        **kwargs,
    ) -> tuple[dict, dict, dict]:
        """
        Split the user-facing arguments into per-stage keyword dictionaries.

        Parameters
        ----------
        question : str, optional
            A single question. Mutually exclusive with ``questions``.
        questions : list[str], optional
            Several questions answered against the same context.
        answer_prefix : str, optional
            Text placed after each question, i.e. where the answer starts.
        press : BasePress, optional
            Compression method applied while pre-filling the context.
            ``press.post_init_from_model`` is called with the pipeline's model.
        max_new_tokens : int, default=50
            Upper bound on the number of generated tokens per question.
        max_context_length : int, optional
            If given, the tokenized context is truncated to this many tokens.
        cache : Cache, optional
            Cache object to fill; a ``DynamicCache`` is created when omitted.
        **kwargs
            Unsupported arguments. They are ignored, with a warning.

        Returns
        -------
        tuple[dict, dict, dict]
            Keyword arguments for ``preprocess``, ``_forward`` and
            ``postprocess`` respectively.

        Raises
        ------
        ValueError
            If both ``question`` and ``questions`` are provided.
        """
        if question is not None and questions is not None:
            raise ValueError("Cannot provide both 'question' and 'questions'")

        if kwargs:
            # Surface typos such as ``max_new_token=...`` instead of dropping them silently.
            logger.warning("Ignoring unsupported pipeline arguments: %s", sorted(kwargs))

        if question is not None:
            questions = [question]

        preprocess_kwargs: dict = {"questions": questions if questions is not None else []}
        forward_kwargs: dict = {"max_new_tokens": max_new_tokens, "cache": cache}
        postprocess_kwargs: dict = {}

        if answer_prefix is not None:
            preprocess_kwargs["answer_prefix"] = answer_prefix

        if max_context_length is not None:
            preprocess_kwargs["max_context_length"] = max_context_length

        if press is not None:
            forward_kwargs["press"] = press
            press.post_init_from_model(self.model)

        return preprocess_kwargs, forward_kwargs, postprocess_kwargs

    def preprocess(
        self,
        context: str,
        questions: list[str],
        answer_prefix: str = "",
        max_context_length: Optional[int] = None,
    ) -> dict:
        """
        Tokenize the context and the questions.

        Parameters
        ----------
        context : str
            The shared context text.
        questions : list[str]
            Questions to answer; may be empty.
        answer_prefix : str, default=""
            Text appended after each question before tokenization, so the
            model continues generating from the start of the answer.
        max_context_length : int, optional
            If given, keep only the first ``max_context_length`` context tokens.

        Returns
        -------
        dict
            ``context_ids`` (a ``(1, n_tokens)`` long tensor on the model
            device), ``questions_ids`` (one list of token ids per question)
            and the ``answer_prefix`` passed in.
        """
        context_ids = self.tokenizer.encode(context, add_special_tokens=False)

        # Slicing is a no-op when the context is already short enough.
        if max_context_length is not None:
            context_ids = context_ids[:max_context_length]

        context_tensor = torch.tensor([context_ids], dtype=torch.long, device=self.model.device)

        # The answer prefix belongs *after* the question: the prompt order is
        # context -> question -> answer prefix -> generated answer.
        questions_ids = [
            self.tokenizer.encode(q + answer_prefix, add_special_tokens=False) for q in questions
        ]

        return {
            "context_ids": context_tensor,
            "questions_ids": questions_ids,
            "answer_prefix": answer_prefix,
        }

    @torch.inference_mode()
    def _forward(
        self,
        model_inputs: dict,
        max_new_tokens: int = 50,
        press: Optional[BasePress] = None,
        cache: Optional[Cache] = None,
    ) -> dict:
        """
        Pre-fill the context (with optional compression) and answer each question.

        Parameters
        ----------
        model_inputs : dict
            Output of ``preprocess``.
        max_new_tokens : int, default=50
            Maximum number of tokens generated per question.
        press : BasePress, optional
            Compression method applied during the context pre-fill only.
        cache : Cache, optional
            Cache to fill with the context; a ``DynamicCache`` is created when
            omitted.

        Returns
        -------
        dict
            ``answers`` (one decoded string per question) and the
            ``answer_prefix`` taken from ``model_inputs``.

        Raises
        ------
        ValueError
            If a question (including the answer prefix) tokenizes to nothing.
        """
        context_ids = model_inputs["context_ids"]
        questions_ids = model_inputs["questions_ids"]
        answer_prefix = model_inputs["answer_prefix"]

        if cache is None:
            cache = DynamicCache()

        device = self.model.device
        eos_token_id = self.tokenizer.eos_token_id

        # Position of the first token that follows the context. It is computed
        # from the *uncompressed* lengths, before the press shrinks the cache,
        # so that question and answer tokens keep their true positions instead
        # of restarting at the compressed cache length.
        first_question_position = cache.get_seq_length() + context_ids.shape[1]

        # Context prefill; the compression hooks are only active in this block.
        with self._apply_press(press, cache):
            outputs = self.model(
                input_ids=context_ids,
                past_key_values=cache,
                use_cache=True,
                return_dict=True,
            )
            cache = outputs.past_key_values

        answers = []

        for question_ids in questions_ids:
            if not question_ids:
                raise ValueError("Each question must contain at least one token after tokenization")

            # Each question works on its own copy so the context cache stays
            # untouched for the following questions.
            question_cache = DynamicCache()
            for key, value in zip(cache.key_cache, cache.value_cache):
                question_cache.key_cache.append(key.clone())
                question_cache.value_cache.append(value.clone())

            # Question prefill
            question_tensor = torch.tensor([question_ids], dtype=torch.long, device=device)
            next_position = first_question_position + question_tensor.shape[1]
            outputs = self.model(
                input_ids=question_tensor,
                past_key_values=question_cache,
                position_ids=torch.arange(first_question_position, next_position, device=device).unsqueeze(0),
                use_cache=True,
                return_dict=True,
            )
            question_cache = outputs.past_key_values

            # Greedy decoding
            generated_ids = []
            next_token = outputs.logits[:, -1, :].argmax(dim=-1)

            for step in range(max_new_tokens):
                # A single ``.item()`` per step: each call forces a device sync.
                token_id = next_token.item()
                generated_ids.append(token_id)

                # Stop on EOS, or once the budget is used up (no need to run
                # the model for a token that would never be emitted).
                if token_id == eos_token_id or step == max_new_tokens - 1:
                    break

                outputs = self.model(
                    input_ids=next_token.unsqueeze(0),
                    past_key_values=question_cache,
                    position_ids=torch.tensor([[next_position + step]], dtype=torch.long, device=device),
                    use_cache=True,
                    return_dict=True,
                )
                question_cache = outputs.past_key_values
                next_token = outputs.logits[:, -1, :].argmax(dim=-1)

            answers.append(self.tokenizer.decode(generated_ids, skip_special_tokens=True))

        return {"answers": answers, "answer_prefix": answer_prefix}

    def postprocess(self, model_outputs: dict) -> dict:
        """
        Format the model outputs.

        Returns ``{"answer": str}`` when exactly one answer was generated and
        ``{"answers": list[str]}`` otherwise (including when there are none).
        """
        answers = model_outputs["answers"]
        if len(answers) == 1:
            return {"answer": answers[0]}
        return {"answers": answers}

    @contextlib.contextmanager
    def _apply_press(self, press: Optional[BasePress], cache: Cache):
        """
        Context manager that applies KV cache compression while it is active.

        A forward hook is attached to the ``self_attn`` module of every layer
        returned by ``_get_layers``; all hooks are removed on exit. When
        ``press`` is ``None`` nothing is registered.
        """
        if press is None:
            yield
            return

        hooks = []
        try:
            # Registration happens inside the ``try`` so that a failure halfway
            # through still removes the hooks that were already attached.
            for layer_idx, layer in enumerate(self._get_layers()):
                hooks.append(
                    layer.self_attn.register_forward_hook(
                        self._create_compression_hook(press, cache, layer_idx),
                        # Attention modules are usually called with keyword
                        # arguments, which plain forward hooks do not receive.
                        with_kwargs=True,
                    )
                )
            yield
        finally:
            for hook in hooks:
                hook.remove()

    def _create_compression_hook(self, press: BasePress, cache: Cache, layer_idx: int):
        """
        Create the forward hook that compresses one layer of ``cache``.

        The returned callable has the ``(module, args, kwargs, outputs)``
        signature expected by ``register_forward_hook(..., with_kwargs=True)``.
        It replaces the keys and values stored for ``layer_idx`` with the
        result of ``press.compress`` and returns ``outputs`` unchanged.
        """

        def hook(module, args, kwargs, outputs):
            # Nothing to compress if the cache has no entry for this layer.
            if not hasattr(cache, "key_cache") or len(cache.key_cache) <= layer_idx:
                return outputs

            # Hidden states may arrive by keyword or positionally.
            hidden_states = kwargs.get("hidden_states")
            if hidden_states is None and args:
                hidden_states = args[0]

            # Attention weights, when the module returns them as a second output.
            has_attentions = isinstance(outputs, (tuple, list)) and len(outputs) > 1
            attentions = outputs[1] if has_attentions else None

            keys, values = press.compress(
                module=module,
                hidden_states=hidden_states,
                keys=cache.key_cache[layer_idx],
                values=cache.value_cache[layer_idx],
                attentions=attentions,
                kwargs=kwargs,
            )

            # Update cache
            cache.key_cache[layer_idx] = keys
            cache.value_cache[layer_idx] = values

            return outputs

        return hook

    def _get_layers(self):
        """
        Get the transformer layers from the model.

        Looks for ``model.model.layers`` first, then ``model.transformer.h``.

        Raises
        ------
        ValueError
            If neither attribute path exists on the model.
        """
        if hasattr(self.model, "model") and hasattr(self.model.model, "layers"):
            return self.model.model.layers
        if hasattr(self.model, "transformer") and hasattr(self.model.transformer, "h"):
            return self.model.transformer.h
        raise ValueError("Cannot find layers in model")


# Register the pipeline under a custom task name so it can be created through
# the ``transformers.pipeline`` factory.
try:
    PIPELINE_REGISTRY.register_pipeline(
        "kv-press-text-generation",
        pipeline_class=KVPressTextGenerationPipeline,
        pt_model=AutoModelForCausalLM,
    )
except Exception:
    # Registration stays best-effort (the class remains usable directly), but
    # the failure is logged rather than hidden so that real problems surface.
    logger.warning("Could not register the 'kv-press-text-generation' pipeline", exc_info=True)
