import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
from torch import nn

from transformers.generation.logits_process import (
    LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
    StoppingCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)
import transformers
from transformers.generation.utils import SampleOutput



def _expand_attn_mask_to_cache(model_kwargs, input_ids):
    """Resize ``attention_mask`` in place so it covers a pre-filled KV cache.

    Generation may start from a cache that is already populated (e.g. 632
    entries for a 633-token image-expanded sequence) while the mask handed
    over by generate() still has the raw text length (e.g. 58).  The decoder
    would then build a 4-D mask whose last dimension disagrees with the
    key/value length and raise a ValueError.  To avoid that, the mask is
    replaced by an all-ones tensor of shape ``(batch, cache_len + 1)``.

    Nothing happens when there is no cache, when the cache length cannot be
    read, when there is no mask, or when the mask already has the right width.

    Args:
        model_kwargs: Generation kwargs dict; ``"attention_mask"`` may be
            overwritten.
        input_ids: Current token ids; only the batch dimension is read.
    """
    past = model_kwargs.get("past_key_values")
    if past is None:
        return

    # Cache length is read from the second-to-last dim of the last tensor of
    # the last cache entry; any unexpected layout simply disables the fix-up.
    try:
        cached_tokens = past[-1][-1].shape[-2]
    except (IndexError, TypeError, AttributeError):
        return

    # One extra slot for the token that is about to be fed to the model.
    wanted_width = cached_tokens + 1

    mask = model_kwargs.get("attention_mask")
    if mask is None:
        return
    if mask.shape[-1] == wanted_width:
        return

    batch_size = input_ids.shape[0]
    model_kwargs["attention_mask"] = torch.ones(
        (batch_size, wanted_width),
        dtype=mask.dtype,
        device=mask.device,
    )


def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    """Auto-regressive sampling loop with optional visual contrastive decoding.

    Each step runs the model on the regular inputs.  When contrastive decoding
    (CD) is active, the model is run a second time on inputs built by
    ``self.prepare_inputs_for_generation_cd`` and the two logit sets are
    combined as ``(1 + cd_alpha) * logits - cd_alpha * logits_cd``; tokens whose
    regular logit is below ``log(cd_beta) + max(logits)`` are masked out.

    CD is active when ``model_kwargs["images_cd"]`` is not None or when a
    pre-computed CD cache was found on ``self._dscr_past_key_values_cd`` (the
    same rule ``greedy_search_vcd`` uses).

    Args:
        self: The model; must provide ``prepare_inputs_for_generation``,
            ``_update_model_kwargs_for_generation``, ``generation_config`` and
            ``config`` (plus ``prepare_inputs_for_generation_cd`` when CD is
            active).
        input_ids: Prompt token ids; generated tokens are appended on the last
            dimension.
        logits_processor: Processors applied to the (possibly contrastive)
            logits; defaults to an empty list.
        stopping_criteria: Criteria evaluated after every step; defaults to an
            empty list.
        logits_warper: Warpers applied after ``logits_processor``; defaults to
            an empty list.
        max_length: Deprecated; when given, a warning is emitted and it is
            folded into ``stopping_criteria``.
        pad_token_id: Token written for already-finished sequences; falls back
            to ``self.generation_config.pad_token_id``.
        eos_token_id: One id or a list of ids that end a sequence; falls back
            to ``self.generation_config.eos_token_id``.
        output_attentions: Whether to collect attentions; falls back to the
            generation config.
        output_hidden_states: Whether to collect hidden states; falls back to
            the generation config.
        output_scores: Whether to collect per-step scores; falls back to the
            generation config.
        return_dict_in_generate: Return an output object instead of the bare
            id tensor; falls back to the generation config.
        synced_gpus: Keep running forward passes until every peer is done.
        streamer: Optional object receiving each step's tokens via ``put`` and
            a final ``end`` call.
        **model_kwargs: Forwarded to input preparation.  Keys read here:
            ``past_key_values``, ``images_cd``, ``cd_alpha`` (default 0.5),
            ``cd_beta`` (default 0.1), ``attention_mask``, ``cache_position``
            and, for encoder-decoder models, ``encoder_outputs``.

    Returns:
        ``SampleEncoderDecoderOutput`` or ``SampleDecoderOnlyOutput`` when
        ``return_dict_in_generate`` is true, otherwise the tensor of token ids.

    Raises:
        ValueError: If an EOS token id is defined but no pad token id is.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )

    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)

    this_peer_finished = False  # used by synced_gpus only

    # VCD + DSCR: a non-None past_key_values on entry means the caller supplied a pre-filled cache.
    _using_dscr_cd_cache = False
    _using_dscr_cache = model_kwargs.get("past_key_values") is not None

    # CRITICAL: Initialize cache_position BEFORE copying model_kwargs
    # This ensures proper position calculation when using pre-filled cache (DSCR)
    if hasattr(self, "_get_initial_cache_position"):
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    # Shallow copy: the CD path gets its own kwargs dict for the first forward pass.
    model_kwargs_cd = model_kwargs.copy()

    if _using_dscr_cache:
        # DSCR cache already contains image embeddings - remove image inputs to avoid duplication
        model_kwargs.pop("pixel_values", None)
        model_kwargs.pop("images", None)
        model_kwargs.pop("image_grid_thw", None)

    if hasattr(self, "_dscr_past_key_values_cd") and self._dscr_past_key_values_cd is not None:
        # Use the pre-computed CD cache (DSCR-refined noisy cache) for the CD path.
        # The attribute is one-shot: it is deleted so a later call does not reuse it.
        model_kwargs_cd["past_key_values"] = self._dscr_past_key_values_cd
        delattr(self, "_dscr_past_key_values_cd")
        _using_dscr_cd_cache = True
        # CD cache already contains image embeddings - remove image inputs to avoid duplication
        model_kwargs_cd.pop("images_cd", None)
        model_kwargs_cd.pop("pixel_values_cd", None)
        model_kwargs_cd.pop("image_grid_thw_cd", None)
    elif _using_dscr_cache:
        # CRITICAL: CD path must NOT inherit DSCR clean cache!
        # If no separate CD cache was provided, remove past_key_values from CD kwargs
        # so VCD recomputes noisy logits from images_cd (preserves contrastive signal).
        model_kwargs_cd.pop("past_key_values", None)
        model_kwargs_cd.pop("cache_position", None)

    # LLaVA 1.6 / VCD+DSCR: expand attention_mask to match the pre-filled cache
    # length before generation begins (see detailed comment in _expand_attn_mask_to_cache).
    if _using_dscr_cache:
        _expand_attn_mask_to_cache(model_kwargs, input_ids)
    if _using_dscr_cd_cache:
        _expand_attn_mask_to_cache(model_kwargs_cd, input_ids)

    # auto-regressive generation
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        ## For contrastive decoding initial
        # CD runs when a distorted image was passed OR a pre-computed CD cache was
        # consumed above; this matches greedy_search_vcd, so a supplied CD cache is
        # never silently ignored.
        use_cd = model_kwargs.get("images_cd") is not None or _using_dscr_cd_cache
        output_attentions_wo_img = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states_wo_img = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )

        if use_cd:
            if model_kwargs_cd.get("image_grid_thw_cd") is not None:
                # Mask is dropped here and re-seeded from the main path after the step.
                model_kwargs_cd["attention_mask"] = None
            ## cd_comments: forward pass of the model with distorted image input
            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions_wo_img,
                output_hidden_states=output_hidden_states_wo_img,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]

            ## cd_comments: pre-process logits from contrastive inputs
            cd_alpha = model_kwargs.get("cd_alpha") if model_kwargs.get("cd_alpha") is not None else 0.5
            cd_beta = model_kwargs.get("cd_beta") if model_kwargs.get("cd_beta") is not None else 0.1

            # version 1  set cutoff for Adaptive Plausibility Constraints
            # probs = nn.functional.softmax(next_token_logits, dim=-1)
            # cutoff = cd_beta * probs.max(dim=-1, keepdim=True).values

            # version 2 set cutoff for Adaptive Plausibility Constraints
            cutoff = torch.log(torch.tensor(cd_beta)) + next_token_logits.max(dim=-1, keepdim=True).values

            diffs = (1 + cd_alpha) * next_token_logits - cd_alpha * next_token_logits_cd
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))

            ## cd_comments: apply temperature warping and top-k filtering in contrastive decoding
            cd_logits = logits_processor(input_ids, cd_logits)
            cd_logits = logits_warper(input_ids, cd_logits)

            # Replace NaN / infinite entries before softmax / argmax.
            cd_logits = torch.nan_to_num(cd_logits, nan=-float("inf"), posinf=1e4, neginf=-1e4)

            next_token_scores = cd_logits
            if bool(getattr(self.generation_config, "do_sample", True)):
                cd_probs = nn.functional.softmax(cd_logits, dim=-1)
                next_tokens = torch.multinomial(cd_probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(cd_logits, dim=-1)
        else:
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)
            if bool(getattr(self.generation_config, "do_sample", True)):
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        if "cache_position" not in model_kwargs:
            model_kwargs["cache_position"] = torch.arange(
                input_ids.shape[1], device=input_ids.device, dtype=torch.long
            )
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        ## cd_comments: update model_kwargs_cd for contrastive decoding
        if use_cd:
            if "cache_position" not in model_kwargs_cd:
                model_kwargs_cd["cache_position"] = torch.arange(
                    input_ids.shape[1], device=input_ids.device, dtype=torch.long
                )
            # Re-seed a dropped CD mask from the (already updated) main-path mask.
            if model_kwargs_cd.get("attention_mask") is None and model_kwargs.get("attention_mask") is not None:
                model_kwargs_cd["attention_mask"] = model_kwargs["attention_mask"]
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )

            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Imported locally (as greedy_search_vcd does for its own output classes) so the
        # names are always bound here; previously they were referenced without an import.
        from transformers.generation.utils import SampleDecoderOnlyOutput, SampleEncoderDecoderOutput

        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return SampleDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
            )
    else:
        return input_ids

def evolve_vcd_sampling():
    """Monkey-patch ``GenerationMixin`` so generation uses the VCD decoding loops.

    ``sample`` is installed under both its public and its protected name
    (newer Transformers releases call ``_sample``).  ``greedy_search_vcd``
    always replaces ``greedy_search`` and replaces ``_greedy_search`` only
    when the mixin already defines it.
    """
    mixin = transformers.generation.utils.GenerationMixin

    # Sampling: patch the public name and the protected name used by newer releases.
    mixin.sample = sample
    mixin._sample = sample

    # Greedy search: the protected variant is patched only where it exists.
    mixin.greedy_search = greedy_search_vcd
    if hasattr(mixin, "_greedy_search"):
        mixin._greedy_search = greedy_search_vcd


def greedy_search_vcd(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
):
    """Greedy (argmax) decoding loop with optional visual contrastive decoding.

    Each step runs the model on the regular inputs and picks the argmax token.
    When contrastive decoding (CD) is active, the model is run a second time on
    inputs built by ``self.prepare_inputs_for_generation_cd`` and the argmax is
    taken over ``(1 + cd_alpha) * logits - cd_alpha * logits_cd``, after masking
    tokens whose regular logit is below ``log(cd_beta) + max(logits)``.

    CD is active when ``model_kwargs["images_cd"]`` is not None or when a
    pre-computed CD cache was found on ``self._dscr_past_key_values_cd``.

    Args:
        self: The model; must provide ``prepare_inputs_for_generation``,
            ``_update_model_kwargs_for_generation``, ``generation_config`` and
            ``config`` (plus ``prepare_inputs_for_generation_cd`` when CD is
            active).
        input_ids: Prompt token ids; generated tokens are appended on the last
            dimension.
        logits_processor: Processors applied to the (possibly contrastive)
            logits; defaults to an empty list.
        stopping_criteria: Criteria evaluated after every step; defaults to an
            empty list.
        max_length: Deprecated; when given, a warning is emitted and it is
            folded into ``stopping_criteria``.
        pad_token_id: Token written for already-finished sequences; falls back
            to ``self.generation_config.pad_token_id``.
        eos_token_id: One id or a list of ids that end a sequence; falls back
            to ``self.generation_config.eos_token_id``.
        output_attentions: Whether to collect attentions; falls back to the
            generation config.
        output_hidden_states: Whether to collect hidden states; falls back to
            the generation config.
        output_scores: Whether to collect per-step scores; falls back to the
            generation config.
        return_dict_in_generate: Return an output object instead of the bare
            id tensor; falls back to the generation config.
        synced_gpus: Keep running forward passes until every peer is done.
        streamer: Optional object receiving each step's tokens via ``put`` and
            a final ``end`` call.
        **model_kwargs: Forwarded to input preparation.  Keys read here:
            ``past_key_values``, ``images_cd``, ``cd_alpha`` (default 0.5),
            ``cd_beta`` (default 0.1), ``attention_mask``, ``cache_position``
            and, for encoder-decoder models, ``encoder_outputs``.

    Returns:
        ``GreedySearchEncoderDecoderOutput`` or ``GreedySearchDecoderOnlyOutput``
        when ``return_dict_in_generate`` is true, otherwise the tensor of token
        ids.

    Raises:
        ValueError: If an EOS token id is defined but no pad token id is.
    """
    # Resolve every optional argument, falling back to the model's generation config.
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

    # Normalise a single EOS id to a list so it can be turned into a tensor.
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # Accumulators are tuples only when the corresponding output was requested, else None.
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # Encoder outputs are only read (and only needed) for encoder-decoder models.
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # 1 = sequence still generating, 0 = sequence already hit EOS.
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False

    # A non-None past_key_values on entry is treated as a caller-supplied pre-filled cache.
    _using_dscr_cd_cache = False
    _using_dscr_cache = model_kwargs.get("past_key_values") is not None

    # Must run before the copy below so both kwargs dicts start from the same cache_position.
    if hasattr(self, "_get_initial_cache_position"):
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
    # Shallow copy: the CD path gets its own dict, values are shared until replaced.
    model_kwargs_cd = model_kwargs.copy()

    if _using_dscr_cache:
        # Image inputs are dropped from the main path; the pre-filled cache is
        # assumed to already contain the image embeddings.
        model_kwargs.pop("pixel_values", None)
        model_kwargs.pop("images", None)
        model_kwargs.pop("image_grid_thw", None)

    if hasattr(self, "_dscr_past_key_values_cd") and self._dscr_past_key_values_cd is not None:
        # One-shot CD cache: use it for the CD path, then delete the attribute
        # and drop the CD image inputs.
        model_kwargs_cd["past_key_values"] = self._dscr_past_key_values_cd
        delattr(self, "_dscr_past_key_values_cd")
        _using_dscr_cd_cache = True
        model_kwargs_cd.pop("images_cd", None)
        model_kwargs_cd.pop("pixel_values_cd", None)
        model_kwargs_cd.pop("image_grid_thw_cd", None)
    elif _using_dscr_cache:
        # No separate CD cache: the CD path must not reuse the main-path cache,
        # so it starts without a cache and without a cache_position.
        model_kwargs_cd.pop("past_key_values", None)
        model_kwargs_cd.pop("cache_position", None)

    # LLaVA 1.6 / VCD+DSCR: expand attention_mask to match the pre-filled cache
    # length before generation begins (see _expand_attn_mask_to_cache for details).
    if _using_dscr_cache:
        _expand_attn_mask_to_cache(model_kwargs, input_ids)
    if _using_dscr_cd_cache:
        _expand_attn_mask_to_cache(model_kwargs_cd, input_ids)

    while True:
        if synced_gpus:
            # Each peer contributes 0.0 once finished; a reduced sum of 0.0 means all are done.
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            if this_peer_finished_flag.item() == 0.0:
                break

        # Main-path forward pass.
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # A finished peer still runs the forward pass above but skips token selection.
        if synced_gpus and this_peer_finished:
            continue

        # Only the logits of the last position decide the next token.
        next_token_logits = outputs.logits[:, -1, :]
        # CD runs when a distorted image was passed or a pre-computed CD cache was consumed.
        use_cd = model_kwargs.get("images_cd") is not None or _using_dscr_cd_cache
        output_attentions_wo_img = (
            output_attentions if output_attentions is not None else self.generation_config.output_attentions
        )
        output_hidden_states_wo_img = (
            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )

        if use_cd:
            # Mask is dropped here and re-seeded from the main path after the step.
            if model_kwargs_cd.get("image_grid_thw_cd") is not None:
                model_kwargs_cd["attention_mask"] = None
            # CD-path forward pass.
            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions_wo_img,
                output_hidden_states=output_hidden_states_wo_img,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]

            cd_alpha = model_kwargs.get("cd_alpha") if model_kwargs.get("cd_alpha") is not None else 0.5
            cd_beta = model_kwargs.get("cd_beta") if model_kwargs.get("cd_beta") is not None else 0.1

            # Plausibility cutoff in logit space: log(beta) + max main-path logit per row.
            cutoff = torch.log(torch.tensor(cd_beta)) + next_token_logits.max(dim=-1, keepdim=True).values
            diffs = (1 + cd_alpha) * next_token_logits - cd_alpha * next_token_logits_cd
            # Tokens below the cutoff on the main path are excluded.
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))
            cd_logits = logits_processor(input_ids, cd_logits)
            # Replace NaN / infinite entries before argmax.
            cd_logits = torch.nan_to_num(cd_logits, nan=-float("inf"), posinf=1e4, neginf=-1e4)

            next_token_scores = cd_logits
            next_tokens = torch.argmax(cd_logits, dim=-1)
        else:
            # Plain greedy step on the processed main-path logits.
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # Record per-step outputs of the main path when requested.
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # Sequences that already finished get the pad token instead of the chosen token.
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # Append the new token and advance the main-path kwargs for the next step.
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        if "cache_position" not in model_kwargs:
            model_kwargs["cache_position"] = torch.arange(
                input_ids.shape[1], device=input_ids.device, dtype=torch.long
            )
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )
        # Advance the CD-path kwargs the same way, using the CD forward outputs.
        if use_cd:
            if "cache_position" not in model_kwargs_cd:
                model_kwargs_cd["cache_position"] = torch.arange(
                    input_ids.shape[1], device=input_ids.device, dtype=torch.long
                )
            # Re-seed a dropped CD mask from the (already updated) main-path mask.
            if model_kwargs_cd.get("attention_mask") is None and model_kwargs.get("attention_mask") is not None:
                model_kwargs_cd["attention_mask"] = model_kwargs["attention_mask"]
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )

        # A sequence becomes finished once its new token equals any EOS id.
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        if stopping_criteria(input_ids, scores):
            this_peer_finished = True
        # With synced_gpus the loop exits via the all_reduce check at the top instead.
        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Output classes are imported lazily, only when a structured result is requested.
        from transformers.generation.utils import GreedySearchDecoderOnlyOutput, GreedySearchEncoderDecoderOutput

        if self.config.is_encoder_decoder:
            return GreedySearchEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        return GreedySearchDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
    return input_ids


# ---------------------------------------------------------------------------
# Qwen2.5-VL compatible VCD sampling  (transformers >=4.49)
# ---------------------------------------------------------------------------
# The _sample() signature changed in transformers 4.49:
#   _sample(self, input_ids, logits_processor, stopping_criteria,
#           generation_config, synced_gpus, streamer, **model_kwargs)
#
# The old sample() expected logits_warper, pad_token_id, eos_token_id, etc.
# as separate kwargs – which are now bundled into generation_config.
# Using the old function causes:
#   - empty logits_warper  (temperature/top-p never applied in warper slot)
#   - generation_config leaking into model_kwargs → passed through to
#     prepare_inputs_for_generation → may cause silent shape issues
#   - always multinomial sampling regardless of do_sample
#   - pad/eos resolved from *model* generation_config instead of the
#     one constructed by generate()
# The cumulative effect on Qwen2.5-VL is that VCD drastically over-corrects
# and produces near-random / all-"no" answers.
# ---------------------------------------------------------------------------

def sample_qwen25(
    self,
    input_ids: torch.LongTensor,
    logits_processor: "LogitsProcessorList",
    stopping_criteria: "StoppingCriteriaList",
    generation_config: "GenerationConfig",
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"],
    **model_kwargs,
) -> Union["GenerateNonBeamOutput", torch.LongTensor]:
    """VCD-aware _sample() that matches the transformers >=4.49 interface.

    Use *only* for Qwen2.5-VL (patched via evolve_vcd_sampling_qwen25).

    Every step runs a "clean" forward pass.  When contrastive decoding (CD) is
    active -- ``images_cd`` is present in ``model_kwargs`` or a pre-computed CD
    cache was attached to the model as ``_dscr_past_key_values_cd`` -- a second
    forward pass is run through ``prepare_inputs_for_generation_cd`` and the two
    logit sets are combined as ``(1 + cd_alpha) * clean - cd_alpha * cd``, with
    every token whose clean logit is below ``log(cd_beta) + max(clean)`` masked
    out before ``logits_processor`` is applied.

    Args:
        input_ids: Prompt token ids; its two dimensions are read as
            ``(batch_size, cur_len)``.
        logits_processor: Callable applied as ``logits_processor(input_ids, logits)``.
        stopping_criteria: Iterable and callable; called as
            ``stopping_criteria(input_ids, scores)`` after each new token.
        generation_config: Read for ``_pad_token_tensor``, ``output_attentions``,
            ``output_hidden_states``, ``output_scores``, ``output_logits``
            (optional, defaults to ``False``), ``return_dict_in_generate``,
            ``max_length`` and ``do_sample``.
        synced_gpus: When true, loop termination is agreed across ranks with an
            all-reduce and a finished rank keeps running clean forward passes.
        streamer: Optional object receiving ``put(next_tokens.cpu())`` per step
            and ``end()`` once generation stops.
        **model_kwargs: Forwarded to ``prepare_inputs_for_generation``.  Keys
            inspected here: ``past_key_values``, ``images_cd``, ``cd_alpha``
            (default 0.5), ``cd_beta`` (default 0.1), ``attention_mask``,
            ``encoder_outputs`` and the image tensors that are dropped when a
            pre-filled cache is used.

    Returns:
        ``input_ids`` extended with the generated tokens, or -- when
        ``generation_config.return_dict_in_generate`` is set -- a
        ``GenerateDecoderOnlyOutput`` / ``GenerateEncoderDecoderOutput``.
    """

    # ---- extract params from generation_config (mirrors original _sample) --
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = getattr(generation_config, "output_logits", False)
    return_dict_in_generate = generation_config.return_dict_in_generate
    max_length = generation_config.max_length
    do_sample = generation_config.do_sample
    has_eos_stopping_criteria = any(
        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
    )

    # ---- init score / attention / hidden-state tuples -------------------
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = (
            model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        )
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # ---- sequence tracking ------------------------------------------------
    batch_size, cur_len = input_ids.shape
    # Prompt length, kept so the debug output below can number generated tokens.
    # (cur_len and input_ids.shape[1] advance together, so their difference
    # cannot serve as a step counter.)
    prompt_len = cur_len
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)

    # ---- VCD + DSCR cache handling ----------------------------------------
    _using_dscr_cd_cache = False
    # Detect pre-computed DSCR cache: generate() always creates an empty
    # DynamicCache, so we must check that the cache actually contains data.
    _pkv = model_kwargs.get("past_key_values")
    _using_dscr_cache = (
        _pkv is not None
        and hasattr(_pkv, "get_seq_length")
        and _pkv.get_seq_length() > 0
    )

    # cache_position – must be set before we copy model_kwargs
    model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

    model_kwargs_cd = model_kwargs.copy()

    # Pre-computed CD cache attached to the model by the caller (None if absent).
    _dscr_cd_pkv = getattr(self, "_dscr_past_key_values_cd", None)

    # DynamicCache (transformers>=4.49) is mutated in-place by the model's
    # forward pass.  A shallow dict copy shares the same cache object, so
    # the clean forward would contaminate the CD cache.  Give the CD path
    # its own independent cache so both paths build KV entries separately.
    _cd_pkv = model_kwargs_cd.get("past_key_values")
    if _cd_pkv is not None and hasattr(_cd_pkv, "get_seq_length"):
        if _cd_pkv.get_seq_length() == 0:
            # Empty cache (non-DSCR) → create a fresh empty cache for CD
            from transformers import DynamicCache
            model_kwargs_cd["past_key_values"] = DynamicCache()
        elif _dscr_cd_pkv is None and not _using_dscr_cache:
            # Non-empty cache that is neither replaced nor dropped below →
            # deep copy so clean & CD stay separate.  The copy is skipped when
            # the CD entry is about to be overwritten/popped anyway, which
            # avoids duplicating the whole KV cache for nothing.
            model_kwargs_cd["past_key_values"] = copy.deepcopy(_cd_pkv)

    if _using_dscr_cache:
        # Image inputs are dropped from the clean path when starting from a
        # pre-filled cache.
        model_kwargs.pop("pixel_values", None)
        model_kwargs.pop("images", None)
        model_kwargs.pop("image_grid_thw", None)

    if _dscr_cd_pkv is not None:
        # transformers >=4.49 requires a Cache object; convert legacy tuple if needed
        if isinstance(_dscr_cd_pkv, tuple):
            from transformers import DynamicCache
            _dc = DynamicCache()
            for layer_idx, (k, v) in enumerate(_dscr_cd_pkv):
                _dc.update(k, v, layer_idx)
            _dscr_cd_pkv = _dc
        model_kwargs_cd["past_key_values"] = _dscr_cd_pkv
        # One-shot hand-off: remove the attribute so a later call does not reuse it.
        delattr(self, "_dscr_past_key_values_cd")
        _using_dscr_cd_cache = True
        model_kwargs_cd.pop("images_cd", None)
        model_kwargs_cd.pop("pixel_values_cd", None)
        model_kwargs_cd.pop("image_grid_thw_cd", None)
    elif _using_dscr_cache:
        # Clean path has a pre-filled cache but no CD cache was supplied:
        # the CD path starts without cache state.
        model_kwargs_cd.pop("past_key_values", None)
        model_kwargs_cd.pop("cache_position", None)

    # ---- helpers for stopping (same as 4.49 _sample) ----------------------
    def _has_unfinished(peer_finished):
        if synced_gpus:
            # Sum a 0/1 flag over all ranks; non-zero means some rank still generates.
            flag = torch.tensor(0.0 if peer_finished else 1.0).to(input_ids.device)
            dist.all_reduce(flag, op=dist.ReduceOp.SUM)
            return flag.item() != 0.0
        return not peer_finished

    # ---- auto-regressive generation loop ----------------------------------
    while _has_unfinished(this_peer_finished) and cur_len < max_length:
        # --- clean path ---------------------------------------------------
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        outputs = self(**model_inputs, return_dict=True)

        # update clean model_kwargs FIRST (mirrors 4.49 _sample ordering)
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        if synced_gpus and this_peer_finished:
            # This rank is done; it only keeps stepping until the other ranks finish.
            cur_len += 1
            continue

        next_token_logits = outputs.logits[:, -1, :].clone().float()
        next_token_logits = next_token_logits.to(input_ids.device)

        # --- contrastive decoding (CD) path --------------------------------
        use_cd = model_kwargs.get("images_cd") is not None or _using_dscr_cd_cache
        if use_cd:
            # Qwen2.5-VL: remove attention_mask for CD when image_grid_thw_cd
            # is present (model recalculates internally after image expansion)
            if model_kwargs_cd.get("image_grid_thw_cd") is not None:
                model_kwargs_cd["attention_mask"] = None

            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            model_inputs_cd.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs_cd.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
            outputs_cd = self(**model_inputs_cd, return_dict=True)

            next_token_logits_cd = outputs_cd.logits[:, -1, :].clone().float()
            next_token_logits_cd = next_token_logits_cd.to(input_ids.device)

            # VCD formula
            cd_alpha = model_kwargs.get("cd_alpha", 0.5)
            cd_beta = model_kwargs.get("cd_beta", 0.1)

            # Plausibility cutoff: keep tokens whose clean logit is within log(cd_beta) of the max.
            cutoff = torch.log(torch.tensor(cd_beta, device=input_ids.device, dtype=next_token_logits.dtype)) + next_token_logits.max(dim=-1, keepdim=True).values
            diffs = (1.0 + cd_alpha) * next_token_logits - cd_alpha * next_token_logits_cd
            cd_logits = diffs.masked_fill(next_token_logits < cutoff, -float("inf"))

            # DEBUG: log VCD diagnostics for first 3 tokens
            gen_step = input_ids.shape[1] - prompt_len + 1
            if gen_step <= 3:
                _topk_clean = torch.topk(next_token_logits[0], 5)
                _topk_cd = torch.topk(next_token_logits_cd[0], 5)
                _topk_diffs = torch.topk(diffs[0], 5)
                _n_survive = (next_token_logits[0] >= cutoff[0, 0]).sum().item()
                print(f"[VCD-DBG step={gen_step}] "
                      f"cutoff={cutoff[0,0].item():.2f} survive={_n_survive} "
                      f"clean_top5_ids={_topk_clean.indices.tolist()} clean_top5_vals={[f'{v:.2f}' for v in _topk_clean.values.tolist()]} "
                      f"cd_top5_ids={_topk_cd.indices.tolist()} cd_top5_vals={[f'{v:.2f}' for v in _topk_cd.values.tolist()]} "
                      f"diffs_top5_ids={_topk_diffs.indices.tolist()} diffs_top5_vals={[f'{v:.2f}' for v in _topk_diffs.values.tolist()]}")

            # Apply the logits_processor (which already includes temperature / top-p / top-k)
            cd_logits = logits_processor(input_ids, cd_logits)
            # Clamp infinities to large finite values; NaNs become -inf so they are never picked.
            cd_logits = torch.nan_to_num(cd_logits, nan=-float("inf"), posinf=1e4, neginf=-1e4)

            next_token_scores = cd_logits

            # Token selection – respect do_sample
            if do_sample:
                cd_probs = nn.functional.softmax(cd_logits, dim=-1)
                next_tokens = torch.multinomial(cd_probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(cd_logits, dim=-1)

            # update CD model_kwargs
            if model_kwargs_cd.get("attention_mask") is None and model_kwargs.get("attention_mask") is not None:
                # Restore a mask (taken from the clean path) before the kwargs update below.
                model_kwargs_cd["attention_mask"] = model_kwargs["attention_mask"]
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )
        else:
            # No CD – standard sampling / greedy
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

        # --- store diagnostics ---------------------------------------------
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # --- pad finished sequences ----------------------------------------
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # --- update input_ids -----------------------------------------------
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # free large tensors
        del outputs
        if use_cd:
            del outputs_cd

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Import the output class used by 4.49
        try:
            from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput
        except ImportError:
            # Fallback for older names
            from transformers.generation.utils import SampleDecoderOnlyOutput as GenerateDecoderOnlyOutput
            from transformers.generation.utils import SampleEncoderDecoderOutput as GenerateEncoderDecoderOutput

        if self.config.is_encoder_decoder:
            return GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        else:
            return GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids


def evolve_vcd_sampling_qwen25():
    """Install ``sample_qwen25`` as ``GenerationMixin._sample``.

    This is the Qwen2.5-VL / transformers>=4.49 compatible patch; call it
    INSTEAD OF evolve_vcd_sampling() when using Qwen2.5-VL.  The assignment is
    made on the class itself, so it applies process-wide.
    """
    generation_mixin = transformers.generation.utils.GenerationMixin
    generation_mixin._sample = sample_qwen25