# ========================= 1. Environment Setup and Library Import =========================
import sys
import os
import requests
import json
import re
import time
import subprocess
import pandas as pd
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import traceback
from pathlib import Path
from typing import Optional, Union, List, Dict, Any


# ------------------------- File Path Configuration -----------------------------
# Every path below is joined onto root_dir; an empty root_dir leaves the
# paths relative to the current working directory.
root_dir = ''
# RABIT language-inclusion checker (run as `java -jar <rabit_jar> ...`).
rabit_jar = os.path.join(root_dir, 'aliyun/rabit250 - new_extracted/rabit250 - new/out/artifacts/rabit250____jar/rabit250 - new.jar')
# BA file written from the generated LTL; first automaton handed to RABIT.
input_ba_path = os.path.join(root_dir, 'aliyun/output.ba')
# Reference BA; second automaton handed to RABIT.
comparison_automaton = os.path.join(root_dir, 'aliyun/Baserule.ba')
Outputfilename = os.path.join(root_dir, 'aliyun/Output.ba')

print("\n".join(
    f"{label}: {path}"
    for label, path in (
        ("RABIT JAR path", rabit_jar),
        ("Input BA file path", input_ba_path),
        ("Comparison automaton path", comparison_automaton),
    )
))

# Check if key files exist
def check_file_exists(file_path):
    """
    Report whether *file_path* names an existing regular file.

    Prints a ✅/❌ status line and also returns the result, so callers can
    act on it; callers that ignore the return value behave as before.

    :param file_path: path to check
    :return: True if the path exists and is a regular file, otherwise False
    """
    exists = os.path.isfile(file_path)
    if exists:
        print(f"✅ {file_path} exists")
    else:
        print(f"❌ {file_path} does not exist")
    return exists

# Report up front whether the RABIT jar and the reference automaton
# (both arguments of rabit_command) are present.
for _required_path in (rabit_jar, comparison_automaton):
    check_file_exists(_required_path)
del _required_path

class LLM_Generator:
    """
    Unified wrapper for GPT-API / Local LLM inference.

    * ``mode="gpt"``: POSTs chat-completion requests to
      ``{base_url}/chat/completions`` with a bearer ``api_key``.
    * ``mode="local"``: builds a Hugging Face ``text-generation`` pipeline
      from ``model_path`` and samples from it.
    """
    def __init__(
        self,
        mode: str = "gpt",
        api_key: Optional[str] = None,
        model_name: str = "gpt-4o",
        base_url: str = "https://gpt-api.hkust-gz.edu.cn/v1",
        model_path: Optional[str] = None,
        device: Union[str, int, None] = None,
        device_map: Union[str, dict, None] = "auto",
        torch_dtype: torch.dtype = torch.bfloat16,
        load_in_4bit: bool = True,
        terminator_tokens: Optional[List[str]] = None,
    ):
        """
        :param mode: "gpt" or "local" (case-insensitive)
        :param api_key: bearer token; required in "gpt" mode
        :param model_name: model name sent with GPT requests
        :param base_url: API root; requests go to ``{base_url}/chat/completions``
        :param model_path: model id or path; required in "local" mode
        :param device: device label used for the load message only; placement
            of the local model is controlled by ``device_map``
        :param device_map: forwarded to ``pipeline(device_map=...)``
        :param torch_dtype: dtype forwarded to ``pipeline``
        :param load_in_4bit: forwarded to the model via ``model_kwargs``
        :param terminator_tokens: extra token strings that end local generation
        :raises ValueError: unknown mode, or missing api_key / model_path
        :raises RuntimeError: the local pipeline could not be loaded
        """
        self.mode = mode.lower()
        self.model_name = model_name
        self.base_url = base_url.rstrip("/")
        # `is None` (not truthiness) so an explicit device 0 is kept.
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        self.terminator_tokens = terminator_tokens or []

        if self.mode == "gpt":
            if not api_key:
                raise ValueError("Please provide GPT API key")
            self.api_key = api_key
            print("✅ GPT-API mode ready")
        elif self.mode == "local":
            if model_path is None:
                raise ValueError("For local mode, please provide model_path")
            try:
                self.text_gen = pipeline(
                    "text-generation",
                    model=model_path,
                    torch_dtype=torch_dtype,
                    device_map=device_map,
                    model_kwargs={"load_in_4bit": load_in_4bit},
                )
                self.tokenizer = self.text_gen.tokenizer
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token
                self.default_terminators = [
                    self.tokenizer.eos_token_id,
                    *self._tokens_to_ids(self.terminator_tokens),
                ]
                # str() because device may be an int (e.g. 0).
                print(f"✅ Local model loaded to {str(self.device).upper()} (dtype={torch_dtype}, 4bit={load_in_4bit})")
            except Exception as e:
                raise RuntimeError(f"Local model loading failed → {e}") from e
        else:
            raise ValueError("mode only supports 'gpt' or 'local'")

    def _tokens_to_ids(self, tokens):
        """Look each string up as a single token id, skipping unknown tokens (id None)."""
        ids = (self.tokenizer.convert_tokens_to_ids(tok) for tok in tokens)
        return [tok_id for tok_id in ids if tok_id is not None]

    def _chat_completions_url(self) -> str:
        """Return the chat-completions endpoint derived from ``base_url``."""
        return f"{self.base_url}/chat/completions"

    def generate(self, prompt: str, max_tokens: int = 400, temperature: float = 0.7, 
                 top_p: float = 0.9, stop: Optional[List[str]] = None, **api_params):
        """
        Generate a completion for *prompt* with the configured backend.

        :param prompt: user prompt text
        :param max_tokens: maximum number of new tokens
        :param temperature: sampling temperature
        :param top_p: nucleus-sampling threshold
        :param stop: stop strings (GPT) / stop tokens (local)
        :param api_params: extra payload fields for the GPT API (ignored locally)
        :return: generated text with surrounding whitespace stripped
        """
        if self.mode == "gpt":
            return self._gpt_generate(prompt, max_tokens, temperature, top_p, stop, **api_params)
        return self._local_generate(prompt, max_tokens, temperature, top_p, stop, **api_params)

    def _gpt_generate(self, prompt, max_tokens, temperature, top_p, stop, **api_params):
        """Send one chat-completion request; any failure is re-raised as RuntimeError."""
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = dict(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stop=stop,
            **api_params,
        )
        try:
            r = requests.post(self._chat_completions_url(), headers=headers, json=payload, timeout=120)
            r.raise_for_status()
            return r.json()["choices"][0]["message"]["content"].strip()
        except Exception as e:
            raise RuntimeError(f"GPT request failed → {e}") from e

    def _local_generate(self, prompt, max_tokens, temperature, top_p, stop, **_):
        """Sample from the local pipeline; API-only keyword arguments are ignored."""
        stop_ids = list(self.default_terminators)
        if stop:
            stop_ids += self._tokens_to_ids(stop)

        res = self.text_gen(
            prompt,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=stop_ids,
            pad_token_id=self.tokenizer.eos_token_id,
            return_full_text=False,
        )
        return res[0]["generated_text"].strip()
        
# ------------------------- Model configuration -----------------------------
# GPT model configuration.
# The API key comes from the GPT_API_KEY environment variable rather than a
# literal in source; LLM_Generator raises ValueError if it is unset/empty.
gpt_gen = LLM_Generator(
    mode='gpt',
    api_key=os.environ.get("GPT_API_KEY", ""),
    model_name="gpt-4"
)

# Local model configuration
model_id = "aliyun/trained_model_10/"
text_generator = pipeline(
    "text-generation",
    model=model_id,
    device="cuda:1",
    torch_dtype=torch.bfloat16,
    model_kwargs={"load_in_4bit": False}
)

# Define terminators: EOS plus "<|eot_id|>" (presumably the end-of-turn token
# of a Llama-3-style chat template — confirm for the trained model).
# convert_tokens_to_ids can yield None for a token the tokenizer does not
# know; such ids are dropped so generation never receives None as an EOS id.
terminators = [
    token_id
    for token_id in (
        text_generator.tokenizer.eos_token_id,
        text_generator.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    )
    if token_id is not None
]


# RABIT command configuration: <input BA> is checked for inclusion in <comparison BA>.
rabit_command = [
    'java', '-jar', rabit_jar, input_ba_path, comparison_automaton, '-fastc'
]


# ---------------------- Syntax check and correction -------------------------
def check_syntactic_correctness(ltl_formula):
    """
    Return True when the round parentheses in *ltl_formula* are balanced.

    Every ')' must close an earlier, still-open '('; no '(' may remain open
    at the end. Other characters are ignored.
    """
    depth = 0
    for ch in ltl_formula:
        if ch == ')':
            if depth == 0:
                return False  # closing bracket with nothing to close
            depth -= 1
        elif ch == '(':
            depth += 1
    return depth == 0

def correct_syntactic_errors(ltl_formula):
    """
    Ask the local model to repair syntax errors (chiefly unbalanced
    parentheses) in *ltl_formula*.

    The request is wrapped with the tokenizer's chat template and sampled
    with the module-level ``text_generator`` and ``terminators``.

    :param ltl_formula: formula that failed check_syntactic_correctness
    :return: the model's reply with the prompt removed and whitespace
             stripped; always a string, possibly empty
    """
    prompt_fix = f"""
    The following LTL formula has syntax errors, please correct it:
    {ltl_formula}
    You need to pay attention to the matching of parentheses.
    Curent syntatic ERROR TYPE is "Bracket doesn't match". 
    LTL Expression Requirements: Include multiple nested "eventually" (F), "globally" (G), "next" (X), "implies"(->), "equivalent"(<->), "and"(&), "or"(|),"not"(!) operators.
    AP Expression Requirements: Words in a single AtomicProposition should be connected using underlines (_). Any other punctuation (like “%”) is not allowed.
    You need to check whether the operator in LTL formula is valid and if it doesn't valid,correct it.
    You don't need to output any other things like:"LTL1","Here is the corrected LTL1:","The corrected LTL1 is:","The LTL1 is:","The LTL1 is as follows:","The LTL1 is as follows below:","The LTL1 is as follows below:" and so on.
    Please output only a pure LTL expression, no extra text.Only output the corrected LTL formula without any additional information or explanation.This is very important, please only output the LTL expression.
    """

    messages = [{"role": "user", "content": prompt_fix}]

    formatted_prompt = text_generator.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    outputs = text_generator(
        formatted_prompt,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=text_generator.tokenizer.eos_token_id
    )

    # The generated text is assumed to begin with the prompt itself; slicing
    # it off leaves only the completion (a str).
    return outputs[0]["generated_text"][len(formatted_prompt):].strip()

def correct_ltl_formula(ltl_formula, max_attempts=6):
    """
    Repeatedly ask the local model to fix *ltl_formula* until its
    parentheses balance or *max_attempts* corrections have been made.

    :param ltl_formula: formula to check and, if needed, repair
    :param max_attempts: maximum number of model corrections (default 6)
    :return: the possibly-corrected formula (it may still be unbalanced if
             every attempt failed), or None if the model returned nothing
    """
    for _ in range(max_attempts):
        if check_syntactic_correctness(ltl_formula):
            return ltl_formula
        new_ltl = correct_syntactic_errors(ltl_formula)
        # An empty reply (or None) is an unusable correction, not a formula.
        if not new_ltl:
            print("Error correcting LTL formula")
            return None
        ltl_formula = new_ltl
        print(f"Syntax correction: {ltl_formula}")
        print("-" * 50)

    if not check_syntactic_correctness(ltl_formula):
        print(f"⚠️  Parentheses still unbalanced after {max_attempts} correction attempts")
    return ltl_formula


# ---------------------- LTL generation and transformation ----------------------
def transform_nl_to_ltl_with_local_model(input_path, output_path, debug=True):
    """
    Use the local model to convert natural-language instructions to LTL.

    Reads a JSON list of records from *input_path*. Every record without a
    truthy "ltl" field gets one generated from its "nl" field. The whole list
    is then written to *output_path*.

    :param input_path: JSON file containing a list of {"nl": ..., "ltl": ...} records
    :param output_path: destination for the updated JSON list
    :param debug: print per-record progress
    :raises KeyError: a record needing generation has no "nl" field
    """
    print(f"📖 Reading input file: {input_path}")
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"📊 Total {len(data)} records")

    for i, item in enumerate(data, 1):
        # Skip entries that already have LTL
        if item.get("ltl"):
            if debug:
                print(f"⏭️  Entry #{i}: Already has LTL, skipping")
            continue

        nl_instruction = item["nl"]
        if debug:
            print(f"\n🔄 Processing entry #{i}")
            print(f"Natural language: {nl_instruction}")

        # The instruction is interpolated once by the f-string. The text must
        # NOT go through str.format() again: braces inside an instruction
        # would then be parsed as format fields.
        user_prompt = (
            f"Transform the following natural language driving instruction into an LTL formula as a professional LTL expert: {nl_instruction}\n"
            "LTL Expression Requirements: Include multiple nested 'eventually' (F), 'globally' (G), 'next' (X), 'implies'(->), 'equivalent'(<->), 'and'(&), 'or'(|),'not'(!) operators.\n"
            "As a reminder, you must use the punctuations listed above as logical operators. For example, use '!' to express 'not', instead of using the natural language 'not' itself directly.\n"
            "AP Expression Requirements: Words in a single AtomicProposition should be connected using underlines (_). Any other punctuation (like '%','~','=','°', etc.) is not allowed to appear in the APs.\n"
            "Logically represent the navigation sequence, ensuring the expression accurately reflects the instructions and conditions.\n"
            "pay attention to the matching relationships between the parentheses.\n"
            "Must obey: Output only the raw LTL formula in one line with no explanation, no formatting, no quotes, no extra information, text display symbols, such as (*#). Sentences like 'Here is the transformed LTL formula:' are not allowed. Only pure LTL formula is needed.\n"
            "One example is as following:\n natural language: Go straight 1.2km through tunnel, turn left after light tower, final stop is 200m right.\n output: G(darkness -> headlights_on) & F(straight_1.2km -> F(lighttower & X(left_turn))) -> X(right_200m -> F(arrive))"
        )

        messages = [{"role": "user", "content": user_prompt}]

        # Apply chat template
        prompt = text_generator.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        # Generate LTL
        outputs = text_generator(
            prompt,
            max_new_tokens=256,
            eos_token_id=terminators,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            pad_token_id=text_generator.tokenizer.eos_token_id
        )

        # The generated text is assumed to start with the prompt; keep only the completion.
        ltl_output = outputs[0]["generated_text"][len(prompt):].strip()

        item["ltl"] = ltl_output

        if debug:
            print(f"Generated LTL: {ltl_output}")
            print("-" * 60)

    # Save results
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    print(f"\n✅ Processing complete! Results saved to {output_path}")

# --------------------- LTL -> HOA -> BA ----------------------
def run_ltl2tgba(ltl_formula, debug=True, full_output=True):
    """
    Translate an LTL formula into a Büchi automaton with SPOT's ltl2tgba.

    The HOA text printed by ``~/spot/bin/ltl2tgba -f <formula>`` is turned
    into BA text by hoa_to_ba() and written to ``input_ba_path``.

    :param ltl_formula: LTL formula handed to ltl2tgba
    :param debug: print the command, tool output and conversion result
    :param full_output: with debug on, print the complete HOA/BA text (and
                        any stderr) instead of a length summary / short preview
    :return: the BA text, or None if ltl2tgba exits with an error or the
             executable cannot be found
    """
    spot_binary = os.path.expanduser("~/spot/bin/ltl2tgba")
    argv = [spot_binary, "-f", ltl_formula]

    if debug:
        print(f"🔧 Executing command: {' '.join(argv)}")
        print(f"📝 Input LTL formula: {ltl_formula}")

    try:
        completed = subprocess.run(
            argv,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,
        )
        hoa_text = completed.stdout
        stderr_text = completed.stderr

        if debug:
            print("✅ ltl2tgba execution successful")
            print(f"📊 Total output length: {len(hoa_text)} characters")
            if not full_output:
                print(f"Output length: {len(hoa_text)} characters")
            else:
                wide_rule = "=" * 80
                print("\n" + wide_rule)
                print("🔍 Complete SPOT output (HOA format):")
                print(wide_rule)
                print(hoa_text)
                print(wide_rule)
                if stderr_text:
                    thin_rule = "-" * 40
                    print("\n⚠️  STDERR output:")
                    print(thin_rule)
                    print(stderr_text)
                    print(thin_rule)

        ba_text = hoa_to_ba(hoa_text)
        save_to_ba_file(ba_text, input_ba_path)

        if debug:
            print("✅ Conversion complete, results saved as output.ba")
            if not full_output:
                print("Conversion result preview:")
                print(ba_text if len(ba_text) <= 200 else ba_text[:200] + "...")
            else:
                medium_rule = "=" * 60
                print("\n" + medium_rule)
                print("🔍 Complete BA format output:")
                print(medium_rule)
                print(ba_text)
                print(medium_rule)

        return ba_text

    except subprocess.CalledProcessError as err:
        print("❌ ltl2tgba execution failed:")
        print(f"Error code: {err.returncode}")
        print(f"STDERR: {err.stderr}")
        return None
    except FileNotFoundError as err:
        print("❌ ltl2tgba executable not found, please check if the path is correct:")
        print(f"Error details: {err}")
        return None

def save_to_ba_file(content, file_path):
    """
    Write BA text to *file_path* (UTF-8, overwriting any existing file).

    Failures are reported on stdout rather than raised. The return value
    tells the caller whether the write actually happened, so it can avoid
    working with a stale or missing file.

    :param content: BA text to write
    :param file_path: destination path
    :return: True on success, False if the file could not be written
    """
    try:
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)
    except (OSError, TypeError, UnicodeError) as e:
        print(f"❌ Error occurred while saving file: {e}")
        return False
    print(f"✅ File saved to {file_path}")
    return True

def parse_condition(condition_str):
    """
    Normalise a conjunctive condition string.

    The literal 't' (unconditional) becomes 'True'. Any other string is split
    on '&'; each literal is stripped, and the literals are re-joined with '&'
    in sorted order.
    """
    cleaned = condition_str.strip()
    if cleaned != 't':
        return '&'.join(sorted(map(str.strip, cleaned.split('&'))))
    return 'True'

def hoa_to_ba(hoa_content):
    """
    Convert SPOT's HOA output into the plain-text BA format read by RABIT.

    Output layout, one item per line:
        [initial_state]
        label,[src]->[dst]   (one line per transition, and per top-level
                              disjunct of a transition label)
        [accepting_state]    (one line per accepting state, first-seen order)

    AP indices in labels are replaced by the names from the ``AP:`` header.
    ``[t]`` labels are emitted as the literal ``t``. Acceptance is
    approximated by a set of accepting states:
      * "1 Inf(0)" / "1 Inf(1)": states carrying that mark, or, for marks
        on transitions, the transition's destination;
      * acceptance containing '&': states whose (non-empty) mark set is
        contained in the set indices referenced by ``Inf(...)``;
      * "0 t": destinations of ``[t]`` transitions.

    :param hoa_content: HOA text as printed by ltl2tgba
    :return: BA text, always newline-terminated
    :raises RuntimeError: a transition line could not be parsed
    """
    lines = hoa_content.strip().splitlines()

    ap_table = []            # AP names, indexed by the numbers used in labels
    initial_state = None
    accepting_states = []
    transitions = []

    parsing_states = False
    acceptance_text = ""     # "Acceptance:" value with normalised spacing, e.g. "1 Inf(0)"
    current_state = None     # state whose outgoing transitions are being read

    def parse_label(label):
        """Parse the label and process the operators and atomic propositions within it."""

        def process_condition(condition):
            """Replace AP indices by AP names; keep '!', '&', '(' and ')'; drop spaces."""
            result = ""
            i = 0
            while i < len(condition):
                char = condition[i]
                if char.isdigit():
                    # AP indices may have several digits (AP 10, 11, ...): read the whole number.
                    j = i + 1
                    while j < len(condition) and condition[j].isdigit():
                        j += 1
                    result += ap_table[int(condition[i:j])]
                    i = j
                    continue
                if char in "!&()":
                    result += char
                elif char != " ":
                    raise ValueError(f"Unexpected character in label: {char}")
                i += 1
            return result

        def split_conditions(expression, operator):
            """Split *expression* on *operator*, only at parenthesis depth 0."""
            parts = []
            depth = 0
            current = []
            for char in expression:
                if char == "(":
                    depth += 1
                elif char == ")":
                    depth -= 1
                if char == operator and depth == 0:
                    parts.append("".join(current).strip())
                    current = []
                else:
                    current.append(char)
            parts.append("".join(current).strip())
            return parts

        if "|" in label:
            # Each top-level disjunct becomes its own BA transition.
            return [process_condition(cond.strip()) for cond in split_conditions(label, "|")]
        if "&" in label:
            and_conditions = split_conditions(label, "&")
            return ["&".join(process_condition(cond.strip()) for cond in and_conditions)]
        return [process_condition(label)]

    for line in lines:
        line = line.strip()

        # Skip blank lines, the format header, --BODY--/--END-- and the name.
        if not line or line.startswith("HOA") or line.startswith("--") or line.startswith("name"):
            continue

        if line.startswith("Start:"):
            initial_state = f"[{line.split()[1]}]"

        elif line.startswith("Acceptance:"):
            acceptance_text = " ".join(line.split()[1:])

        elif line.startswith("AP:"):
            ap_table = [ap.strip('"') for ap in line.split(" ")[2:]]

        elif line.startswith("State:"):
            parsing_states = True
            state_info = line.split()
            current_state = f"[{state_info[1]}]"
            if "{0}" in state_info and "1 Inf(0)" in acceptance_text:
                accepting_states.append(current_state)
            elif "{1}" in state_info and "1 Inf(1)" in acceptance_text:
                accepting_states.append(current_state)
            elif "&" in acceptance_text:
                # Generalized acceptance: accept a state that carries marks, all
                # of which are among the sets referenced by Inf(...). States
                # without any mark are never accepting.
                referenced = {int(n) for n in re.findall(r'Inf\((\d+)\)', acceptance_text)}
                mark_match = re.search(r'\{([\d\s]*)\}', line)
                marks = {int(m) for m in mark_match.group(1).split()} if mark_match else set()
                if marks and marks.issubset(referenced):
                    accepting_states.append(current_state)

        elif parsing_states and line.startswith("["):
            try:
                if '{' in line:
                    # Transition-based mark: approximate it by accepting the destination.
                    mark_parts = line.split()
                    marked_destination = f"[{mark_parts[-2]}]"
                    if "{0}" in mark_parts and "1 Inf(0)" in acceptance_text:
                        accepting_states.append(marked_destination)
                    elif "{1}" in mark_parts and "1 Inf(1)" in acceptance_text:
                        accepting_states.append(marked_destination)
                # Collapse " | " so the label stays a single whitespace-free token.
                if ' | ' in line:
                    line = re.sub(r'\s*\|\s*', '|', line)
                transition_parts = line.split()
                label = transition_parts[0]
                destination_state = f"[{transition_parts[1]}]"

                if label == "[t]":
                    transitions.append(f"t,{current_state}->{destination_state}")
                    if "0 t" in acceptance_text:
                        accepting_states.append(destination_state)
                else:
                    for sub_label in parse_label(label.strip("[]")):
                        transitions.append(f"{sub_label},{current_state}->{destination_state}")

            except (IndexError, ValueError) as e:
                raise RuntimeError(f"Error processing transition '{line}': {e}") from e

    # Deduplicate while keeping first-seen order, so the output is deterministic.
    accepting_states = list(dict.fromkeys(accepting_states))

    # Without a "Start:" line, fall back to the source of the first transition.
    if not initial_state and transitions:
        initial_state = transitions[0].split(",")[1].split("->")[0]

    ba_content = f"{initial_state}\n"
    ba_content += "\n".join(transitions) + "\n"
    ba_content += "\n".join(accepting_states) + "\n"

    return ba_content


# --------------------- RABIT  -------------------------
# Re-check the files RABIT needs. check_file_exists and rabit_command are
# already defined earlier in this module (rabit_command from the same three
# paths), so they are reused here instead of being redefined.
# input_ba_path is written by run_ltl2tgba, so it may not exist yet at import time.
check_file_exists(rabit_jar)
check_file_exists(input_ba_path)
check_file_exists(comparison_automaton)

print(f"RABIT.jar path: {rabit_jar}")
print(f"Input automaton path: {input_ba_path}")
print(f"Comparison automaton path: {comparison_automaton}")

def run_rabit_and_check_inclusion(debug=True):
    """
    Run RABIT (the module-level ``rabit_command``) and report whether the
    language of the input BA is included in that of the comparison BA.

    Inclusion is decided by the substring "Included" in RABIT's stdout. A
    30-second timeout or any other execution error counts as "not included".

    :param debug: print the command, RABIT's stdout/stderr and the verdict
    :return: (included, text) where text is RABIT's stdout, or a short
             description of the failure
    """
    if debug:
        print("🔍 Using RABIT to check language inclusion...")
        print(f"Command: {' '.join(rabit_command)}")

    try:
        completed = subprocess.run(rabit_command, capture_output=True, text=True, timeout=30)
        stdout_text = completed.stdout
        stderr_text = completed.stderr

        if debug:
            print("\n📊 RABIT tool output:")
            print(stdout_text)
            if stderr_text:
                print("\n⚠️  RABIT tool error:")
                print(stderr_text)

        included = "Included" in stdout_text
        if debug:
            print(
                "✅ The language of the first automaton is included in the language of the second automaton"
                if included
                else "❌ The language of the first automaton is not included in the language of the second automaton"
            )
        return included, stdout_text

    except subprocess.TimeoutExpired:
        print("⏰ RABIT execution timeout")
        return False, "RABIT execution timeout"
    except Exception as exc:
        print(f"❌ RABIT execution error: {exc}")
        return False, str(exc)

def extract_atomic_propositions(ltl_formula):
    """
    Return the atomic propositions of *ltl_formula*, as listed in the
    ``AP:`` header of the HOA automaton SPOT's ltl2tgba builds for it.

    :param ltl_formula: LTL formula handed to ``ltl2tgba -f``
    :return: AP names in HOA order; [] when the header lists no names, or
             when ltl2tgba fails or cannot be executed
    """
    ltl2tgba_path = os.path.expanduser("~/spot/bin/ltl2tgba")
    command = [ltl2tgba_path, "-f", ltl_formula]

    try:
        process_result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True)
    except (subprocess.CalledProcessError, OSError, ValueError) as e:
        # CalledProcessError: bad formula; OSError: missing binary;
        # ValueError: e.g. an embedded NUL in the argument or undecodable output.
        print(f"Error extracting atomic propositions: {e}")
        return []

    ap_match = re.search(r'AP:\s+\d+\s+((?:"[^"]*"\s*)+)', process_result.stdout)
    if not ap_match:
        print("AP section not found")
        return []

    # Pull out each quoted name directly instead of splitting on '" "', so
    # the result does not depend on the exact whitespace between names.
    return [name.strip() for name in re.findall(r'"([^"]*)"', ap_match.group(1))]

# ----------------  correct_input_with_gpt  ----------------
def gpt_replace_AP(atomic_proposition_library, LTL1, debug=True):
    """
    Ask the GPT model to rename the atomic propositions in *LTL1* to similar
    ones from a predefined library, leaving the logical structure unchanged.

    :param atomic_proposition_library: allowed APs, interpolated into the prompt as-is
    :param LTL1: LTL formula whose APs should be mapped
    :param debug: print the updated formula
    :return: the model's reply with surrounding whitespace stripped
    :raises RuntimeError: from gpt_gen.generate when the API request fails
    """
    # The reply must keep LTL operators and parentheses, so the prompt only
    # forbids text *around* the formula, not punctuation in general.
    prompt = (
        f"""
Please match the following LTL's atomic propositions: {LTL1}
to a pre-defined library of atomic propositions: {atomic_proposition_library}.
If any proposition is similar to one in the library, replace it with the library's proposition.
Only modify the atomic propositions. Do not change any logical structure or operators.
Return the updated raw LTL formula only. No explanation, no formatting, no quotes or other text around the formula.
"""
    )
    # gpt_gen is in "gpt" mode, whose generate() always returns a str.
    updated_ltl = gpt_gen.generate(prompt, max_tokens=256).strip()

    if debug:
        print(f"Updated LTL: {updated_ltl}")

    return updated_ltl

def gpt_understand_rabit_output(LTL1, input_BA, comparison_BA, checking_output, comparison_LTL, nl_instruction="", debug=True):
    """
    Use GPT to interpret RABIT's inclusion-check output and suggest (but not
    write) corrections to LTL1, optionally constrained by the original
    natural-language instruction.

    :param LTL1: current LTL formula
    :param input_BA: BA text generated from LTL1
    :param comparison_BA: reference BA text
    :param checking_output: RABIT's output for the inclusion check
    :param comparison_LTL: LTL of the reference automaton
    :param nl_instruction: original instruction; omitted from the prompt when empty
    :param debug: print the analysis
    :return: GPT's analysis text
    """
    # Put the instruction on its own line so it is not fused with the guidance sentence after it.
    nl_context = f"\n- Original_Natural_Language_Instruction: {nl_instruction}\nLogically represent the navigation sequence, ensuring the expression accurately reflects the instructions and conditions. Ensure that the movements of the vehicle are reflected in the LTL." if nl_instruction else ""

    prompt = (
        f"""
Task: Refine LTL Formula Using RABIT Tool Output

Analyze the RABIT tool output to identify counterexamples and propose flexible changes to LTL1, ensuring the language of `input_BA` is a subset of `comparison_BA`. You need to work as a expertise professor in LTL and Büchi Automata.

Inputs:
- LTL1_Formula: {LTL1}
- Input_BA: {input_BA}
- Comparison_BA: {comparison_BA}
- RABIT_Output: {checking_output}
- Comparison_LTL: {comparison_LTL}{nl_context}

**IMPORTANT**: If natural language instruction is provided, ensure that any suggested modifications to LTL1 preserve the semantic correctness and intent of the original natural language instruction. The corrected LTL should still accurately represent the driving behavior described in the natural language.

Steps:
1. Analyze Counterexamples: For each counterexample, explain why it's accepted by `input_BA` but rejected by `comparison_BA`, and identify the issue in LTL1 causing the discrepancy.
2. Diagnose LTL1 Issues: Identify problematic parts of LTL1 (e.g., transitions, constraints, temporal logic) and their interaction with states and transitions in both automata.
3. Consider Natural Language Semantics: If provided, ensure proposed changes maintain semantic alignment with the original natural language instruction.
4. Propose Adjustments: Suggest minimal but flexible changes to LTL1 to resolve issues, keeping the formula simple and avoiding unnecessary complexity. Provide guidance for correction but not a concrete LTL formula.
5. Ensure Alignment: Confirm that proposed changes address all counterexamples and improve LTL1's compatibility with `comparison_BA`.

Output: 
1. Counterexample_Analysis: "Sequence": "counterexample_sequence", "Reason": "why accepted by input_BA but rejected by comparison_BA", "Issue_in_LTL1": "issue in LTL1"
2. Proposed_Adjustments: "Adjustment": "change to LTL1", "Justification": "how it resolves the discrepancy"
3. Natural_Language_Alignment: "Semantic_Check": "ensure changes preserve original natural language intent"
4. General_Guidance: "Summary of the approach and its broader impact"

Tips:
1. The 't' in BA is a specially string, it't not a true alphabet string, but a signal of 'unconditional', so don't add anything about 't' to the LTL directly, instead you should think about its logic and change that.
2. Don't add operators like '<', '>' and '=', which are not allowed by LTL's syntactic rules.
        """
    )
    result = gpt_gen.generate(prompt, max_tokens=2048)

    if debug:
        print("💡 GPT analysis result:")
        print(result)

    return result

def gpt_correct_ltl(LTL1, understanding_output, nl_instruction="", debug=True):
    """
    Use local model to correct LTL based on understanding output, maintaining natural language semantic correctness.

    Builds a correction prompt from the failing formula, the analysis text
    for the failed inclusion check and (optionally) the original natural
    language instruction, then samples one completion from the
    ``text_generator`` pipeline using its tokenizer's chat template.

    :param LTL1: LTL formula that failed the inclusion check
    :param understanding_output: analysis / revision guidance for LTL1
    :param nl_instruction: original natural-language instruction; when
        non-empty it is embedded in the prompt so the model keeps its intent
    :param debug: whether to print progress and the corrected formula
    :return: the corrected LTL formula as a stripped string, or "" when the
        model produced nothing usable
    """
    nl_context = f"""
        7) SEMANTIC PRESERVATION: The original natural language instruction is: "{nl_instruction}"
           Logically represent the navigation sequence, ensuring the expression accurately reflects the instructions and conditions. Ensure that the movements of the vehicle are reflected in the LTL.
        """ if nl_instruction else ""
    
    prompt = (
        f"""
        The current LTL1:
        {LTL1}
        does not satisfy the inclusion check based on the Buchi automaton comparison. Below is the analysis and revision guidance:
        {understanding_output}

        Your task is to modify and simplify LTL1 to ensure that it satisfies the inclusion check while keeping the formula as concise and flexible as possible. Adjustments should focus on resolving the specific issues identified in the guidance rather than adding restrictive or overly complex conditions. Aim for an intuitive and streamlined solution that aligns LTL1 with the comparison Buchi automaton.

        Instructions:
        1)Modify LTL1 based on the provided analysis.
        2)Ensure the revised formula directly addresses the counterexample(s) while improving alignment with the comparison Buchi automaton.
        3)Avoid adding unnecessary constraints or increasing the formula's complexity—prioritize simplicity and precision.
        4)The 't' in BA is a specially string, it't not a true alphabet string, but a signal of 'unconditional', so don't add anything about 't' to the LTL directly, instead you should think about its logic and change that.
        5)Don't add operators like '<', '>' and '=', which are not allowed by LTL's syntactic rules.
        6)Pay attention LTL should only  contain :"goStraight","turnLeft","turnRight","reachDestination"
        {nl_context}
        
        Output Format:
        Must obey: Output only the raw updated LTL formula in one line with no explanation, no formatting, no quotes, no extra information, text display symbols, such as (*#).\n
        Don't output any other thingss like:"LTL1","Here is the corrected LTL1:","The corrected LTL1 is:","The LTL1 is:","The LTL1 is as follows:","The LTL1 is as follows below:","The LTL1 is as follows below:" and so on.
        """
    )
    
    messages = [{"role": "user", "content": prompt}]

    formatted_prompt = text_generator.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    if debug:
        print("🔧 Using local model to correct LTL...")
        if nl_instruction:
            print(f"📝 Maintaining semantic consistency: {nl_instruction}")

    outputs = text_generator(
        formatted_prompt,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=text_generator.tokenizer.eos_token_id
    )

    generated = outputs[0]["generated_text"] if outputs else ""
    # A pipeline that returns the full text echoes the prompt before the
    # completion; drop it only when it is actually there so a pipeline
    # configured with return_full_text=False is not truncated.
    if generated.startswith(formatted_prompt):
        generated = generated[len(formatted_prompt):]

    # The prompt forbids formatting, but models still sometimes wrap the
    # answer in a Markdown code fence; drop fence lines and stray backticks.
    kept_lines = [
        line for line in generated.splitlines()
        if not line.strip().startswith("```")
    ]
    corrected_ltl = "\n".join(kept_lines).strip().strip("`").strip()

    if debug:
        print(f"Corrected LTL: {corrected_ltl}")

    return corrected_ltl



def AutoSafeLTL_Method(ltl_3, nl_instruction="", debug=True, max_iterations=10):
    """
    Complete LTL verification method.

    Converts the formula to a Büchi automaton and repeatedly checks, with
    RABIT, whether its language is included in the comparison automaton
    (``comparison_automaton``). After each failed check the LLM analyses the
    RABIT output, proposes a corrected formula, the formula is syntax-checked
    and re-converted, and the next iteration checks it again. A correction is
    only generated when another iteration remains to check it, so the
    returned formula is always the last one RABIT actually examined.

    NOTE(review): ``run_ltl2tgba`` is assumed to write the automaton to
    ``input_ba_path`` (which is read back for the analysis step) — confirm.

    :param ltl_3: Initial LTL formula
    :param nl_instruction: Corresponding natural language instruction for semantic preservation
    :param debug: Whether to show debug information
    :param max_iterations: Maximum number of iterations (RABIT checks)
    :return: (is_included, final_ltl) - Inclusion result and final LTL formula
    """
    comparison_LTL = 'G((turnLeft | turnRight) -> (F(goStraight) | F(reachDestination)))'
    atomic_proposition_library = extract_atomic_propositions(comparison_LTL)

    if not ltl_3:
        print("❌ No LTL formula generated, exiting")
        return False, ""
    
    if debug:
        print(f"🎯 Start verifying LTL: {ltl_3}")
        if nl_instruction:
            print(f"📝 Corresponding natural language: {nl_instruction}")
        print(f"🔖 Atomic proposition library: {atomic_proposition_library}")
        print(f"🔄 Maximum iterations: {max_iterations}")
    
    # Replace atomic propositions
    ltl_3 = gpt_replace_AP(atomic_proposition_library, ltl_3, debug)

    # LTL -> BA
    if debug:
        print("🔄 Step 1: Convert LTL to Büchi automaton")
    run_ltl2tgba(ltl_3, debug)
    
    # Read comparison automaton
    with open(comparison_automaton, 'r') as f:
        autB_ba_text = f.read()

    current_ltl = ltl_3
    included = False
    iteration = 0

    while not included and iteration < max_iterations:
        iteration += 1
        print(f"\n{'='*20} Iteration {iteration}/{max_iterations} {'='*20}")

        # RABIT check
        if debug:
            print("🔄 Step 2: Use RABIT to check language inclusion")
        included, rabit_output = run_rabit_and_check_inclusion(debug)
        print("rabit_output:", rabit_output)
        if included:
            print("✅ Language inclusion check passed!")
            print(f"🎯 Final successful LTL: {current_ltl}")
            break

        # No iteration is left to check a new correction: stop here rather
        # than spend an LLM call and return a formula RABIT never verified.
        if iteration >= max_iterations:
            break

        # Read current input automaton
        with open(input_ba_path, 'r') as f:
            autA_ba_text = f.read()

        # GPT analyze RABIT output (including natural language semantics)
        if debug:
            print("🔄 Step 3: GPT analyze RABIT output")
        understanding_output = gpt_understand_rabit_output(
            current_ltl, autA_ba_text, autB_ba_text, rabit_output, comparison_LTL, nl_instruction, debug
        )
        
        # GPT correct LTL (maintaining natural language semantics)
        if debug:
            print("🔄 Step 4: Correct LTL formula")
        corrected_ltl = gpt_correct_ltl(current_ltl, understanding_output, nl_instruction, debug)
        
        # gpt_correct_ltl reports failure with "", so test falsiness, not None.
        if not corrected_ltl:
            print("❌ Unable to get corrected LTL formula from GPT, exiting")
            break

        # Syntax check
        if debug:
            print("🔄 Step 5: Syntax check")
        corrected_ltl = correct_ltl_formula(corrected_ltl)
        
        if not corrected_ltl:
            print("❌ Unable to correct LTL syntax, exiting")
            break

        print(f"🔧 LTL after {iteration} corrections: {corrected_ltl}")

        # Re-convert to BA
        run_ltl2tgba(corrected_ltl, debug)
        current_ltl = corrected_ltl

    if included:
        print(f"\n🎉 After {iteration} iterations, the language of the first automaton is now included in the language of the second automaton")
        print(f"🎯 Final successful LTL formula: {current_ltl}")
    else:
        # Report the iterations actually run; the loop may stop early.
        print(f"\n❌ After {iteration} iterations, still unable to achieve 'inclusion' state")
        print(f"🔧 Last attempted LTL formula: {current_ltl}")
    
    return included, current_ltl

def process_ltl_verification(input_path, output_path, start_index=1, end_index=None, debug=True):
    """
    Process LTL verification and update status - supports specifying processing range.

    Reads a JSON list of entries and, for each entry in the requested range,
    runs ``AutoSafeLTL_Method`` on its ``"ltl"`` formula together with its
    ``"nl"`` instruction. The result is recorded on the entry as ``"status"``
    ("included" / "not included" / "error") plus ``"final_ltl"``. Entries
    that already have a status, have an empty or null LTL, or are not JSON
    objects are skipped. The full list is re-written to ``output_path``
    after every attempted entry so partial progress survives a crash.

    :param input_path: Input JSON file path (UTF-8, top-level list)
    :param output_path: Output JSON file path
    :param start_index: Start processing entry index (1-based); values below 1 are treated as 1
    :param end_index: End processing entry index (1-based, inclusive), None means process to the end
    :param debug: Whether to show debug information
    :return: True if the batch ran to completion, False if it aborted
    """

    def _save_results(records):
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(records, f, indent=2)

    try:
        # Read input file
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if not isinstance(data, list):
            print(f"❌ Batch processing failed: expected a JSON list in {input_path}")
            return False
        
        total_items = len(data)
        # Clamp the 1-based range; a start below 1 would otherwise yield
        # negative indices that silently process entries from the end.
        first = max(1, start_index)
        last = total_items if end_index is None else min(end_index, total_items)
        
        print(f"📖 Total {total_items} records")
        print(f"🎯 Processing range: entries {first} to {last}")
        
        success_count = 0
        error_count = 0
        skip_count = 0
        
        # Process records in specified range
        for i in range(first - 1, last):
            item = data[i]
            item_num = i + 1
            print(f"\n🔄 Processing entry #{item_num}")

            if not isinstance(item, dict):
                print(f"⏭️  Skip entry #{item_num}: entry is not a JSON object")
                skip_count += 1
                continue

            # `or ""` also covers explicit JSON nulls, which .get(key, "") does not.
            ltl_formula = str(item.get("ltl") or "").strip()
            nl_instruction = str(item.get("nl") or "").strip()  # Get natural language instruction
            
            print(f"Natural language: {nl_instruction[:100]}{'...' if len(nl_instruction) > 100 else ''}")
            
            # Skip entries with empty LTL or existing status
            if not ltl_formula or item.get("status"):
                print(f"⏭️  Skip entry #{item_num}: LTL is empty or status already exists")
                skip_count += 1
                continue
                
            try:
                print(f"🔤 Original LTL formula: {ltl_formula}")
                print("-" * 60)
                
                # Call verification function, pass natural language instruction
                is_included, final_ltl = AutoSafeLTL_Method(ltl_formula, nl_instruction, debug=debug)
                
                # Update status and final LTL
                item["status"] = "included" if is_included else "not included"
                if final_ltl and final_ltl != ltl_formula:
                    item["final_ltl"] = final_ltl  # Save final corrected LTL
                    print(f"🔧 Final LTL formula: {final_ltl}")
                else:
                    item["final_ltl"] = ltl_formula  # If no correction, save original LTL
                
                result_emoji = '✅' if is_included else '❌'
                print(f"\n{result_emoji} Entry #{item_num} verification result: {item['status']}")
                success_count += 1
                
            except Exception as e:
                error_msg = f"Verification failed: {str(e)}"
                print(f"⚠️ Entry #{item_num} {error_msg}")
                if debug:
                    print(f"Error details:\n{traceback.format_exc()}")
                item["status"] = "error"
                item["final_ltl"] = ltl_formula  # Also save original LTL in error cases
                error_count += 1

            # Save after every attempted entry (success or error) to prevent data loss
            _save_results(data)
            
            print("=" * 80)
        
        # Final save results
        _save_results(data)
        
        print(f"\n📊 Processing completion statistics:")
        print(f"✅ Successfully processed: {success_count} entries")
        print(f"❌ Processing failed: {error_count} entries") 
        print(f"⏭️  Skipped: {skip_count} entries")
        print(f"💾 Results saved to {output_path}")
        
        return True
    
    except Exception as e:
        print(f"❌ Batch processing failed: {str(e)}")
        if debug:
            print(f"Error details:\n{traceback.format_exc()}")
        return False

# --------------------- Main -----------------------------
if __name__ == "__main__":
    # Fill in these paths before running. A stage whose paths are left empty
    # is skipped (or reported) instead of being called with "" paths.
    nl_to_ltl_input_path = ""
    nl_to_ltl_output_path = ""
    verification_input_path = ""
    verification_output_path = ""

    # Optional first stage: generate LTL formulas from natural language.
    if nl_to_ltl_input_path and nl_to_ltl_output_path:
        transform_nl_to_ltl_with_local_model(
            input_path=nl_to_ltl_input_path,
            output_path=nl_to_ltl_output_path,
        )
    else:
        print("⏭️  Skipping NL -> LTL generation: input/output paths are not set")

    # Execute batch processing (saves final LTL per entry)
    if verification_input_path and verification_output_path:
        success = process_ltl_verification(
            input_path=verification_input_path,
            output_path=verification_output_path,
            start_index=1,
            end_index=None,
        )
    else:
        print("❌ Verification input/output paths are not set")
        success = False

    if success:
        print("\nVerification complete! Check status field in output file")
    else:
        print("\nError occurred during processing, please check logs")