from collections.abc import Iterable
from typing import Optional, Union

import torch, unicodedata
from torch import nn
from modelscope import AutoTokenizer
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler, Sampler, _apply_top_k_top_p
from vllm.model_executor.layers.sampler import *
from vllm.model_executor.layers.sampler import _apply_min_tokens_penalty, _apply_top_k_top_p, _apply_min_p, _sample, _build_sampler_output
import torch.distributed as dist

# from line_profiler import profile

def last_not_value(lst, value):
    """Return the last element of ``lst`` that differs from ``value``.

    The sequence is scanned from its end towards its start. When every
    element equals ``value`` the final element ``lst[-1]`` is returned
    instead, which means an empty sequence raises ``IndexError``.
    """
    missing = object()
    found = next((item for item in reversed(lst) if item != value), missing)
    if found is missing:
        return lst[-1]
    return found

@torch.no_grad()
def batch_varied_top_k_p_in_logits(
    logits_batch: torch.Tensor,      # [N, V]
    lang_masks:   torch.Tensor,      # [C, V] bool
    top_ks:       torch.Tensor,      # [N] long
    top_ps:       torch.Tensor,      # [N] float
) -> torch.Tensor:
    """
    Report, for every sequence, which language categories still have at
    least one token after that sequence's own top-k / top-p filtering.

    Parameters
    ----------
    logits_batch : raw logits, one row per sequence, shape [N, V]
    lang_masks   : per-category vocabulary membership, shape [C, V]
    top_ks       : per-sequence top-k; values are clamped to ``[0, V]`` so
                   "keep everything" may be expressed as any k >= V
    top_ps       : per-sequence top-p; exactly 1.0 disables nucleus filtering

    Returns
    -------
    labels : BoolTensor of shape [N, C]
    """
    N, V = logits_batch.shape
    C = lang_masks.shape[0]
    device = logits_batch.device

    # An empty batch has no maximum k to take; answer with an empty result.
    if N == 0:
        return torch.zeros((0, C), dtype=torch.bool, device=device)

    # torch.topk rejects k > V (and k < 0), so bound the requested values.
    top_ks = top_ks.clamp(min=0, max=V)

    # Use the maximum top_k for the initial top-k operation
    max_k = int(top_ks.max())

    # 1) Top-k logits / indices (for max_k), sorted in descending order
    topk_logits, topk_idx = torch.topk(logits_batch,
                                       k=max_k,
                                       dim=-1,
                                       sorted=True)  # [N, max_k]

    # 2) Probabilities of those k tokens.
    # The denominator runs over the full vocabulary, i.e. the probabilities
    # are NOT renormalised over the top-k survivors.
    # NOTE(review): a sampler that renormalises after top-k would keep a
    # different nucleus -- confirm which behaviour is wanted here.
    log_denom = torch.logsumexp(logits_batch.float(), dim=-1, keepdim=True)
    topk_probs = torch.exp(topk_logits.float() - log_denom)
    topk_probs = topk_probs.to(logits_batch.dtype)  # [N, max_k]

    # 3) Sequence-specific top_k: True for the first top_ks[i] columns of row i
    k_mask = torch.arange(max_k, device=device).expand(N, -1) < top_ks.unsqueeze(-1)

    # Zero out columns beyond each row's own k before accumulating mass
    topk_probs.masked_fill_(~k_mask, 0.0)

    # 4) Nucleus (top_p) filter inside the k tokens
    cum_p = topk_probs.cumsum(dim=-1)  # [N, max_k]

    # A token is kept while the mass accumulated *before* it is below top_p,
    # so the first token that crosses the threshold is still kept.
    prev_cum_p = torch.zeros_like(cum_p)
    prev_cum_p[:, 1:] = cum_p[:, :-1]
    p_mask = prev_cum_p < top_ps.unsqueeze(-1)  # [N, max_k]

    # Skip nucleus filtering when top_p == 1.0
    p_mask = p_mask | (top_ps == 1.0).unsqueeze(-1)

    # Final keep mask combines k and p filtering
    keep_mask = k_mask & p_mask  # [N, max_k]

    # 5) Category membership of the surviving tokens
    token_is_lang = lang_masks[:, topk_idx].permute(1, 2, 0)  # [N, max_k, C]
    any_lang = (keep_mask.unsqueeze(-1) & token_is_lang).any(dim=1)  # [N, C]

    return any_lang

@torch.no_grad()
def batch_varied_top_k_p_in_logits_opt(
    logits_batch: torch.Tensor,      # [N, V]
    lang_masks:   torch.Tensor,      # [C, V] bool
) -> torch.Tensor:
    """
    Report which language categories still have a candidate token.

    This expects ``logits_batch`` to be processed already such that every
    token outside the top-k / top-p selection holds ``-inf``; any entry
    that is not ``-inf`` counts as a surviving candidate.

    Returns
    -------
    labels : BoolTensor of shape [N, C]
    """
    survivors = torch.isneginf(logits_batch).logical_not()  # [N, V]
    # Broadcast [N, 1, V] against [1, C, V] and reduce over the vocabulary.
    per_category = torch.bitwise_and(survivors[:, None], lang_masks[None])  # [N, C, V]
    return per_category.any(dim=2)  # [N, C]

def contains_cj(text):
    """Return 1 if ``text`` holds a Chinese/Japanese letter, otherwise 0.

    A character counts when its Unicode general category starts with ``L``
    and its Unicode name contains ``CJK``, ``KATAKANA`` or ``HIRAGANA``.
    ASCII spaces and characters without a Unicode name are ignored.
    """
    script_markers = ('CJK', 'KATAKANA', 'HIRAGANA')

    def is_cj_letter(ch):
        uni_name = unicodedata.name(ch, None)
        if uni_name is None:
            return False
        if not unicodedata.category(ch).startswith('L'):
            return False
        return any(marker in uni_name for marker in script_markers)

    return int(any(is_cj_letter(ch) for ch in text if ch != ' '))

def contains_latin(text):
    """Return 1 if ``text`` holds a Latin-script letter, otherwise 0.

    A character counts when its Unicode general category starts with ``L``
    and its Unicode name contains ``LATIN``. ASCII spaces and characters
    without a Unicode name are ignored.
    """

    def is_latin_letter(ch):
        uni_name = unicodedata.name(ch, None)
        if uni_name is None:
            return False
        return unicodedata.category(ch).startswith('L') and 'LATIN' in uni_name

    return int(any(is_latin_letter(ch) for ch in text if ch != ' '))

def contains_only_special(text):
    """Return 1 if ``text`` holds no letter of any script, otherwise 0.

    Whitespace and characters without a Unicode name are ignored; any other
    character whose Unicode general category starts with ``L`` makes the
    result 0. Empty or whitespace-only text therefore yields 1.
    """

    def is_named_letter(ch):
        if unicodedata.name(ch, None) is None:
            return False
        return unicodedata.category(ch).startswith('L')

    has_letter = any(is_named_letter(ch) for ch in text if not ch.isspace())
    return 0 if has_letter else 1

# Module-level tokenizer, loaded once at import time from a hard-coded
# relative checkpoint path. get_lang_masks, print_logits and
# get_control_toks all read this global.
# NOTE(review): the path is relative, so it is presumably resolved against
# the process working directory -- confirm the launch directory requirement.
tokenizer = AutoTokenizer.from_pretrained('./cs_gate_train/models/turbo-nothink-gate-qwen3-controlfix-20k_95p_flores_2025-08-17-21:06:11_plugged')

def get_lang_masks(vocab_size):
    """Classify every token id below ``vocab_size`` into language categories.

    Each id is decoded with the module-level ``tokenizer`` and labelled by
    ``contains_cj``, ``contains_latin`` and ``contains_only_special``; a
    token for which all three are false is labelled "low-res".

    Returns a bool tensor of shape [4, vocab_size] on the current CUDA
    device, with rows ordered (cj, latin, special, low-res).

    The dtype is set explicitly: the helper functions return ints while the
    low-res label is a bool, and letting ``torch.tensor`` infer the dtype
    from that mix yields an integer tensor rather than the boolean mask the
    consumers of this tensor document.
    """
    cj_labels, latin_labels, special_labels, lowres_labels = [], [], [], []
    for tok in range(vocab_size):
        tok_text = tokenizer.decode(tok)
        is_cj = bool(contains_cj(tok_text))
        is_latin = bool(contains_latin(tok_text))
        is_special = bool(contains_only_special(tok_text))
        cj_labels.append(is_cj)
        latin_labels.append(is_latin)
        special_labels.append(is_special)
        # Low-res: has letters, but none of them CJ or Latin.
        lowres_labels.append(not (is_cj or is_latin or is_special))
    current_device = torch.cuda.current_device()
    lang_masks = torch.tensor(
        [cj_labels, latin_labels, special_labels, lowres_labels],
        dtype=torch.bool,
    ).to(current_device)
    return lang_masks

def print_logits(logits):
    """Debug helper: print the 20 highest-scoring tokens of the last row.

    Tokens are decoded with the module-level ``tokenizer`` and written on a
    single line, each entry followed by ``', '``.
    """
    last_row = logits[-1, :]
    _, best_ids = torch.topk(last_row, 20, dim=-1)
    decoded = (tokenizer.decode(tok_id.item()) for tok_id in best_ids)
    for i, token_str in enumerate(decoded):
        print(f"Top {i+1}: {[token_str]}", end = ', ')


def get_control_toks(vocab_size):
    """Build a boolean mask marking the tokenizer's special (control) tokens.

    Every entry of ``tokenizer.special_tokens_map`` -- either a single token
    string or a collection of them -- is encoded with the module-level
    ``tokenizer`` and must map to exactly one token id.

    Returns a CPU bool tensor of shape [vocab_size] that is True at the ids
    of those tokens.
    """
    special_toks = []
    for vs in tokenizer.special_tokens_map.values():
        # A map value is either one token string or a collection of them.
        for tok_text in ([vs] if isinstance(vs, str) else vs):
            encoded = tokenizer.encode(tok_text)
            assert len(encoded) == 1, (
                f"special token {tok_text!r} does not encode to a single id")
            special_toks.append(encoded[0])
    # Allocate the mask as bool directly; wrapping an existing tensor in
    # torch.tensor() copy-constructs it and triggers a UserWarning.
    control_toks_mask = torch.zeros(vocab_size, dtype=torch.bool)
    if special_toks:
        control_toks_mask[special_toks] = True
    return control_toks_mask


class GatingMixin:
    """Mixin that adds a language-gating head and gated sampling to a causal LM.

    It is meant to precede a vLLM causal-LM class in the MRO (see the
    concrete classes at the end of this file). On top of the wrapped model it
    creates a two-layer head (``code_switch_pre`` -> ReLU ->
    ``code_switch_head``) producing 4 logits per position, one per language
    category (cj, latin, special, low-res), and a ``GateSampler`` that uses
    those logits to restrict the sampled vocabulary.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Forward ``prefix`` as well; it was previously accepted but dropped,
        # so the wrapped model always saw the default empty prefix.
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        print("after init")
        self.packed_modules_mapping = {
            "qkv_proj": [
                "q_proj",
                "k_proj",
                "v_proj",
            ],
            "gate_up_proj": [
                "gate_proj",
                "up_proj",
            ],
        }
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config

        # Gating head: hidden_size -> hidden_size -> 4 category logits.
        self.code_switch_pre = ColumnParallelLinear(config.hidden_size, config.hidden_size, quant_config=quant_config, return_bias=False)
        self.code_switch_head = RowParallelLinear(config.hidden_size, 4, quant_config=quant_config, return_bias=False)
        # The same two layers grouped as a Sequential; compute_logits does
        # not go through this container.
        self._lang_head = nn.Sequential(
            self.code_switch_pre,
            nn.ReLU(),
            self.code_switch_head,
        )
        lang_masks = get_lang_masks(self.logits_processor.vocab_size)
        print('before', lang_masks[0].sum(), lang_masks[1].sum(), lang_masks[2].sum(), lang_masks[3].sum())
        control_toks_mask = get_control_toks(self.logits_processor.vocab_size)
        # Control tokens belong to the "special" category (row 2) only.
        lang_masks[:, control_toks_mask] = 0
        lang_masks[2, control_toks_mask] = 1
        print('after', lang_masks[0].sum(), lang_masks[1].sum(), lang_masks[2].sum(), lang_masks[3].sum())
        # Every token must fall into exactly one category, otherwise the
        # argmax below would not be a well-defined token -> category map.
        assert((lang_masks.int().sum(dim=0) == 1).all())
        tok_to_lang = lang_masks.int().argmax(dim=0)
        self.tok_to_lang = tok_to_lang
        self.lang_masks = lang_masks
        self.sampler = GateSampler(self.tok_to_lang, self.lang_masks)

    # @profile
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata,
    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        """Compute vocabulary logits and language-gating logits.

        Args:
            hidden_states: hidden states produced by the wrapped model.
            sampling_metadata: when not None, its ``selected_token_indices``
                picks the rows the gating head is evaluated on.

        Returns:
            ``(batch_logits, batch_lang_logits)``: the output of
            ``self.logits_processor`` and the 4-way gating logits.
        """
        batch_logits = self.logits_processor(
            self.lm_head, hidden_states, sampling_metadata
        )
        if sampling_metadata is not None:
            hidden_states = hidden_states.index_select(0, sampling_metadata.selected_token_indices)

        # NOTE(review): the layers' quant_method.apply is called directly
        # instead of their forward() -- presumably to bypass tensor-parallel
        # collectives; confirm the result with tensor parallel size > 1.
        hidden_lang = self.code_switch_pre.quant_method.apply(
            self.code_switch_pre, hidden_states, bias=self.code_switch_pre.bias
        )
        hidden_lang = nn.functional.relu(hidden_lang)
        batch_lang_logits = self.code_switch_head.quant_method.apply(
            self.code_switch_head, hidden_lang, bias=self.code_switch_head.bias
        )
        return batch_logits, batch_lang_logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names reported as loaded.

        The parent loader runs first, then an ``AutoWeightsLoader`` over this
        module; both result sets are merged. The two gating-head bias names
        are always added to the returned set.
        """
        loaded_params = super().load_weights(weights)
        # print("loaded_params", loaded_params)
        # NOTE(review): ``weights`` is iterated a second time below. If the
        # caller passes a one-shot iterator, this second pass sees nothing --
        # confirm the parent loader already covers the code_switch_* weights.
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["lm_head."])
        # TODO: a bit hacky, need to fix this
        params_to_ignore = {'code_switch_head.bias', 'code_switch_pre.bias'}
        return loader.load_weights(weights).union(loaded_params).union(params_to_ignore)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        batch_lang_logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens through the language-gated sampler."""
        next_tokens = self.sampler(logits, batch_lang_logits, sampling_metadata)
        return next_tokens

    @classmethod
    def is_backend_compatible(cls) -> bool:
        """Always report compatibility."""
        return True  # Or implement actual logic if needed

class GateSampler(Sampler):
    """Sampler that masks out tokens of languages the gating head rejects.

    ``lang_masks`` is a [C, V] per-category vocabulary mask and
    ``tok_to_lang`` a [V] tensor mapping each token id to its category row.
    Category rows are ordered (cj, latin, special, low-res).
    """

    # Number of trailing tokens inspected to find the previous language.
    _HISTORY_WINDOW = 10
    # Row indices of the two categories that are never masked out.
    _SPECIAL_LANG = 2
    _LOWRES_LANG = 3

    def __init__(self, tok_to_lang, lang_masks):
        super().__init__()
        self.tok_to_lang = tok_to_lang
        self.lang_masks = lang_masks

    def apply_lang_mask(
        self,
        logits_top_k_top_p: torch.Tensor,
        logits_top_k_top_p_double: torch.Tensor,
        batch_lang_logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Mask tokens of rejected language categories in filtered logits.

        Args:
            logits_top_k_top_p: logits in which tokens removed by the
                request's top-k / top-p filtering already hold ``-inf``.
            logits_top_k_top_p_double: the same logits after a second,
                fixed top-k / top-p pass; only read to decide whether an
                allowed language still has a candidate.
            batch_lang_logits: gating-head logits, one row per logits row.
            sampling_metadata: supplies the sequence groups, their sample
                row indices and their token histories.

        Returns:
            A copy of ``logits_top_k_top_p`` in which, for every sequence
            that needs intervention, tokens of disallowed categories are set
            to ``-inf``. The input tensors are not modified.
        """
        batch_logits = logits_top_k_top_p.clone()
        device = batch_logits.device
        window = self._HISTORY_WINDOW

        # 1. Collect, per sampled sequence, its logits row and recent tokens.
        token_windows = []
        logit_indices_arr = []
        for seq_group in sampling_metadata.seq_groups:
            if not seq_group.do_sample:
                continue
            seq_ids_list = seq_group.seq_ids
            num_seqs_in_group = len(seq_ids_list)
            for idx_within_group, logit_index_in_selected in enumerate(seq_group.sample_indices):
                if idx_within_group >= num_seqs_in_group:
                    print(f"Warning: Index mismatch. idx_within_group={idx_within_group}, num_seqs={num_seqs_in_group}")
                    continue

                seq_data = seq_group.seq_data[seq_ids_list[idx_within_group]]
                recent = list(seq_data.get_token_ids()[-window:])
                # Sequences shorter than the window are left-padded with
                # their oldest token. The padding sits to the left of every
                # real token and duplicates one of them, so the "last
                # non-special token" search below returns the same category
                # as a search over the real tokens alone.
                recent = recent[:1] * (window - len(recent)) + recent

                token_windows.append(recent)
                logit_indices_arr.append(logit_index_in_selected)

        # Nothing is being sampled in this step: there is nothing to gate.
        if not token_windows:
            return batch_logits

        # Row indices into the logits. The common case is the identity
        # mapping, which can be built on-device without a host transfer.
        num_selected = len(logit_indices_arr)
        if logit_indices_arr == list(range(num_selected)):
            logit_indices = torch.arange(num_selected, device=device, dtype=torch.long)
        else:
            logit_indices = torch.tensor(logit_indices_arr, device=device, dtype=torch.long)

        selected_logits = batch_logits.index_select(0, logit_indices)
        selected_logits2 = logits_top_k_top_p_double.index_select(0, logit_indices)
        selected_lang_logits = batch_lang_logits.index_select(0, logit_indices)

        # 2. Category of the last non-special token in each history window.
        # Built in pinned host memory so the device copy can be non-blocking.
        window_ids = torch.tensor(
            token_windows, dtype=torch.long, pin_memory=True,
        ).to(self.tok_to_lang.device, non_blocking=True)
        tok_langs = self.tok_to_lang[window_ids]  # [B, window]

        non_special = tok_langs != self._SPECIAL_LANG
        has_non_special = non_special.any(dim=1)
        # argmax over the reversed mask gives the distance of the last True
        # from the right edge.
        offsets = non_special.flip(dims=[1]).float().argmax(dim=1)
        last_pos = window - 1 - offsets
        # Rows with only special tokens fall back to the newest token.
        last_pos = torch.where(has_non_special, last_pos,
                               torch.full_like(last_pos, window - 1))
        last_tok_langs = tok_langs.gather(1, last_pos.unsqueeze(1)).squeeze(1)  # [B]

        # 3. Which categories are present among the surviving candidates.
        logits_lang_labels_sampling = batch_varied_top_k_p_in_logits_opt(selected_logits, self.lang_masks)
        logits_lang_labels_topk = batch_varied_top_k_p_in_logits_opt(selected_logits2, self.lang_masks) | logits_lang_labels_sampling

        # 4. Categories allowed by the gating head: every category with
        # sigmoid >= 0.5, plus the best-scoring one so the set is never empty.
        lang_probs = selected_lang_logits.sigmoid()
        pred_lang_labels = lang_probs >= 0.5
        max_idx = lang_probs.argmax(dim=-1, keepdim=True)
        best_is_true = torch.zeros_like(pred_lang_labels).scatter_(-1, max_idx, 1).bool()
        pred_lang_labels = pred_lang_labels | best_is_true

        original_pred_lang_labels = pred_lang_labels.clone()
        # Continuing in the language of the previous token is always allowed.
        pred_lang_labels.scatter_(dim=1, index=last_tok_langs.unsqueeze(-1), value=True)
        # no intervention if last is symbol
        mask = (last_tok_langs == self._SPECIAL_LANG)
        pred_lang_labels[mask, :] = True

        # Determine which languages to mask out
        langs_to_mask = ~pred_lang_labels
        # Never mask 'special' or 'low-res' categories
        langs_to_mask[:, [self._SPECIAL_LANG, self._LOWRES_LANG]] = False

        # 5. Intervene only where masking removes something AND a category
        # predicted by the head still has a candidate left afterwards.
        intervention_has_effect = (logits_lang_labels_sampling & langs_to_mask).any(dim=1)
        target_in_topk = (logits_lang_labels_topk & original_pred_lang_labels).any(dim=1)
        apply_intervention = intervention_has_effect & target_in_topk

        if apply_intervention.any():
            indices_to_modify = logit_indices[apply_intervention]
            logits_to_modify = selected_logits[apply_intervention]
            final_langs_to_mask = langs_to_mask[apply_intervention]
            # [B', C] x [C, V] -> non-zero wherever a masked category owns the token
            vocab_mask = torch.matmul(final_langs_to_mask.float(), self.lang_masks.float()).bool()
            batch_logits[indices_to_modify] = logits_to_modify.masked_fill(vocab_mask, -torch.inf)

        return batch_logits

    def forward(
        self,
        logits: torch.Tensor,
        batch_lang_logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """
        Single-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Pythonize sampling result & logprobs tensor

        Multi-step scheduling:
        * Perform GPU-side sampling computation & compute
          GPU-side logprobs tensor
        * Defer Pythonization of sampling result & logprobs
          tensor
        * Encapsulate arguments required for deferred Pythonization
          in the :class:`SamplerOutput` structure

        The language gate (``apply_lang_mask``) runs only when top-p / top-k
        filtering is active and ``flashinfer_top_k_top_p_sampling`` is None.

        Args:
            logits: (num_tokens, vocab_size).
            batch_lang_logits: gating-head logits, one row per logits row.
            sampling_metadata: Metadata for sampling.
        """
        assert logits is not None

        # Prepare sampling tensors with pinned memory to avoid blocking.
        if not sampling_metadata.reuse_sampling_tensors:
            self._init_sampling_tensors(logits, sampling_metadata)
        elif self._do_penalties:
            # In this case, the sampling tensors logic depends on
            # "output_tokens" of a sequence. As a result, we cannot
            # reuse sampling tensors, since "output_tokens" changes
            # between decode runs.
            self._init_sampling_tensors(logits, sampling_metadata)

        assert self._sampling_tensors is not None
        sampling_tensors = self._sampling_tensors
        do_penalties = self._do_penalties
        do_top_p_top_k = self._do_top_p_top_k
        do_min_p = self._do_min_p

        logits = _apply_min_tokens_penalty(logits, sampling_metadata)

        # Apply presence and frequency penalties.
        if do_penalties:
            logits = apply_penalties(logits, sampling_tensors.prompt_tokens,
                                     sampling_tensors.output_tokens,
                                     sampling_tensors.presence_penalties,
                                     sampling_tensors.frequency_penalties,
                                     sampling_tensors.repetition_penalties)

        # Use float32 to apply temperature scaling.
        # Use in-place division to avoid creating a new tensor.
        logits = logits.to(torch.float)
        logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

        if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)
            # Second, fixed filter (top_p=0.999, top_k=5).
            # NOTE(review): this runs on the logits already filtered above,
            # not on the raw logits -- confirm that is the intended input.
            logits2 = _apply_top_k_top_p(logits, torch.full(sampling_tensors.top_ps.size(), 0.999, device=logits.device),
                                        torch.full(sampling_tensors.top_ks.size(), 5, device=logits.device))
            logits = self.apply_lang_mask(logits, logits2, batch_lang_logits, sampling_metadata)

        if do_min_p:
            logits = _apply_min_p(logits, sampling_tensors.min_ps)

        # We use float32 for probabilities and log probabilities.
        # Compute the probabilities.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # Compute the log probabilities.
        logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

        # Sample the next tokens.
        maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample(
            probs,
            logprobs,
            sampling_metadata,
            sampling_tensors,
            include_gpu_probs_tensor=self.include_gpu_probs_tensor,
            modify_greedy_probs=self._should_modify_greedy_probs_inplace,
        )

        if self.include_gpu_probs_tensor:
            # Since we will defer sampler result Pythonization,
            # preserve GPU-side tensors in support of later
            # deferred pythonization of logprobs
            assert maybe_sampled_tokens_tensor is not None
            on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
        else:
            # Since Pythonization has already happened, don't preserve
            # GPU-side tensors.
            on_device_tensors = None

        # Get the logprobs query results.
        prompt_logprobs = None
        sample_logprobs = None
        if not sampling_metadata.skip_sampler_cpu_output:
            # Pythonize logprobs now (GPU -> CPU); do not defer.
            assert not isinstance(maybe_deferred_sample_results,
                                  SampleResultArgsType)
            prompt_logprobs, sample_logprobs = get_logprobs(
                logprobs, sampling_metadata, maybe_deferred_sample_results)

        return _build_sampler_output(
            maybe_deferred_sample_results,
            sampling_metadata,
            prompt_logprobs,
            sample_logprobs,
            on_device_tensors=on_device_tensors,
            skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output)

class Qwen3Gating(GatingMixin, Qwen3ForCausalLM):
    """``Qwen3ForCausalLM`` combined with ``GatingMixin``; adds no members of its own."""

class Qwen3MoeGating(GatingMixin, Qwen3MoeForCausalLM):
    """``Qwen3MoeForCausalLM`` combined with ``GatingMixin``; adds no members of its own."""