import torch
from transformers import LogitsProcessor, NoBadWordsLogitsProcessor
from transformers import StoppingCriteria, StoppingCriteriaList
import re
from copy import deepcopy


def construct_token_lists(
    tokenizer,
    regex_contains: list = None,
    regex_fullmatch: list = None,
    any_newlines: bool = False,         # workaround for the yamlism
    any_double_newlines: bool = False,
    added_ids_list = None,
):
    """Collect the ids of all vocabulary tokens that match the given patterns.

    Args:
        tokenizer: object exposing a ``vocab`` mapping of token string -> token id.
        regex_contains: patterns tested with ``re.search`` against each token.
        regex_fullmatch: patterns tested with ``re.fullmatch`` against each token.
        any_newlines: also match tokens containing a newline (``\\n`` or the
            BPE marker ``Ċ``). Takes precedence over ``any_double_newlines``.
        any_double_newlines: also match tokens containing a double newline
            (``\\n\\n`` or ``ĊĊ``); only used when ``any_newlines`` is False.
        added_ids_list: extra token ids appended verbatim to the result.

    Returns:
        A list with the matched token ids (unordered, unique) followed by
        ``added_ids_list``.

    Note:
        For both pattern lists a token is also selected when the raw pattern
        string occurs in it as a plain substring, so ``regex_fullmatch``
        entries effectively behave like "contains" checks for literal patterns.
    """
    # Work on private copies: the arguments (and, previously, the shared
    # default lists) must not accumulate the newline patterns across calls.
    contains_patterns = [] if regex_contains is None else list(regex_contains)
    fullmatch_patterns = [] if regex_fullmatch is None else list(regex_fullmatch)
    extra_ids = [] if added_ids_list is None else list(added_ids_list)

    if any_newlines:
        newline_patterns = [r'\n', r'Ċ']
    elif any_double_newlines:
        newline_patterns = [r'\n\n', r'ĊĊ']  # Ċ for BPE
    else:
        newline_patterns = []
    contains_patterns.extend(newline_patterns)
    fullmatch_patterns.extend(newline_patterns)

    # read the vocabulary once instead of once per pattern
    vocab_items = list(tokenizer.vocab.items())
    vocab_subset = set()

    for rs_str in contains_patterns:
        rs = re.compile(rs_str)
        for t, tokid in vocab_items:
            # `rs_str in t` also covers the exact-equality case
            if rs.search(t) is not None or rs_str in t:
                vocab_subset.add(tokid)

    for rs_str in fullmatch_patterns:
        rs = re.compile(rs_str)
        for t, tokid in vocab_items:
            if rs.fullmatch(t) is not None or rs_str in t:
                vocab_subset.add(tokid)

    return list(vocab_subset) + extra_ids



class ExtraEOSTokenLogitsProcessorWithConstructor(LogitsProcessor):
    """Fold the probability mass of extra stop tokens into the EOS token.

    The list of extra stop tokens is produced by calling
    ``constructor(tokenizer=tokenizer)``. On every call the score of the EOS
    column becomes the logsumexp over the EOS column and all extra stop-token
    columns, and the extra stop-token columns are set to ``-inf``.

    Args:
        tokenizer: passed through to ``constructor``.
        constructor: callable returning a list of token ids.
        eos_token_id: column that receives the pooled score.
        **kwargs_unused: accepted and ignored.

    Raises:
        AssertionError: if ``eos_token_id`` is among the constructed tokens.
    """

    def __init__(
        self,
        tokenizer,
        constructor: callable,
        eos_token_id: int,
        **kwargs_unused
    ):
        # De-duplicate (order preserving): a repeated id would be counted
        # more than once inside the logsumexp in __call__.
        self.considered_tokens = list(dict.fromkeys(constructor(tokenizer=tokenizer)))
        print(f"Number of extra stop tokens: {len(self.considered_tokens)}")
        assert eos_token_id not in self.considered_tokens, "Adding eos token id to considered tokens kills the crab."
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Pool the extra stop tokens into EOS; modifies and returns ``scores``."""
        if len(self.considered_tokens) > 0:
            # do the log sum exp to combine the probs of the required tokens into the eos
            pooled_columns = self.considered_tokens + [self.eos_token_id]
            scores[:, self.eos_token_id] = torch.logsumexp(scores[:, pooled_columns], dim=1)
            # assign the unwanted tokens to be -inf
            scores[:, self.considered_tokens] = -torch.inf
        return scores


class MaskingTokenLogitsProcessorWithConstructor(LogitsProcessor):
    """Set the scores of a constructed set of token ids to ``-inf``.

    The token ids come from ``constructor(tokenizer=tokenizer)``. The EOS id
    is stored but must not be part of the masked set.
    """

    def __init__(
        self,
        tokenizer,
        constructor: callable,
        eos_token_id: int,
        **kwargs_unused
    ):
        # TODO: optionally implement the CoT compatibility
        masked_ids = constructor(tokenizer=tokenizer)
        self.considered_tokens = masked_ids
        print(f"Number of masked tokens: {len(masked_ids)}")
        print(sorted(masked_ids))
        assert eos_token_id not in masked_ids, "Adding eos token id to considered tokens kills the crab."
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask the stored token columns in ``scores`` (in place) and return it."""
        if len(self.considered_tokens) == 0:
            # nothing to mask
            return scores
        scores[:, self.considered_tokens] = -torch.inf
        return scores


class StopOnCharCriteria(StoppingCriteria):
    """Stopping criterion that fires when the last token is one of ``stop_on_ids``.

    Args:
        stop_on_ids: iterable of token ids that end generation. ``None``
            (the default) means no stop ids.
    """

    def __init__(self, stop_on_ids = None):
        super().__init__()
        # a fresh list per instance instead of one shared default list
        self.stop_on_ids = [] if stop_on_ids is None else stop_on_ids

    def __call__(self, input_ids, scores, **kwargs):
        """Return a bool tensor with one entry per row of ``input_ids``.

        An entry is True when the last token of that row equals any stop id.
        Extra keyword arguments are accepted and ignored so the criterion can
        be called with the keyword pass-through used by stopping-criteria lists.
        """
        last_tokens = input_ids[:, -1]
        stop_flags = torch.zeros((input_ids.shape[0]), dtype=torch.bool, device=input_ids.device)
        for sid in self.stop_on_ids:
            stop_flags = stop_flags | (last_tokens == sid)
        return stop_flags


def prepare_critlist(*criterias):
    """Wrap the given stopping criteria into a single ``StoppingCriteriaList``."""
    bundle = StoppingCriteriaList(criterias)
    return bundle


@torch.no_grad()
def get_generation_length_ids_from_batch(sequences, eos_token_id):
    """Return, for each row, the index of the first EOS token.

    Args:
        sequences: 2-D tensor of token ids, indexed as ``[row, position]``.
        eos_token_id: token id marking the end of a sequence.

    Returns:
        A CPU ``torch.long`` tensor with one entry per row: the position of
        the first ``eos_token_id``, or the last position
        (``sequences.shape[1] - 1``) when the row contains no EOS. Rows of
        zero width yield 0.
    """
    n_rows, n_cols = sequences.shape[0], sequences.shape[1]
    if n_cols == 0:
        # nothing to scan
        return torch.zeros(n_rows, dtype=torch.long)

    is_eos = sequences.cpu() == eos_token_id
    positions = torch.arange(n_cols, dtype=torch.long).expand(n_rows, n_cols)
    # Non-EOS positions are replaced by the last index, so the row-wise
    # minimum is the first EOS position, or the last index when there is none.
    candidates = positions.masked_fill(~is_eos, n_cols - 1)
    return candidates.min(dim=1).values


class ExtraEOSTokenLogitsProcessor(LogitsProcessor):
    """Pool the scores of extra stop tokens into the EOS column.

    The extra stop tokens are given directly as a list; duplicates are
    removed. The EOS id must not be part of that list.
    """

    def __init__(
        self,
        considered_tokens: list,
        eos_token_id: int,
    ):
        # TODO: optionally implement the CoT compatibility
        assert eos_token_id not in considered_tokens, "Adding eos token id to considered tokens kills the crab."
        unique_tokens = set(considered_tokens)
        self.considered_tokens = list(unique_tokens)
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Move the extra stop tokens' mass to EOS in ``scores`` (in place) and return it."""
        extra_ids = self.considered_tokens
        if len(extra_ids) == 0:
            return scores
        # logsumexp over the extra columns plus the EOS column becomes the new EOS score
        pooled_columns = extra_ids + [self.eos_token_id]
        scores[:, self.eos_token_id] = torch.logsumexp(scores[:, pooled_columns], dim=1)
        # the extra columns themselves can no longer be selected
        scores[:, extra_ids] = -torch.inf
        return scores


def probe_token_id(tokenstr, tokenizer):
    """Return the last token id produced when encoding ``'a' + tokenstr``.

    The string is encoded with ``add_special_tokens=False``.
    """
    prefixed = 'a' + tokenstr
    encoded = tokenizer.encode(prefixed, add_special_tokens=False)
    return encoded[-1]


@torch.no_grad()
def reassign_ids(ids, considered_ids, target_token, starting_from=0):
    """Replace every occurrence of the considered ids with ``target_token``.

    Only positions along the last dimension that are strictly greater than
    ``starting_from`` are rewritten.

    Args:
        ids: tensor of token ids.
        considered_ids: sized iterable of token ids to replace.
        target_token: replacement token id.
        starting_from: positions ``<= starting_from`` are left untouched.

    Returns:
        ``ids`` itself when ``considered_ids`` is empty, otherwise a modified
        copy (the input tensor is never changed).
    """
    if len(considered_ids) == 0:
        return ids  # nothing to replace, don't try
    ids = ids.detach().clone()  # make a copy, don't want to mess things up
    # the position mask does not depend on the token, so build it once
    pos_indicator = (torch.ones_like(ids).cumsum(-1) - 1) > starting_from
    replace_mask = torch.zeros_like(ids, dtype=torch.bool)
    for ct in considered_ids:
        replace_mask |= ids == ct
    ids[replace_mask & pos_indicator] = target_token
    return ids


###
### The following two are for VLM batching, hope they work not just for phi35_vi
### https://gist.github.com/tomasruizt/21cfd764f8d89a7802bf32537af55bbe
###
from transformers import BatchFeature


def stack_and_pad_inputs(inputs: list[BatchFeature], pad_token_id: int, device='cpu') -> BatchFeature:
    """Merge several single-sample features into one left-padded batch.

    The first ``input_ids`` row of every item is left-padded with
    ``pad_token_id`` to a common length; ``pixel_values`` and ``image_sizes``
    are concatenated along dim 0. The attention mask is 1 wherever the padded
    ids differ from ``pad_token_id``. The resulting ``BatchFeature`` is moved
    to ``device``.
    """
    first_rows = [feature.input_ids[0] for feature in inputs]
    padded_ids = pad_left(first_rows, pad_token_id=pad_token_id)

    stacked_pixels = torch.cat([feature.pixel_values for feature in inputs], dim=0)
    stacked_sizes = torch.cat([feature.image_sizes for feature in inputs], dim=0)
    not_padding = padded_ids.ne(pad_token_id).long()

    batch = BatchFeature({
        'pixel_values': stacked_pixels,
        'image_sizes': stacked_sizes,
        'input_ids': padded_ids,
        'attention_mask': not_padding,
    })
    return batch.to(device)


def pad_left(seqs: list[torch.Tensor], pad_token_id: int) -> torch.Tensor:
    """Left-pad the sequences to a common length and stack them row-wise.

    Example: pad_left([[1, 2], [3, 4, 5]], pad_token_id=0) -> [[0, 1, 2], [3, 4, 5]]

    Empty sequences produce a row consisting only of padding, and an empty
    ``seqs`` yields a tensor of shape ``(0, 0)``.
    """
    max_len = max((len(seq) for seq in seqs), default=0)
    padded = torch.full((len(seqs), max_len), pad_token_id)
    for i, seq in enumerate(seqs):
        seq_len = len(seq)
        if seq_len == 0:
            # `-0:` would select the whole row and fail on the shape mismatch
            continue
        padded[i, max_len - seq_len:] = seq
    return padded


@torch.no_grad()
def evaluate_sequences_transition_probs(
    model, # CausalLMModel
    sequences, # [batch_size, max_seq_len]
    # sequences_len, # [batch_size]
    prompt, # [1, prompt_len]
    stop_on_tokens=None, # list[int]
    logits_mode="logits", # [logits, scores]
):
    """Recompute per-token transition log-probabilities for given sequences.

    The model is run once on ``sequences``; its logits are turned into
    log-probabilities, and the log-probability assigned to each next token of
    the sequence is gathered. All stop tokens (``stop_on_tokens`` plus the
    model's EOS id(s)) share one pooled log-probability, the logsumexp over
    their individual values.

    Args:
        model: callable as ``model(input_ids=...)`` returning an object with a
            ``logits`` attribute; must expose ``parameters()`` and
            ``config.eos_token_id`` (an int, a list of ints, or None).
        sequences: token ids, ``[batch_size, max_seq_len]``.
        prompt: only ``prompt.shape[-1]`` (the prompt length) is used.
        stop_on_tokens: extra token ids treated as equivalent to EOS;
            ``None`` means no extra stop tokens.
        logits_mode: currently unused.

    Returns:
        A CPU tensor in the default float dtype with the transition
        log-probabilities, starting at position ``prompt.shape[-1] - 1``.
    """
    device = next(model.parameters()).device
    seq_dev = sequences.to(device)
    mouts = model(input_ids=seq_dev)
    logprobs = mouts.logits.log_softmax(dim=-1)  # turn into log probs

    # normalise the EOS id(s) to a list; some configs hold several ids
    eos_ids = model.config.eos_token_id
    if eos_ids is None:
        eos_ids = []
    elif isinstance(eos_ids, int):
        eos_ids = [eos_ids]
    else:
        eos_ids = list(eos_ids)
    extra_stop_ids = [] if stop_on_tokens is None else list(stop_on_tokens)
    # unique ids only: a repeated id would be double counted in the logsumexp
    stop_ids = list(dict.fromkeys(extra_stop_ids + eos_ids))

    if stop_ids:
        # assign the same added up probability to all stop tokens (fine to do
        # here, since the sampling is not done, just reconstruction)
        pooled = torch.logsumexp(logprobs[:, :, stop_ids], dim=-1, keepdim=True)
        logprobs[:, :, stop_ids] = pooled

    # the log-prob at position t of the token actually found at position t+1
    next_tokens = seq_dev[:, 1:].unsqueeze(-1)
    transition_scores = logprobs[:, :-1].gather(-1, next_tokens).squeeze(-1)
    transition_scores = transition_scores.to(device='cpu', dtype=torch.get_default_dtype())
    return transition_scores[:, prompt.shape[-1]-1:]



@torch.no_grad()
def compute_in_cot_out_ranges(sequences, think_tok_start, think_tok_end, eos_token_id):
    """Locate the input, chain-of-thought and answer spans of each sequence.

    Every row is scanned left to right until the first ``eos_token_id`` or the
    end of the row:

    * at ``think_tok_start`` (position ``j``) the input span ends at ``j`` and
      the CoT span starts at ``j + 1``;
    * at ``think_tok_end`` the CoT span ends at ``j`` and the answer span
      starts at ``j + 1``;
    * at EOS, or at the last position, the answer span ends at ``j`` and the
      row length is ``j``.

    If a marker occurs several times, the last occurrence wins.

    Args:
        sequences: 2-D tensor of token ids, indexed as ``[row, position]``.
        think_tok_start: token id opening the chain of thought.
        think_tok_end: token id closing the chain of thought.
        eos_token_id: token id ending the sequence.

    Returns:
        ``(tensor_len, (input_ranges, cot_ranges, answer_ranges))`` where
        ``tensor_len`` has one entry per row and every ``*_ranges`` tensor has
        shape ``[rows, 2]`` holding ``(start, end)``; all are ``torch.long``.
    """
    n_rows = sequences.shape[0]
    input_ranges = torch.zeros((n_rows, 2), dtype=torch.long)
    cot_ranges = torch.zeros((n_rows, 2), dtype=torch.long)
    answer_ranges = torch.zeros((n_rows, 2), dtype=torch.long)
    tensor_len = torch.zeros(n_rows, dtype=torch.long)

    sequences = sequences.cpu()
    last_pos = sequences.shape[1] - 1
    for i in range(n_rows):
        # the last cot and the most input possible is taken
        for j in range(sequences.shape[1]):
            tok = sequences[i, j]
            if tok == think_tok_start:
                input_ranges[i, 1] = j
                cot_ranges[i, 0] = j + 1
            elif tok == think_tok_end:
                # previously this branch re-tested think_tok_start and was unreachable
                cot_ranges[i, 1] = j
                answer_ranges[i, 0] = j + 1
            elif tok == eos_token_id:
                answer_ranges[i, 1] = j
                tensor_len[i] = j
                break
            if j == last_pos:
                # ran off the end without EOS (also when the last token is a
                # think marker): close the answer span at the last position
                answer_ranges[i, 1] = j
                tensor_len[i] = j

    return tensor_len, (input_ranges, cot_ranges, answer_ranges)


from typing import Dict, Tuple, List
from collections import defaultdict


@torch.no_grad()
def mask_out_scores(
    scores,
    ids: list,
    retain: bool = False
):
    """Set a subset of a 1-D score vector to ``-inf`` in place.

    With ``retain=False`` the entries at ``ids`` are set to ``-inf``. With
    ``retain=True`` every entry except those at ``ids`` is set to ``-inf``,
    which forces a choice among ``ids``. The modified ``scores`` is returned.
    """
    assert scores.ndim == 1, "No batched stuff here!"
    # 1 marks the positions listed in `ids`, 0 everything else
    selected = torch.zeros_like(scores)
    selected[ids] = 1.
    if retain:
        to_kill = selected < 0.5
    else:
        to_kill = selected > 0.5
    scores[to_kill] = -torch.inf
    return scores


@torch.no_grad()
def reallocate_scores(
    scores,
    idx_from: list,
    idx_to: list,
):
    """Move the probability mass of ``idx_from`` onto ``idx_to`` in log space.

    ``scores`` is a 1-D vector of log-scores and is modified in place. The
    total mass of the ``idx_from`` entries (their logsumexp) is added to the
    ``idx_to`` entries, and the ``idx_from`` entries are then set to ``-inf``,
    so the overall logsumexp of ``scores`` is preserved. With several
    recipients the moved mass is split equally between them.

    Args:
        scores: 1-D tensor of log-scores.
        idx_from: list of donor indices.
        idx_to: list of recipient indices.

    Returns:
        The modified ``scores`` tensor.

    Raises:
        AssertionError: if ``scores`` is not 1-D or the indices are not lists.
    """
    import math

    # reallocate the logit values without messing up the normalization
    assert scores.ndim == 1, "No batched stuff here!"
    assert isinstance(idx_from, list) and isinstance(idx_to, list), "idx_from and idx_to must be lists!"
    if len(idx_to) > 1:
        # total donor mass, split equally: log(m / n) = log(m) - log(n)
        moved = torch.logsumexp(scores[idx_from], dim=0)
        share = moved - math.log(len(idx_to))
        # every recipient keeps its own mass and gains one share
        recipients = scores[idx_to]
        combined = torch.stack([recipients, share.expand_as(recipients)])
        scores[idx_to] = torch.logsumexp(combined, dim=0)
        # set the from scores to -inf
        scores[idx_from] = -torch.inf
    else:
        # a single recipient (or none): pool donors and recipient together
        logsum = torch.logsumexp(scores[idx_from+idx_to], dim=0)
        # distribute
        scores[idx_to] = logsum
        # set the from scores to -inf
        scores[idx_from] = -torch.inf
    return scores



class LogitsProcessorWithState(LogitsProcessor):
    """Base class for logits processors that expose a ``reset`` method."""

    def reset(self):
        """Reset the processor's state; this base version does nothing."""
        # resets the state
        # NOTE(review): this returns the `NotImplemented` singleton rather than
        # raising NotImplementedError, so calling the base method does not fail.
        return NotImplemented

    
try:
    import xgrammar as xgr


    class xgrGuidedGeneratorStandard(LogitsProcessorWithState):
        """Constrain generation to a grammar compiled with xgrammar.

        One grammar matcher per batch row is created lazily on the first call.
        On every later call the last token of each row is fed to its matcher,
        and the resulting token bitmask is applied to the scores.

        Args:
            tokenizer: Hugging Face tokenizer handed to ``xgr.TokenizerInfo``.
            grammar_string: regex (default) or EBNF grammar source.
            vocab_size: vocabulary size used for tokenizer info and bitmask.
            is_ebnf_provided: compile ``grammar_string`` as EBNF instead of regex.
            root_rule_name: root rule for EBNF grammars.
        """

        def __init__(
            self,
            tokenizer,
            grammar_string,
            vocab_size,
            is_ebnf_provided = False, # defaults to regex
            root_rule_name = 'taggedroot'
        ):
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
            grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
            if not is_ebnf_provided:
                self.compiled_grammar = grammar_compiler.compile_regex(grammar_string)
            else:
                self.compiled_grammar = grammar_compiler.compile_grammar(grammar_string, root_rule_name=root_rule_name)
            # per-batch state, created on the first __call__
            self.matchers = None
            self.token_bitmasks = None
            self.vocab_size = vocab_size

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            """Apply the grammar mask to ``scores`` and return the masked scores."""
            if self.matchers is None or self.token_bitmasks is None:
                # first call: nothing generated yet, so no token is fed to the matchers
                self.matchers = [xgr.GrammarMatcher(self.compiled_grammar) for _ in range(input_ids.shape[0])]
                self.token_bitmasks = xgr.allocate_token_bitmask(input_ids.shape[0], self.vocab_size)
            else:
                # advance every live matcher with the token generated last
                for matcher, tok in zip(self.matchers, input_ids[:, -1].tolist()):
                    if not matcher.is_terminated():
                        # keep the call outside the assert: under `python -O`
                        # asserts are stripped and the matcher would never advance
                        accepted = matcher.accept_token(tok, debug_print=False)
                        assert accepted, "Grammar matcher rejected the generated token."

            # update all the scores of interest after resetting the mask
            xgr.reset_token_bitmask(self.token_bitmasks)
            for i, matcher in enumerate(self.matchers):
                if not matcher.is_terminated():
                    matcher.fill_next_token_bitmask(self.token_bitmasks, i)

            # the mask is applied on CPU and the result moved back afterwards
            scores_dev = scores.device
            scores = scores.cpu()
            xgr.apply_token_bitmask_inplace(scores, self.token_bitmasks.to(scores.device))

            scores = scores.to(scores_dev)
            return scores

        def reset(self) -> bool:
            """Drop matchers and bitmask so the next call starts a fresh generation."""
            self.matchers = None
            self.token_bitmasks = None
            return True


    class xgrGuidedGeneratorPlusThinking(LogitsProcessorWithState):
        """Grammar-guided generation preceded by a forced chain-of-thought phase.

        Each batch row runs through a small state machine:

        * 0 - force ``cot_start_id`` as the next token, then go to 1;
        * 1 - free thinking; EOS mass is moved to ``cot_end_id``. Goes to 3
          when the row produced ``cot_end_id`` itself, or to 2 once
          ``max_cot_tokens`` steps have passed;
        * 2 - force ``cot_end_id`` as the next token, then go to 3;
        * 3 - first guided step (no token is fed to the matcher), then 4;
        * 4 - guided steps: feed the last token to the matcher and mask.

        Args:
            tokenizer: Hugging Face tokenizer handed to ``xgr.TokenizerInfo``.
            grammar_string: regex (default) or EBNF grammar source.
            vocab_size: vocabulary size used for tokenizer info and bitmask.
            cot_start_id: token id that opens the chain of thought.
            cot_end_id: token id that closes the chain of thought.
            eos_ids: list of EOS ids whose mass is redirected while thinking.
            max_cot_tokens: number of thinking steps after which the end of the
                chain of thought is forced.
            is_ebnf_provided: compile ``grammar_string`` as EBNF instead of regex.
            root_rule_name: root rule for EBNF grammars.
        """

        def __init__(
            self,
            tokenizer,
            grammar_string,
            vocab_size,
            cot_start_id,
            cot_end_id,
            eos_ids,
            max_cot_tokens,
            is_ebnf_provided = False, # defaults to regex
            root_rule_name = 'taggedroot'
        ):
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
            grammar_compiler = xgr.GrammarCompiler(tokenizer_info)
            if not is_ebnf_provided:
                self.compiled_grammar = grammar_compiler.compile_regex(grammar_string)
            else:
                self.compiled_grammar = grammar_compiler.compile_grammar(grammar_string, root_rule_name=root_rule_name)
            # per-batch state, created on the first __call__
            self.matchers = None
            self.token_bitmasks = None
            self.vocab_size = vocab_size

            # thinking part
            self.cot_start_id = cot_start_id
            self.cot_end_id = cot_end_id
            self.eos_ids = eos_ids
            self.max_cot_tokens = max_cot_tokens
            self.thinking_states = []
            self.cot_lens = []

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            """Advance every row's state machine and return the adjusted scores."""
            if self.matchers is None or self.token_bitmasks is None:
                self.matchers = [xgr.GrammarMatcher(self.compiled_grammar) for _ in range(input_ids.shape[0])]
                self.token_bitmasks = xgr.allocate_token_bitmask(input_ids.shape[0], self.vocab_size)
                self.thinking_states = [0]*input_ids.shape[0]
                self.cot_lens = [0]*input_ids.shape[0]

            # reset all masks
            xgr.reset_token_bitmask(self.token_bitmasks)

            # for each sequence individually
            for i, (matcher, tok) in enumerate(zip(self.matchers, input_ids[:, -1].tolist())):
                if self.thinking_states[i] == 0:
                    # initiate start sequence
                    scores[i] = mask_out_scores(scores[i], self.cot_start_id, retain=True)
                    self.thinking_states[i] = 1

                elif self.thinking_states[i] == 1:
                    self.cot_lens[i] += 1
                    # check for length condition
                    if self.cot_lens[i] >= self.max_cot_tokens:
                        self.thinking_states[i] = 2
                    else:
                        # check that it hasn't been completed by the model itself
                        if tok == self.cot_end_id:
                            # jump to regex directly
                            self.thinking_states[i] = 3
                        else:
                            # reallocate the eos tokens logits to eot token to continue cot generation
                            scores[i] = reallocate_scores(scores[i], self.eos_ids, [self.cot_end_id])

                # deliberately a fresh `if`: a row that just moved to state 2 or 3
                # above is handled in this same step
                if self.thinking_states[i] == 2:
                    # finish cot
                    scores[i] = mask_out_scores(scores[i], self.cot_end_id, retain=True)
                    self.thinking_states[i] = 3

                elif self.thinking_states[i] >= 3:
                    if self.thinking_states[i] == 3:
                        # first guided step: the last token closed the CoT and
                        # is not part of the grammar, so it is not fed in
                        print('Starting guided: ', self.cot_lens[i])
                        self.thinking_states[i] = 4
                    elif not matcher.is_terminated():
                        # keep the call outside the assert: under `python -O`
                        # asserts are stripped and the matcher would never advance
                        accepted = matcher.accept_token(tok, debug_print=False)
                        assert accepted, "Grammar matcher rejected the generated token."
                    # fill the mask for the given token
                    if not matcher.is_terminated():
                        matcher.fill_next_token_bitmask(self.token_bitmasks, i)

            # the mask is applied on CPU and the result moved back afterwards
            scores_dev = scores.device
            scores = scores.cpu()
            xgr.apply_token_bitmask_inplace(scores, self.token_bitmasks.to(scores.device))
            scores = scores.to(scores_dev)
            return scores

        def reset(self) -> bool:
            """Drop all per-batch state so the next call starts a fresh generation."""
            self.matchers = None
            self.token_bitmasks = None
            self.thinking_states = []
            self.cot_lens = []
            return True

    from types import SimpleNamespace

    class WaitOnLogitPseudoGrammar(object):
        """Minimal matcher-like object that waits for a release token.

        It offers the matcher methods used by the processors in this module
        (``fill_next_token_bitmask``, ``accept_token``, ``is_terminated``)
        without any real grammar: it stays active until ``release_token`` is
        accepted or ``max_duration`` tokens have been accepted.

        Args:
            release_token: token id that terminates the pseudo grammar.
            max_duration: terminate once this many non-release tokens have been
                accepted; values ``<= 0`` disable the limit.
            avoid_eos: token id to forbid while active; negative disables.
        """

        def __init__(self, release_token, max_duration=-1, avoid_eos=-1):
            self.release_token = release_token
            self.max_duration = max_duration
            self.avoid_eos = avoid_eos

            self.cur_duration = 0
            self.released = False

        def reset(self):
            """Return to the initial, not-released state."""
            self.cur_duration = 0
            self.released = False

        def fill_next_token_bitmask(self, bitmask, batch_id):
            """Clear the ``avoid_eos`` bit in row ``batch_id`` of ``bitmask``.

            The bitmask is treated as rows of 32-bit words, one bit per token
            id. Nothing happens when ``avoid_eos`` is negative.
            """
            if self.avoid_eos < 0:
                return
            # the final bitmask is just a torch bit array
            # with the original int32 type, -1 is the full mask
            word_idx = self.avoid_eos // 32
            bit_idx = self.avoid_eos % 32
            # Build the clearing mask as a signed 32-bit value: for bit 31 the
            # plain `~(1 << 31)` lies outside the int32 range.
            clear_mask = ~(1 << bit_idx) & 0xFFFFFFFF
            if clear_mask > 0x7FFFFFFF:
                clear_mask -= 1 << 32
            # set the correct bit to 0
            bitmask[batch_id, word_idx] = bitmask[batch_id, word_idx] & clear_mask

        def accept_token(self, token_id, debug_print=False):
            """Consume one token; always returns True.

            The object becomes released when ``token_id`` is the release token
            or when the duration limit has been reached.
            """
            duration_exceeded = self.max_duration > 0 and self.cur_duration >= self.max_duration
            if token_id == self.release_token or duration_exceeded:
                self.released = True
            else:
                self.cur_duration += 1
            return True

        def is_terminated(self):
            """Return True once the release condition has been met."""
            return self.released


    def create_extended_matcher(grammar, **kwargs):
        """Create a fresh matcher for a compiled grammar or a pseudo grammar.

        Args:
            grammar: an ``xgr.CompiledGrammar`` or a ``WaitOnLogitPseudoGrammar``.
            **kwargs: forwarded to ``xgr.GrammarMatcher`` for compiled grammars;
                ignored for pseudo grammars.

        Returns:
            A new ``xgr.GrammarMatcher``, or a reset deep copy of the pseudo
            grammar (so every batch row gets independent state).

        Raises:
            TypeError: if ``grammar`` is of neither supported type.
        """
        # adapter to also be able to create some fake grammars
        if isinstance(grammar, xgr.CompiledGrammar):
            return xgr.GrammarMatcher(grammar, **kwargs)
        if isinstance(grammar, WaitOnLogitPseudoGrammar):
            retmatcher = deepcopy(grammar)
            retmatcher.reset()
            return retmatcher
        # fail here instead of handing back None and crashing later on use
        raise TypeError(f"Unsupported grammar type: {type(grammar).__name__}")


    class xgrAdaptiveTriggeredGuide(LogitsProcessorWithState):
        """Switch grammar guidance on when a trigger string shows up in the output.

        Each batch row is either free (state 0) or guided by one of the
        configured grammars (state ``k >= 1``). A free row becomes guided when
        the decoded tail of its sequence contains a trigger string; a guided
        row returns to state 0 once its matcher terminates.

        Args:
            tokenizer: Hugging Face tokenizer, also used to decode the tail.
            vocab_size: vocabulary size used for tokenizer info and bitmask.
            grammars_and_triggers: mapping ``name -> config dict``. Every config
                needs ``trigger_seq`` and, depending on the flags:
                ``is_ebnf_provided`` -> ``grammar_string`` and ``root_rule_name``;
                ``is_clogic_provided`` -> ``release_token``, ``max_duration`` and
                ``avoid_eos``; no flag -> ``grammar_string`` compiled as regex.
                ``is_stag_provided`` is not implemented. The dicts are not
                modified.
            n_backpeek_triggers: maximum number of trailing tokens decoded when
                looking for a trigger.
            starting_state: state every row starts in (0 = free).
            starting_state_by_trigger: if given, start in the state belonging
                to this trigger string instead of ``starting_state``.
            min_updates_in: initial size of the back-peek window; the window
                grows by one per call (allows looking into the prompt x tokens
                back).
        """

        def __init__(
            self,
            tokenizer,
            vocab_size,
            grammars_and_triggers: dict,
            n_backpeek_triggers: int = 10,
            starting_state: int = 0, # allows to start with a specific grammar matcher
            starting_state_by_trigger: str = None, # set the starting set by trigger
            min_updates_in: int = 1, # allows looking into the prompt x tokens back
        ):
            # prepare basics
            self.tokenizer = tokenizer
            self.vocab_size = vocab_size
            tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)
            grammar_compiler = xgr.GrammarCompiler(tokenizer_info)

            # process the triggers and grammars
            self.grammars = {}
            self.triggers = {}
            self.trigger_states = {}
            self.states_trigger = {}

            conf_defaults = {
                'is_ebnf_provided': False,
                'is_stag_provided': False,
                'is_clogic_provided': False,
            }

            # go through the dict
            for i, (trigger_name, raw_conf) in enumerate(grammars_and_triggers.items()):
                # merge into a new dict so the caller's config is left untouched
                trigger_conf = SimpleNamespace(**{**conf_defaults, **raw_conf})
                # trigger conf: trigger_seq, is_ebnf_provided, grammar_string, root_rule_name
                # compile the grammar
                if trigger_conf.is_ebnf_provided:
                    compiled_grammar = grammar_compiler.compile_grammar(trigger_conf.grammar_string, root_rule_name=trigger_conf.root_rule_name)
                elif trigger_conf.is_stag_provided:
                    raise NotImplementedError('Need to implement structural tags')
                elif trigger_conf.is_clogic_provided:
                    compiled_grammar = WaitOnLogitPseudoGrammar(
                        release_token=trigger_conf.release_token,
                        max_duration=trigger_conf.max_duration,
                        avoid_eos=trigger_conf.avoid_eos,
                    )
                else:
                    compiled_grammar = grammar_compiler.compile_regex(trigger_conf.grammar_string)

                self.grammars[trigger_name] = compiled_grammar # grammar name -> compiled grammar
                self.triggers[trigger_conf.trigger_seq] = trigger_name  # trigger sequence -> grammar name
                self.trigger_states[trigger_name] = i+1 # grammar name -> state index
                self.states_trigger[i+1] = trigger_name # state index -> grammar name

            # prepare states
            self.matchers = None
            self.token_bitmasks = None
            self.states = None
            self.updates_in = min_updates_in

            self.min_updates_in = min_updates_in
            self.n_backpeek_triggers = n_backpeek_triggers
            if starting_state_by_trigger is None:
                self.starting_state = starting_state
            else:
                # set based on trigger
                self.starting_state = self.trigger_states[self.triggers[starting_state_by_trigger]]

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            """Update every row's trigger/grammar state and return the masked scores."""
            # NOTE(review): the print calls in this method look like debug output
            print('ping')
            just_initialized = False
            if self.matchers is None or self.token_bitmasks is None or self.states is None:
                if self.starting_state != 0:
                    self.states = [self.starting_state,]*input_ids.shape[0]
                    self.matchers = [create_extended_matcher(self.grammars[self.states_trigger[self.starting_state]]) for i in range(input_ids.shape[0])]
                    just_initialized = True
                else:
                    self.states = [0,]*input_ids.shape[0]
                    self.matchers = [None,]*input_ids.shape[0]
                self.token_bitmasks = xgr.allocate_token_bitmask(input_ids.shape[0], self.vocab_size)

            # reset all masks
            xgr.reset_token_bitmask(self.token_bitmasks)

            # for each sequence check the last k tokens and state; the window is
            # capped by the sequence length (dim 1), not by the batch size
            peek_len = min(input_ids.shape[1], self.n_backpeek_triggers, self.updates_in)
            top_peek = input_ids[:, -peek_len:].cpu()
            if not just_initialized:
                for i in range(input_ids.shape[0]):
                    if self.states[i] == 0:
                        # check if we have triggered anything
                        last_seq = self.tokenizer.decode(top_peek[i])
                        for trigger in self.triggers:
                            if trigger in last_seq:
                                # activate the trigger by updating state
                                self.states[i] = self.trigger_states[self.triggers[trigger]]
                                # initialize the matcher for the triggered grammar
                                self.matchers[i] = create_extended_matcher(self.grammars[self.triggers[trigger]])
                                # apply the first mask
                                self.matchers[i].fill_next_token_bitmask(self.token_bitmasks, i)
                                assert not self.matchers[i].is_terminated(), "Matcher terminated on arrival! Bad!"
                                print(self.states, self.updates_in)
                                break # just in case: avoid triggering multiple grammars!

                    else:
                        # put the next token through the matcher; keep the call
                        # outside the assert, which is stripped under `python -O`
                        accepted = self.matchers[i].accept_token(int(top_peek[i, -1]), debug_print=False)
                        assert accepted, "Grammar matcher rejected the generated token."
                        # fill the mask or roll back the state if terminated
                        if not self.matchers[i].is_terminated():
                            self.matchers[i].fill_next_token_bitmask(self.token_bitmasks, i)
                        else:
                            self.states[i] = 0
                            self.matchers[i] = None
                            print(self.states, self.updates_in)

            else:
                # rows start inside a grammar: mask right away, nothing to accept yet
                for i in range(input_ids.shape[0]):
                    self.matchers[i].fill_next_token_bitmask(self.token_bitmasks, i)

            # the mask is applied on CPU and the result moved back afterwards
            scores_dev = scores.device
            scores = scores.cpu()
            xgr.apply_token_bitmask_inplace(scores, self.token_bitmasks.to(scores.device))
            scores = scores.to(scores_dev)

            self.updates_in += 1  # to make sure we do not peek too much into prompt
            return scores

        def reset(self) -> bool:
            """Drop all per-batch state and restore the initial back-peek window."""
            self.matchers = None
            self.token_bitmasks = None
            self.states = None
            self.updates_in = self.min_updates_in
            return True



except ImportError as e:
    print(e)
    print("Failed to import xgrammar, proceeding without the corresponding logits processors!")
