from typing import List
import torch

# Prompt for the optional "reasoning error" judge LLM used by
# GraphConstrainedDecoding.check_reason_errors. That method fills the
# {question}, {path1} and {path2} placeholders with str.replace (str.format
# would not work here: the template contains other brace tokens that are not
# format fields) and then looks for "[yes]" / "[no]" in the model's reply,
# case-insensitively.
# NOTE(review): {instruction} ... {/instruction} are never substituted in this
# file, so they reach the model verbatim, presumably as section delimiters --
# confirm this is intended. The [INST] / <<SYS>> markers look like the
# Llama-2 chat format; confirm it matches the judge model in use.
CHECK_REASON_ERRORS_TEMPLATE = '''
[INST]
<<SYS>> You are a linguist with excellent semantic understanding and reasoning abilities. <</SYS>>
{instruction} You will complete the following task: Given a question and a known correct reasoning path (path1), path2 is a current single-step reasoning that continues from the known reasoning path path1. You need to determine whether the current single-step reasoning path2 is relevant to the question (helpful for answering the question). If path2 is relevant to the question, please provide a "[Yes]" answer; otherwise, if path2 is irrelevant to the question, provide a "[No]" answer.
You need to complete the task following these steps:
Step 1. Thoroughly understand the semantics of the question and path1
Step 2. Based on the given question and path1, conduct a step-by-step analysis of "Is the current single-step reasoning path2 helpful for answering the question?"
Step 3. Provide your judgment
{/instruction}
### Your Turn
question: 
{question}
path1: 
{path1}
path2: 
{path2}
Your judgment: 
[Yes]/[No]
[/INST]
'''

class GraphConstrainedDecoding:
    """Trie-based constrained decoding helper for graph reasoning paths.

    ``allowed_tokens_fn(batch_id, sent)`` returns the token ids allowed to
    follow ``sent``. While generation is inside an open constrained span, the
    candidates are the continuations offered by ``trie``; otherwise the whole
    vocabulary is allowed. The ``(batch_id, sent)`` signature matches the
    ``prefix_allowed_tokens_fn`` hook of Hugging Face ``generate`` --
    presumably how this class is used; confirm against the caller.

    A constrained span is opened by ``start_token_ids`` and closed by
    ``end_token_ids``. When either is ``None``, the constraint state is fixed
    by ``enable_constrained_by_default`` and the constrained suffix starts at
    the length of the first sequence passed to ``allowed_tokens_fn``.
    """

    def __init__(
        self,
        tokenizer,
        question,
        trie,
        start_token_ids=None,
        end_token_ids=None,
        enable_constrained_by_default=False,
        reason_errors_LLM=None,
    ):
        """
        Args:
            tokenizer: Tokenizer; ``len(tokenizer)`` defines the full
                vocabulary and ``decode`` turns token ids back into text.
            question: Question text, used by ``check_reason_errors``.
            trie: Object whose ``get(prefix_ids)`` returns the token ids
                allowed after the list ``prefix_ids``.
            start_token_ids: Token id opening a constrained span, or ``None``.
                It is compared element-wise with the generated ids, so it is
                presumably a single int id -- TODO confirm.
            end_token_ids: Token id closing a constrained span, or ``None``.
            enable_constrained_by_default: Constraint flag used when the
                start/end tokens are not both given.
            reason_errors_LLM: Optional judge model exposing
                ``prepare_for_inference()`` and ``generate_sentence(prompt)``;
                only required by ``check_reason_errors``.
        """
        self.reason_errors_LLM = reason_errors_LLM
        # The judge LLM is optional (default None); preparing it
        # unconditionally made construction without one raise AttributeError.
        if self.reason_errors_LLM is not None:
            self.reason_errors_LLM.prepare_for_inference()

        self.question = question
        self.tokenizer = tokenizer
        self.trie = trie
        self.start_token = start_token_ids
        self.end_token = end_token_ids
        self.all_tokens = list(range(len(tokenizer)))
        self.constrained_flag = enable_constrained_by_default
        # Start of the constrained suffix when no start/end tokens are
        # configured; cached from the first sequence seen.
        self.L_input = None
        # Index of the last still-open start token, updated by
        # check_constrained_flag (None when the last span is closed).
        self.last_start_token = None

    def check_reason_errors(self, decoded_text: str):
        """Ask the judge LLM whether the latest reasoning step is relevant.

        ``decoded_text`` is split on its last two ``"->"`` separators: the
        text before them is ``path1`` (the known reasoning so far) and the
        remaining two pieces, re-joined with ``"->"``, form ``path2`` (the
        current step). E.g. ``"a -> r -> b"`` gives ``path1 = "a"`` and
        ``path2 = "r->b"``. Without any ``"->"``, ``path2`` is empty.

        Returns:
            ``True`` if the reply contains ``[yes]`` (case-insensitive),
            ``False`` if it contains ``[no]`` but not ``[yes]``, and otherwise
            the raw stripped reply string.

        Raises:
            RuntimeError: if no ``reason_errors_LLM`` was supplied.
        """
        if self.reason_errors_LLM is None:
            raise RuntimeError("check_reason_errors requires a reason_errors_LLM")

        decoded_text = decoded_text.strip()
        parts = decoded_text.rsplit("->", 2)
        path1 = parts[0].strip()
        path2 = "->".join(p.strip() for p in parts[1:]).strip()
        question = self.question.strip()

        # str.replace rather than str.format: the template holds other brace
        # tokens ({instruction}, {/instruction}) that are not format fields.
        llm_input = (CHECK_REASON_ERRORS_TEMPLATE
                     .replace("{question}", question)
                     .replace("{path1}", path1)
                     .replace("{path2}", path2))

        judgement = self.reason_errors_LLM.generate_sentence(llm_input).strip()

        lowered = judgement.lower()
        if "[yes]" in lowered:
            return True
        if "[no]" in lowered:
            return False
        # NOTE(review): an unparseable reply is returned as-is (truthy unless
        # empty); callers that test truthiness will treat it as "relevant".
        return judgement

    def check_constrained_flag(self, sent: torch.Tensor):
        """Decide whether ``sent`` is currently inside an open constrained span.

        Args:
            sent: Token ids generated so far, assumed to be a 1-D tensor
                (prompt included) -- TODO confirm against the caller.

        Returns:
            ``(True, idx)`` when the last start token (at index ``idx``) has
            not yet been followed by an end token; ``idx`` is where the trie
            prefix begins. ``(False, len(sent))`` when there is no start token
            or the last span is already closed.
        """
        start_positions = torch.where(sent == self.start_token)[0]
        if len(start_positions) == 0:
            return False, len(sent)

        # Plain int so it can be used directly as a list slice index.
        last_start = int(start_positions[-1])
        # Only an end token after the last start token closes the current span.
        span_closed = bool((sent[last_start:] == self.end_token).any())
        if span_closed:
            self.last_start_token = None
            return False, len(sent)

        self.last_start_token = last_start
        return True, last_start

    def allowed_tokens_fn(self, batch_id: int, sent: torch.Tensor):
        """Return the token ids allowed to follow ``sent``.

        Args:
            batch_id: Index of the sequence within the batch (not used here;
                kept for the callback signature).
            sent: Token ids generated so far, assumed to be a 1-D tensor.

        Returns:
            The trie's candidates for the constrained suffix of ``sent`` while
            a constrained span is open; otherwise -- or when the trie offers
            no continuation -- the full vocabulary ``self.all_tokens``.
        """
        if self.start_token is not None and self.end_token is not None:
            constrained_flag, L_input = self.check_constrained_flag(sent)
        else:
            # No span markers: the constraint state is the default flag and
            # the constrained suffix starts at the first-seen sequence length.
            if self.L_input is None:
                self.L_input = len(sent)
            constrained_flag, L_input = self.constrained_flag, self.L_input

        if not constrained_flag:
            return self.all_tokens

        allow_tokens = self.trie.get(sent.tolist()[L_input:])

        # An empty candidate set would leave no valid next token; open the
        # full vocabulary instead to avoid a deadlock.
        if len(allow_tokens) == 0:
            return self.all_tokens

        # NOTE: with beam search, the trie may offer fewer candidates than the
        # beam size; how the generator handles that is not addressed here.

        # TODO: reasoning-error pruning (not yet enabled). Planned approach:
        # when a candidate token decodes to "->" and the decoded constrained
        # text sent[L_input:] already contains an even, non-zero number of
        # "->", call check_reason_errors on that text and drop the candidate
        # if the judgement is False; if no candidates remain, allow only
        # self.tokenizer.eos_token_id so the path is terminated.
        return allow_tokens