from collections.abc import Iterable
from typing import Optional, Union

import torch, unicodedata
from torch import nn
from modelscope import AutoTokenizer
from vllm.config import VllmConfig
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import IntermediateTensors, PoolerOutput

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
from vllm.model_executor.models.utils import AutoWeightsLoader, maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
import torch.distributed as dist
from .tok_contains_partial_cjk import contain_partial_cjk

def last_not_value(lst, value):
    """Return the right-most element of ``lst`` that is not equal to ``value``.

    When every element equals ``value`` the final element is returned instead.
    An empty ``lst`` raises ``IndexError``.
    """
    _absent = object()
    found = next((entry for entry in reversed(lst) if entry != value), _absent)
    if found is _absent:
        return lst[-1]
    return found

@torch.no_grad()
def batch_varied_top_k_p_in_logits(
    logits_batch: torch.Tensor,      # [N, V]
    lang_masks:   torch.Tensor,      # [C, V] bool
    top_ks:       torch.Tensor,      # [N] long
    top_ps:       torch.Tensor,      # [N] float
) -> torch.Tensor:
    """
    Report, per row, which language categories survive top-k / top-p filtering.

    Each row of ``logits_batch`` is filtered with its own ``top_k`` and
    ``top_p``; a category is flagged when at least one token that survives the
    filter belongs to it according to ``lang_masks``.

    Parameters
    ----------
    logits_batch : Tensor of shape [N, V]
    lang_masks   : Tensor of shape [C, V]; non-zero where token v is in category c
    top_ks       : LongTensor of shape [N]. Values larger than V are treated
                   as V. Values are expected to be positive; a "-1 means all"
                   sentinel must be translated by the caller.
    top_ps       : FloatTensor of shape [N]. A value of exactly 1.0 disables
                   the nucleus filter for that row.

    Returns
    -------
    labels : BoolTensor of shape [N, C]
    """
    N, V = logits_batch.shape
    C    = lang_masks.shape[0]
    device = logits_batch.device

    # Nothing to classify; also avoids calling ``max`` on an empty tensor.
    if N == 0:
        return torch.zeros((0, C), dtype=torch.bool, device=device)

    # One top-k call with the largest requested k serves every row.
    # ``torch.topk`` rejects k > V, so clamp to the vocabulary size.
    max_k = min(int(top_ks.max()), V)

    # 1) Top-k logits / indices (for max_k), sorted in descending order
    topk_logits, topk_idx = torch.topk(logits_batch,
                                       k=max_k,
                                       dim=-1,
                                       sorted=True)  # [N, max_k]

    # 2) Probabilities of those k tokens, normalised over the FULL vocabulary
    #    (not renormalised inside the top-k set).
    log_denom = torch.logsumexp(logits_batch.float(), dim=-1, keepdim=True)
    topk_probs = torch.exp((topk_logits.float() - log_denom))
    topk_probs = topk_probs.to(logits_batch.dtype) # [N, max_k]

    # 3) Row-specific top_k: position j is kept only while j < top_ks[row]
    k_mask = torch.arange(max_k, device=device).expand(N, -1) < top_ks.unsqueeze(-1)

    # Tokens outside a row's own top_k must not contribute nucleus mass
    topk_probs.masked_fill_(~k_mask, 0.0)

    # 4) Nucleus (top_p) filter inside the k tokens
    cum_p = topk_probs.cumsum(dim=-1) # [N, max_k]

    # A token is kept while the mass accumulated BEFORE it is below top_p, so
    # the first token that crosses the threshold is still kept.
    prev_cum_p = torch.zeros_like(cum_p)
    prev_cum_p[:, 1:] = cum_p[:, :-1]

    p_mask = prev_cum_p < top_ps.unsqueeze(-1)       # [N, max_k]
    # top_p == 1.0 means "no nucleus filtering" regardless of rounding
    top_p_eq_1_mask = (top_ps == 1.0).unsqueeze(-1)  # [N, 1]
    p_mask = p_mask | top_p_eq_1_mask                # broadcast to [N, max_k]

    # Final keep mask combines k and p filtering
    keep_mask = k_mask & p_mask # [N, max_k]

    # 5) Category membership of the surviving tokens
    token_is_lang = lang_masks[:, topk_idx]             # [C, N, max_k]
    token_is_lang = token_is_lang.permute(1, 2, 0)      # [N, max_k, C]
    keep_mask_exp = keep_mask.unsqueeze(-1)             # [N, max_k, 1]
    any_lang = (keep_mask_exp & token_is_lang).any(dim=1) # [N, C]

    return any_lang

def contains_cj(text):
    """Return 1 if ``text`` has a letter from a Chinese/Japanese script, else 0.

    A character counts when its Unicode category is a letter (``L*``) and its
    Unicode name mentions CJK, KATAKANA or HIRAGANA. Characters without a
    Unicode name are ignored.
    """
    script_markers = ('CJK', 'KATAKANA', 'HIRAGANA')
    for ch in text:
        # '' stands in for "no Unicode name"; such code points are skipped.
        uni_name = unicodedata.name(ch, '')
        if ch == ' ' or not uni_name:
            continue
        if not unicodedata.category(ch).startswith('L'):
            continue
        if any(marker in uni_name for marker in script_markers):
            return 1
    return 0

def contains_lang_name(text, lang_names):
    """Return 1 if ``text`` has a letter whose Unicode name contains a given script name.

    Args:
        text: String to inspect.
        lang_names: One script name (e.g. ``'ARABIC'``) or a non-empty
            list/tuple of script names. Matching is a substring test against
            ``unicodedata.name`` of each letter.

    Returns:
        1 when any letter (Unicode category ``L*``) matches one of the names,
        otherwise 0. Characters without a Unicode name are ignored.

    Raises:
        TypeError: ``lang_names`` is not a str, list or tuple.
        ValueError: ``lang_names`` is empty.
    """
    # A bare string is a single script name, not a sequence of characters.
    if isinstance(lang_names, str):
        lang_names = [lang_names]
    elif not isinstance(lang_names, (list, tuple)):
        raise TypeError(
            f"lang_names must be a str or a list/tuple of str, got {type(lang_names).__name__}")
    if not lang_names:
        raise ValueError("lang_names must not be empty")

    for char in text:
        if char == ' ':
            continue
        name = unicodedata.name(char, '')
        if not name:
            continue
        if not unicodedata.category(char).startswith('L'):
            continue
        if any(lang_name in name for lang_name in lang_names):
            return 1
    return 0

def contains_latin(text):
    """Return 1 if ``text`` has at least one Latin-script letter, else 0.

    A character qualifies when it has a Unicode name containing ``LATIN`` and
    its Unicode category is a letter (``L*``).
    """
    def _is_latin_letter(ch):
        # Code points without a Unicode name never qualify.
        label = unicodedata.name(ch, None)
        if label is None or ch == ' ':
            return False
        return unicodedata.category(ch).startswith('L') and 'LATIN' in label

    return 1 if any(_is_latin_letter(ch) for ch in text) else 0

def contains_only_special(text):
    """Return 1 if ``text`` has no letters at all, else 0.

    Whitespace and code points without a Unicode name are ignored; any other
    character whose Unicode category is a letter (``L*``) makes the result 0.
    An empty string yields 1.
    """
    for ch in text:
        is_named = unicodedata.name(ch, None) is not None
        if is_named and not ch.isspace() and unicodedata.category(ch).startswith('L'):
            return 0
    return 1

# Module-level tokenizer shared by every helper and by GatingMixin in this
# file. It is loaded as a side effect of importing the module, from a
# hard-coded relative path.
# NOTE(review): presumably the path is resolved against the current working
# directory, and the same vocabulary is assumed for all model variants defined
# below -- confirm both.
tokenizer = AutoTokenizer.from_pretrained('./models/gate-qwen3-lowres-en')

def print_logits(logits):
    """Print the 20 highest-scoring token strings for the last row of ``logits``.

    Output goes to stdout on a single line, entries separated by ``', '``;
    each token string is wrapped in a list so whitespace stays visible.
    """
    final_row = logits[-1, :]
    best_ids = torch.topk(final_row, 20, dim=-1).indices
    for rank, token_id in enumerate(best_ids, start=1):
        piece = tokenizer.decode(token_id.item())
        print(f"Top {rank}: {[piece]}", end = ', ')


def get_control_toks(vocab_size):
    """Build a boolean mask over the vocabulary marking control tokens.

    Control tokens are every entry of ``tokenizer.special_tokens_map`` (single
    strings and lists of strings alike) plus the first token of the word
    ``'assistant'``.

    Args:
        vocab_size: Length of the returned mask.

    Returns:
        BoolTensor of shape ``[vocab_size]`` (on CPU) that is True at the ids
        of control tokens.

    Raises:
        AssertionError: a special token does not encode to exactly one id.
    """
    # Flatten the map values: each value is either one token string or a
    # collection of token strings.
    special_strs = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            special_strs.append(entry)
        else:
            special_strs.extend(entry)

    special_toks = []
    for tok_str in special_strs:
        ids = tokenizer.encode(tok_str)  # encode once, reuse for check and value
        assert len(ids) == 1, f"special token {tok_str!r} encodes to {len(ids)} ids"
        special_toks.append(ids[0])
    # a bit hacky, assistant should be treat as symbol because of the <|im_end|>\n<|im_start|>assistant\n template
    special_toks.append(tokenizer.encode('assistant')[0])
    print(special_toks)

    # Allocate the mask as bool directly instead of float zeros followed by
    # torch.tensor(tensor).bool(), which copies and emits a UserWarning.
    control_toks_mask = torch.zeros(vocab_size, dtype=torch.bool)
    control_toks_mask[special_toks] = True
    return control_toks_mask

def get_lang_masks(vocab_size):
    """Classify every token id into one of four language categories.

    Rows of the returned tensor, in order: 0 = Chinese/Japanese, 1 = Latin,
    2 = special (no letters), 3 = low-resource (none of the others). Control
    tokens reported by ``get_control_toks`` are forced into row 2 only.

    Returns a tensor of shape ``[4, vocab_size]`` placed on the current CUDA
    device. Category counts before and after the control-token override are
    printed to stdout.
    """
    def _classify(tok_id):
        # Returns (cj, latin, special, lowres) flags for one token id.
        piece = tokenizer.decode(tok_id)
        is_cj = contains_cj(piece) or contain_partial_cjk(tokenizer, tok_id)
        is_latin = contains_latin(piece)
        # A token flagged as CJ is never counted as special.
        is_special = not is_cj and contains_only_special(piece)
        is_lowres = (is_cj == 0) and (is_latin == 0) and (is_special == 0)
        return is_cj, is_latin, is_special, is_lowres

    per_token = [_classify(tok_id) for tok_id in range(vocab_size)]
    # Transpose token-major flags into one list per category.
    per_category = [[flags[cat] for flags in per_token] for cat in range(4)]

    current_device = torch.cuda.current_device()
    lang_masks = torch.tensor(per_category).to(current_device)
    print('before', *[lang_masks[cat].sum() for cat in range(4)])

    # Control tokens belong to the special category and to nothing else.
    control_toks_mask = get_control_toks(vocab_size)
    lang_masks[:, control_toks_mask] = 0
    lang_masks[2, control_toks_mask] = 1
    print('after', *[lang_masks[cat].sum() for cat in range(4)])
    return lang_masks

def new_get_lang_masks(vocab_size):
    """Assign each token id exactly one of nine script labels, one-hot encoded.

    Labels (first match wins, checked in this order):
        2 control token (per ``get_control_toks``), 0 Chinese/Japanese,
        5 Arabic, 6 Hangul, 7 Thai, 8 Cyrillic, 1 Latin,
        4 decoded text contains U+FFFD (replacement character),
        2 no letters at all, 3 any other script.

    Args:
        vocab_size: Number of token ids to classify (ids ``0..vocab_size-1``).

    Returns:
        One-hot tensor of shape ``[vocab_size, 9]``; row ``i`` describes token
        id ``i``. The tensor is left on the device ``torch.tensor`` creates it
        on (it is not moved to CUDA).
    """
    labels = []
    control_toks_mask = get_control_toks(vocab_size)
    for tok in range(vocab_size):
        if control_toks_mask[tok]:
            # this should be treated as special tok; skip the script checks so
            # the token gets exactly one label and rows stay aligned with ids
            labels.append(2)
            continue
        tok_text = tokenizer.decode(tok)
        if contains_cj(tok_text) or contain_partial_cjk(tokenizer, tok):
            labels.append(0)
        # contains_lang_name takes a list of script names
        elif contains_lang_name(tok_text, ['ARABIC']):
            labels.append(5)
        elif contains_lang_name(tok_text, ['HANGUL']):
            labels.append(6)
        elif contains_lang_name(tok_text, ['THAI']):
            labels.append(7)
        elif contains_lang_name(tok_text, ['CYRILLIC']):
            labels.append(8)
        elif contains_latin(tok_text):
            labels.append(1)
        elif '\ufffd' in tok_text:
            # partial byte sequence that did not decode to a full character
            labels.append(4)
        elif contains_only_special(tok_text):
            labels.append(2)
        else:
            # other lowres lang
            labels.append(3)
    lang_masks = torch.nn.functional.one_hot(
        torch.tensor(labels, dtype=torch.long), num_classes=9)
    return lang_masks

class GatingMixin:
    """Mixin that gates next-token logits by language category.

    Meant to be listed before a vLLM ``*ForCausalLM`` class in the bases; the
    base class is expected to provide ``lm_head``, ``logits_processor`` and
    ``load_weights``.

    A small two-layer head (``code_switch_pre`` -> ReLU -> ``code_switch_head``)
    maps the hidden state to 4 logits, one per category row produced by
    ``get_lang_masks`` (0 = Chinese/Japanese, 1 = Latin, 2 = special,
    3 = low-resource). ``compute_logits`` uses them to set the logits of tokens
    from unwanted categories to ``-inf``.

    Buffers (non-persistent):
        tok_to_lang: ``[V]`` category index of every token id.
        lang_masks:  ``[4, V]`` category membership of every token id.

    Attributes set by ``load_weights``:
        token_norm: per-row norm of the gathered ``lm_head`` weight.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Forward ``prefix`` too, so a non-default prefix is not dropped.
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        print("after init")
        self.packed_modules_mapping = {
            "qkv_proj": [
                "q_proj",
                "k_proj",
                "v_proj",
            ],
            "gate_up_proj": [
                "gate_proj",
                "up_proj",
            ],
        }
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config

        # Gating head: hidden_size -> hidden_size -> 4 category logits.
        self.code_switch_pre = ColumnParallelLinear(config.hidden_size, config.hidden_size, quant_config=quant_config, return_bias=False)
        self.code_switch_head = RowParallelLinear(config.hidden_size, 4, quant_config=quant_config, return_bias=False)
        self._lang_head = nn.Sequential(
            self.code_switch_pre,
            nn.ReLU(),
            self.code_switch_head,
        )
        lang_masks = get_lang_masks(self.logits_processor.vocab_size)

        # Every token must belong to exactly one category, otherwise the
        # argmax below would not be a faithful token -> category map.
        assert((lang_masks.int().sum(dim=0) == 1).all())
        tok_to_lang = lang_masks.int().argmax(dim=0)

        self.register_buffer(
            "tok_to_lang",
            tok_to_lang,
            persistent=False
        )

        self.register_buffer(
            "lang_masks",
            lang_masks,
            persistent=False
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata,
    ) -> Optional[torch.Tensor]:
        """Compute next-token logits and mask out unwanted language categories.

        Steps:
          1. Base logits come from ``self.logits_processor``.
          2. The gating head scores the 4 categories for each sampled row.
          3. For each sampled sequence the category of its most recent
             non-special token (within the last 20 tokens) is looked up.
          4. Allowed categories = gate sigmoid >= 0.5, plus the gate argmax,
             plus the last-token category; everything is allowed when the
             last-token category is special (2). Categories 2 and 3 are never
             masked.
          5. A row is modified only if a category to be masked appears in its
             top-20 / p=0.95 window AND a gate-predicted category appears in
             the union of the top-5 / p=0.999 and top-20 / p=0.95 windows.
          6. Tokens of masked categories get ``-inf`` in those rows.

        The per-request ``top_k`` / ``top_p`` values are collected but only
        serve as shape/dtype templates; the windows above are fixed.

        Returns the (possibly modified) logits. The processor output is
        returned untouched when it is ``None`` or empty, when
        ``sampling_metadata`` is ``None``, or when no sequence group samples.
        Diagnostic information is printed to stdout.
        """
        batch_logits = self.logits_processor(self.lm_head, hidden_states,
                                             sampling_metadata)
        if batch_logits is None or batch_logits.numel() == 0:
            return batch_logits
        if sampling_metadata is None:
            # Without per-sequence metadata there is no token history to gate
            # on; pass the logits through instead of failing on ``None``.
            return batch_logits

        device = batch_logits.device
        batch_logits = batch_logits.clone()
        original_hidden_shape = hidden_states.shape  # kept for the debug dump
        hidden_states = hidden_states.index_select(0, sampling_metadata.selected_token_indices)

        ## `compute_logits` expects non-parallel code, so should use apply to ensure no parallelism
        # this is wrong `batch_lang_logits = self._lang_head(hidden_states)`
        # NOTE(review): calling quant_method.apply directly presumably skips the
        # tensor-parallel reduction of RowParallelLinear -- confirm the result
        # is correct when tensor parallel size > 1.
        hidden_lang = self.code_switch_pre.quant_method.apply(self.code_switch_pre, hidden_states, bias=self.code_switch_pre.bias)
        hidden_lang = nn.functional.relu(hidden_lang)
        batch_lang_logits = self.code_switch_head.quant_method.apply(self.code_switch_head, hidden_lang, bias=self.code_switch_head.bias)

        top_k_arr = []
        top_p_arr = []
        last_tok_lang_arr = []
        logit_indices_arr = []
        token_ids_arr = []

        for seq_group in sampling_metadata.seq_groups:
            if not seq_group.do_sample:
                continue
            sampling_params = seq_group.sampling_params
            top_p = sampling_params.top_p
            top_k = sampling_params.top_k
            seq_ids_list = seq_group.seq_ids
            num_seqs_in_group = len(seq_ids_list)
            for idx_within_group, logit_index_in_selected in enumerate(seq_group.sample_indices):
                if idx_within_group >= num_seqs_in_group:
                    print(f"Warning: Index mismatch. idx_within_group={idx_within_group}, num_seqs={num_seqs_in_group}")
                    continue
                seq_data = seq_group.seq_data[seq_ids_list[idx_within_group]]
                # Category of the most recent non-special token (2 = special);
                # falls back to the very last token if all 20 are special.
                last_token_ids = list(seq_data.get_token_ids()[-20:])
                tok_langs = self.tok_to_lang[last_token_ids].tolist()
                last_tok_lang_arr.append(last_not_value(tok_langs, 2))
                # Handle top_k=-1, which means consider all tokens
                top_k_arr.append(top_k if top_k != -1 else batch_logits.shape[1])
                top_p_arr.append(top_p)
                token_ids_arr.append(seq_data)
                logit_indices_arr.append(logit_index_in_selected)

        if not logit_indices_arr:
            # No row is being sampled, so there is nothing to gate.
            return batch_logits

        # --- Vectorized Processing ---
        # 1. Convert lists to tensors for batch operations
        logit_indices = torch.tensor(logit_indices_arr, device=device, dtype=torch.long)
        selected_logits = batch_logits[logit_indices]            # copy of the rows
        selected_lang_logits = batch_lang_logits[logit_indices]
        top_ks = torch.tensor(top_k_arr, device=device, dtype=torch.long)
        top_ps = torch.tensor(top_p_arr, device=device, dtype=torch.float)
        last_tok_langs = torch.tensor(last_tok_lang_arr, device=device, dtype=torch.long)

        # 2. Perform batched calculations
        # Categories present in the top-20 / p=0.95 window
        logits_lang_labels_sampling = batch_varied_top_k_p_in_logits(
            selected_logits, self.lang_masks, torch.full_like(top_ks, 20), torch.full_like(top_ps, 0.95)
        )
        # Categories present in top-5 / p=0.999 OR in the window above
        # (the second operand is reused instead of being recomputed)
        logits_lang_labels_topk = batch_varied_top_k_p_in_logits(
            selected_logits, self.lang_masks, torch.full_like(top_ks, 5), torch.full_like(top_ps, 0.999)
        ) | logits_lang_labels_sampling

        # Get language predictions and ensure the last token's language is allowed
        pred_type = 'gate'
        if pred_type == 'gate':
            # pred with gate: every category with sigmoid >= 0.5, and always
            # the highest-scoring one so at least one category is predicted
            gate_probs = selected_lang_logits.sigmoid()
            pred_lang_labels = gate_probs >= 0.5
            max_idx = gate_probs.argmax(dim=-1, keepdim=True)  # Shape: [..., 1]
            best_is_true = torch.zeros_like(pred_lang_labels).scatter_(-1, max_idx, 1).bool()
            pred_lang_labels = pred_lang_labels | best_is_true
        elif pred_type == 'norm':
            selected_logits_normalized = selected_logits / self.token_norm
            pred_lang_labels = batch_varied_top_k_p_in_logits(
                selected_logits_normalized / 0.7, self.lang_masks, torch.full_like(top_ks, 20), torch.full_like(top_ps, 0.8)
            )
        else:
            # pred with lasttok
            pred_lang_labels = torch.nn.functional.one_hot(last_tok_langs, num_classes=4).bool()

        original_pred_lang_labels = pred_lang_labels.clone()
        pred_lang_labels.scatter_(dim=1, index=last_tok_langs.unsqueeze(-1), value=True)
        # no intervention if last is symbol
        mask = (last_tok_langs == 2)
        pred_lang_labels[mask, :] = True

        # Determine which languages to mask out
        langs_to_mask = ~pred_lang_labels
        langs_to_mask[:, [2, 3]] = False  # Never mask 'special' or 'low-res' categories

        # 3. Determine which sequences need intervention
        intervention_has_effect = (logits_lang_labels_sampling & langs_to_mask).any(dim=1)
        target_in_topk = (logits_lang_labels_topk & original_pred_lang_labels).any(dim=1)
        apply_intervention = intervention_has_effect & target_in_topk

        # check if only symbol allowed for debugging
        symbol_true_lang_labels = original_pred_lang_labels.clone()
        symbol_true_lang_labels[:, 2] = True
        only_symbol_allowed = symbol_true_lang_labels.sum(dim=-1) == 1

        # Debug dump: a CJ candidate while the last token is not CJ, or a
        # Latin candidate while the last token is not Latin. The condition is
        # evaluated on-device and read back once.
        cj_mismatch = logits_lang_labels_sampling[:, 0].bool() & (last_tok_langs != 0)
        latin_mismatch = logits_lang_labels_sampling[:, 1].bool() & (last_tok_langs != 1)
        for i, flagged in enumerate((cj_mismatch | latin_mismatch).tolist()):
            if not flagged:
                continue
            print("!" * 30)
            print("apply intervention", apply_intervention[i])
            print([tokenizer.decode(token_ids_arr[i].get_token_ids())])
            print("intervention_has_effect", intervention_has_effect[i])
            print("target_in_topk", target_in_topk[i])
            print("logits_lang_labels_sampling", logits_lang_labels_sampling[i])
            print("logits_lang_labels_topk", logits_lang_labels_topk[i])
            print("only_symbol_allowed", only_symbol_allowed[i])
            print("original_pred_lang_labels", original_pred_lang_labels[i])
            print("selected_lang_logits", selected_lang_logits[i].sigmoid())
            print("langs_to_mask", langs_to_mask[i])
            for tok in selected_logits[i].topk(k=10).indices.tolist():
                print(tok, [tokenizer.decode(tok)], end=', ')
            print()

        # 4. Apply masks if necessary
        apply_flags = apply_intervention.tolist()
        if any(apply_flags):
            # Rows of batch_logits that get rewritten
            indices_to_modify = logit_indices[apply_intervention]
            logits_to_modify = selected_logits[apply_intervention]
            final_langs_to_mask = langs_to_mask[apply_intervention]

            # [num_interventions, C] @ [C, V] -> [num_interventions, V];
            # non-zero where the token belongs to any masked category
            vocab_mask = torch.matmul(final_langs_to_mask.float(), self.lang_masks.float()).bool()

            batch_logits[indices_to_modify] = logits_to_modify.masked_fill(vocab_mask, -torch.inf)

        # Debug dump for every row that was modified
        for i, applied in enumerate(apply_flags):
            if not applied:
                continue
            print("?" * 30)
            print([tokenizer.decode(token_ids_arr[i].get_token_ids())])
            print('original_hidden_state', original_hidden_shape)
            print('hidden_states', hidden_states.shape)
            print('batch_logits', batch_logits.shape)
            print("intervention_has_effect", intervention_has_effect[i])
            print("target_in_topk", target_in_topk[i])
            print("logits_lang_labels_sampling", logits_lang_labels_sampling[i])
            print("logits_lang_labels_topk", logits_lang_labels_topk[i])
            print("only_symbol_allowed", only_symbol_allowed[i])
            print("original_pred_lang_labels", original_pred_lang_labels[i])
            print("selected_lang_logits", selected_lang_logits[i].sigmoid())
            print("langs_to_mask", langs_to_mask[i])
            print("last_tok_langs", last_tok_langs[i])
            # Top-10 of THIS row before masking (selected_logits is a copy
            # taken before the rewrite) ...
            for tok in selected_logits[i].topk(k=10).indices.tolist():
                print(tok, [tokenizer.decode(tok)], end=', ')
            print()
            # ... and after masking.
            for tok in batch_logits[logit_indices_arr[i]].topk(k=10).indices.tolist():
                print(tok, [tokenizer.decode(tok)], end=', ')
            print()
        return batch_logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load model weights and cache the per-token norm of ``lm_head``.

        Weights are first handed to the base class, then to an
        ``AutoWeightsLoader`` over ``self``. The two bias names of the gating
        head are added to the returned set unconditionally.

        Afterwards the ``lm_head`` weight of every rank is all-gathered,
        concatenated along dim 0 and its row norms are stored in
        ``self.token_norm``.

        Returns the union of all parameter names reported as loaded.
        """
        loaded_params = super().load_weights(weights)
        # NOTE(review): ``weights`` is iterated a second time below; if the
        # caller passes a one-shot iterator this second pass sees nothing --
        # confirm this is intended.
        loader = AutoWeightsLoader(self,
                                   ignore_unexpected_prefixes=["lm_head."])
        # TODO: a bit hacky, need to fix this
        params_to_ignore = {'code_switch_head.bias', 'code_switch_pre.bias'}
        res = loader.load_weights(weights).union(loaded_params).union(params_to_ignore)

        # NOTE(review): uses the default process group; assumes torch.distributed
        # is initialised and lm_head is sharded along dim 0 across exactly
        # these ranks -- confirm for pipeline-parallel setups.
        world_size = dist.get_world_size()
        gathered_weights = [torch.empty_like(self.lm_head.weight) for _ in range(world_size)]
        dist.all_gather(gathered_weights, self.lm_head.weight)
        full_weight = torch.cat(gathered_weights, dim=0)
        self.token_norm = full_weight.norm(dim=1)
        return res

    @classmethod
    def is_backend_compatible(cls) -> bool:
        """Always report compatibility; no backend check is performed."""
        return True  # Or implement actual logic if needed


class Qwen3Gating(GatingMixin, Qwen3ForCausalLM):
    """Qwen3ForCausalLM combined with GatingMixin.

    GatingMixin comes first in the bases, so its ``__init__``,
    ``compute_logits`` and ``load_weights`` take precedence and reach the
    Qwen3 implementations through ``super()``.
    """
    pass

class Qwen3MoeGating(GatingMixin, Qwen3MoeForCausalLM):
    """Qwen3MoeForCausalLM combined with GatingMixin.

    GatingMixin comes first in the bases, so its ``__init__``,
    ``compute_logits`` and ``load_weights`` take precedence and reach the
    Qwen3-MoE implementations through ``super()``.
    """
    pass

from vllm.model_executor.models.gpt_oss import GptOssForCausalLM
class GptOssGating(GatingMixin, GptOssForCausalLM):
    """GptOssForCausalLM combined with GatingMixin.

    GatingMixin comes first in the bases, so its ``__init__``,
    ``compute_logits`` and ``load_weights`` take precedence and reach the
    GPT-OSS implementations through ``super()``.
    """
    pass