import torch

from transformers.utils import logging

from transformers.generation import GenerationMixin
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList

logger = logging.get_logger(__name__)

@torch.no_grad()
def generate(
    model: GenerationMixin,
    cfg_scale: float = None,
    enhance_query: bool = False,
    **kwargs,
):
    """Run greedy decoding through the public ``model.generate`` API.

    Args:
        model: Object exposing ``generate(**kwargs)``.
        cfg_scale: Accepted for signature parity with ``sample``; not used here.
        enhance_query: Accepted for signature parity with ``sample``; not used here.
        **kwargs: Forwarded to ``model.generate``. Three keys are handled specially:
            ``max_new_tokens`` is forwarded only when it is not ``None``;
            ``return_scores`` and ``tokenizer`` are accepted but dropped.
            ``do_sample`` defaults to ``False`` and ``use_cache`` to ``True``
            unless the caller supplies them.

    Returns:
        ``{"output_ids": ...}`` where the value is ``outputs.sequences`` when the
        result of ``model.generate`` has that attribute, otherwise the result itself.
    """
    max_new_tokens = kwargs.pop("max_new_tokens", None)
    # Strip interface-compatibility keys so they never reach model.generate.
    kwargs.pop("return_scores", None)
    kwargs.pop("tokenizer", None)

    # Build a single kwargs dict: passing the defaults as a second ``**`` mapping
    # would raise "got multiple values for keyword argument" whenever the caller
    # also supplied ``do_sample`` or ``use_cache``.
    gen_kwargs = dict(kwargs)
    if max_new_tokens is not None:
        gen_kwargs["max_new_tokens"] = max_new_tokens
    gen_kwargs.setdefault("do_sample", False)
    gen_kwargs.setdefault("use_cache", True)

    outputs = model.generate(**gen_kwargs)

    # Some configurations return a structured output carrying ``sequences``,
    # others return the id tensor directly.
    output_ids = getattr(outputs, "sequences", outputs)

    return dict(output_ids=output_ids)

def _init_cache_position(model, input_ids, model_kwargs):
    """Call ``model._get_initial_cache_position`` with whichever signature it accepts.

    Each candidate argument tuple is tried in order; a ``TypeError`` moves on to
    the next one. If none is accepted, ``model_kwargs`` is returned unchanged.
    """
    candidate_args = (
        # (input_ids, model_kwargs)
        (input_ids, model_kwargs),
        # (model_kwargs)
        (model_kwargs,),
        # NOTE(review): newer transformers releases appear to take
        # (seq_length, device, model_kwargs) -- confirm against the pinned version.
        (input_ids.shape[1], input_ids.device, model_kwargs),
    )
    for args in candidate_args:
        try:
            return model._get_initial_cache_position(*args)
        except TypeError:
            continue
    return model_kwargs


@torch.no_grad()
def sample(
    model: GenerationMixin,
    input_ids: torch.LongTensor,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    cfg_scale: float = None,
    enhance_query: bool = False,
    return_scores: bool = False,
    topk: int = 10,
    **model_kwargs,
):
    """Greedy decoding loop with optional guidance, query averaging and top-k scores.

    At every step the model is run on the inputs built by
    ``model.prepare_inputs_for_generation`` and the arg-max of the last-position
    logits is appended to ``input_ids``. The loop runs while
    ``model._has_unfinished_sequences`` returns true.

    Args:
        model: Callable model that also provides ``prepare_inputs_for_generation``,
            ``_get_initial_cache_position``, ``_has_unfinished_sequences``,
            ``_update_model_kwargs_for_generation`` and ``config.is_encoder_decoder``.
        input_ids: 2-D tensor of token ids, ``(batch, sequence)``.
        stopping_criteria: Iterable of criteria that is also callable as
            ``stopping_criteria(input_ids, None)``; its result marks finished rows.
        generation_config: Read for ``max_length`` and the pad token
            (``_pad_token_tensor`` first, then ``pad_token_id``, then ``0``).
        cfg_scale: When not ``None``, row 0 is treated as the conditional branch and
            row 1 as the unconditional branch; the logits are combined as
            ``uncond + cfg_scale * (cond - uncond)`` and the chosen token is
            written to both rows.
        enhance_query: When true (and ``cfg_scale`` is ``None``), logits are averaged
            over the batch and the chosen token is written to every row.
        return_scores: When true, also return per-step top-k logits and indices.
            Requires batch size 1, or batch size 2 when ``cfg_scale`` is set.
        topk: Number of top logits recorded per step; clamped to the vocabulary size.
        **model_kwargs: Forwarded to ``model.prepare_inputs_for_generation`` and
            updated after every forward pass.

    Returns:
        ``{"output_ids": tensor}``; with ``return_scores`` the dict additionally holds
        ``topk_scores`` and ``topk_indices`` as nested Python lists with one row per
        generated step (empty lists when no step was generated).

    Raises:
        AssertionError: If ``return_scores`` is set and the batch size does not match
            the requirement above.
    """
    if return_scores:
        # Scores are recorded from a single row of logits. Guidance collapses its
        # two rows into one, so it needs exactly the cond/uncond pair.
        if cfg_scale is None:
            assert input_ids.size(0) == 1, "Only batch size 1 is supported in return_scores mode."
        else:
            assert input_ids.size(0) == 2, (
                "Batch size 2 (conditional + unconditional) is required in "
                "return_scores mode when cfg_scale is set."
            )
    topk_scores, topk_indices = [], []

    # compatible with different implementations
    pad_token_id = getattr(generation_config, "_pad_token_tensor", None)
    if pad_token_id is None:
        pad_id_int = getattr(generation_config, "pad_token_id", None)
        if pad_id_int is None:
            pad_id_int = 0
        pad_token_id = torch.tensor(pad_id_int, device=input_ids.device)
    max_length = generation_config.max_length
    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape
    this_peer_finished = False
    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
    model_kwargs = _init_cache_position(model, input_ids, model_kwargs)

    while model._has_unfinished_sequences(
        this_peer_finished, False, device=input_ids.device, cur_len=cur_len, max_length=max_length
    ):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        outputs = model(**model_inputs, return_dict=True)

        model_kwargs = model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=model.config.is_encoder_decoder,
        )

        # NOTE: Clone is needed
        next_token_logits = outputs.logits[:, -1, :].clone().float()
        next_token_logits = next_token_logits.to(input_ids.device)

        # token selection
        if cfg_scale is not None:
            next_token_logits_cond = next_token_logits[0:1]
            next_token_logits_uncond = next_token_logits[1:2]
            next_token_logits = (
                next_token_logits_uncond + cfg_scale * (
                    next_token_logits_cond - next_token_logits_uncond
                )
            )
            # Both branches must continue with the same guided token.
            next_tokens = torch.argmax(next_token_logits, dim=-1).expand(2)
        elif enhance_query:
            num_enhance = next_token_logits.shape[0]
            next_token_logits = next_token_logits.mean(dim=0, keepdim=True)
            next_tokens = torch.argmax(next_token_logits, dim=-1).expand(num_enhance)
        else:
            next_tokens = torch.argmax(next_token_logits, dim=-1)

        if return_scores:
            if cfg_scale is not None:
                _ref = next_token_logits[0:1]
            else:
                _ref = next_token_logits
            # torch.topk raises when k exceeds the last dimension.
            k = min(topk, _ref.size(-1))
            cur_topk_score, cur_topk_index = torch.topk(_ref, k, dim=-1)
            topk_scores.append(cur_topk_score)
            topk_indices.append(cur_topk_index)

        if has_eos_stopping_criteria:
            # Rows that already finished keep receiving the pad token.
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1
        del outputs

    if not return_scores:
        return dict(
            output_ids=input_ids,
        )

    # torch.cat rejects an empty list, which happens when no step was generated.
    if topk_scores:
        topk_scores = torch.cat(topk_scores, dim=0).tolist()
        topk_indices = torch.cat(topk_indices, dim=0).tolist()
    return dict(
        output_ids=input_ids,
        topk_scores=topk_scores,
        topk_indices=topk_indices,
    )