from transformers import LogitsProcessor
import torch
import copy


import torch
from transformers import LogitsProcessor, BeamSearchScorer
import numpy as np



def build_concept_transition_matrix(transition_matrix_np, variants_count):
    """
    Collapse a per-variant transition table into a per-concept one.

    The variants of each concept occupy contiguous columns of
    ``transition_matrix_np``; ``variants_count[j]`` is the number of columns of
    concept ``j``. Three single-column concepts are appended after the given
    ones. For every concept, the column of its first variant is taken as the
    concept's transition column.

    Args:
        transition_matrix_np: 2-D array ``[#states, #variant columns]``.
        variants_count: number of variants of each concept (not modified).

    Returns:
        float64 ``np.ndarray`` of shape ``[#states, len(variants_count) + 3]``.
    """
    # three trailing single-variant concepts (the original comment names _dot_
    # and eos; NOTE(review): confirm what the third column stands for)
    counts = list(variants_count) + [1, 1, 1]
    # offsets[j] == sum(counts[:j]): column of the first variant of concept j
    offsets = np.concatenate(([0], np.cumsum(counts)[:-1])).astype(np.intp)
    return np.asarray(transition_matrix_np)[:, offsets].astype(np.float64)

def complete_variants(dict_):
    """
    Expand every concept's surface forms with capitalized and space-prefixed forms.

    A ``"_dot_"`` concept with the single form ``'.'`` is added (overriding any
    existing ``"_dot_"`` entry). For each concept the result contains, in this
    order and without duplicates: the original forms, their ``str.capitalize()``
    versions, then a ``" " + form`` version of every form that does not already
    start with a space. The input dict is not modified.

    Args:
        dict_: mapping concept -> iterable of surface strings.

    Returns:
        New dict concept -> list of strings, with keys in the input order
        followed by ``"_dot_"`` (when it was not already present).
    """
    completed = {}
    for key, words in {**dict_, "_dot_": ['.']}.items():
        expanded = list(words)
        expanded += [word.capitalize() for word in expanded]
        expanded += [" " + word for word in expanded if not word.startswith(" ")]
        # dict.fromkeys deduplicates while keeping a deterministic order
        # (set() order depends on per-process string hashing)
        completed[key] = list(dict.fromkeys(expanded))
    return completed

def token_match_convolution(A: torch.Tensor, B: torch.Tensor):
    """
    For every (beam window, variant) pair, find the variant token that would
    extend the longest prefix of the variant already present at the end of the
    window.

    Args:
        A: (B_size, M-1) token ids, the last M-1 tokens of each beam
           (``DFALogitsProcessor._compute_windows`` left-pads it with -1).
        B: (V_size, M) token ids of every variant, right-padded with -1.

    Returns:
        A_completer_tokens: (B_size, V_size) token id ``B[v, k]``, where ``k`` is
            the index returned as the third element.
        true_completer: (B_size, V_size) bool, True when that token is the
            variant's last non-pad token (emitting it completes the variant).
        indices_of_B_to_gather_f: (B_size, V_size) index ``k`` of the completer
            token inside the variant, i.e. the number of leading variant tokens
            matched at the end of ``A``; 0 when nothing matches or when the whole
            variant is already matched.
    """
    
    B_size,M1 = A.shape
    V_size,M = B.shape
    assert M1 == M-1, "A should be of shape (B_size, M-1) and B should be of shape (V_size, M)"
    # only the first M-1 tokens of a variant can already sit in the window
    suffix_B = B[:, :M1]
    
    # Shift-and-compare: shift k matches when A[b, k:] == B[v, :M1-k], i.e. the
    # last M1-k window tokens equal the first M1-k variant tokens.
    #   op_A[b, k, j] = A[b, j]                      -> (B_size, M1, M1)
    #   op_B[v, k, j] = suffix_B[v, (j - k) mod M1]  -> (V_size, M1, M1)
    #   masks[k, j]   = 1 iff j >= k (the positions compared for shift k)
    op_A = torch.tile(A, (M1,1,1)).transpose(0,1)
    op_B = torch.stack([torch.roll(suffix_B,k, dims=1) for k in range(M-1)], dim=0).transpose(0,1)
    masks = torch.triu(torch.ones((M1, M1)), diagonal=0).to(A.device)

    # broadcast all three to (B_size, V_size, M1, M1)
    op_A = op_A.unsqueeze(1).expand(-1, V_size, -1, -1)
    op_B = op_B.unsqueeze(0).expand(B_size, -1, -1, -1)
    masks = masks.unsqueeze(0).expand(B_size, V_size, -1, -1)
    # compared positions contribute their equality, masked-out positions count as 1
    op_AB = (op_A == op_B).float()* masks + (1 - masks)

    # AB_shift_matched[b, v, k] == 1 iff shift k matches -> (B_size, V_size, M1)
    AB_shift_matched = op_AB.all(dim=-1).float()

    # position of the first -1 pad == variant length; argmin returns 0 when a row
    # has no pad (variants are right-padded, so 0 means "no pad" unless a variant
    # is empty), in which case the length is M
    first_pads_B = torch.argmin((B!=-1).int(), dim=1)
    first_pads_B = torch.where(first_pads_B == 0, M, first_pads_B)
    # argmin returns the first (smallest) matching shift == longest matched prefix
    min_indices = torch.argmin(1-AB_shift_matched, dim=-1, )
    # True where no shift matches at all
    AB_no_match = (1-AB_shift_matched).all(dim=-1)
    # matched prefix length M1 - k == M - (k + 1) == index of the next variant token
    indices_of_B_to_gather = (M-(1+min_indices)) 

    # whole variant already matched (no next token) -> restart at its first token
    indices_of_B_to_gather_f = torch.where(indices_of_B_to_gather <first_pads_B, indices_of_B_to_gather, torch.zeros_like(indices_of_B_to_gather))
    # no match -> the variant must be started from its first token
    indices_of_B_to_gather_f = torch.where(AB_no_match, torch.zeros_like(indices_of_B_to_gather_f), indices_of_B_to_gather_f)
    # the completer is the variant's last non-pad token
    true_completer = indices_of_B_to_gather_f==(first_pads_B-1)
    # gather along the token axis of B as (V_size, B_size), then transpose back
    A_completer_tokens = torch.gather(B,1, indices_of_B_to_gather_f.T).T
    return A_completer_tokens, true_completer, indices_of_B_to_gather_f

class DFALogitsProcessor(LogitsProcessor):
    """
    DFA-aware logits processor with α-blending and budget pruning.

    One DFA state is tracked per beam in ``self.beam_states``. States are
    1-indexed and start at 1; tensors indexed by state use ``state - 1``.
    At every step the processor:

    1. computes, for every vocabulary token, the DFA state the beam would reach
       if that token were emitted, using the tokenized surface-form variants of
       each concept;
    2. masks tokens whose next state needs more steps to accept (``dist``) than
       remain before ``max_length``;
    3. α-blends toward the row maximum the scores of tokens that move the beam
       strictly closer to acceptance at a word boundary;
    4. optionally masks tokens containing a newline, or a '.' that is not
       exactly '.'.

    ``dfa_layer`` must contain
      • 'transition_matrix' : array-like indexed ``[state, variant column, k]``.
        The ``[..., 1]`` slice is used as the next-state table, with each
        concept's variants in contiguous columns. (NOTE(review): confirm the
        meaning of axis k.)
      • 'dist'              : dict state → min-steps-to-accept. Its values are
        used in insertion order, presumably states 1..S; verify against the
        producer.
    """

    def __init__(self, dfa_layer: dict, variants, concepts,  tokenizer, num_beams: int,
                 max_length: int, alpha: float = 0.5, gamma: float = 0.5, device='cuda', eps_favor: float = 1e-2,
                 remove_line_jump: bool = True,
                 ramping: bool = True):
        """
        Args:
            dfa_layer: see the class docstring.
            variants: ``variants[i]`` is the list of surface strings of ``concepts[i]``.
            concepts: concept names; only ``concepts[:-3]`` are paired with
                ``variants``. A ``_dot_`` concept, an ``eos`` concept and a
                trailing fallback column are added internally.
            tokenizer: Hugging Face-style tokenizer (``len``, ``encode``,
                ``batch_decode``, ``convert_*`` and ``eos_token`` are used).
            num_beams: number of beams; one DFA state is tracked per beam.
            max_length: total sequence-length budget used for pruning.
            alpha, gamma: blending strength and ramp exponent (see ``__call__``).
            device: device for every precomputed tensor.
            eps_favor: small offset that keeps shifted scores positive and
                scales the extra bonus of state-changing tokens.
            remove_line_jump: mask tokens containing a newline, or a '.' that is
                not exactly '.'.
            ramping: if True, alpha grows with ``dist / remaining_steps``.
        """
        super().__init__()
        self.tk = tokenizer
        self.num_beams = num_beams
        self.max_length = max_length
        self.alpha = alpha
        self.gamma = gamma
        self.ramping = ramping
        self.variants = variants
        self.eps_favor = eps_favor
        self.device = device
        self.remove_line_jump = remove_line_jump

        # ---------- per-token string features, indexed by token id ----------
        token_pieces = self.tk.convert_ids_to_tokens(range(len(tokenizer)), skip_special_tokens=False)
        all_strings = np.array(self.tk.batch_decode(self.tk.convert_tokens_to_ids(token_pieces),
                                                    skip_special_tokens=True))
        self.all_strings = all_strings

        def _token_mask(predicate):
            # one flag per token id, computed on the decoded token string
            return torch.from_numpy(np.array([predicate(s) for s in all_strings])).to(self.device)

        self.token_has_space = _token_mask(lambda s: ' ' in s)
        self.token_line_jump = _token_mask(lambda s: '\n' in s)
        self.dot_in_token = _token_mask(lambda s: '.' in s and s != '.')
        self.token_ends_with_space = _token_mask(lambda s: s.endswith(" "))

        # ---------- concept-level transitions ----------
        transition_matrix_np = np.array(dfa_layer['transition_matrix'])[:, :, 1]
        n_given = len(concepts[:-3])
        variants_dict = {concepts[i]: self.variants[i] for i in range(n_given)}
        variants_count = [len(variants[i]) for i in range(n_given)]
        transition_m_cpt = build_concept_transition_matrix(transition_matrix_np, variants_count)

        # expanded surface forms (capitalized, space-prefixed, plus _dot_) and eos
        all_variants_dict = complete_variants(variants_dict)
        all_variants_dict["eos"] = [tokenizer.eos_token]

        variants_tokens = [tokenizer.encode(form, add_special_tokens=False)
                           for forms in all_variants_dict.values() for form in forms]

        # one column per expanded variant, plus the trailing fallback column
        new_variants_count = [len(forms) for forms in all_variants_dict.values()]
        new_variants_count.append(1)

        # repeat each concept's column once per variant of that concept
        t_mc = np.concatenate([np.tile(transition_m_cpt[:, j], (new_variants_count[j], 1)).T
                               for j in range(transition_m_cpt.shape[-1])], axis=1)

        # ---------- DFA tensors ----------
        self.T_mc = torch.as_tensor(t_mc, dtype=torch.long).to(self.device)   # [S, V+1], last = fallback
        self.dist = torch.as_tensor(list(dfa_layer['dist'].values()),
                                    dtype=torch.float).to(self.device)         # [S]
        # num_variants counts the fallback column too
        self.num_states, self.num_variants = self.T_mc.shape

        self.concepts = all_variants_dict.keys()

        # ---------- variant token ids, right-padded with -1: [V, M] ----------
        max_len = max(len(e) for e in variants_tokens)
        self.max_m = max_len
        padded = [e + [-1] * (max_len - len(e)) for e in variants_tokens]
        self.var_ids = torch.tensor(padded, dtype=torch.long).to(device)      # [V, M]

        # whether each variant's first token contains a space: [V]
        self.var_has_space = torch.gather(self.token_has_space, 0, self.var_ids[:, 0])
        # DFA state per beam, starting at state 1
        self.beam_states = [1] * num_beams

    def _compute_windows(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Return the last ``max_m - 1`` tokens of every row as a (B, max_m - 1) tensor.

        Rows shorter than the window are left-padded with -1. When
        ``max_m == 1`` the slice ``[:, -0:]`` returns the whole input (the caller
        does not use the window in that case).
        """
        batch, length = input_ids.shape
        window_size = self.max_m - 1
        if length < window_size:
            left = torch.full((batch, window_size - length), -1, dtype=torch.long, device=input_ids.device)
            return torch.cat([left, input_ids], 1)
        return input_ids[:, -(window_size):]

    def _compute_completer_maps(self,
        completer_tokens: torch.LongTensor,   # (B, V), values in [0..T-1]
        true_completer:   torch.BoolTensor,   # (B, V)
                                        ):
        """
        Project per-variant completer information onto the vocabulary.

        Not called by ``__call__``; kept as a utility.

        Returns:
            partial     BoolTensor (B, T): True for tokens that are the completer
                        token of at least one variant.
            truec       BoolTensor (B, T): ``true_completer`` scattered onto the
                        completer tokens (for a token shared by several variants
                        one scattered value is kept; which one is unspecified).
            associated  LongTensor (B, T) in [-1..V-1]: a completing variant's
                        index when the scatter kept one, otherwise the index of
                        any variant mapping to the token; -1 when none does.
        """
        B, V = completer_tokens.shape
        device = completer_tokens.device
        vocab_size = len(self.tk)
        # 1) partial completer: any mapping, completion flag ignored
        partial = torch.zeros((B, vocab_size), dtype=torch.bool, device=device)
        partial.scatter_(
            dim=1,
            index=completer_tokens,
            src=torch.ones((B, V), dtype=torch.bool, device=device)
        )

        # 2) true completer: scatter the completion flags
        truec = torch.zeros((B, vocab_size), dtype=torch.bool, device=device)
        truec.scatter_(
            dim=1,
            index=completer_tokens,
            src=true_completer
        )

        # 3) associated variant: completing variants first, -1 elsewhere
        associated = torch.full((B, vocab_size), -1, dtype=torch.long, device=device)
        j_idx = torch.arange(V, device=device).unsqueeze(0).expand(B, V)
        src_full = torch.where(
            true_completer,
            j_idx,
            torch.full_like(j_idx, -1)
        )
        associated.scatter_(
            dim=1,
            index=completer_tokens,
            src=src_full
        )

        # fill the remaining -1 entries with any variant mapping to the token
        part_assoc = torch.full_like(associated, -1)
        part_assoc.scatter_(
            dim=1,
            index=completer_tokens,
            src=j_idx
        )
        mask_empty = (associated == -1)
        associated[mask_empty] = part_assoc[mask_empty]

        return partial, truec, associated

    def _compute_next_states(self,
        completer_tokens: torch.LongTensor,  # (B, V)
        true_completer:   torch.BoolTensor,  # (B, V)
        S_b:              torch.LongTensor,  # (B,)
    ):
        """
        Compute per-token next DFA states.

        Returns ``(next_states, next_quasi_states)``, both LongTensor (B, T):

        * ``next_states[b, t]`` is the state reached by completing variant ``j``
          when ``t`` is ``j``'s completer token and fully completes it. Tokens
          that only extend (do not finish) a variant keep ``S_b[b]``. Every
          other token takes the fallback transition ``T_mc[S_b[b] - 1, -1]``.
        * ``next_quasi_states[b, t]`` is the state reached by completing variant
          ``j`` whenever ``t`` is ``j``'s completer token, whether or not ``t``
          finishes it. Every other token takes the fallback transition.

        When several variants share a completer token, ``scatter`` keeps one of
        their values and PyTorch does not specify which.
        """
        vocab_size = len(self.tk)
        S_b = S_b.to(self.device)
        rows = S_b - 1  # 1-indexed DFA states -> 0-indexed rows of T_mc

        # (B, V): state reached by completing each variant from the current state
        next_states_per_var = self.T_mc[rows, :-1]
        # variants that are only extended (not finished) leave the state unchanged
        true_next_states_per_var = torch.where(true_completer, next_states_per_var, S_b.unsqueeze(1))

        # (B, T): fallback transition for tokens that are not a completer token
        fallback = self.T_mc[rows, -1].unsqueeze(1).expand(-1, vocab_size)

        next_states = fallback.scatter(dim=1, index=completer_tokens, src=true_next_states_per_var)
        next_quasi_states = fallback.scatter(dim=1, index=completer_tokens, src=next_states_per_var)
        return next_states, next_quasi_states

    def _complete_separated_word(self,
        completer_tokens:  torch.LongTensor,  # (B, V)
        prev_idx:          torch.LongTensor,  # (B, V)
        input_ids:         torch.LongTensor,  # (B, L)
        variant_has_space: torch.BoolTensor,  # (V,)
    ) -> torch.BoolTensor:                    # (B, T)
        """
        Flag tokens that complete or extend a variant at a word boundary.

        For beam b and variant j the boundary check passes when the variant's
        first token contains a space (``variant_has_space[j]``). It also passes
        when ``prev_idx[b, j]`` is a valid position in ``input_ids`` whose token
        ends with a space (``self.token_ends_with_space``).

        Returns:
            BoolTensor (B, T), True for token t if at least one variant whose
            completer token is t passes the check.
        """
        B, V = completer_tokens.shape
        device = completer_tokens.device

        # gather needs non-negative indices; invalid positions are masked below
        safe_prev_idx = prev_idx.clamp(min=0)
        prev_tok = input_ids.gather(1, safe_prev_idx)                                 # (B, V)
        valid_prev = torch.logical_and(prev_idx >= 0, prev_idx < input_ids.size(1))   # (B, V)

        vhs = variant_has_space.view(1, V).expand(B, V)                               # (B, V)
        ends = self.token_ends_with_space[prev_tok]                                   # (B, V)
        space_ok = vhs | (valid_prev & ends)                                          # (B, V)

        # accumulate counts in an integer tensor and threshold, so the OR over
        # variants sharing a token does not depend on bool scatter_add_ semantics
        counts = torch.zeros((B, len(self.tk)), dtype=torch.long, device=device)
        counts.scatter_add_(dim=1, index=completer_tokens, src=space_ok.long())
        return counts > 0

    # ------------------------------------------------------------------
    # HF call each decoding step
    # ------------------------------------------------------------------
    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor, neginf) -> torch.FloatTensor:
        """
        Apply DFA pruning and α-blending to one decoding step's scores.

        Args:
            input_ids: (B, L) tokens so far; row b is assumed to correspond to
                ``self.beam_states[b]``.
            scores: (B, T) next-token scores. NOTE(review): assumes
                ``scores.shape[-1] == len(tokenizer)``; the (B, len(tokenizer))
                masks below would not broadcast against a wider logit matrix.
            neginf: large negative *finite* value used to mask tokens. It must
                be finite because ``scores - 1.5 * neginf`` is used to exclude
                masked entries from the row minimum.

        Returns:
            (B, T) modified scores. ``next_states`` is also stored in
            ``self._last_next_states`` for ``update_beam_states``.
        """
        B, _ = scores.shape
        device = scores.device
        vocab_size = len(self.tk)
        neg_fill = neginf * torch.ones(1, device=device)

        length = input_ids.size(1)
        remain = self.max_length - length
        next_remain = self.max_length - (length + 1)

        S_b = torch.tensor(self.beam_states, device=self.device)  # [B]

        # ---------------- 1. next DFA state for every token ----------------
        if self.var_ids.shape[1] != 1:
            window = self._compute_windows(input_ids)
            completer_tokens, full_completer, index_in_v = token_match_convolution(window, self.var_ids)  # [B, V]
        else:
            # single-token variants: every variant's only token completes it
            completer_tokens = self.var_ids[:, 0].unsqueeze(0).expand(B, -1)                          # [B, V]
            full_completer = torch.ones((B, self.var_ids.shape[0]), dtype=torch.bool, device=device)  # [B, V]
            index_in_v = torch.zeros((B, self.var_ids.shape[0]), dtype=torch.long, device=device)    # [B, V]

        next_states, quasi_next_states = self._compute_next_states(completer_tokens, full_completer, S_b)  # [B, T]
        cur_states = S_b.unsqueeze(1)
        state_changers = next_states != cur_states              # [B, T]
        state_quasi_changers = quasi_next_states != cur_states  # [B, T]

        # distance to acceptance after each token (states are 1-indexed)
        next_dist = self.dist[next_states - 1]                # [B, T]
        quasi_next_dist = self.dist[quasi_next_states - 1]    # [B, T]

        # ---------------- 2. prune tokens that make acceptance impossible ----------------
        to_remove = next_dist > next_remain
        scores = torch.where(to_remove, neg_fill.expand_as(scores), scores)

        # ---------------- 3. promote tokens that get closer to acceptance ----------------
        dist_cur = self.dist[S_b - 1].unsqueeze(1)  # [B, 1]
        closer = next_dist < dist_cur               # [B, T]
        quasi_closer = quasi_next_dist < dist_cur   # [B, T]

        # index_in_v[b, j] variant tokens already end input_ids, so this is the
        # position where variant j starts in the history (-1 if none).
        # NOTE(review): the boundary check reads this as "the token *before* the
        # variant"; that would be length - index_in_v - 1. Confirm which is intended.
        has_token_in_history = index_in_v != 0
        prev_idx = torch.where(has_token_in_history, length - index_in_v, -torch.ones_like(index_in_v))

        separated_word_completion = self._complete_separated_word(
            completer_tokens, prev_idx, input_ids, self.var_has_space
        )

        good_changers = state_changers & closer & separated_word_completion
        good_quasi_changers = state_quasi_changers & quasi_closer & separated_word_completion

        # smallest unmasked score per row: masked entries are pushed up by -1.5*neginf
        not_neg_inf = scores > neginf
        false_scores = torch.where(not_neg_inf, scores, scores - 1.5 * neginf)
        second_minimum = false_scores.min(dim=-1).values.unsqueeze(1)  # [B, 1]

        # shift so that unmasked scores are >= eps_favor before blending
        mod_scores = scores - second_minimum + self.eps_favor
        apply_blend = good_changers | good_quasi_changers
        if apply_blend.any():
            if self.ramping:
                # the closer the remaining budget is to the required distance, the stronger the pull
                factor = self.dist[S_b - 1] / max(1, remain)
                alpha_k = self.alpha + (1 - self.alpha) * factor.pow(self.gamma)
            else:
                alpha_k = scores.new_full((B,), self.alpha)
            alpha_k = alpha_k.unsqueeze(1)
            bonus = (alpha_k ** 2) * self.eps_favor
            blended = alpha_k * mod_scores.max(dim=-1, keepdim=True).values + \
                      (1 - alpha_k + bonus) * mod_scores
            # tokens that really change the state get an extra bonus on top
            blended_true_changers = blended + bonus * mod_scores
            mod_scores = torch.where(good_quasi_changers, blended, mod_scores)
            mod_scores = torch.where(good_changers, blended_true_changers, mod_scores)

        # undo the shift
        scores = mod_scores + second_minimum - self.eps_favor

        if self.remove_line_jump:
            # mask tokens containing a newline, or a '.' that is not exactly '.'
            banned = (self.token_line_jump | self.dot_in_token).unsqueeze(0).expand(B, vocab_size)
            scores = torch.where(banned, neg_fill.expand_as(scores), scores)

        self._last_input_ids = input_ids      # (B, L)
        self._last_next_states = next_states  # (B, T)

        return scores

    # ------------------------------------------------------------------
    # Keep DFA state list in sync with HF beam re-ordering
    # ------------------------------------------------------------------
    def update_beam_states(self,  beam_idx: torch.LongTensor, next_tokens: torch.LongTensor):
        """
        Advance the per-beam DFA states after beam search reordered the beams.

        Args:
            beam_idx: for each new beam, the row of the previous step's
                ``_last_next_states`` it continues from.
            next_tokens: for each new beam, the token that was appended.

        ``_last_next_states`` is replaced by its rows reordered by ``beam_idx``,
        and ``self.beam_states[b]`` becomes the reordered
        ``_last_next_states[b, next_tokens[b]]``.
        """
        self._last_next_states = self._last_next_states[beam_idx]
        rows = torch.arange(next_tokens.size(0), device=next_tokens.device)
        self.beam_states = self._last_next_states[rows, next_tokens].tolist()





class DFABeamSearchScorer(BeamSearchScorer):
    """
    ``BeamSearchScorer`` that carries a reference to a ``DFALogitsProcessor``.

    ``process`` delegates unchanged to ``BeamSearchScorer.process``. It does not
    call ``dfa_processor.update_beam_states``, so per-beam DFA states are not
    synchronized here.
    """

    def __init__(self, dfa_processor: DFALogitsProcessor, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stored for external use; not read by this class
        self.dfa_processor = dfa_processor

    def process(self, input_ids, next_scores, next_tokens, next_indices, **kwargs):
        """Run the stock beam pruning/reordering and return its result as-is."""
        return super().process(input_ids, next_scores, next_tokens, next_indices, **kwargs)