import warnings
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from torch import nn

import transformers
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import (
    StoppingCriteriaList,
    validate_stopping_criteria,
)
from transformers.generation.utils import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput


def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[SampleDecoderOnlyOutput, SampleEncoderDecoderOutput, torch.LongTensor]:
    """AGLA-aware replacement for the legacy ``GenerationMixin.sample`` loop.

    ``evolve_agla_sampling()`` installs this function as both
    ``GenerationMixin.sample`` and ``GenerationMixin._sample``. It follows the
    old interface (separate ``logits_warper`` / ``pad_token_id`` /
    ``eos_token_id`` arguments) and adds an optional second, "augmented"
    forward pass per step.

    When ``model_kwargs["images_cd"]`` is not None, every step also calls
    ``self.prepare_inputs_for_generation_cd`` and runs a forward pass on a
    separate kwargs dict (``model_kwargs_cd``). It then combines the two
    next-token logits::

        cutoff   = log(cd_beta) + max(clean_logits)
        combined = clean_logits + cd_alpha * augmented_logits
        combined[clean_logits < cutoff] = -inf

    ``cd_alpha`` defaults to 1 and ``cd_beta`` to 0.5 when missing or None in
    ``model_kwargs``. The combined logits then pass through
    ``logits_processor`` and ``logits_warper``.

    The augmented path's cache is taken from
    ``model_kwargs["past_key_values_cd"]`` if present. Otherwise it uses
    ``self._agla_cache_augmented`` when that attribute exists.

    Tokens are drawn with ``torch.multinomial`` when
    ``self.generation_config.do_sample`` is truthy (True if the attribute is
    missing). Otherwise ``argmax`` is used.

    Args:
        input_ids: Current token ids; one token per step is appended on the
            last dimension.
        logits_processor: Applied to the (combined) logits every step.
        stopping_criteria: Evaluated after every step. The deprecated
            ``max_length`` is folded in via ``validate_stopping_criteria``.
        logits_warper: Applied after ``logits_processor`` every step.
        max_length: Deprecated; emits a ``UserWarning`` when given.
        pad_token_id, eos_token_id: Default to ``self.generation_config``.
            An int ``eos_token_id`` is wrapped into a one-element list.
        output_attentions, output_hidden_states, output_scores,
        return_dict_in_generate: Default to ``self.generation_config``.
        synced_gpus: Loop until an ``all_reduce`` shows every rank finished.
        streamer: Receives each new token batch via ``put`` and ``end()`` at
            the end.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation``. May
            carry ``images_cd``, ``cd_alpha``, ``cd_beta`` and
            ``past_key_values_cd``.

    Returns:
        ``SampleEncoderDecoderOutput`` or ``SampleDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is set; otherwise the final ``input_ids``.
        Collected ``scores`` hold the processed/warped logits that tokens were
        selected from, not the raw model logits.

    Raises:
        ValueError: If ``eos_token_id`` is set but ``pad_token_id`` is None.
    """
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

    # Normalise to a list so several EOS ids can be matched at once below.
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )

    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # Accumulators are only allocated when they will be returned.
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Encoder outputs are computed once before this loop; only reported for enc-dec models.
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # 1 = still generating, 0 = already emitted an EOS id.
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False

    # Shallow copy: the augmented path gets its own kwargs dict (and, below,
    # possibly its own past_key_values) but shares the remaining values.
    model_kwargs_cd = model_kwargs.copy()
    # For AGLA+DSCR: use separate cache for augmented image if provided
    # Check both model_kwargs and model._agla_cache_augmented (temporary storage)
    if "past_key_values_cd" in model_kwargs:
        model_kwargs_cd["past_key_values"] = model_kwargs["past_key_values_cd"]
    elif hasattr(self, '_agla_cache_augmented'):
        # NOTE(review): this also assigns None if the attribute exists but is None.
        model_kwargs_cd["past_key_values"] = self._agla_cache_augmented

    # Track initial input length for first-step debugging
    initial_input_len = input_ids.shape[1]
    debug_first_step = hasattr(self, '_agla_debug_first_step') and getattr(self, '_agla_debug_first_step', False)
    debug_tokenizer = getattr(self, '_agla_debug_tokenizer', None) if debug_first_step else None

    while True:
        if synced_gpus:
            # Sum of "still running" flags over all ranks; 0 means everyone is done.
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # --- clean forward pass ---
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # A finished rank still ran the forward above while other ranks continue;
        # it skips token selection and bookkeeping.
        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # The augmented pass only runs when images_cd is supplied; a CD cache alone does not enable it.
        use_cd = model_kwargs.get("images_cd") is not None
        # NOTE: these resolve exactly like output_attentions / output_hidden_states
        # above, so they always equal those values.
        output_attentions_wo_img = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states_wo_img = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )

        if use_cd:
            # --- augmented (AGLA) forward pass on its own kwargs/cache ---
            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions_wo_img,
                output_hidden_states=output_hidden_states_wo_img,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]

            cd_alpha = model_kwargs.get("cd_alpha") if model_kwargs.get("cd_alpha") is not None else 1
            cd_beta = model_kwargs.get("cd_beta") if model_kwargs.get("cd_beta") is not None else 0.5

            # Plausibility cutoff: keep only tokens whose clean logit is within
            # log(cd_beta) of the clean maximum (per row).
            cutoff = torch.log(torch.tensor(cd_beta)) + next_token_logits.max(dim=-1, keepdim=True).values

            # AGLA adds the augmented logits (VCD-style decoding would subtract them).
            diffs = next_token_logits + cd_alpha * next_token_logits_cd
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))

            cd_logits = logits_processor(input_ids, cd_logits)
            cd_logits = logits_warper(input_ids, cd_logits)

            # Debug first step: print logits comparison
            # NOTE(review): topk(...squeeze(0)) and .item() below assume batch size 1.
            is_first_step = debug_first_step and (input_ids.shape[1] == initial_input_len)
            combined_argmax = None
            if is_first_step and debug_tokenizer is not None:
                # Get top-5 tokens for each logits
                def get_topk_tokens_and_scores(logits_tensor, k=5):
                    topk_vals, topk_indices = torch.topk(logits_tensor.squeeze(0), k=k)
                    tokens = [debug_tokenizer.decode([idx.item()]) if hasattr(debug_tokenizer, 'decode') else str(idx.item()) 
                             for idx in topk_indices]
                    return list(zip(tokens, topk_vals.cpu().tolist()))
                
                clean_top5 = get_topk_tokens_and_scores(next_token_logits)
                cd_top5 = get_topk_tokens_and_scores(next_token_logits_cd)
                combined_top5 = get_topk_tokens_and_scores(cd_logits)
                
                clean_argmax = torch.argmax(next_token_logits, dim=-1).item()
                cd_argmax = torch.argmax(next_token_logits_cd, dim=-1).item()
                combined_argmax = torch.argmax(cd_logits, dim=-1).item()
                
                clean_argmax_token = debug_tokenizer.decode([clean_argmax]) if hasattr(debug_tokenizer, 'decode') else str(clean_argmax)
                cd_argmax_token = debug_tokenizer.decode([cd_argmax]) if hasattr(debug_tokenizer, 'decode') else str(cd_argmax)
                combined_argmax_token = debug_tokenizer.decode([combined_argmax]) if hasattr(debug_tokenizer, 'decode') else str(combined_argmax)
                
                print("\n" + "="*80)
                print("[AGLA FIRST STEP DEBUG]")
                print(f"  Clean logits (global) top-5: {clean_top5}")
                print(f"  CD logits (augmented) top-5: {cd_top5}")
                print(f"  Combined AGLA logits top-5: {combined_top5}")
                print(f"  Clean argmax: {clean_argmax_token} (id={clean_argmax})")
                print(f"  CD argmax: {cd_argmax_token} (id={cd_argmax})")
                print(f"  Combined argmax: {combined_argmax_token} (id={combined_argmax})")
                print(f"  cd_alpha={cd_alpha}, cd_beta={cd_beta}")
                print(f"  cutoff={cutoff.item():.4f}")
                print("="*80 + "\n")

            # NOTE(review): do_sample is read from the model's default
            # self.generation_config, not from the per-call config that generate()
            # used to route here; confirm this is intended.
            if bool(getattr(self.generation_config, "do_sample", True)):
                cd_probs = nn.functional.softmax(cd_logits, dim=-1)
                next_tokens = torch.multinomial(cd_probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(cd_logits, dim=-1)
            next_token_scores = cd_logits
            
            # Print selected token in first step
            if is_first_step and debug_tokenizer is not None and combined_argmax is not None:
                selected_token_id = next_tokens.item()
                selected_token = debug_tokenizer.decode([selected_token_id]) if hasattr(debug_tokenizer, 'decode') else str(selected_token_id)
                print(f"[AGLA FIRST STEP DEBUG] Selected token: {selected_token} (id={selected_token_id})")
                print(f"  (Same as combined argmax: {selected_token_id == combined_argmax})\n")
        else:
            # Plain (non-contrastive) decoding from the clean logits only.
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            if bool(getattr(self.generation_config, "do_sample", True)):
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Only the clean pass's attentions / hidden states are recorded.
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if self.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # Sequences that already finished emit pad_token_id instead of a new token.
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # Advance both kwargs dicts (cache, attention mask, ...) from their own outputs.
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if use_cd:
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )

        # A row becomes finished once its new token equals any of the EOS ids.
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        # With synced_gpus the loop exits only via the all_reduce check at the top.
        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        return SampleDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
    return input_ids


def evolve_agla_sampling():
    """Install the legacy-interface AGLA decoding loops on ``GenerationMixin``.

    - ``sample`` is bound to both ``GenerationMixin.sample`` and
      ``GenerationMixin._sample``.
    - ``greedy_search_agla`` is bound to ``GenerationMixin.greedy_search``.
    - ``greedy_search_agla`` is also bound to
      ``GenerationMixin._greedy_search``, but only when the class already
      defines that attribute.

    These are class-attribute assignments, so they affect every class that
    inherits the attributes from ``GenerationMixin``.
    """
    mixin_cls = transformers.generation.utils.GenerationMixin
    for sample_attr in ("sample", "_sample"):
        setattr(mixin_cls, sample_attr, sample)
    mixin_cls.greedy_search = greedy_search_agla
    # Only override the private alias when this transformers version defines it.
    if hasattr(mixin_cls, "_greedy_search"):
        mixin_cls._greedy_search = greedy_search_agla


def greedy_search_agla(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union["GreedySearchDecoderOnlyOutput", "GreedySearchEncoderDecoderOutput", torch.LongTensor]:
    """AGLA-aware replacement for the legacy ``GenerationMixin.greedy_search`` loop.

    ``evolve_agla_sampling()`` installs this function as
    ``GenerationMixin.greedy_search``, and also as ``_greedy_search`` when that
    attribute exists. It behaves like ``sample`` in this module, with three
    differences:
    - tokens are always chosen by ``argmax``;
    - there is no ``logits_warper``;
    - there is no first-step debug output.

    When ``model_kwargs["images_cd"]`` is not None, each step also calls
    ``self.prepare_inputs_for_generation_cd`` and runs a forward pass on
    ``model_kwargs_cd``. The two next-token logits are then combined::

        cutoff   = log(cd_beta) + max(clean_logits)
        combined = clean_logits + cd_alpha * augmented_logits
        combined[clean_logits < cutoff] = -inf

    ``cd_alpha`` defaults to 1 and ``cd_beta`` to 0.5 when missing or None.
    ``logits_processor`` is applied to the result. The augmented cache comes
    from ``model_kwargs["past_key_values_cd"]``, or failing that from
    ``self._agla_cache_augmented`` if the attribute exists.

    Args:
        input_ids: Current token ids; one token per step is appended on the
            last dimension.
        logits_processor: Applied to the (combined) logits every step.
        stopping_criteria: Evaluated after every step. The deprecated
            ``max_length`` is folded in via ``validate_stopping_criteria``.
        max_length: Deprecated; emits a ``UserWarning`` when given.
        pad_token_id, eos_token_id, output_*, return_dict_in_generate:
            Default to ``self.generation_config``.
        synced_gpus: Loop until an ``all_reduce`` shows every rank finished.
        streamer: Receives each new token batch via ``put`` and ``end()`` at
            the end.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation``. May
            carry ``images_cd``, ``cd_alpha``, ``cd_beta`` and
            ``past_key_values_cd``.

    Returns:
        ``GreedySearchEncoderDecoderOutput`` or
        ``GreedySearchDecoderOnlyOutput`` when ``return_dict_in_generate`` is
        set; otherwise the final ``input_ids``.

    Raises:
        ValueError: If ``eos_token_id`` is set but ``pad_token_id`` is None.
    """
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

    # Normalise to a list so several EOS ids can be matched at once below.
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # Accumulators are only allocated when they will be returned.
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # 1 = still generating, 0 = already emitted an EOS id.
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False

    # Separate (shallow-copied) kwargs for the augmented path, optionally with
    # its own cache: explicit past_key_values_cd first, then the attribute stash.
    model_kwargs_cd = model_kwargs.copy()
    if "past_key_values_cd" in model_kwargs:
        model_kwargs_cd["past_key_values"] = model_kwargs["past_key_values_cd"]
    elif hasattr(self, "_agla_cache_augmented"):
        # NOTE(review): this also assigns None if the attribute exists but is None.
        model_kwargs_cd["past_key_values"] = self._agla_cache_augmented

    while True:
        if synced_gpus:
            # Sum of "still running" flags over all ranks; 0 means everyone is done.
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # --- clean forward pass ---
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # A finished rank still ran the forward above while other ranks continue;
        # it skips token selection and bookkeeping.
        if synced_gpus and this_peer_finished:
            continue

        next_token_logits = outputs.logits[:, -1, :]

        # The augmented pass only runs when images_cd is supplied.
        use_cd = model_kwargs.get("images_cd") is not None
        # NOTE: these resolve exactly like output_attentions / output_hidden_states
        # above, so they always equal those values.
        output_attentions_wo_img = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states_wo_img = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )

        if use_cd:
            # --- augmented (AGLA) forward pass on its own kwargs/cache ---
            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions_wo_img,
                output_hidden_states=output_hidden_states_wo_img,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]

            cd_alpha = model_kwargs.get("cd_alpha") if model_kwargs.get("cd_alpha") is not None else 1
            cd_beta = model_kwargs.get("cd_beta") if model_kwargs.get("cd_beta") is not None else 0.5

            # Keep only tokens whose clean logit is within log(cd_beta) of the clean max;
            # AGLA adds (not subtracts) the augmented logits.
            cutoff = torch.log(torch.tensor(cd_beta)) + next_token_logits.max(dim=-1, keepdim=True).values
            diffs = next_token_logits + cd_alpha * next_token_logits_cd
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))
            cd_logits = logits_processor(input_ids, cd_logits)
            next_token_scores = cd_logits
            next_tokens = torch.argmax(cd_logits, dim=-1)
        else:
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Only the clean pass's attentions / hidden states are recorded.
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if self.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # Sequences that already finished emit pad_token_id instead of a new token.
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        # Advance both kwargs dicts (cache, attention mask, ...) from their own outputs.
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        if use_cd:
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )

        # A row becomes finished once its new token equals any of the EOS ids.
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        # With synced_gpus the loop exits only via the all_reduce check at the top.
        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Imported lazily; only needed when a structured output is returned.
        from transformers.generation.utils import GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput

        if self.config.is_encoder_decoder:
            return GreedySearchEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        return GreedySearchDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
    return input_ids


# ---------------------------------------------------------------------------
# Qwen2.5-VL compatible AGLA sampling  (transformers >=4.49)
# ---------------------------------------------------------------------------
# The _sample() signature changed in transformers 4.49:
#   _sample(self, input_ids, logits_processor, stopping_criteria,
#           generation_config, synced_gpus, streamer, **model_kwargs)
#
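# The old sample() takes logits_warper, pad_token_id and eos_token_id as
# separate kwargs; these are now bundled into generation_config.
# evolve_agla_sampling() installs that old-interface sample() as both `sample`
# and `_sample`. 4.49's generate() never calls `sample`, and the old
# parameters do not match the new `_sample()` interface shown above.
# sample_qwen25() below implements the new `_sample()` interface.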
# ---------------------------------------------------------------------------

def sample_qwen25(
    self,
    input_ids: torch.LongTensor,
    logits_processor: "LogitsProcessorList",
    stopping_criteria: "StoppingCriteriaList",
    generation_config: "GenerationConfig",
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union["GenerateNonBeamOutput", torch.LongTensor]:
    """AGLA-aware _sample() that matches the transformers >=4.49 interface.
    Use *only* for Qwen2.5-VL (patched via evolve_agla_sampling_qwen25).
    AGLA formula: logits = clean + α · augmented  (add, not subtract like VCD).
    """

    # ---- extract params from generation_config (mirrors original _sample) --
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = getattr(generation_config, "output_logits", False)
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    do_sample = generation_config.do_sample
    has_eos_stopping_criteria = any(
        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
    )

    # ---- init score / attention / hidden-state tuples -------------------
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = (
            model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        )
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # ---- sequence tracking ------------------------------------------------
    batch_size, cur_len = input_ids.shape
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    # ---- AGLA + DSCR cache handling ----------------------------------------
    # Detect pre-computed DSCR cache: generate() always creates an empty
    # DynamicCache, so we must check that the cache actually contains data.
    _pkv = model_kwargs.get("past_key_values")
    _using_dscr_cache = (
        _pkv is not None
        and hasattr(_pkv, "get_seq_length")
        and _pkv.get_seq_length() > 0
    )

    # Qwen2.5-VL dynamic image tokenization: augmented image may produce a
    # different number of image tokens.  input_ids_cd has the correct token
    # count for the augmented image (passed from mme_qwen25.py).
    input_ids_cd = model_kwargs.pop("input_ids_cd", None)

    # cache_position – must be set before we copy model_kwargs
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    model_kwargs_cd = model_kwargs.copy()

    # DynamicCache (transformers>=4.49) is mutated in-place by the model's
    # forward pass.  A shallow dict copy shares the same cache object, so
    # the clean forward would contaminate the CD cache.  Give the CD path
    # its own independent cache so both paths build KV entries separately.
    _cd_pkv = model_kwargs_cd.get("past_key_values")
    if _cd_pkv is not None and hasattr(_cd_pkv, "get_seq_length"):
        if _cd_pkv.get_seq_length() == 0:
            from transformers import DynamicCache
            model_kwargs_cd["past_key_values"] = DynamicCache()
        else:
            import copy as copy_module
            model_kwargs_cd["past_key_values"] = copy_module.deepcopy(_cd_pkv)

    # If augmented input_ids has different length, fix CD cache_position
    if input_ids_cd is not None and input_ids_cd.shape[1] != input_ids.shape[1]:
        model_kwargs_cd["cache_position"] = torch.arange(
            0, input_ids_cd.shape[1], device=input_ids.device, dtype=torch.long
        )

    if _using_dscr_cache:
        model_kwargs.pop("pixel_values", None)
        model_kwargs.pop("images", None)
        model_kwargs.pop("image_grid_thw", None)

    # AGLA+DSCR: use separate augmented cache if provided
    _using_dscr_cd_cache = False
    if hasattr(self, "_agla_cache_augmented") and self._agla_cache_augmented is not None:
        model_kwargs_cd["past_key_values"] = self._agla_cache_augmented
        # Don't delete here; mme_qwen25.py cleans up in its finally block
        _using_dscr_cd_cache = True
        model_kwargs_cd.pop("images_cd", None)
        model_kwargs_cd.pop("pixel_values_cd", None)
        model_kwargs_cd.pop("image_grid_thw_cd", None)
    elif _using_dscr_cache:
        # CD path must NOT inherit DSCR clean cache
        model_kwargs_cd.pop("past_key_values", None)
        model_kwargs_cd.pop("cache_position", None)

    # ---- debug tracking ---------------------------------------------------
    initial_input_len = input_ids.shape[1]
    debug_first_step = hasattr(self, "_agla_debug_first_step") and getattr(self, "_agla_debug_first_step", False)
    debug_tokenizer = getattr(self, "_agla_debug_tokenizer", None) if debug_first_step else None

    # ---- helpers for stopping (same as 4.49 _sample) ----------------------
    def _has_unfinished(peer_finished):
        if synced_gpus:
            flag = torch.tensor(0.0 if peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(flag, op=dist.ReduceOp.SUM)
            return flag.item() != 0.0
        return not peer_finished

    # ---- auto-regressive generation loop ----------------------------------
    is_prefill = True
    while _has_unfinished(this_peer_finished) and cur_len < max_length:
        # --- clean path ---------------------------------------------------
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        if is_prefill:
            outputs = self(**model_inputs, return_dict=True)
            is_prefill = False
        else:
            outputs = self(**model_inputs, return_dict=True)

        # update clean model_kwargs FIRST (mirrors 4.49 _sample ordering)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        if synced_gpus and this_peer_finished:
            cur_len += 1
            continue        next_token_logits = outputs.logits[:, -1, :].clone().float()
        next_token_logits = next_token_logits.to(input_ids.device)

        # --- contrastive decoding (AGLA) path --------------------------------
        use_cd = model_kwargs.get("images_cd") is not None or _using_dscr_cd_cache
        if use_cd:
            # Qwen2.5-VL: remove attention_mask for CD when image_grid_thw_cd
            # is present (model recalculates internally after image expansion)
            if model_kwargs_cd.get("image_grid_thw_cd") is not None:
                model_kwargs_cd["attention_mask"] = None

            # Use augmented input_ids if available (different image token count)
            _cd_ids = input_ids_cd if input_ids_cd is not None else input_ids
            model_inputs_cd = self.prepare_inputs_for_generation_cd(_cd_ids, **model_kwargs_cd)
            model_inputs_cd.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs_cd.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
            outputs_cd = self(**model_inputs_cd, return_dict=True)

            next_token_logits_cd = outputs_cd.logits[:, -1, :].clone().float()
            next_token_logits_cd = next_token_logits_cd.to(input_ids.device)

            # AGLA formula: clean + α * augmented  (ADD, not subtract like VCD)
            cd_alpha = model_kwargs.get("cd_alpha", 1.0)
            cd_beta = model_kwargs.get("cd_beta", 0.5)

            cutoff = torch.log(torch.tensor(cd_beta, device=input_ids.device, dtype=next_token_logits.dtype)) + next_token_logits.max(dim=-1, keepdim=True).values
            diffs = next_token_logits + cd_alpha * next_token_logits_cd
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))

            # Debug first step
            is_first_step = debug_first_step and (input_ids.shape[1] == initial_input_len)
            combined_argmax = None
            if is_first_step and debug_tokenizer is not None:
                def get_topk_tokens_and_scores(logits_tensor, k=5):
                    topk_vals, topk_indices = torch.topk(logits_tensor.squeeze(0), k=k)
                    tokens = [debug_tokenizer.decode([idx.item()]) if hasattr(debug_tokenizer, "decode") else str(idx.item())
                             for idx in topk_indices]
                    return list(zip(tokens, topk_vals.cpu().tolist()))

                clean_top5 = get_topk_tokens_and_scores(next_token_logits)
                cd_top5 = get_topk_tokens_and_scores(next_token_logits_cd)
                combined_top5 = get_topk_tokens_and_scores(cd_logits)
                clean_argmax = torch.argmax(next_token_logits, dim=-1).item()
                cd_argmax = torch.argmax(next_token_logits_cd, dim=-1).item()
                combined_argmax = torch.argmax(cd_logits, dim=-1).item()
                clean_argmax_token = debug_tokenizer.decode([clean_argmax]) if hasattr(debug_tokenizer, "decode") else str(clean_argmax)
                cd_argmax_token = debug_tokenizer.decode([cd_argmax]) if hasattr(debug_tokenizer, "decode") else str(cd_argmax)
                combined_argmax_token = debug_tokenizer.decode([combined_argmax]) if hasattr(debug_tokenizer, "decode") else str(combined_argmax)
                print("\n" + "=" * 80)
                print("[AGLA FIRST STEP DEBUG]")
                print(f"  Clean logits top-5: {clean_top5}")
                print(f"  CD logits (augmented) top-5: {cd_top5}")
                print(f"  Combined AGLA logits top-5: {combined_top5}")
                print(f"  Clean argmax: {clean_argmax_token} (id={clean_argmax})")
                print(f"  CD argmax: {cd_argmax_token} (id={cd_argmax})")
                print(f"  Combined argmax: {combined_argmax_token} (id={combined_argmax})")
                print(f"  cd_alpha={cd_alpha}, cd_beta={cd_beta}")
                print(f"  cutoff={cutoff.item():.4f}")
                print("=" * 80 + "\n")

            # Apply logits_processor (includes temperature / top-p / top-k)
            cd_logits = logits_processor(input_ids, cd_logits)
            cd_logits = torch.nan_to_num(cd_logits, nan=-float("inf"), posinf=1e4, neginf=-1e4)

            next_token_scores = cd_logits

            # Token selection – respect do_sample
            if do_sample:
                cd_probs = nn.functional.softmax(cd_logits, dim=-1)
                next_tokens = torch.multinomial(cd_probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(cd_logits, dim=-1)

            # Print selected token in first step
            if is_first_step and debug_tokenizer is not None and combined_argmax is not None:
                selected_token_id = next_tokens.item()
                selected_token = debug_tokenizer.decode([selected_token_id]) if hasattr(debug_tokenizer, "decode") else str(selected_token_id)
                print(f"[AGLA FIRST STEP DEBUG] Selected token: {selected_token} (id={selected_token_id})")
                print(f"  (Same as combined argmax: {selected_token_id == combined_argmax})\n")

            # update CD model_kwargs
            if model_kwargs_cd.get("attention_mask") is None and model_kwargs.get("attention_mask") is not None:
                model_kwargs_cd["attention_mask"] = model_kwargs["attention_mask"]
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # After the first CD forward (prefill), the augmented cache is built.
            # Subsequent iterations must process the GENERATED tokens (from the
            # clean input_ids which is updated with new tokens each step), NOT
            # re-process the last token of the augmented input_ids.
            # Clearing input_ids_cd makes the loop use `input_ids` for the CD path.
            if input_ids_cd is not None:
                input_ids_cd = None
        else:
            # No CD – standard sampling / greedy
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

        # --- store diagnostics ---------------------------------------------
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # --- pad finished sequences ----------------------------------------
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # --- update input_ids -----------------------------------------------
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if input_ids_cd is not None:
            input_ids_cd = torch.cat([input_ids_cd, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # free large tensors
        del outputs
        if use_cd:
            del outputs_cd

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        try:
            from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput
        except ImportError:
            from transformers.generation.utils import SampleDecoderOnlyOutput as GenerateDecoderOnlyOutput
            from transformers.generation.utils import SampleEncoderDecoderOutput as GenerateEncoderDecoderOutput

        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids


def evolve_agla_sampling_qwen25():
    """Install ``sample_qwen25`` as ``GenerationMixin._sample``.

    This is the Qwen2.5-VL / transformers>=4.49 compatible variant of the
    sampling patch. Call it in place of ``evolve_agla_sampling()`` (not in
    addition to it) when running Qwen2.5-VL.

    The replacement is made on the ``GenerationMixin`` class itself. Every
    model class that inherits ``_sample`` from it without overriding the
    method will therefore use the patched version for the rest of the
    process.
    """
    generation_mixin_cls = transformers.generation.utils.GenerationMixin
    setattr(generation_mixin_cls, "_sample", sample_qwen25)