import copy
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, List
from types import MethodType
import torch
import torch.nn.functional as F

from transformers.modeling_utils import PreTrainedModel
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList, MinLengthLogitsProcessor
from transformers.generation.candidate_generator import CandidateGenerator

from ...config import capture_config
from mllmsd.modules.sps.candidate_generator_vlmsd import AssistedCandidateGeneratorVLM, AssistedCandidateGeneratorCascadeVLM

def _get_candidate_generator_vlm(
    self,
    generation_config: GenerationConfig,
    input_ids: torch.LongTensor,
    inputs_tensor: torch.Tensor,
    assistant_model: "PreTrainedModel",
    logits_processor: LogitsProcessorList,
    model_kwargs: Dict,
) -> CandidateGenerator:
    """
    Build the candidate generator to be used in `assisted_generation`.

    The generator class is chosen from the assistant's `_config['drafting']`
    entry: a `list` value selects `AssistedCandidateGeneratorCascadeVLM`, any
    other value selects `AssistedCandidateGeneratorVLM`. Both classes receive
    the same keyword arguments, including the assistant's `prompt_setter`.

    Returns:
        The constructed candidate generator instance.
    """
    # Only the class differs between the two drafting modes, so pick it first
    # and construct once.
    uses_drafting_list = isinstance(assistant_model._config['drafting'], list)
    if uses_drafting_list:
        generator_cls = AssistedCandidateGeneratorCascadeVLM
    else:
        generator_cls = AssistedCandidateGeneratorVLM

    return generator_cls(
        input_ids=input_ids,
        assistant_model=assistant_model,
        generation_config=generation_config,
        model_kwargs=model_kwargs,
        inputs_tensor=inputs_tensor,
        logits_processor=logits_processor,
        prompt_setter=assistant_model.prompt_setter,
    )

def _validate_assistant_vlm(self, assistant_model):
    if assistant_model is None:
        return

    if self.config.is_encoder_decoder and not assistant_model.config.is_encoder_decoder:
        attributes_to_check = ["encoder_attention_heads", "encoder_ffn_dim", "encoder_layers"]
        attributes_to_check = [attr for attr in dir(assistant_model.config) if attr in attributes_to_check]
        are_equal = all(
            getattr(self.config, attr) == getattr(assistant_model.config, attr) for attr in attributes_to_check
        )
        if not are_equal:
            raise ValueError(
                "The main model and the assistant don't have compatible encoder-dependent input shapes. "
                "Ensure you load the assistant with the correct encoder-decoder class, e.g. `AutoModelForSpeechSeq2Seq` for Whisper."
            )

    if hasattr(assistant_model.config, "vocab_size"):
        # Text-only drafting
        if not self.config.text_config.vocab_size == assistant_model.config.vocab_size:
            raise ValueError("Make sure the main and assistant model use the same tokenizer")
    else:
        # Multimodal drafting
        if not self.config.text_config.vocab_size == assistant_model.config.text_config.vocab_size:
            raise ValueError("Make sure the main and assistant model use the same tokenizer")