from src.graph_constrained_decoding import GraphConstrainedDecoding
from .base_hf_causal_model import HfCausalModel
from transformers import StoppingCriteriaList
from anchoring import SPALogitsProcessor, spa_tokenize

class GraphConstrainedDecodingModel(HfCausalModel):
    """Causal LM that generates under a ``GraphConstrainedDecoding`` prefix
    constraint while an ``SPALogitsProcessor`` adjusts the logits.
    """

    # End-of-turn / end-of-sequence markers; a long run of any single one of
    # them at the tail of a sequence aborts generation.
    _END_TOKENS = ("<|eot_id|>", "<|im_end|>", "</s>")

    def __init__(self, args):
        super().__init__(args)

    def _end_token_ids(self):
        """Return the set of ids obtained by taking the first token of each
        ``_END_TOKENS`` entry as encoded by ``self.tokenizer``.

        Markers that encode to an empty list are skipped instead of raising
        ``IndexError``.
        """
        ids = set()
        for marker in self._END_TOKENS:
            encoded = self.tokenizer.encode(marker, add_special_tokens=False)
            if encoded:
                ids.add(encoded[0])
        return ids

    def generate_sentence(self, question, llm_input, ground_truth_idx, scores, trie, start_token_ids = None, end_token_ids = None, enable_constrained_by_default = True, reason_errors_LLM = None):
        """Generate a response for ``llm_input``.

        Args:
            question: Forwarded to ``GraphConstrainedDecoding``.
            llm_input: Prompt string; also sliced to extract the anchors.
            ground_truth_idx: Iterable of ``(start, end)`` pairs; each
                ``llm_input[start:end]`` slice is passed to ``spa_tokenize``
                as a global anchor.
            scores: Not used by this method.
            trie: Forwarded to ``GraphConstrainedDecoding``.
            start_token_ids: Forwarded to ``GraphConstrainedDecoding``.
            end_token_ids: Forwarded to ``GraphConstrainedDecoding``.
            enable_constrained_by_default: Forwarded to
                ``GraphConstrainedDecoding``.
            reason_errors_LLM: Forwarded to ``GraphConstrainedDecoding``.

        Returns:
            The decoded continuation as a string when exactly one sequence is
            generated, a list of decoded continuations otherwise, or ``None``
            when ``generate`` raises (the exception is printed).
        """
        global_anchors = [llm_input[start:end] for start, end in ground_truth_idx]

        inputs, aux_inputs, mask_token = spa_tokenize(
            prompt_with_anchors=llm_input,
            global_anchors=global_anchors,
            tokenizer=self.tokenizer,
            device=self.model.device
        )

        strength_value = 2.0

        aux_inputs_ids = aux_inputs.input_ids.to(self.model.device)
        spa_processor = SPALogitsProcessor(
            aux_model=self.model,
            aux_input_ids=aux_inputs_ids,
            strength=strength_value,
            modulated_by_prob=False,
            use_attention_mask=True,
            mask_token=mask_token,
            tokenizer=self.tokenizer
        )

        input_ids = inputs.input_ids.to(self.model.device)
        attention_mask = inputs.attention_mask.to(self.model.device)
        gcr = GraphConstrainedDecoding(self.tokenizer, question, trie, start_token_ids, end_token_ids, enable_constrained_by_default, reason_errors_LLM)

        # Resolve the marker ids once per call; the stopping callback below is
        # invoked at every decoding step and must not re-run the tokenizer.
        repeated_end_ids = self._end_token_ids()

        def check_repetitive_eot_stopping(input_ids, scores, repetition_threshold=10, **kwargs):
            """Return True when any sequence in the batch ends with
            ``repetition_threshold`` copies of the same end marker id.
            """
            if input_ids.shape[1] < repetition_threshold:
                return False
            tail = input_ids[:, -repetition_threshold:]
            # One vectorized comparison per marker instead of a Python-level
            # loop over individual tensor elements.
            return any(
                bool((tail == token_id).all(dim=1).any())
                for token_id in repeated_end_ids
            )

        stopping_criteria = StoppingCriteriaList([check_repetitive_eot_stopping])

        try:
            res = self.model.generate(
                input_ids = input_ids,
                attention_mask = attention_mask,
                generation_config=self.generation_cfg,
                prefix_allowed_tokens_fn=gcr.allowed_tokens_fn,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.eos_token_id,
                renormalize_logits=True,
                early_stopping="never",
                stopping_criteria=stopping_criteria,
                logits_processor=[spa_processor],
            )
        except Exception as e:
            # Best-effort: report the failure and let the caller handle None.
            print(e)
            return None

        # Strip the prompt tokens so only the newly generated part is decoded.
        prompt_len = input_ids.shape[1]
        if len(res.sequences) == 1:
            return self.tokenizer.decode(res.sequences[0][prompt_len:], skip_special_tokens=True)
        return [
            self.tokenizer.decode(seq[prompt_len:], skip_special_tokens=True)
            for seq in res.sequences
        ]
        
