import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from types import MethodType
import torch
import torch.nn.functional as F
from einops import rearrange

from transformers.modeling_utils import PreTrainedModel
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList, MinLengthLogitsProcessor
from transformers.generation.candidate_generator import CandidateGenerator, _prepare_attention_mask, _prepare_token_type_ids, _crop_past_key_values
from transformers.utils import is_torchdynamo_compiling

from ...config import capture_config
from ..promptsetter import PromptSetter
from .modeling_llava_vlmsd import _merge_input_ids_with_image_features_image_top_k
from ...utils.util import patch_function, noop_context
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration
from transformers.cache_utils import Cache
from ...modules.classifier import Classifier


class CandidateGenerator:
    """Interface that every candidate generator used during assisted generation has to implement."""

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Produces the candidate tokens that should be verified for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated to each candidate.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        message = (
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
        )
        raise NotImplementedError(message)

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Adapts the candidate generation strategy to the outcome of the last verification round.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        message = (
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
            "`update_candidate_strategy`."
        )
        raise NotImplementedError(message)

    def _set_num_rejected_vlm(self, num_rejected_token_vlm):
        """Store `num_rejected_token_vlm` on `self.assistant_model` (which the subclass must have set)."""
        assistant = self.assistant_model
        assistant.num_rejected_token_vlm = num_rejected_token_vlm

    def _get_num_rejected_vlm(self):
        """Return the value previously stored on `self.assistant_model.num_rejected_token_vlm`."""
        assistant = self.assistant_model
        return assistant.num_rejected_token_vlm
    

class AssistedCandidateGeneratorVLM(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: "LogitsProcessorList" = None,
        prompt_setter: "PromptSetter" = None,
    ):
        """
        Configure the assistant (draft) model, the kwargs forwarded to it, and its private generation config.

        Side effects visible in this method: `assistant_model` gets a `_config` attribute, may get its
        `_merge_input_ids_with_image_features` method replaced, and has its `generation_config.eos_token_id`
        overwritten with the target model's value.

        NOTE(review): `prompt_setter` defaults to `None` but is dereferenced unconditionally below, so callers
        must always provide one.

        Raises:
            ValueError: if the assistant's first declared architecture is neither `LlavaForConditionalGeneration`
                nor a `*ForCausalLM` class, or if `logits_processor` contains a `MinLengthLogitsProcessor`.
            AssertionError: if image-attention features are requested for a drafting mode other than
                `multimodal` / `image-pool`.
        """
        # VLMSD
        # Project-level configuration; its keys are read throughout this class.
        self._config = capture_config()
        
        # Make sure all data at the same device as assistant model
        device = assistant_model.device
        input_ids = input_ids.to(device)
        if inputs_tensor is not None:
            inputs_tensor = inputs_tensor.to(device)

        # Prepare the assistant and the starting number of candidate tokens
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens

        # VLMSD - drafting
        self.prompt_setter = prompt_setter
        # From here on `input_ids` refers to the prompt setter's manipulated ids; the device-moved tensor above is
        # discarded. In this method the new value is only used for its batch size / device further below.
        input_ids = prompt_setter.manipulated_input_ids
        # Expose the captured config on the assistant model itself.
        self.assistant_model._config = self._config

        # When image top-k attention or image-attention output is requested, bind the project's replacement for
        # `_merge_input_ids_with_image_features` onto this assistant model instance.
        if (self._config['image_top_k_attention'] > 0) or self._config['output_image_attentions']:
            assert self._config['drafting'] in ['multimodal', 'image-pool']
            self.assistant_model._merge_input_ids_with_image_features = MethodType(
                _merge_input_ids_with_image_features_image_top_k, self.assistant_model
            )

        # Set eos in assistant same as in target model
        self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id

        # Prepare the kwargs for the assistant model
        assistant_kwargs = {}
        # Only the first entry of `config.architectures` is inspected.
        architecture_draft_model = assistant_model.config.architectures[0]
        if not (architecture_draft_model=="LlavaForConditionalGeneration" or architecture_draft_model.endswith("ForCausalLM")):
            raise ValueError(
                "The assistant model should be a decoder-only model, like `LlavaForConditionalGeneration` or `XXForCausalLM`."
            )
        for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
            if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
                if key=='pixel_values' and prompt_setter.drafting in ['text-only', 'tokenized-image', 'special-token', 'caption']:
                    # VLMSD: Text-only drafting: no need to pass pixel_values to the assistant model
                    continue
                # If the prompt setter holds an initial attention mask whose length differs from the target's mask,
                # the assistant gets the prompt setter's mask instead.
                if key=='attention_mask' and prompt_setter.attention_mask_initial is not None and value.size(1) != prompt_setter.attention_mask_initial.size(1):
                    value = prompt_setter.attention_mask_initial
                # Tensors are detached and moved to the assistant's device; everything else is deep-copied.
                assistant_kwargs[key] = (
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )
        if "assistant_encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
        elif assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs

        # Prepare assistant model's keys of inputs
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            self.input_ids_key = "decoder_input_ids"
        elif "encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            self.input_ids_key = "input_ids"
            # Falls back to a single-column mask of ones when no decoder attention mask was provided.
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
            # both are decoder-only
            self.input_ids_key = "input_ids"

        # Prepare generation-related options.
        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        # Work on a private deep copy so the flags changed below do not leak into the caller's config.
        self.generation_config = copy.deepcopy(generation_config)
        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True

        # Disable sampling -- this implementation of assisted generation/speculative decoding uses the assistant
        # greedily to maximize matches. Disables sampling-related flags to prevent warnings
        self.generation_config.do_sample = False
        for attr in ("temperature", "top_p", "min_p", "typical_p", "top_k", "epsilon_cutoff", "eta_cutoff"):
            setattr(self.generation_config, attr, None)

        # avoid unnecessary warnings that min_length is larger than max_new_tokens
        # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
        # The target's `min_length` is remembered here; `get_candidates` uses it to derive `min_new_tokens`.
        self.main_model_min_length = self.generation_config.min_length
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None
        # Despite the comment above, a `MinLengthLogitsProcessor` is not removed: finding one is an error.
        for processor in self.logits_processor:
            if isinstance(processor, MinLengthLogitsProcessor):
                raise ValueError(
                    "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
                    "Please pass in `min_length` into `.generate()` instead"
                )

        # We need to roll back the cache in assisted generation, only DynamicCache is supported
        self.generation_config.cache_implementation = None

    def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated to each candidate.
        """
        input_ids = input_ids.to(self.assistant_model.device)
        # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - input_ids.size(1) - 1)
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - input_ids.size(1)), 0)

        # VLMSD - various drafting
        input_ids = self.prompt_setter.get_resulting_input(input_ids)
        new_cur_len = input_ids.shape[-1]
        if max_new_tokens == 0:
            return dict(
                candidate_ids=self.prompt_setter.rollback_to_original_prompt(input_ids),
                candidate_logits=None,
                time_prefill=None,
                num_prefill_tokens=None,
                time_prompt_process=None,
            )

        # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
        # (which implicitly contains the number of accepted candidates from the previous round)
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            new_cache_size = self.assistant_kwargs.get("past_key_values")[0][0].shape[2] - self._get_num_rejected_vlm()
            self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
                self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size
            )

            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)

        # 2. Forecast next N tokens using the assistant model.
        assistant_generation_kwargs = {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
            "output_attentions": self._config['output_image_attentions'],
        }

        assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)

        # 3. Update variables for the next round of candidaɢte generation
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

        # 4. Prepare variables for output
        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        candidate_ids = assistant_output.sequences
        
        # VLMSD - drafting: rollback to original prompt
        candidate_ids_rollback = self.prompt_setter.rollback_to_original_prompt(candidate_ids)

        if assistant_output.time_prefill is not None and hasattr(self.assistant_model, "time_prompt_process"):
            time_prompt_process = self.assistant_model.time_prompt_process 
        else:
            time_prompt_process = None

        candidate_outputs = dict(
            candidate_ids=candidate_ids_rollback,
            candidate_logits=candidate_logits,
            time_prefill=assistant_output.time_prefill,
            num_prefill_tokens=assistant_output.num_prefill_tokens,
            time_prompt_process=time_prompt_process,
        )

        # VLMSD - drafting: store attentions for image tokens if needed
        if self._config['output_image_attentions']:
            candidate_outputs['draft_image_attentions'] = self.get_image_attentions(assistant_output.attentions, self.assistant_model.attention_kwargs)
        
        return candidate_outputs

    def get_image_attentions(self, attentions, attention_kwargs):
        image_regions = attention_kwargs['image_regions'][0]  # Indices of image tokens
        
        # Initialize a list to store image attention values
        image_attentions = []

        # Iterate over generated tokens
        for token_attention in attentions:
            # Initialize a list for the current token's image attention values across all layers
            token_image_attentions = []
            
            # Iterate over layers
            for layer_attention in token_attention:
                # `layer_attention` has shape (batch_size, num_heads, generated_length, sequence_length)
                
                # Select the attention values corresponding to `image_regions` indices
                # We index the last dimension (sequence_length) with `image_regions`
                # The resulting shape will be (batch_size, num_heads, generated_length, len(image_regions))
                # Only choose the last token (for prefill)
                layer_image_attention = layer_attention[..., -1:, image_regions]
                
                # Append the result for this layer
                token_image_attentions.append(layer_image_attention.tolist())
            
            # Append the result for this token
            image_attentions.append(tuple(token_image_attentions))
        
        # Return as a tuple of tuples to match the input structure
        return tuple(image_attentions)
    
    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
        # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
        # cost of forecasting incorrect assistant tokens.
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
            if num_matches == int(self.num_assistant_tokens):
                self.num_assistant_tokens += 2.0
            else:
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)

class AssistedCandidateGeneratorCascadeVLM(CandidateGenerator):
    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        model_kwargs: Dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: "LogitsProcessorList" = None,
        prompt_setter: "PromptSetter" = None,
    ):
        """
        Build one `AssistedCandidateGeneratorVLM` per configured drafting mode and, optionally, a classifier.

        Every drafting mode listed in the captured config's `'drafting'` entry gets its own deep copy of
        `assistant_model` and its own shallow copy of `prompt_setter` (see `set_prompt_setter`).

        Note that `assistant_model.generation_config.num_assistant_tokens` is set to 1 on the object passed in
        (before it is copied), while this instance keeps the value read from `generation_config`.
        """
        # VLMSD
        config = capture_config()
        self._config = config
        self.prompt_setter = prompt_setter

        # The cascade's own draft length comes from `generation_config`; each wrapped generator's assistant
        # model is configured with a draft length of one.
        self.generation_config = copy.copy(generation_config)
        self.num_assistant_tokens = generation_config.num_assistant_tokens
        assistant_model.generation_config.num_assistant_tokens = 1

        # Cascade functionality
        self.cascade_rule = config['cascade_rule']
        self.mm_weight_policy = config['mm_weight_policy']

        self.candidate_generator_cascade = {
            drafting: AssistedCandidateGeneratorVLM(
                input_ids=input_ids,
                assistant_model=copy.deepcopy(assistant_model),
                generation_config=generation_config,
                model_kwargs=model_kwargs,
                inputs_tensor=inputs_tensor,
                logits_processor=logits_processor,
                prompt_setter=self.set_prompt_setter(prompt_setter, drafting),
            )
            for drafting in config['drafting']
        }

        if config['load_classifier']:
            classifier = Classifier(len(config['drafting']), config['classifier_top_k'], config)
            self.classifier = classifier.to(assistant_model.device)

    def set_prompt_setter(self, prompt_setter_cascade, drafting):
        prompt_setter = copy.copy(prompt_setter_cascade)
        prompt_setter.drafting = drafting
        prompt_setter.manipulated_input_ids = prompt_setter_cascade.manipulated_input_ids.get(drafting)
        prompt_setter.manipulated_input_ids_length_initial = prompt_setter_cascade.manipulated_input_ids_length_initial.get(drafting)
        prompt_setter.attention_mask_initial = prompt_setter_cascade.attention_mask_initial.get(drafting)
        return prompt_setter

    def get_candidates(self, input_ids: torch.LongTensor) -> Dict[str, Any]:
        """
        Draft up to `self.num_assistant_tokens` tokens, one at a time, by consulting every drafting mode.

        At each step all wrapped generators propose a next token (`get_next_tokens`), one token is selected
        (`pick_next_token`), and the running state is extended (`update_inputs`). Drafting stops early when no
        generator can propose a token or when the selected token equals the EOS id for every batch element.

        Return:
            `dict` with `candidate_ids` (input ids plus drafted tokens), `candidate_logits` and
            `candidate_logits_total` (as accumulated by `update_inputs`), and `time_prefill` /
            `num_prefill_tokens`, which are always `None` here.
        """
        state = {
            "input_ids": input_ids,
            "logits": None,  # filled in by `update_inputs` once a token has been picked
            "logits_total": {},
        }
        eos_token_id = self.generation_config.eos_token_id

        for step in range(self.num_assistant_tokens):
            # The first step runs unpatched; every later step runs with
            # `LlavaForConditionalGeneration._get_initial_cache_position` swapped for the cascade variant
            # (presumably for the duration of the `with` block only -- see `patch_function`).
            context = (
                noop_context()
                if step == 0
                else patch_function(
                    LlavaForConditionalGeneration,
                    '_get_initial_cache_position',
                    _get_initial_cache_position_cascade,
                )
            )

            with context:
                proposals = self.get_next_tokens(state['input_ids'])
                next_token, best_index = self.pick_next_token(proposals)

                # No generator produced a proposal: nothing more can be drafted.
                if next_token is None:
                    break

                # Update num_rejected_vlm for the candidate generator
                self._set_num_rejected_vlm(best_index=best_index)

                # Extend the running ids and logits with the picked token
                state = self.update_inputs(state, next_token, proposals, best_index)

                if (next_token == eos_token_id).all():
                    break

        return {
            'candidate_ids': state['input_ids'],
            'candidate_logits': state['logits'],  # logits of the picked tokens
            'candidate_logits_total': state['logits_total'],  # all logits, kept for debugging
            'time_prefill': None,
            'num_prefill_tokens': None,
        }

    def get_next_tokens(self, input_ids: torch.LongTensor):
        per_drafting_outputs = {}

        for drafting, candidate_generator in self.candidate_generator_cascade.items():
            candidate_outputs = candidate_generator.get_candidates(input_ids)
            candidate_ids = candidate_outputs['candidate_ids']
            candidate_logits = candidate_outputs['candidate_logits']

            if candidate_logits is None:
                per_drafting_outputs[drafting] = {
                    'next_token_id': None,
                    'next_token_log_prob_value': None,
                    'next_token_logit': None
                }
                continue

            # Get the next token (the last token in candidate_ids)
            next_token_id = candidate_ids[:, -1].unsqueeze(-1)  # shape (batch_size, 1)
            next_token_logit = candidate_logits[:, -1, :]  # shape (batch_size, vocab_size)

            # Convert logits to probabilities
            next_token_log_probs = F.softmax(next_token_logit, dim=-1)
            # Get the log probability assigned to next_token_id
            next_token_log_prob_value = next_token_log_probs.gather(1, next_token_id)  # shape (batch_size, 1)

            per_drafting_outputs[drafting] = {
                'next_token_id': next_token_id,
                'next_token_log_prob_value': next_token_log_prob_value,
                'next_token_logit': next_token_logit  # Include logits for updating inputs
            }

        return per_drafting_outputs

    def pick_next_token(self, next_tokens_outputs):
        per_drafting_outputs = next_tokens_outputs

        # Collect tokens and their log probabilities, ignoring None values
        tokens = []
        logits = []
        log_probs = []
        valid_draftings = []  # Keep track of valid draftings (i.e., where logits are not None)
        for drafting in per_drafting_outputs:
            if per_drafting_outputs[drafting]['next_token_id'] is not None:
                tokens.append(per_drafting_outputs[drafting]['next_token_id'])  # shape (batch_size, 1)
                logits.append(per_drafting_outputs[drafting]['next_token_logit'])  # shape (batch_size, vocab_size)
                log_probs.append(per_drafting_outputs[drafting]['next_token_log_prob_value'])  # shape (batch_size, 1)
                valid_draftings.append(drafting)  # Add only valid draftings

        # Stack tokens and log probabilities along a new dimension if there are valid draftings
        if tokens:
            tokens = torch.cat(tokens, dim=1)  # shape (batch_size, num_draftings)
            log_probs = torch.cat(log_probs, dim=1)  # shape (batch_size, num_draftings)
            logits = torch.stack(logits, dim=1)  # shape (batch_size, num_draftings, vocab_size)

            if self.cascade_rule == 'confidence':
                # Find the token with the highest log probability for each batch element
                best_indices = torch.argmax(log_probs, dim=1)  # shape (batch_size,)
                best_next_tokens = tokens[torch.arange(tokens.size(0)), best_indices].unsqueeze(-1)  # shape (batch_size, 1)

                # Convert best_indices to correspond to valid_draftings
                best_indices = [valid_draftings[idx] for idx in best_indices.tolist()]  # List of drafting names
            
            elif self.cascade_rule == 'dist-sum':
                # Convert logits to probabilities using softmax
                probs = F.softmax(logits, dim=-1)  # shape (batch_size, num_draftings, vocab_size)

                # Sum probabilities across the drafting dimension
                summed_probs = probs.sum(dim=1)  # shape (batch_size, vocab_size)

                # Pick the token with the highest summed probability
                best_indices = valid_draftings # List of all drafting names
                best_next_tokens = torch.argmax(summed_probs, dim=-1).unsqueeze(-1)  # shape (batch_size, 1)
            
            elif self.cascade_rule == 'mm-weight':
                # Find the index of the multimodal drafting in the valid_draftings list
                if 'multimodal' in valid_draftings:
                    dim_1_index_multimodel_drafting = valid_draftings.index('multimodal')
                else:
                    dim_1_index_multimodel_drafting = valid_draftings.index('caption')


                if isinstance(self.mm_weight_policy, int) or isinstance(self.mm_weight_policy, float):
                    mm_weight = self.mm_weight_policy
                
                elif self.mm_weight_policy == 'img-nec':

                    # Apply softmax to logits to get probabilities
                    probs = F.softmax(logits, dim=-1)  # shape (batch_size, num_draftings, vocab_size)

                    # Get the index for the counterpart drafting (the one that is not dim_1_index_multimodel_drafting)
                    ids_prob = {i for i in range(probs.size(1))}  # Create set of drafting indices
                    ids_prob.remove(dim_1_index_multimodel_drafting)  # Remove multimodel drafting index

                    # Ensure there is exactly one counterpart drafting index
                    assert len(ids_prob) == 1, f"Invalid number of draftings: {len(ids_prob)}"
                    dim_1_index_counterpart_drafting = ids_prob.pop()  # Get the counterpart drafting index

                    # Find the best token from counterpart drafting
                    best_indices_img_nec = torch.argmax(probs, dim=-1)  # shape (batch_size, num_draftings)
                    best_token_from_counterpart = best_indices_img_nec[0, dim_1_index_counterpart_drafting]

                    # Get the probability of the best token from multimodel drafting
                    prob_best_token_mm = probs[0, dim_1_index_multimodel_drafting].max().item()
                    prob_indexed_by_best_token_counterpart = probs[0, dim_1_index_multimodel_drafting, best_token_from_counterpart].item()

                    # Calculate the mm_weight
                    if self._config['mm_weight_k'] is not None:
                        mm_weight = 1 + (prob_best_token_mm - prob_indexed_by_best_token_counterpart) * self._config['mm_weight_k']
                    else:
                        mm_weight = 1 + (prob_best_token_mm - prob_indexed_by_best_token_counterpart)

                else:
                    raise ValueError(f"Invalid multimodal weight policy: {self.mm_weight_policy}")
                

                
                # Convert logits to probabilities using softmax
                probs = F.softmax(logits, dim=-1)  # shape (batch_size, num_draftings, vocab_size)

                # Weight the multimodal drafting probabilities
                probs[:, dim_1_index_multimodel_drafting] *= mm_weight

                # Sum probabilities across the drafting dimension
                summed_probs = probs.sum(dim=1)  # shape (batch_size, vocab_size)

                if hasattr(self, 'classifier'):
                    # Get the top-k indices from the summed_probs
                    topk_probs, topk_indices = torch.topk(summed_probs, k=self._config['classifier_top_k'], dim=-1)  # shape (batch_size, top_k)

                    # Expand topk_indices to match the dimensions needed for indexing
                    # topk_indices_expanded: shape (batch_size, num_draftings, top_k)
                    topk_indices_expanded = topk_indices.unsqueeze(1).expand(-1, probs.size(1), -1)

                    # Gather the probabilities of the top-k tokens across each drafting
                    # topk_probs_per_drafting: shape (batch_size, num_draftings, top_k)
                    topk_probs_per_drafting = torch.gather(probs, dim=2, index=topk_indices_expanded)

                    # Rearrange to shape (batch_size, top_k, num_draftings) using einops.rearrange
                    # This avoids the use of permute and fixes the dimension mismatch error
                    topk_probs_per_drafting = rearrange(topk_probs_per_drafting, 'b n k -> b k n')  # shape (batch_size, top_k, num_draftings)

                    # Flatten batch_size and top_k dimensions to create an input for the classifier
                    batch_size, top_k, num_draftings = topk_probs_per_drafting.shape
                    flat_probs_per_drafting = topk_probs_per_drafting.reshape(batch_size * top_k, num_draftings)  # shape (batch_size * top_k, num_draftings)

                    # Classifier forward pass
                    classification_scores = self.classifier(flat_probs_per_drafting.flatten())  # shape (batch_size * top_k, num_classes)

                    # Reshape classification_scores back to (batch_size, top_k, num_classes)
                    classification_scores = classification_scores.reshape(batch_size, top_k, -1)# -1 corresponds to num_classes

                    # Assume the classifier outputs a single score per token (e.g., probability of the positive class)
                    # We'll select the token with the highest classifier score for each example in the batch
                    # if classification_scores.size(-1) == 1:
                    #     # If there's only one class, squeeze the last dimension
                    token_scores = classification_scores.squeeze(-1)  # shape (batch_size, top_k)
                    # else:
                    #     # If multiple classes, you might want to select the score of a specific class or aggregate
                    #     # For example, taking the score for the positive class
                    #     token_scores = classification_scores[:, :, self.positive_class_index]  # Adjust as needed

                    # Find the best token index in the top_k for each batch based on classifier scores
                    best_token_indices = torch.argmax(token_scores, dim=-1)  # shape (batch_size,)

                    # Select the best next tokens from topk_indices
                    best_next_tokens = topk_indices[torch.arange(batch_size), best_token_indices].unsqueeze(-1)  # shape (batch_size, 1)

                    # Optionally, you can keep track of which drafting contributed to the best token
                    best_indices = valid_draftings  # List of all drafting names


                else:
                    # Pick the token with the highest summed probability
                    best_indices = valid_draftings # List of all drafting names
                    best_next_tokens = torch.argmax(summed_probs, dim=-1).unsqueeze(-1)  # shape (batch_size, 1)
                


                
            else:
                raise ValueError(f"Invalid cascade rule: {self.cascade_rule}")

        else:
            # Handle the case where there are no valid tokens (all None)
            best_next_tokens = None
            best_indices = []

        return best_next_tokens, best_indices

    def update_inputs(self, inputs, next_token, next_tokens_outputs, best_index):
        """
        Folds one decoding step into the running `inputs` dict and returns that same dict.

        Args:
            inputs (`dict`):
                Running state. `inputs['input_ids']` and `inputs['logits']` must be present (the latter may be
                `None`); `inputs['logits_total']` is optional and maps a drafting name to its accumulated logits.
            next_token (`torch.Tensor`):
                Token(s) concatenated to `inputs['input_ids']` along the last dimension.
            next_tokens_outputs (`dict`):
                Maps a drafting name to a dict holding that drafting's `'next_token_logit'` tensor.
            best_index:
                Container of drafting names (membership-tested with `in`); only the logits of these draftings are
                added to `inputs['logits']`.

        Return:
            `dict`: the very `inputs` object that was passed in, updated in place.
        """
        # Extend the running sequence with the newly picked token.
        inputs['input_ids'] = torch.cat([inputs['input_ids'], next_token], dim=-1)

        accepted_logits = inputs['logits']
        # An existing history dict is mutated in place; a fresh one is created otherwise.
        history = inputs.get('logits_total', {})
        picked = []

        for name, output in next_tokens_outputs.items():
            step_logits = output['next_token_logit']  # assumed (batch_size, vocab_size) -- TODO confirm
            step_slice = step_logits.unsqueeze(1)  # adds a length-1 axis at dim 1

            # Every drafting's logits are recorded, whether or not it was selected.
            if drafting_seen := name in history:
                history[name] = torch.cat([history[name], step_slice], dim=1)
            if not drafting_seen:
                history[name] = step_slice

            # Only the selected draftings contribute to the candidate logits.
            if name in best_index:
                picked.append(step_logits)

        # NOTE(review): stacking adds a new *leading* axis, so for per-drafting logits of shape
        # (batch_size, vocab_size) the result is (num_selected, batch_size, vocab_size). That matches
        # (batch_size, 1, vocab_size) only when one drafting is selected and batch_size == 1 -- confirm.
        # When no drafting name is in `best_index`, `torch.stack` is handed an empty list.
        stacked = torch.stack(picked, dim=0)

        if accepted_logits is None:
            inputs['logits'] = stacked
        else:
            inputs['logits'] = torch.cat([accepted_logits, stacked], dim=1)
        inputs['logits_total'] = history

        return inputs


    def _set_num_rejected_vlm(self, num_rejected_token_vlm=None, best_index=None):
        """
        Forwards a rejected-token count to every generator stored in `self.candidate_generator_cascade`.

        Args:
            num_rejected_token_vlm:
                Value passed to each generator's own `_set_num_rejected_vlm` when `best_index` is `None`.
                Ignored when `best_index` is given.
            best_index (iterable of drafting names, *optional*):
                When provided, every generator in the cascade first receives `1`; afterwards each generator
                whose key appears in `best_index` receives `0`. Duplicates are collapsed, and every name must
                be a key of the cascade.
        """
        cascade = self.candidate_generator_cascade

        if best_index is not None:
            # Mark every drafting as rejected first ...
            for generator in cascade.values():
                generator._set_num_rejected_vlm(1)
            # ... then clear the mark for the draftings that were selected.
            for name in set(best_index):
                cascade[name]._set_num_rejected_vlm(0)
            return

        # Plain broadcast: the same value goes to every generator in the cascade.
        for generator in cascade.values():
            generator._set_num_rejected_vlm(num_rejected_token_vlm)
    
    def _get_num_rejected_vlm(self):
        """
        Returns the value reported by the first generator in `self.candidate_generator_cascade`.

        Only the first entry (in the mapping's iteration order) is consulted; the remaining generators are
        never queried. Returns `None` when the cascade is empty.
        """
        generators = iter(self.candidate_generator_cascade.values())
        try:
            first_generator = next(generators)
        except StopIteration:
            # Empty cascade: nothing to report.
            return None
        return first_generator._get_num_rejected_vlm()


def _get_initial_cache_position_cascade(self, input_ids, model_kwargs):
    """
    Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length.

    The number of positions already held by the cache (`past_length`) is derived once, up front, and is used both
    to choose between the pre-fill and the decode branch and to drop the already-cached positions at the end. A
    missing or `None` `past_key_values`, a legacy tuple cache, and a `Cache` whose `get_seq_length()` returns
    `None` are all tolerated (the last two count as "read the length from the first key tensor" and "length 0"
    respectively).

    Args:
        input_ids (`torch.LongTensor`):
            Token ids; indexed as `input_ids[0, :]` and `input_ids.shape[1]`, so a 2-D tensor is expected.
        model_kwargs (`dict`):
            Generation kwargs. Read keys: `"inputs_embeds"` (optional) and `"past_key_values"` (optional).
            Written key: `"cache_position"`.

    Return:
        `dict`: the `model_kwargs` object that was passed in, with `"cache_position"` set.
    """
    cache = model_kwargs.get("past_key_values")

    # How many positions the cache already covers; stays 0 when there is no usable cache.
    past_length = 0
    if cache is not None:
        if not isinstance(cache, Cache):
            # Legacy tuple-style cache: length is read from dim 2 of the first layer's first tensor.
            past_length = cache[0][0].shape[2]
        elif hasattr(cache, "get_seq_length"):
            seq_length = cache.get_seq_length()
            if seq_length is not None:
                past_length = seq_length

    # `torch.compile`-friendly `torch.arange` from a shape -- the `ones_like(...).cumsum(0) - 1` lines below are
    # equivalent to `torch.arange`
    if "inputs_embeds" in model_kwargs:
        cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
    elif input_ids.shape[1] > past_length:
        # prefill: there are more input tokens than cached positions
        cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
    else:
        # decode subsequent chunks for VLMSD: after the slice below this leaves one position when
        # `self.num_rejected_token_vlm == 0` and two positions otherwise, starting at `past_length`.
        num_additional_cache = 1 if self.num_rejected_token_vlm == 0 else 2
        cache_position = torch.arange(
            past_length + num_additional_cache, device=input_ids.device, dtype=torch.int64
        )

    # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
    # end-to-end compilation will yield bad results because `cache_position` will be incorrect.
    if cache is not None and not is_torchdynamo_compiling():
        cache_position = cache_position[past_length:]

    model_kwargs["cache_position"] = cache_position
    return model_kwargs