﻿from typing import List, Dict, Callable
import math
import os
import shutil
from transformers.generation.utils import GenerationConfig
from transformers import StoppingCriteriaList, GenerationConfig


import numpy as np
import time
import json

# GPUs this process may use. CUDA_VISIBLE_DEVICES is set here, before torch is
# imported further down, so CUDA only ever sees these device indices.
devices = [3]
d_str = ",".join(str(d) for d in devices)
os.environ["CUDA_VISIBLE_DEVICES"] = d_str


import torch
from torch.utils.data import DataLoader
USER_PATH = '/home/XXXX/home/XXXX/6_26_backup/fs_backup_feb13/'

# SECURITY NOTE(review): blanking the CA-bundle variables is a known workaround
# to skip TLS certificate verification for curl/requests downloads in this
# process -- confirm this is still intended.
os.environ["CURL_CA_BUNDLE"] = ""
os.environ["REQUESTS_CA_BUNDLE"] = ""

# Hugging Face cache location used for tokenizer/model downloads.
# os.environ['TRANSFORMERS_CACHE'] = USER_PATH + '/.cache/huggingface/hub'
cache_dir = '/work/XXXX/'
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
# os.environ['HF_HUB_OFFLINE'] ='1'

# Few-shot prompt text, read once at import time. The context manager closes
# the file handle instead of leaking it.
with open('/home/XXXX/clause/logic_cot_fewshot_quail_new.txt', 'r') as _fewshot_file:
    FEWSHOT = _fewshot_file.read()
# import transformers

# import urllib3
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import argparse
from tqdm import tqdm
import time
import datetime
import warnings
import contextlib

import requests
from urllib3.exceptions import InsecureRequestWarning

# Unpatched implementation; restored when no_ssl_verification() exits.
old_merge_environment_settings = requests.Session.merge_environment_settings


@contextlib.contextmanager
def no_ssl_verification():
    """Temporarily disable TLS certificate verification for ``requests``.

    While the context is active, ``requests.Session.merge_environment_settings``
    is patched so every request is sent with ``verify=False``, and
    ``InsecureRequestWarning`` is suppressed. On exit the original method is
    restored and every adapter used inside the context is closed.

    Security: this permits man-in-the-middle attacks; only wrap trusted calls.
    """
    opened_adapters = set()

    def merge_environment_settings(self, url, proxies, stream, verify, cert):
        # Verification happens only once per connection, so we must close all
        # adapters opened here once we're done. Otherwise the effects of
        # verify=False persist beyond the end of this context manager.
        opened_adapters.add(self.get_adapter(url))

        settings = old_merge_environment_settings(self, url, proxies, stream, verify, cert)
        settings['verify'] = False
        return settings

    requests.Session.merge_environment_settings = merge_environment_settings

    try:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', InsecureRequestWarning)
            yield
    finally:
        requests.Session.merge_environment_settings = old_merge_environment_settings

        for adapter in opened_adapters:
            # Best-effort cleanup: a failing close() must not mask the body's
            # own exception. KeyboardInterrupt/SystemExit still propagate.
            try:
                adapter.close()
            except Exception:
                pass

class Struct:
    """Attribute bag: every keyword argument becomes an instance attribute.

    Example: ``Struct(lr=1e-4).lr == 1e-4``.
    """

    def __init__(self, **entries):
        namespace = vars(self)
        for key in entries:
            namespace[key] = entries[key]

# Run configuration. It is wrapped in Struct below so values read as attributes
# (e.g. args.model_choice).
args = {
    'train_file_path': './example_data',
    'test_file_path': './example_data',
    'save_path': './../SFT_train_res',
    'model_choice': 'meta-llama/Llama-2-13b-chat-hf',
    'n_rows': 20,
    'max_length': 1000,
    'temperature': 1,
    'lr': 5e-05,
    'weight_decay': 0.0,
    'epochs': 10,
    'max_grad_norm': 1.0,
    'batch_size': 2,
    'save_strategy': 'no',
    'use_lora': True,
}
# Backbone override. Alternative: 'mistralai/Mistral-7B-Instruct-v0.3'
args.update(model_choice='meta-llama/Meta-Llama-3-8B-Instruct')

args = Struct(**args)

from transformers import StoppingCriteria
from transformers import StoppingCriteria

class StopOnNextStepTokens(StoppingCriteria):
    """Stop generation once every sequence has produced a given token pattern.

    Only tokens after the first ``prompt_len`` columns are inspected. A row
    counts as finished as soon as ``stop_token_ids`` appears anywhere in its
    generated part. Rows that reach the marker at different steps therefore
    still end the batch once the last row gets there. (Checking only the most
    recent tokens would require all rows to hit the marker on the same step.)

    NOTE(review): ``stop_token_ids`` is typically obtained by encoding the
    marker on its own. If the tokenizer merges the marker with neighbouring
    text when it is generated in context, the pattern may never match;
    ``StopOnNextStep`` matches on decoded text instead.
    """

    def __init__(self, stop_token_ids, prompt_len, device):
        self.stop_token_ids = torch.tensor(
            stop_token_ids, device=device, dtype=torch.long
        )
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores, **kwargs):
        # Newly generated tokens only: (batch, n_generated).
        gen = input_ids[:, self.prompt_len:]
        width = self.stop_token_ids.numel()

        # Nothing to match, or not enough tokens generated yet.
        if width == 0 or gen.shape[1] < width:
            return False

        # All contiguous length-`width` windows: (batch, n_windows, width).
        windows = gen.unfold(1, width, 1)
        # A row is done if any window equals the stop pattern.
        row_done = (windows == self.stop_token_ids).all(dim=-1).any(dim=-1)
        return bool(row_done.all())

class StopOnNextStep(StoppingCriteria):
    """Text-level stopping criterion for numbered step-by-step generation.

    Generation stops once *every* sequence in the batch has emitted either the
    next step marker (newline + ``#{step_idx + 2}.``) or the newline +
    ``Final Answer:`` marker. Only newly generated tokens are decoded. Markers
    already present in the prompt (e.g. few-shot examples) therefore do not
    trigger an immediate stop.

    Args:
        tokenizer: tokenizer used to decode the generated tokens.
        step_idx: index of the current step; the stop marker is step_idx + 2.
        prompt_len: number of prompt columns in ``input_ids``. If omitted, it
            is inferred on the first call (see note in ``__call__``), so a
            fresh instance should be used for each ``generate`` call.
    """

    def __init__(self, tokenizer, step_idx, prompt_len=None):
        self.tokenizer = tokenizer
        self.next_step = f"\n#{step_idx + 2}."
        self.final = "\nFinal Answer:"
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids shape: (batch_size, seq_len)
        if self.prompt_len is None:
            # NOTE(review): assumes the first call happens right after a single
            # token has been appended to the prompt (the standard sampling
            # loop); pass prompt_len explicitly otherwise.
            self.prompt_len = input_ids.shape[1] - 1

        decoded = self.tokenizer.batch_decode(
            input_ids[:, self.prompt_len:], skip_special_tokens=True
        )

        # Stop ONLY if *all* samples hit a boundary. The decoded text grows
        # monotonically, so a row stays "done" once a marker has appeared.
        return all(
            (self.next_step in text) or (self.final in text)
            for text in decoded
        )

class LLM():
    """Wrapper around a 4-bit (NF4) quantised Hugging Face causal LM.

    Loads the checkpoint named by the module-level ``args.model_choice`` into
    ``cache_dir``. Methods either score fixed continuations of prompts
    (``sentence_probabilities``, ``nli``, ``yn``) or sample new text
    (``complete``, ``generate_next_thought``). Input tensors are moved to the
    current CUDA device with ``.cuda()``.
    """

    def __init__(self):
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
            bnb_4bit_use_double_quant=True,
        )
        # Downloads run with TLS certificate verification disabled.
        with no_ssl_verification():
            self.tokenizer = AutoTokenizer.from_pretrained(
                args.model_choice,
                cache_dir=cache_dir,
                token=' ',
            )
            # Pad with EOS (this also sets pad_token_id). Pad on the left so
            # batched prompts end flush against the first new/scored token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = 'left'

            self.model = AutoModelForCausalLM.from_pretrained(
                args.model_choice,
                cache_dir=cache_dir,
                quantization_config=quant_config,
                device_map='auto',
                token=' ',
                attn_implementation="flash_attention_2",
            )

    def sentence_probabilities(self, sentences):
        """Return the summed token log-probability of each sentence.

        For every sentence this sums ``log p(token_t | tokens_<t)`` over all
        positions after the first, skipping padding positions.

        Args:
            sentences: list of strings, scored together as one padded batch.

        Returns:
            1-D CPU tensor with one summed log-probability per sentence.
        """
        with torch.no_grad():
            enc = self.tokenizer(sentences, return_tensors='pt', padding=True)
            token_ids = enc.input_ids.cuda()
            attn = enc.attention_mask.cuda()
            batch, seq_len = token_ids.shape
            # Positions that skip left padding, so padded rows start at 0.
            position_ids = (attn.long().cumsum(-1) - 1).clamp(min=0)

            # Speed-up (~4-5x): leading columns identical in every row are run
            # through the model once and their KV cache is broadcast to the
            # batch (may cause small imprecisions under quantisation).
            shared = (token_ids == token_ids[:1]).all(dim=0)
            prefix_len = int(shared.long().cumprod(0).sum())
            # Keep at least one token for the per-row pass.
            prefix_len = min(prefix_len, seq_len - 1)

            if prefix_len <= 0:
                # No shared prefix (e.g. rows padded differently): plain pass.
                logits = self.model(
                    token_ids, attention_mask=attn, position_ids=position_ids
                ).logits
                token_logp = (
                    logits[:, :-1].log_softmax(-1)
                    .gather(2, token_ids[:, 1:, None])
                    .squeeze(-1)
                )
            else:
                prefix_out = self.model(
                    token_ids[:1, :prefix_len],
                    attention_mask=attn[:1, :prefix_len],
                    position_ids=position_ids[:1, :prefix_len],
                    use_cache=True,
                )
                past = tuple(
                    tuple(t.expand(batch, -1, -1, -1) for t in layer)
                    for layer in prefix_out.past_key_values
                )
                rest_out = self.model(
                    token_ids[:, prefix_len:],
                    attention_mask=attn,  # covers cached prefix + new tokens
                    position_ids=position_ids[:, prefix_len:],
                    past_key_values=past,
                )

                # Score the prefix once instead of materialising
                # (batch, seq_len, vocab) logits for it.
                prefix_logp = prefix_out.logits[0].log_softmax(-1)  # (prefix_len, vocab)
                # Targets 1..prefix_len-1 are identical in every row.
                shared_logp = (
                    prefix_logp[:-1]
                    .gather(1, token_ids[0, 1:prefix_len, None])
                    .squeeze(-1)
                )
                # The last prefix position predicts the first row-specific token.
                boundary_logp = prefix_logp[-1, token_ids[:, prefix_len]]
                # Remaining targets come from the per-row pass.
                rest_logp = (
                    rest_out.logits[:, :-1].log_softmax(-1)
                    .gather(2, token_ids[:, prefix_len + 1:, None])
                    .squeeze(-1)
                )
                token_logp = torch.cat(
                    [
                        shared_logp.unsqueeze(0).expand(batch, -1),
                        boundary_logp.unsqueeze(1),
                        rest_logp,
                    ],
                    dim=1,
                )

            # Drop padded target positions, then sum per sentence.
            return (token_logp * attn[:, 1:]).sum(-1).cpu()

    def nli(self, sentences, unknown):
        """Score candidate entailment labels appended to one prompt.

        Args:
            sentences: a single prompt string; each label is appended to it.
            unknown: if true, score ``"(A)"``/``"(B)"``/``"(C)"`` (returned
                as True/Maybe/False); otherwise score ``"True"``/``"False"``.

        Returns:
            dict mapping label name to a 0-d tensor.
        """
        if unknown:
            # NOTE(review): these are raw summed log-probs, while the two-way
            # branch below is softmax-normalised; confirm callers expect this.
            true_probs, maybe_probs, false_probs = self.sentence_probabilities(
                [sentences + "(A)", sentences + "(B)", sentences + "(C)"]
            )
            return {'True': true_probs, 'Maybe': maybe_probs, 'False': false_probs}
        true_probs, false_probs = self.sentence_probabilities(
            [sentences + "True", sentences + "False"]
        ).softmax(-1)
        return {'True': true_probs, 'False': false_probs}

    def yn(self, sentences, norm=True, relaxed=False, obvious=False, fewshot=None, maybe=False):
        """Score each prompt against a fixed set of answer suffixes.

        Args:
            sentences: iterable of prompt strings.
            norm: if true, softmax the summed log-probs across each prompt's
                suffixes. Otherwise only ``pyes`` is filled, with the heuristic
                ``1 - yes / (yes + no)`` on raw log-probs. Three-way mode
                requires ``norm=True``.
            relaxed: use suffixes ``"Most likely"`` / ``"Not necessarily"``.
            obvious: use ``"obviously true."`` / ``"not obviously true."``.
            fewshot: optional text prepended to every prompt.
            maybe: use ``"Yes"`` / ``"Maybe"`` / ``"No"``. Ignored when
                ``relaxed`` or ``obvious`` is set, since those suffix sets take
                precedence.

        Returns:
            In three-way mode, a ``(3, N)`` tensor of (yes, maybe, no) scores.
            Otherwise a tuple ``(pyes, pmaybe, pno)`` of 1-D tensors, with
            ``pmaybe`` empty.
        """
        if relaxed:
            suffixes = ("Most likely", "Not necessarily")
        elif obvious:
            suffixes = ("obviously true.", "not obviously true.")
        elif maybe:
            suffixes = ("Yes", "Maybe", "No")
        else:
            suffixes = ("Yes", "No")
        # Group size follows the suffixes actually used, keeping scores aligned.
        group = len(suffixes)
        three_way = group == 3

        yns = []
        for sentence in sentences:
            if fewshot:
                sentence = fewshot + sentence
            yns.extend(sentence + suffix for suffix in suffixes)

        probs = []
        batch_size = 256
        for start in range(0, len(yns), batch_size):
            probs += list(self.sentence_probabilities(yns[start:start + batch_size]))
        probs = torch.tensor(probs)

        pyes = []
        pno = []
        pmaybe = []
        for i in range(0, len(probs), group):
            # `p_maybe` must not reuse the name `maybe`, which is the flag.
            if three_way:
                yes, p_maybe, no = probs[i], probs[i + 1], probs[i + 2]
            else:
                yes, no = probs[i], probs[i + 1]

            if norm:
                if three_way:
                    y, m, n = torch.tensor([yes, p_maybe, no]).softmax(-1)
                    pmaybe.append(m)
                else:
                    y, n = torch.tensor([yes, no]).softmax(-1)
                pyes.append(y)
                pno.append(n)
            else:
                # Heuristic on raw (negative) log-probs, not a probability.
                pyes.append(1 - yes / (yes + no))

        if three_way:
            return torch.stack([torch.tensor(pyes), torch.tensor(pmaybe), torch.tensor(pno)])
        return torch.tensor(pyes), torch.tensor(pmaybe), torch.tensor(pno)

    def complete(
        self,
        prompt,
        max_new=25,
        temp=1.0,
        topk=0,
        n_samples=1,
    ):
        """Sample continuations of ``prompt``.

        Args:
            prompt: a string or list of strings (left-padded as one batch).
            max_new: maximum number of new tokens.
            temp: sampling temperature.
            topk: top-k cutoff; 0 disables top-k filtering.
            n_samples: sampled sequences per prompt (``num_return_sequences``).

        Returns:
            List of decoded strings; each includes its prompt text.
        """
        # NOTE(review): truncation keeps at most 3000 tokens using the
        # tokenizer's truncation side, which may drop the end of long prompts.
        enc = self.tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=3000
        )
        encode_ids = enc.input_ids.cuda()
        # pad == eos, so generate() cannot infer the mask from pad tokens.
        attention_mask = enc.attention_mask.cuda()

        gc = GenerationConfig(
            temperature=temp,
            do_sample=True,
            top_k=topk,
            max_new_tokens=max_new,
            num_return_sequences=n_samples,
            return_dict_in_generate=True,
            output_scores=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        outputs = self.model.generate(
            encode_ids,
            attention_mask=attention_mask,
            generation_config=gc,
        )

        return self.tokenizer.batch_decode(
            outputs.sequences,
            skip_special_tokens=True
        )

    def generate_next_thought(self,
        prompt,
        step_idx,
        n_samples=8,
        max_new=100,
        temp=1.0,
        topk=50,
    ):
        """Sample ``n_samples`` candidate continuations for the next step.

        Generation ends early once every sample has emitted the token pattern
        of the next step marker (newline + ``#{step_idx + 2}.``); see
        ``StopOnNextStepTokens``. Samples may contain text past the marker.

        Returns:
            List of ``n_samples`` decoded strings, each including the prompt.
        """
        gc = GenerationConfig(
            do_sample=True,
            temperature=temp,
            top_k=topk,
            max_new_tokens=max_new,
            num_return_sequences=n_samples,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        enc = self.tokenizer(prompt, return_tensors="pt")
        input_ids = enc.input_ids.cuda()
        attention_mask = enc.attention_mask.cuda()

        stop_str = f"\n#{step_idx + 2}."
        stop_ids = self.tokenizer.encode(stop_str, add_special_tokens=False)

        stopper = StopOnNextStepTokens(
            stop_token_ids=stop_ids,
            prompt_len=input_ids.shape[1],
            device=input_ids.device,
        )

        out = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            generation_config=gc,
            stopping_criteria=StoppingCriteriaList([stopper]),
        )

        return self.tokenizer.batch_decode(out, skip_special_tokens=True)
    # def generate_next_thought_batch(self, prompts, step_idx, n_samples=8, max_new=100, temp=1, topk=0):
    #     """
    #     Generate next thoughts for multiple prompts at once.
    #     Returns list of lists: outer list = prompts, inner list = n_samples per prompt
    #     """
    #     # Tokenize all prompts with padding
    #     inputs = self.tokenizer(
    #         prompts, 
    #         return_tensors="pt", 
    #         padding=True,
    #         truncation=True,
    #         max_length=3000
    #     ).to('cuda')
        
    #     stop_str = f"\n#{step_idx + 2}."
    #     stop_ids = self.tokenizer.encode(stop_str, add_special_tokens=False)
        
    #     stopper = StopOnNextStepTokens(
    #         stop_token_ids=stop_ids,
    #         prompt_len=inputs.input_ids.shape[1],
    #         device=inputs.input_ids.device,
    #     )
        
    #     gc = GenerationConfig(
    #         do_sample=True,
    #         temperature=temp,
    #         top_k=topk,
    #         max_new_tokens=max_new,
    #         num_return_sequences=n_samples,
    #         pad_token_id=self.tokenizer.eos_token_id,
    #     )
        
    #     out = self.model.generate(
    #         inputs.input_ids,
    #         attention_mask=inputs.attention_mask,
    #         generation_config=gc,
    #         stopping_criteria=StoppingCriteriaList([stopper]),
    #     )
        
    #     texts = self.tokenizer.batch_decode(out, skip_special_tokens=True)
        
    #     # Reshape: texts is flat list of len(prompts) * n_samples
    #     # Reorganize into list of lists
    #     result = []
    #     for i in range(len(prompts)):
    #         result.append(texts[i * n_samples : (i + 1) * n_samples])
        
    #     return result




# import re
# from typing import List, Dict, Tuple, Set, Any
# from itertools import product

# class FOLToCNF:
#     def __init__(self):
#         self.symbol_counter = 0
#         self.mappings = {}
#         self.reverse_mappings = {}
#         self.facts = []
#         self.rules = []
#         self.formulas = []  # Non-quantified formulas (Or, And, Not combinations)
#         self.query = None
#         self.constants = set()
#         self.predicates = {}
        
#     def normalize_predicate(self, pred_name: str, args: List[str]) -> str:
#         """Normalize a predicate to a consistent string format (no spaces after commas)."""
#         return f"{pred_name}({','.join(args)})"
    
#     def get_symbol(self, name: str) -> str:
#         """Get or create a propositional symbol for a ground predicate."""
#         # Normalize the name to ensure consistency
#         name = name.replace(', ', ',').replace(' ,', ',')
        
#         if name not in self.mappings:
#             symbol = f"p{self.symbol_counter}"
#             self.symbol_counter += 1
#             self.mappings[name] = symbol
#             self.reverse_mappings[symbol] = name
#         return self.mappings[name]
    
#     def parse_predicate(self, pred_str: str) -> Tuple[str, List[str]]:
#         """Parse a predicate string like 'previous_stop(x, y)' into (name, [args])."""
#         match = re.match(r'(\w+)\(([^)]+)\)', pred_str.strip())
#         if match:
#             pred_name = match.group(1)
#             args = [arg.strip() for arg in match.group(2).split(',')]
#             return pred_name, args
#         return None, []
    
#     def extract_predicates_from_and(self, and_content: str) -> List[Tuple[str, List[str]]]:
#         """Extract all predicates from And(...) content."""
#         predicates = []
#         pattern = r'(\w+)\(([^)]+)\)'
#         for match in re.finditer(pattern, and_content):
#             pred_name = match.group(1)
#             args_str = match.group(2)
#             args = [arg.strip() for arg in args_str.split(',')]
#             predicates.append((pred_name, args))
#         return predicates
    
#     # def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
#     #     """
#     #     Parse a propositional formula and convert it to CNF.
#     #     Handles Or, And, Not combinations of ground predicates.
#     #     """
#     #     # This is a simplified conversion - for complex cases you'd need a full formula parser
#     #     # We'll handle the most common patterns
        
#     #     formula_str = formula_str.strip()
        
#     #     # Handle Or(And(...), And(...)) - disjunction of conjunctions
#     #     if formula_str.startswith('Or(And('):
#     #         return self.parse_or_of_ands(formula_str)
        
#     #     # Handle simple And(...) - already in CNF (single clause with all literals)
#     #     if formula_str.startswith('And('):
#     #         return self.parse_simple_and(formula_str)
        
#     #     # Handle simple Or(...) - single clause
#     #     if formula_str.startswith('Or('):
#     #         return self.parse_simple_or(formula_str)
        
#     #     return []
#     # def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
#     #     """
#     #     Parse a propositional formula and convert it to CNF.
#     #     Handles Or, And, Not combinations of ground predicates.
#     #     """
#     #     formula_str = formula_str.strip()
        
#     #     # Handle Or(And(...), And(...)) - disjunction of conjunctions
#     #     if formula_str.startswith('Or(And('):
#     #         return self.parse_or_of_ands(formula_str)
        
#     #     # Handle simple And(...) - already in CNF (single clause with all literals)
#     #     if formula_str.startswith('And('):
#     #         return self.parse_simple_and(formula_str)
        
#     #     # Handle simple Or(...) - single clause
#     #     if formula_str.startswith('Or('):
#     #         return self.parse_simple_or(formula_str)
        
#     #     # Handle simple Not(...) - single negated literal
#     #     if formula_str.startswith('Not('):
#     #         return self.parse_simple_not(formula_str)
        
#     #     # Handle single predicate (no logical operator)
#     #     if '(' in formula_str and not any(formula_str.startswith(op) for op in ['Or(', 'And(', 'Not(']):
#     #         literals = self.extract_literals_from_formula(formula_str)
#     #         if literals:
#     #             return [[literals[0]]]  # Single unit clause
        
#     #     return []

#     def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse a propositional formula and convert it to CNF.
#         Handles Or, And, Not, Implies combinations of ground predicates.
#         """
#         formula_str = formula_str.strip()
        
#         # Handle Implies - convert to Or(Not(A), B)
#         if formula_str.startswith('Implies('):
#             return self.parse_implies_formula(formula_str)
        
#         # Handle Or(And(...), And(...)) - disjunction of conjunctions
#         if formula_str.startswith('Or(And('):
#             return self.parse_or_of_ands(formula_str)
        
#         # Handle simple And(...) - already in CNF (single clause with all literals)
#         if formula_str.startswith('And('):
#             return self.parse_simple_and(formula_str)
        
#         # Handle simple Or(...) - single clause
#         if formula_str.startswith('Or('):
#             return self.parse_simple_or(formula_str)
        
#         # Handle simple Not(...) - single negated literal
#         if formula_str.startswith('Not('):
#             return self.parse_simple_not(formula_str)
        
#         # Handle single predicate (no logical operator)
#         if '(' in formula_str and not any(formula_str.startswith(op) for op in ['Or(', 'And(', 'Not(', 'Implies(']):
#             literals = self.extract_literals_from_formula(formula_str)
#             if literals:
#                 return [[literals[0]]]  # Single unit clause
        
#         return []

#     def parse_implies_formula(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse Implies(A, B) and convert to CNF.
#         Implies(A, B) ≡ Or(Not(A), B) ≡ ~A ∨ B
#         """
#         # Extract the two parts of the implication
#         inner_start = 8  # len('Implies(')
#         depth = 0
#         i = inner_start
#         comma_pos = -1
        
#         # Find the comma separating antecedent and consequent at depth 0
#         while i < len(formula_str):
#             if formula_str[i] == '(':
#                 depth += 1
#             elif formula_str[i] == ')':
#                 depth -= 1
#             elif formula_str[i] == ',' and depth == 0:
#                 comma_pos = i
#                 break
#             i += 1
        
#         if comma_pos == -1:
#             return []
        
#         antecedent = formula_str[inner_start:comma_pos].strip()
#         # Find the end of consequent (last closing paren)
#         consequent_start = comma_pos + 1
#         depth = 0
#         i = len(formula_str) - 1
#         while i >= consequent_start and formula_str[i] != ')':
#             i -= 1
#         consequent = formula_str[consequent_start:i].strip()
        
#         # Convert Implies(A, B) to Or(Not(A), B)
#         # First, negate the antecedent
#         if antecedent.startswith('Not('):
#             # Not(Not(A)) → A (double negation)
#             # Extract what's inside the Not
#             inner_start_not = 4
#             depth = 1
#             j = inner_start_not
#             while j < len(antecedent) and depth > 0:
#                 if antecedent[j] == '(':
#                     depth += 1
#                 elif antecedent[j] == ')':
#                     depth -= 1
#                 j += 1
#             antecedent_inner = antecedent[inner_start_not:j-1]
#             negated_antecedent = antecedent_inner
#         else:
#             negated_antecedent = f"Not({antecedent})"
        
#         # Now parse Or(Not(A), B)
#         or_formula = f"Or({negated_antecedent}, {consequent})"
#         return self.parse_simple_or(or_formula)

#     def parse_simple_not(self, formula_str: str) -> List[List[str]]:
#         """Parse Not(...) and apply De Morgan's laws if needed."""
#         # Extract what's inside the Not()
#         if not formula_str.startswith('Not('):
#             return []
        
#         inner_start = 4  # len('Not(')
#         # Find the content inside Not(...)
#         depth = 1
#         i = inner_start
#         while i < len(formula_str) and depth > 0:
#             if formula_str[i] == '(':
#                 depth += 1
#             elif formula_str[i] == ')':
#                 depth -= 1
#             i += 1
        
#         inner_content = formula_str[inner_start:i-1].strip()
        
#         # Apply De Morgan's laws
#         # Not(And(A, B, ...)) → Or(Not(A), Not(B), ...) → single clause [~A, ~B, ...]
#         if inner_content.startswith('And('):
#             literals = self.extract_literals_from_formula(inner_content)
#             # Negate all literals
#             negated_literals = []
#             for lit in literals:
#                 if lit.startswith('~'):
#                     negated_literals.append(lit[1:])  # Remove negation
#                 else:
#                     negated_literals.append(f"~{lit}")  # Add negation
#             return [negated_literals]  # Single clause with all negated literals
        
#         # Not(Or(A, B, ...)) → And(Not(A), Not(B), ...) → multiple unit clauses [[~A], [~B], ...]
#         elif inner_content.startswith('Or('):
#             literals = self.extract_literals_from_formula(inner_content)
#             # Negate all literals and return as separate clauses
#             negated_clauses = []
#             for lit in literals:
#                 if lit.startswith('~'):
#                     negated_clauses.append([lit[1:]])  # Remove negation
#                 else:
#                     negated_clauses.append([f"~{lit}"])  # Add negation
#             return negated_clauses
        
#         # Not(Not(A)) → A (double negation elimination)
#         elif inner_content.startswith('Not('):
#             # Recursively parse the inner Not, then negate the result
#             inner_cnf = self.parse_simple_not(inner_content)
#             # Negate the result
#             result = []
#             for clause in inner_cnf:
#                 negated_clause = []
#                 for lit in clause:
#                     if lit.startswith('~'):
#                         negated_clause.append(lit[1:])
#                     else:
#                         negated_clause.append(f"~{lit}")
#                 result.append(negated_clause)
#             return result
        
#         # Not(predicate(...)) → single negated literal
#         else:
#             literals = self.extract_literals_from_formula(formula_str)
#             if literals:
#                 return [[literals[0]]]  # Should be a single negated literal
#             return []
#     def parse_simple_and(self, formula_str: str) -> List[List[str]]:
#         """Parse And(a, b, c) -> [[a], [b], [c]]"""
#         literals = self.extract_literals_from_formula(formula_str)
#         return [[lit] for lit in literals]
    
#     def parse_simple_or(self, formula_str: str) -> List[List[str]]:
#         """Parse Or(a, b, c) -> [[a, b, c]]"""
#         literals = self.extract_literals_from_formula(formula_str)
#         return [literals]
    
#     def parse_or_of_ands(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse Or(And(a,b), And(c,d)) and convert to CNF.
#         This requires distribution: (a∧b)∨(c∧d) = (a∨c)∧(a∨d)∧(b∨c)∧(b∨d)
#         """
#         # Extract the And(...) components
#         and_groups = []
#         depth = 0
#         current = []
#         in_and = False
        
#         i = 3  # Skip "Or("
#         while i < len(formula_str):
#             if formula_str[i:i+4] == 'And(':
#                 in_and = True
#                 depth = 1
#                 i += 4
#                 start = i
#                 and_content = []
                
#                 while depth > 0 and i < len(formula_str):
#                     if formula_str[i] == '(':
#                         depth += 1
#                     elif formula_str[i] == ')':
#                         depth -= 1
#                     i += 1
                
#                 and_str = formula_str[start:i-1]
#                 and_literals = self.extract_literals_from_formula('And(' + and_str + ')')
#                 and_groups.append(and_literals)
#             else:
#                 i += 1
        
#         # Distribute: generate all combinations
#         if len(and_groups) == 0:
#             return []
        
#         # Create CNF by taking one literal from each And group
#         clauses = []
#         for combo in product(*and_groups):
#             clauses.append(list(combo))
        
#         return clauses
    
#     # def extract_literals_from_formula(self, formula_str: str) -> List[str]:
#     #     """Extract all literals from a formula, handling Not(...)."""
#     #     literals = []
        
#     #     # Find all Not(...) patterns
#     #     not_pattern = r'Not\((\w+)\(([^)]+)\)\)'
#     #     for match in re.finditer(not_pattern, formula_str):
#     #         pred_name = match.group(1)
#     #         args_str = match.group(2)
#     #         args = [arg.strip() for arg in args_str.split(',')]
#     #         pred_str = self.normalize_predicate(pred_name, args)
#     #         symbol = self.get_symbol(pred_str)
#     #         literals.append(f"~{symbol}")
            
#     #         # Track constants
#     #         self.constants.update(args)
        
#     #     # Find all positive predicates (not preceded by Not)
#     #     # Remove the Not(...) parts first
#     #     cleaned = re.sub(r'Not\([^)]+\)', '', formula_str)
        
#     #     pred_pattern = r'(\w+)\(([^)]+)\)'
#     #     for match in re.finditer(pred_pattern, cleaned):
#     #         pred_name = match.group(1)
#     #         # Skip logical operators
#     #         if pred_name in ['Or', 'And', 'Not', 'Implies', 'ForAll']:
#     #             continue
#     #         args_str = match.group(2)
#     #         args = [arg.strip() for arg in args_str.split(',')]
#     #         pred_str = self.normalize_predicate(pred_name, args)
#     #         symbol = self.get_symbol(pred_str)
#     #         literals.append(symbol)
            
#     #         # Track constants
#     #         self.constants.update(args)
        
#     #     return literals
#     def extract_literals_from_formula(self, formula_str: str) -> List[str]:
#         """Extract all literals from a formula, handling Not(...)."""
#         literals = []
        
#         # Find all Not(predicate(...)) patterns
#         # We need to handle the nested parentheses properly
#         i = 0
#         positions_to_remove = []  # Track what we've processed as Not
        
#         while i < len(formula_str):
#             if formula_str[i:i+4] == 'Not(':
#                 # Found a Not - now find the complete predicate inside
#                 start = i
#                 i += 4
                
#                 # Find the predicate name
#                 pred_match = re.match(r'(\w+)\(', formula_str[i:])
#                 if pred_match:
#                     pred_name = pred_match.group(1)
#                     i += len(pred_match.group(0))
                    
#                     # Now find the arguments by tracking parentheses
#                     depth = 1
#                     arg_start = i
#                     while i < len(formula_str) and depth > 0:
#                         if formula_str[i] == '(':
#                             depth += 1
#                         elif formula_str[i] == ')':
#                             depth -= 1
#                         i += 1
                    
#                     # Extract arguments
#                     args_str = formula_str[arg_start:i-1]
#                     args = [arg.strip() for arg in args_str.split(',')]
#                     pred_str = self.normalize_predicate(pred_name, args)
#                     symbol = self.get_symbol(pred_str)
#                     literals.append(f"~{symbol}")
                    
#                     # Track constants
#                     self.constants.update(args)
                    
#                     # Skip the closing paren of Not()
#                     if i < len(formula_str) and formula_str[i] == ')':
#                         i += 1
                    
#                     # Remember this range to skip when looking for positive predicates
#                     positions_to_remove.append((start, i))
#                 else:
#                     i += 1
#             else:
#                 i += 1
        
#         # Now find positive predicates, skipping the Not(...) regions
#         i = 0
#         while i < len(formula_str):
#             # Check if we're in a Not region
#             in_not_region = any(start <= i < end for start, end in positions_to_remove)
#             if in_not_region:
#                 i += 1
#                 continue
            
#             # Look for pattern: word followed by (
#             match = re.match(r'(\w+)\(', formula_str[i:])
#             if match:
#                 pred_name = match.group(1)
                
#                 # Skip logical operators
#                 if pred_name in ['Or', 'And', 'Not', 'Implies', 'ForAll']:
#                     i += len(match.group(0))
#                     continue
                
#                 # This is a real predicate - extract its arguments
#                 start = i + len(match.group(0))
#                 depth = 1
#                 j = start
                
#                 while j < len(formula_str) and depth > 0:
#                     if formula_str[j] == '(':
#                         depth += 1
#                     elif formula_str[j] == ')':
#                         depth -= 1
#                     j += 1
                
#                 # Extract arguments
#                 args_str = formula_str[start:j-1]
#                 args = [arg.strip() for arg in args_str.split(',')]
#                 pred_str = self.normalize_predicate(pred_name, args)
#                 symbol = self.get_symbol(pred_str)
#                 literals.append(symbol)
                
#                 # Track constants
#                 self.constants.update(args)
                
#                 i = j
#             else:
#                 i += 1
        
#         return literals
#     # def parse_z3_code(self, code: str):
#     #     """Parse Z3 code and extract facts, rules, and query."""
#     #     lines = code.strip().split('\n')
        
#     #     for line in lines:
#     #         line = line.strip()
#     #         if not line or line.startswith('#'):
#     #             continue
                
#     #         # Check for ForAll (rules)
#     #         if line.startswith('ForAll'):
#     #             self.parse_rule(line)
#     #         # Check for return statement (query)
#     #         elif line.startswith('return'):
#     #             query_match = re.search(r'return\s+(.+)', line)
#     #             if query_match:
#     #                 self.query = query_match.group(1).strip()
#     #         # Check for complex formulas (Or, And with predicates)
#     #         elif any(line.startswith(op) for op in ['Or(', 'And(', 'Not(']):
#     #             self.formulas.append(line)
#     #         # Otherwise it's a simple fact
#     #         else:
#     #             pred_name, args = self.parse_predicate(line)
#     #             if pred_name:
#     #                 # Normalize and store
#     #                 normalized = self.normalize_predicate(pred_name, args)
#     #                 # Parse it back to get clean args
#     #                 pred_name, args = self.parse_predicate(normalized)
#     #                 self.facts.append((pred_name, args))
#     #                 self.constants.update(args)
#     #                 if pred_name not in self.predicates:
#     #                     self.predicates[pred_name] = len(args)
#     def parse_z3_code(self, code: str):
#         """Parse Z3 code and extract facts, rules, and query."""
#         lines = code.strip().split('\n')
        
#         for line in lines:
#             line = line.strip()
#             if not line or line.startswith('#'):
#                 continue
                
#             # Check for ForAll (rules)
#             if line.startswith('ForAll'):
#                 self.parse_rule(line)
#             # Check for return statement (query)
#             elif line.startswith('return'):
#                 query_match = re.search(r'return\s+(.+)', line)
#                 if query_match:
#                     self.query = query_match.group(1).strip()
#             # Check for Implies at top level (ground implication)
#             elif line.startswith('Implies('):
#                 self.formulas.append(line)
#             # Check for complex formulas (Or, And, Not with predicates)
#             elif any(line.startswith(op) for op in ['Or(', 'And(', 'Not(']):
#                 self.formulas.append(line)
#             # Otherwise it's a simple fact
#             else:
#                 pred_name, args = self.parse_predicate(line)
#                 if pred_name:
#                     # Normalize and store
#                     normalized = self.normalize_predicate(pred_name, args)
#                     # Parse it back to get clean args
#                     pred_name, args = self.parse_predicate(normalized)
#                     self.facts.append((pred_name, args))
#                     self.constants.update(args)
#                     if pred_name not in self.predicates:
#                         self.predicates[pred_name] = len(args)

                        
#     def parse_rule(self, line: str):
#         """Parse a ForAll rule with better parsing."""
#         # Extract variables
#         vars_match = re.search(r'ForAll\(\[([^\]]+)\]', line)
#         if not vars_match:
#             return
        
#         variables = [v.strip() for v in vars_match.group(1).split(',')]
        
#         # Find the Implies or Or part
#         implies_match = re.search(r'Implies\((.*)\)\s*\)', line)
#         or_match = re.search(r'Or\((.*)\)\s*\)', line) if not implies_match else None
        
#         if implies_match:
#             self.parse_implies_rule(variables, implies_match.group(1))
#         elif or_match:
#             # ForAll([x], Or(...)) is already in clause form
#             self.parse_or_rule(variables, or_match.group(1))
    
#     def parse_or_rule(self, variables: List[str], or_body: str):
#         """Parse ForAll([x], Or(a, b, c)) - already a clause."""
#         literals_info = []
        
#         # Extract Not(...) predicates
#         not_pattern = r'Not\((\w+)\(([^)]+)\)\)'
#         for match in re.finditer(not_pattern, or_body):
#             pred_name = match.group(1)
#             args_str = match.group(2)
#             args = [arg.strip() for arg in args_str.split(',')]
#             literals_info.append((pred_name, args, True))  # True = negated
#             if pred_name not in self.predicates:
#                 self.predicates[pred_name] = len(args)
        
#         # Extract positive predicates
#         cleaned = re.sub(r'Not\([^)]+\)', '', or_body)
#         pred_pattern = r'(\w+)\(([^)]+)\)'
#         for match in re.finditer(pred_pattern, cleaned):
#             pred_name = match.group(1)
#             if pred_name in ['Or', 'And', 'Not', 'Implies']:
#                 continue
#             args_str = match.group(2)
#             args = [arg.strip() for arg in args_str.split(',')]
#             literals_info.append((pred_name, args, False))  # False = positive
#             if pred_name not in self.predicates:
#                 self.predicates[pred_name] = len(args)
        
#         # Store as a special kind of rule (disjunction)
#         self.rules.append((variables, [], literals_info))  # Empty antecedent, list of literals
    
#     def parse_implies_rule(self, variables: List[str], implies_body: str):
#         """Parse Implies(And(...), consequent) or Implies(pred, consequent)."""
#         # Split by comma at the top level
#         depth = 0
#         parts = []
#         current = []
        
#         for char in implies_body:
#             if char == '(':
#                 depth += 1
#                 current.append(char)
#             elif char == ')':
#                 depth -= 1
#                 current.append(char)
#             elif char == ',' and depth == 0:
#                 parts.append(''.join(current).strip())
#                 current = []
#             else:
#                 current.append(char)
        
#         if current:
#             parts.append(''.join(current).strip())
        
#         if len(parts) != 2:
#             return
        
#         antecedent_part = parts[0].strip()
#         consequent_part = parts[1].strip()
        
#         # Parse antecedent
#         antecedents = []
#         if antecedent_part.startswith('And('):
#             and_match = re.match(r'And\((.*)\)', antecedent_part)
#             if and_match:
#                 antecedents = self.extract_predicates_from_and(and_match.group(1))
#         else:
#             # Single predicate antecedent
#             pred_name, args = self.parse_predicate(antecedent_part)
#             if pred_name:
#                 antecedents = [(pred_name, args)]
        
#         # Track predicates
#         for pred_name, args in antecedents:
#             if pred_name not in self.predicates:
#                 self.predicates[pred_name] = len(args)
        
#         # Parse consequent
#         consequent_name, consequent_args = self.parse_predicate(consequent_part)
#         if not consequent_name:
#             return
        
#         if consequent_name not in self.predicates:
#             self.predicates[consequent_name] = len(consequent_args)
        
#         self.rules.append((variables, antecedents, (consequent_name, consequent_args)))
    
#     def ground_predicate(self, pred_name: str, args: List[str], substitution: Dict[str, str]) -> str:
#         """Ground a predicate with substitution and return normalized form."""
#         grounded_args = [substitution.get(arg, arg) for arg in args]
#         return self.normalize_predicate(pred_name, grounded_args)
    
#     def ground_rules(self) -> List[List[str]]:
#         """Ground all rules with all possible constant substitutions."""
#         clauses = []
#         constants_list = list(self.constants)
        
#         for variables, antecedents, consequent in self.rules:
#             # Generate all possible substitutions
#             for substitution_values in product(constants_list, repeat=len(variables)):
#                 substitution = dict(zip(variables, substitution_values))
                
#                 # Check if this is a disjunction rule (empty antecedents, consequent is list)
#                 if len(antecedents) == 0 and isinstance(consequent, list):
#                     # This is ForAll([x], Or(...)) form
#                     clause = []
#                     for pred_name, args, is_negated in consequent:
#                         grounded = self.ground_predicate(pred_name, args, substitution)
#                         symbol = self.get_symbol(grounded)
#                         if is_negated:
#                             clause.append(f"~{symbol}")
#                         else:
#                             clause.append(symbol)
#                     clauses.append(clause)
#                 else:
#                     # Standard Implies rule: ~A1 v ~A2 v ... v C
#                     clause = []
                    
#                     # Add negated antecedents
#                     for pred, args in antecedents:
#                         grounded = self.ground_predicate(pred, args, substitution)
#                         clause.append(f"~{self.get_symbol(grounded)}")
                    
#                     # Add consequent
#                     grounded_consequent = self.ground_predicate(consequent[0], consequent[1], substitution)
#                     clause.append(self.get_symbol(grounded_consequent))
                    
#                     clauses.append(clause)
        
#         return clauses
    
#     def convert_to_cnf(self, negate_query: bool = False) -> Tuple[List[List[str]], Dict[str, str]]:
#         """
#         Convert the FOL problem to CNF.
        
#         Args:
#             negate_query: If True, negate the query (to prove it's true via contradiction)
#                          If False, assert the query (to prove it's false via contradiction)
        
#         Returns:
#             Tuple of (clauses, mappings)
#         """
#         clauses = []
        
#         # Add facts as unit clauses
#         for pred_name, args in self.facts:
#             fact_str = self.normalize_predicate(pred_name, args)
#             symbol = self.get_symbol(fact_str)
#             clauses.append([symbol])
        
#         # Add grounded rules
#         clauses.extend(self.ground_rules())
        
#         # Add non-quantified formulas converted to CNF
#         for formula in self.formulas:
#             formula_cnf = self.parse_formula_to_cnf(formula)
#             clauses.extend(formula_cnf)
        
#         # Add query (negated for proving true, asserted for proving false)
#         if self.query:
#             # Convert query to CNF
#             query_cnf = self.parse_formula_to_cnf(self.query)
#             # breakpoint()
#             if negate_query:
#                 query_cnf = self.parse_formula_to_cnf('Not(' + self.query + ')')
#                 # Negate the entire query
#                 # If query is a single literal, just negate it
#             #     if len(query_cnf) == 1 and len(query_cnf[0]) == 1:
#             #         lit = query_cnf[0][0]
#             #         if lit.startswith('~'):
#             #             clauses.append([lit[1:]])
#             #         else:
#             #             clauses.append([f"~{lit}"])
#             #     else:
#             #         # For complex queries, we need to negate the whole formula
#             #         # This is a simplification - proper negation requires DeMorgan's laws
#             #         for clause in query_cnf:
#             #             negated_clause = []
#             #             for lit in clause:
#             #                 if lit.startswith('~'):
#             #                     negated_clause.append(lit[1:])
#             #                 else:
#             #                     negated_clause.append(f"~{lit}")
#             #             clauses.append(negated_clause)
#             # else:
#             #     clauses.extend(query_cnf)
#             clauses.extend(query_cnf)
        
#         return clauses, dict(self.mappings)
import re
from typing import List, Dict, Tuple, Set, Any
from itertools import product

class FOLToCNF:
    def __init__(self):
        # Vocabulary collected while parsing.
        self.constants = set()
        self.predicates = {}  # predicate name -> arity (first one seen wins)

        # Parsed problem content.
        self.facts = []  # (pred_name, args) ground facts
        self.rules = []  # quantified rules built by the rule parsers
        self.formulas = []  # Non-quantified formulas (Or, And, Not combinations)
        self.query = None  # formula text taken from the `return` line

        # Propositional symbol table: next free index, atom -> symbol, symbol -> atom.
        self.symbol_counter = 0
        self.mappings, self.reverse_mappings = {}, {}
        
    def normalize_predicate(self, pred_name: str, args: List[str]) -> str:
        """Render ``pred_name(arg1,arg2,...)`` with bare commas (no spaces), e.g. ``P(a,b)``."""
        return '{}({})'.format(pred_name, ','.join(args))
    
    def get_symbol(self, name: str) -> str:
        """
        Return the propositional symbol (``p0``, ``p1``, ...) for a ground atom.

        The name is normalized first so that whitespace around commas does not
        create distinct symbols: ``P(a, b)``, ``P(a,b)`` and ``P(a ,  b)`` all map
        to the same symbol. A name seen for the first time gets the next counter
        value and is recorded in both ``mappings`` (name -> symbol) and
        ``reverse_mappings`` (symbol -> name). Formula-shaped names such as
        ``Or(...)`` or ``Implies(...)`` are handled the same way and simply act
        as opaque placeholder atoms.

        Returns:
            The symbol string, or ``"p_error"`` when ``name`` cannot be
            normalized (e.g. it is not a string).
        """
        try:
            # Collapse any run of whitespace around commas, not just one space.
            name = re.sub(r'\s*,\s*', ',', name)
        except TypeError:
            # Best-effort by design: callers expect a symbol, not an exception.
            return "p_error"

        symbol = self.mappings.get(name)
        if symbol is None:
            symbol = f"p{self.symbol_counter}"
            self.symbol_counter += 1
            self.mappings[name] = symbol
            self.reverse_mappings[symbol] = name
        return symbol
    
    def parse_predicate(self, pred_str: str) -> Tuple[str, List[str]]:
        """
        Split an atomic predicate string into its name and argument list.

        ``'previous_stop(x, y)'`` -> ``('previous_stop', ['x', 'y'])``.

        Only the leading ``name(args)`` prefix is examined, and the argument
        text runs up to the first ``)``. Logical connectives (Implies, ForAll,
        Exists, Or, And, Not) are not predicates. For them, for text that does
        not match, and for non-string input, the result is ``(None, [])``; note
        the name slot is ``None`` in that case despite the annotation.
        """
        connectives = {'Implies', 'ForAll', 'Exists', 'Or', 'And', 'Not'}
        try:
            found = re.match(r'(\w+)\(([^)]+)\)', pred_str.strip())
        except Exception:
            return None, []

        if found is None or found.group(1) in connectives:
            return None, []

        arguments = [piece.strip() for piece in found.group(2).split(',')]
        return found.group(1), arguments
    
    def extract_predicates_from_and(self, and_content: str) -> List[Tuple[str, List[str]]]:
        """
        Return every ``name(args)`` occurrence in the text of an ``And(...)`` body.

        Each result is ``(name, [stripped args])``, in left-to-right order.
        NOTE(review): the scan is purely regex-based and does not understand
        nesting. ``Not(Q(x))`` is reported as ``("Not", ["Q(x"])``, not as a
        negated ``Q(x)``, so negated conjuncts are not represented correctly.
        """
        return [
            (hit.group(1), [piece.strip() for piece in hit.group(2).split(',')])
            for hit in re.finditer(r'(\w+)\(([^)]+)\)', and_content)
        ]
    
    def extract_top_level_args(self, content: str) -> List[str]:
        """
        Split ``content`` on commas that are not nested inside parentheses.

        Example: ``"a, Or(b, c), d"`` -> ``["a", "Or(b, c)", "d"]``.

        Each piece is whitespace-stripped. A comma at the very end does not
        produce an empty final piece, and ``""`` yields ``[]``. If ``content``
        cannot be processed (e.g. ``None``), the fallback is ``[content]`` when
        it is truthy, otherwise ``[]``.
        """
        try:
            pieces = []
            buffer = []
            nesting = 0
            for ch in content:
                if ch == ',' and nesting == 0:
                    # Top-level separator: close the current piece.
                    pieces.append(''.join(buffer).strip())
                    buffer = []
                    continue
                if ch == '(':
                    nesting += 1
                elif ch == ')':
                    nesting -= 1
                buffer.append(ch)

            if buffer:
                pieces.append(''.join(buffer).strip())
            return pieces
        except Exception:
            return [content] if content else []
    
    def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
        """
        Convert a quantifier-free formula string to CNF (a list of clauses).

        Dispatches on the leading token: the internal ``Exists:`` encoding,
        ``Implies(``, ``And(``, ``Or(`` and ``Not(``. Anything else is tried as
        a single predicate, whose symbol becomes a unit clause and whose
        arguments are added to ``self.constants``.

        Returns ``[]`` for empty or unparseable input, or if any step raises.
        """
        try:
            text = formula_str.strip()
            if not text:
                return []

            # The prefixes are mutually exclusive, so the order here is irrelevant.
            handlers = (
                ('Exists:', self.parse_exists_formula),
                ('Implies(', self.parse_implies_to_cnf),
                ('And(', self.parse_and_to_cnf),
                ('Or(', self.parse_or_to_cnf),
                ('Not(', self.parse_not_to_cnf),
            )
            for prefix, handler in handlers:
                if text.startswith(prefix):
                    return handler(text)

            name, args = self.parse_predicate(text)
            if not name:
                return []
            symbol = self.get_symbol(self.normalize_predicate(name, args))
            self.constants.update(args)
            return [[symbol]]
        except Exception:
            return []
    
    def parse_implies_to_cnf(self, formula_str: str) -> List[List[str]]:
        """
        Convert ``Implies(A, B)`` to CNF via ``Implies(A, B) == Or(Not(A), B)``.

        Returns ``[]`` unless the call has exactly two top-level arguments, or
        if any step raises.
        """
        try:
            # "Implies(x)" (10 chars) is the shortest input with anything inside.
            if len(formula_str) < 10:
                return []

            operands = self.extract_top_level_args(formula_str[len('Implies('):-1])
            if len(operands) != 2:
                return []

            premise, conclusion = (op.strip() for op in operands)
            return self.parse_or_to_cnf(f"Or(Not({premise}), {conclusion})")
        except Exception:
            return []
    
    def parse_and_to_cnf(self, formula_str: str) -> List[List[str]]:
        """
        Convert ``And(F1, F2, ...)`` to CNF.

        A conjunction of CNFs is already CNF, so the result is the clauses of
        every conjunct, concatenated in order. Returns ``[]`` for malformed
        input or if any step raises.
        """
        try:
            # "And()" is 5 chars; require at least one character inside.
            if len(formula_str) < 6:
                return []

            conjuncts = self.extract_top_level_args(formula_str[len('And('):-1])
            return [
                clause
                for conjunct in conjuncts
                for clause in self.parse_formula_to_cnf(conjunct)
            ]
        except Exception:
            return []
    
    def parse_or_to_cnf(self, formula_str: str) -> List[List[str]]:
        """
        Convert ``Or(F1, F2, ...)`` to CNF.

        Each disjunct is converted to CNF on its own. The disjunction is then
        distributed over the conjunctions: one output clause per combination
        of one clause from every disjunct. For example, ``Or(And(a, b), c)``
        becomes ``[[a, c], [b, c]]``.

        A disjunct that yields no clauses (e.g. unparseable text) contributes
        no literals. Output clauses left with no literals at all are dropped.
        They can only come from such unparseable disjuncts, and an empty clause
        would make the whole clause set trivially unsatisfiable.

        Returns ``[]`` for malformed input or if any step raises.
        """
        try:
            # "Or()" is 4 chars; require at least one character inside.
            if len(formula_str) < 5:
                return []

            disjuncts = self.extract_top_level_args(formula_str[3:-1])
            if not disjuncts:
                return []

            # [[]] = a single clause with no literals, i.e. "contributes nothing".
            sub_cnfs = [self.parse_formula_to_cnf(d) or [[]] for d in disjuncts]

            # Distribute Or over And: pick one clause from each disjunct's CNF.
            # product() yields combinations in the same left-to-right order as
            # nested loops over the disjuncts.
            distributed = (
                [lit for clause in combo for lit in clause]
                for combo in product(*sub_cnfs)
            )
            return [clause for clause in distributed if clause]
        except Exception:
            return []
    
    def parse_not_to_cnf(self, formula_str: str) -> List[List[str]]:
        """
        Convert ``Not(...)`` to CNF by pushing the negation inward.

        Rewrites:
          Not(And(A, B, ...)) -> Or(Not(A), Not(B), ...)    (De Morgan)
          Not(Or(A, B, ...))  -> And(Not(A), Not(B), ...)   (De Morgan)
          Not(Not(A))         -> A                          (double negation)
          Not(Implies(A, B))  -> And(A, Not(B))
          Not(pred(args))     -> [[~symbol]]; args are added to self.constants

        Returns ``[]`` for anything else, for malformed input, or if any step
        raises.
        """
        try:
            # "Not()" is 5 chars; require at least one character inside.
            if len(formula_str) < 6:
                return []

            body = formula_str[4:-1].strip()
            if not body:
                return []

            if body.startswith('And('):
                conjuncts = self.extract_top_level_args(body[4:-1])
                negated = ', '.join(f"Not({c})" for c in conjuncts)
                return self.parse_or_to_cnf(f"Or({negated})")

            if body.startswith('Or('):
                clauses = []
                for disjunct in self.extract_top_level_args(body[3:-1]):
                    clauses.extend(self.parse_formula_to_cnf(f"Not({disjunct})"))
                return clauses

            if body.startswith('Not('):
                return self.parse_formula_to_cnf(body[4:-1])

            if body.startswith('Implies('):
                operands = self.extract_top_level_args(body[8:-1])
                if len(operands) != 2:
                    return []
                return self.parse_and_to_cnf(f"And({operands[0]}, Not({operands[1]}))")

            name, args = self.parse_predicate(body)
            if not name:
                return []
            symbol = self.get_symbol(self.normalize_predicate(name, args))
            self.constants.update(args)
            return [[f"~{symbol}"]]
        except Exception:
            return []
    
    def parse_exists_formula(self, formula_str: str) -> List[List[str]]:
        """
        Convert an encoded existential ``Exists:<num_vars>:<body>`` to CNF by grounding.

        ``Exists([x], B(x))`` is expanded to ``B(c1) ∨ B(c2) ∨ ...`` over the
        known constants, or over the single placeholder ``_skolem_constant``
        when no constants are known. Constants are sorted so that the order of
        the disjuncts, and hence symbol numbering, is reproducible.

        Only single-variable existentials are grounded, and the bound variable
        is assumed to be named ``x``: the ``Exists:`` encoding keeps only the
        variable count, not the names. Every ``x`` that is a whole argument is
        replaced, at any position in the argument list and regardless of
        surrounding spaces.

        Returns:
            The CNF clauses, or ``[]`` for a malformed encoding or when
            ``num_vars != 1``.

        Raises:
            ValueError: if ``<num_vars>`` is not an integer.
        """
        parts = formula_str.split(':', 2)
        if len(parts) != 3:
            return []

        num_vars = int(parts[1])
        body = parts[2]

        # Grounding several variables would need their names, which the
        # encoding does not carry; unsupported.
        if num_vars != 1:
            return []

        constants_list = sorted(self.constants) if self.constants else ['_skolem_constant']

        # 'x' as a whole argument: preceded by '(' or ',' and followed by ',' or
        # ')', with optional spaces. Handles 'P(a, x)' and 'P(a, x, b)', which
        # fixed-substring replacement misses.
        var_pattern = re.compile(r'(?<=[(,])\s*x\s*(?=[,)])')

        # Exists([x], body) == body[x:=c1] v body[x:=c2] v ...
        # A function replacement stops re.sub from interpreting backslashes in constants.
        grounded_bodies = [
            var_pattern.sub(lambda _m, c=const: c, body) for const in constants_list
        ]

        if len(grounded_bodies) == 1:
            return self.parse_formula_to_cnf(grounded_bodies[0])
        return self.parse_or_to_cnf(f"Or({', '.join(grounded_bodies)})")
    
    def parse_z3_code(self, code: str):
        """
        Parse Z3-style Python code into facts, rules, formulas and a query.

        Each non-empty line that is not a ``#`` comment is classified by prefix:
          * ``ForAll...``                -> ``self.parse_rule``
          * ``Exists...``                -> ``self.parse_exists`` (grounded later)
          * ``return X`` / ``return(X)`` -> ``self.query``. Parentheses that
            enclose the whole expression are removed, so ``return (P(a))``
            yields ``P(a)``; a parenthesized tuple is left as-is.
          * ``Implies(`` / ``Or(`` / ``And(`` / ``Not(`` -> ``self.formulas``
          * anything else parseable as ``pred(args)`` -> ``self.facts``. Its
            arguments are added to ``self.constants`` and, if the predicate is
            new, its arity is recorded in ``self.predicates``.
        Lines matching none of these are ignored.
        """
        def strip_enclosing_parens(expr: str) -> str:
            # Remove parentheses only when they wrap the entire expression:
            # "(P(a))" -> "P(a)", but "(P(a)) x (Q(b))" and tuples like
            # "(P(a), Q(b))" are returned unchanged.
            while expr.startswith('(') and expr.endswith(')'):
                depth = 0
                for pos, ch in enumerate(expr):
                    if ch == '(':
                        depth += 1
                    elif ch == ')':
                        depth -= 1
                        if depth == 0 and pos != len(expr) - 1:
                            return expr
                    elif ch == ',' and depth == 1:
                        return expr
                expr = expr[1:-1].strip()
            return expr

        for raw_line in code.strip().split('\n'):
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue

            if line.startswith('ForAll'):
                self.parse_rule(line)
            elif line.startswith('Exists'):
                self.parse_exists(line)
            elif line.startswith('return '):
                query_match = re.search(r'return\s+(.+)', line)
                if query_match:
                    self.query = strip_enclosing_parens(query_match.group(1).strip())
            elif line.startswith('return('):
                # Everything between the first "(" and the last ")".
                query_match = re.search(r'return\((.+)\)', line)
                if query_match:
                    self.query = strip_enclosing_parens(query_match.group(1).strip())
            elif line.startswith(('Implies(', 'Or(', 'And(', 'Not(')):
                # Ground compound formulas are converted to CNF later.
                self.formulas.append(line)
            else:
                # Simple fact. parse_predicate already returns stripped args,
                # so no separate normalize/re-parse round trip is needed.
                pred_name, args = self.parse_predicate(line)
                if pred_name:
                    self.facts.append((pred_name, args))
                    self.constants.update(args)
                    self.predicates.setdefault(pred_name, len(args))
    
    def parse_exists(self, line: str):
        """
        Record an ``Exists([vars], body)`` line for later grounding.

        The line is stored in ``self.formulas`` as the string
        ``"Exists:<number of variables>:<body>"``. Only the variable count is
        kept, not the names. The body is the text after the first ``]``, with a
        leading comma and the final closing parenthesis removed. Lines without
        a ``Exists([...]`` header are ignored.
        """
        header = re.search(r'Exists\(\[([^\]]+)\]', line)
        if header is None:
            return
        var_count = len(header.group(1).split(','))

        close_bracket = line.find(']')
        if close_bracket < 0:
            return

        body = line[close_bracket + 1:].strip()
        if body.startswith(','):
            body = body[1:].strip()
        # Drop the ')' that closes Exists( ... ).
        if body.endswith(')'):
            body = body[:-1].strip()

        self.formulas.append(f"Exists:{var_count}:{body}")
    
    def parse_rule(self, line: str):
        """
        Parse a ``ForAll([vars], ...)`` line into a rule.

        If the line contains ``Implies(...)``, its body is handed to
        ``parse_implies_rule``. Otherwise, if it contains ``Or(...)``, the body
        goes to ``parse_or_rule``. Lines with neither, or without a
        ``ForAll([...]`` header, are ignored.
        """
        header = re.search(r'ForAll\(\[([^\]]+)\]', line)
        if not header:
            return
        bound_vars = [name.strip() for name in header.group(1).split(',')]

        # Greedy match: the body runs to the last ')' that is followed by the
        # ')' closing ForAll.
        implication = re.search(r'Implies\((.*)\)\s*\)', line)
        if implication:
            self.parse_implies_rule(bound_vars, implication.group(1))
            return

        disjunction = re.search(r'Or\((.*)\)\s*\)', line)
        if disjunction:
            # ForAll([x], Or(...)) is already in clause form.
            self.parse_or_rule(bound_vars, disjunction.group(1))
    
    def parse_or_rule(self, variables: List[str], or_body: str):
        """Register ``ForAll(variables, Or(...))`` as a rule in ``self.rules``.

        ``or_body`` is the text inside ``Or(...)``; it is split into top-level
        disjuncts with ``self.extract_top_level_args``.

        * If any disjunct starts with ``And(``, ``Or(`` or ``Implies(``, the raw
          body is stored as ``(variables, [], ('COMPLEX_OR', or_body))`` for
          CNF conversion at grounding time.
        * Otherwise each disjunct is a literal, ``Pred(args)`` or
          ``Not(Pred(args))``. The rule is stored as
          ``(variables, [], [(name, args, is_negated), ...])``. Disjuncts that
          ``self.parse_predicate`` cannot parse are dropped. Each predicate's
          arity is recorded in ``self.predicates`` on first sight.
        """
        disjuncts = self.extract_top_level_args(or_body)

        # Nested connectives cannot be stored as a flat literal list.
        if any(d.strip().startswith(('And(', 'Or(', 'Implies(')) for d in disjuncts):
            self.rules.append((variables, [], ('COMPLEX_OR', or_body)))
            return

        literals_info = []
        for raw in disjuncts:
            text = raw.strip()
            negated = text.startswith('Not(')
            # Drop the 'Not(' prefix and its closing paren for negated literals.
            atom = text[4:-1] if negated else text
            pred_name, pred_args = self.parse_predicate(atom)
            if not pred_name:
                continue
            literals_info.append((pred_name, pred_args, negated))
            self.predicates.setdefault(pred_name, len(pred_args))

        # Empty antecedent list + list of literals marks a plain disjunction.
        self.rules.append((variables, [], literals_info))
    
    def parse_implies_rule(self, variables: List[str], implies_body: str):
        """Register ``ForAll(variables, Implies(antecedent, consequent))``.

        ``implies_body`` is the text inside ``Implies(...)``. It is split at
        top-level commas and ignored unless that yields exactly two parts.

        Entries appended to ``self.rules`` as ``(variables, antecedents, consequent)``:

        * ``(variables, [], ('COMPLEX_IMPLIES', implies_body))`` applies when
          either of these holds:

          - the antecedent starts with ``Or(``, ``Implies(``, ``Not(And`` or
            ``Not(Or``;
          - the consequent starts with ``Or(`` or ``And(``.

          These are left for the grounding step to convert to CNF.
        * ``(variables, antecedents, (name, args, True))`` when the consequent
          is ``Not(Name(args))``.
        * ``(variables, antecedents, (name, args))`` for a single-predicate
          consequent.

        ``antecedents`` is a list of ``(name, args)`` pairs. It comes either
        from an ``And(...)`` antecedent via ``self.extract_predicates_from_and``
        or from a single predicate via ``self.parse_predicate``. On the
        non-complex path, every predicate seen has its arity recorded in
        ``self.predicates`` (first occurrence wins).
        """
        # Split at commas that are not nested inside parentheses.
        parts = []
        current = []
        depth = 0
        for char in implies_body:
            if char == ',' and depth == 0:
                parts.append(''.join(current).strip())
                current = []
                continue
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
            current.append(char)
        if current:
            parts.append(''.join(current).strip())

        if len(parts) != 2:
            return

        antecedent_part = parts[0].strip()
        consequent_part = parts[1].strip()

        # Compound antecedents/consequents are kept whole for the grounding step.
        # Because of this early return, Or(...)/And(...) consequents never reach
        # the simple-consequent handling below.
        has_complex_antecedent = antecedent_part.startswith(
            ('Or(', 'Implies(', 'Not(And', 'Not(Or')
        )
        has_complex_consequent = consequent_part.startswith(('Or(', 'And('))
        if has_complex_antecedent or has_complex_consequent:
            self.rules.append((variables, [], ('COMPLEX_IMPLIES', implies_body)))
            return

        # Antecedent: a conjunction of predicates or a single predicate.
        antecedents = []
        if antecedent_part.startswith('And('):
            and_match = re.match(r'And\((.*)\)', antecedent_part)
            if and_match:
                antecedents = self.extract_predicates_from_and(and_match.group(1))
        else:
            pred_name, args = self.parse_predicate(antecedent_part)
            if pred_name:
                antecedents = [(pred_name, args)]

        for pred_name, args in antecedents:
            if pred_name not in self.predicates:
                self.predicates[pred_name] = len(args)

        # Negated single-predicate consequent: Not(Name(args)).
        if consequent_part.startswith('Not('):
            not_match = re.match(r'Not\((\w+)\(([^)]+)\)\)', consequent_part)
            if not_match:
                cons_pred_name = not_match.group(1)
                cons_args = [arg.strip() for arg in not_match.group(2).split(',')]
                if cons_pred_name not in self.predicates:
                    self.predicates[cons_pred_name] = len(cons_args)
                # The third tuple element (True) marks the consequent as negated.
                self.rules.append((variables, antecedents, (cons_pred_name, cons_args, True)))
                return

        # Simple consequent (single positive predicate).
        consequent_name, consequent_args = self.parse_predicate(consequent_part)
        if not consequent_name:
            return
        if consequent_name not in self.predicates:
            self.predicates[consequent_name] = len(consequent_args)
        self.rules.append((variables, antecedents, (consequent_name, consequent_args)))
    
    def ground_predicate(self, pred_name: str, args: List[str], substitution: Dict[str, str]) -> str:
        """Instantiate ``pred_name(args)`` under ``substitution`` and normalise it.

        Arguments that are keys of ``substitution`` are replaced by their mapped
        constant. All other arguments (e.g. constants already) are kept as-is.
        The result is whatever ``self.normalize_predicate`` builds from the
        instantiated argument list.
        """
        instantiated = list(map(lambda a: substitution.get(a, a), args))
        return self.normalize_predicate(pred_name, instantiated)
    
    def ground_rules(self) -> List[List[str]]:
        """Ground every rule in ``self.rules`` over all constant substitutions.

        Substitutions range over
        ``product(list(self.constants), repeat=len(variables))``. Each rule
        ``(variables, antecedents, consequent)`` is encoded according to the
        shape of ``consequent``:

        * ``('COMPLEX_OR', body)``: the variables are substituted into
          ``body``, and ``Or(body)`` is converted with ``self.parse_or_to_cnf``.
        * ``('COMPLEX_IMPLIES', body)``: the variables are substituted into
          ``body``, and ``Implies(body)`` is converted with
          ``self.parse_formula_to_cnf``.
        * a list of ``(pred, args, is_negated)`` with no antecedents: one
          disjunctive clause per grounding.
        * ``(pred, args)`` or ``(pred, args, True)``: the clause
          ``~A1 v ... v ~An v C`` (``~C`` when the third element is present).

        Returns:
            The list of clauses; literals are symbols from ``self.get_symbol``,
            with ``~`` marking negation.
        """
        clauses = []
        constants_list = list(self.constants)

        for variables, antecedents, consequent in self.rules:
            complex_kind = None
            complex_body = None
            arg_re = None
            if (isinstance(consequent, tuple) and len(consequent) == 2
                    and consequent[0] in ('COMPLEX_OR', 'COMPLEX_IMPLIES')):
                complex_kind, complex_body = consequent
                var_names = [v.strip() for v in variables if v.strip()]
                if var_names:
                    # Match a variable only in argument position: preceded by
                    # '(' or ',' and followed by ',' or ')', with optional
                    # surrounding spaces. Lookarounds do not consume the
                    # delimiters, so runs like "R(x,x,x)" and spaced arguments
                    # like "P(x, y)" are substituted completely in one pass.
                    arg_re = re.compile(
                        r'(?<=[(,])(\s*)('
                        + '|'.join(re.escape(v) for v in var_names)
                        + r')(\s*)(?=[,)])'
                    )

            for substitution_values in product(constants_list, repeat=len(variables)):
                substitution = dict(zip(variables, substitution_values))

                if complex_kind is not None:
                    grounded_formula = complex_body
                    if arg_re is not None:
                        by_name = {var.strip(): const for var, const in substitution.items()}
                        grounded_formula = arg_re.sub(
                            lambda m: m.group(1) + str(by_name[m.group(2)]) + m.group(3),
                            grounded_formula,
                        )
                    if complex_kind == 'COMPLEX_OR':
                        clauses.extend(self.parse_or_to_cnf(f"Or({grounded_formula})"))
                    else:
                        # NOTE(review): relies on parse_formula_to_cnf accepting a
                        # top-level Implies(...) formula -- confirm.
                        clauses.extend(self.parse_formula_to_cnf(f"Implies({grounded_formula})"))
                    continue

                if len(antecedents) == 0 and isinstance(consequent, list):
                    # ForAll([x], Or(...)) with flat literals.
                    clause = []
                    for pred_name, args, is_negated in consequent:
                        symbol = self.get_symbol(self.ground_predicate(pred_name, args, substitution))
                        clause.append(f"~{symbol}" if is_negated else symbol)
                    clauses.append(clause)
                    continue

                # Implication: ~A1 v ~A2 v ... v C (or ~C for a negated consequent).
                clause = [
                    f"~{self.get_symbol(self.ground_predicate(pred, args, substitution))}"
                    for pred, args in antecedents
                ]
                grounded_consequent = self.ground_predicate(consequent[0], consequent[1], substitution)
                if isinstance(consequent, tuple) and len(consequent) == 3:
                    clause.append(f"~{self.get_symbol(grounded_consequent)}")
                else:
                    clause.append(self.get_symbol(grounded_consequent))
                clauses.append(clause)

        return clauses
    
    def convert_to_cnf(self, negate_query: bool = False) -> Tuple[List[List[str]], Dict[str, str]]:
        """Assemble facts, grounded rules, formulas and the query as CNF.

        Clauses are produced in this order:

        1. one unit clause per fact;
        2. ``self.ground_rules()``;
        3. ``self.parse_formula_to_cnf`` of each entry of ``self.formulas``;
        4. the query, if ``self.query`` is set.

        Args:
            negate_query: When True the query is negated, so an UNSAT result
                means the query is entailed. A single-literal query is flipped
                directly; anything else is converted as ``Not(<query>)``. When
                False the query's CNF is added as-is.

        Returns:
            ``(clauses, mappings)``. Each clause is a list of literal strings
            (``~`` marks negation), and ``mappings`` is a shallow copy of
            ``self.mappings``.
        """
        fact_clauses = [
            [self.get_symbol(self.normalize_predicate(pred_name, args))]
            for pred_name, args in self.facts
        ]
        clauses = fact_clauses + self.ground_rules()
        for formula in self.formulas:
            clauses += self.parse_formula_to_cnf(formula)

        if not self.query:
            return clauses, dict(self.mappings)

        query_clauses = self.parse_formula_to_cnf(self.query)
        if not negate_query:
            clauses += query_clauses
        elif len(query_clauses) == 1 and len(query_clauses[0]) == 1:
            literal = query_clauses[0][0]
            clauses.append([literal[1:]] if literal.startswith('~') else [f"~{literal}"])
        else:
            # A multi-literal query must be negated as a whole formula.
            clauses += self.parse_formula_to_cnf(f"Not({self.query})")
        return clauses, dict(self.mappings)

        
def write_dimacs_cnf(clauses: List[List[str]], filename: str):
    """Write ``clauses`` to ``filename`` in DIMACS CNF format.

    Each literal is a string ``p<k>`` or ``~p<k>``. Symbol ``p<k>`` becomes
    DIMACS variable ``k + 1`` (negated for ``~``), since DIMACS variables start
    at 1 and ``0`` terminates a clause.

    Empty clauses are dropped. The ``p cnf`` header declares ``max(k) + 1``
    variables (1 when there are no literals) and the number of remaining
    clauses. No ``c`` comment lines are written.
    """
    indices = set()
    rows = []
    for clause in clauses:
        if clause == []:
            continue
        row = []
        for lit in clause:
            negated = lit.startswith('~')
            index = int((lit[1:] if negated else lit).strip('p'))
            indices.add(index)
            row.append(-(index + 1) if negated else index + 1)
        rows.append(row)

    num_vars = max(indices) + 1 if indices else 1
    lines = [f"p cnf {num_vars} {len(rows)}\n"]
    lines.extend(' '.join(map(str, row)) + ' 0\n' for row in rows)
    with open(filename, 'w') as out:
        out.writelines(lines)

def write_mapping_file(mappings: Dict[str, str], filename: str):
    """Write one ``<symbol> -> <predicate>`` line per entry of ``mappings``.

    ``mappings`` maps predicate strings to symbols of the form ``p<k>``. Lines
    are ordered by the numeric part ``k`` (p0, p1, p2, ..., p10), not
    lexically.
    """
    def symbol_number(item):
        _, symbol = item
        return int(symbol.strip('p'))

    with open(filename, 'w') as out:
        ordered = sorted(mappings.items(), key=symbol_number)
        out.writelines(f"{symbol} -> {predicate}\n" for predicate, symbol in ordered)

def fol_to_cnf(z3_code: str) -> Dict:
    """
    Convert a Z3-style FOL problem into two grounded CNF encodings.

    Both encodings come from the same ``FOLToCNF`` instance, so they share one
    symbol table.

    Args:
        z3_code: Source text passed to ``FOLToCNF.parse_z3_code``.

    Returns:
        A dictionary with:
            - 'pos': clauses with the query negated. UNSAT means the query
              follows from the premises (it is True).
            - 'neg': clauses with the query asserted. UNSAT means the query
              contradicts the premises (it is False).
            - 'mappings': dict from normalized predicate string to its
              propositional symbol (e.g. ``'p0'``).
            - 'reverse_mappings': the converter's ``reverse_mappings``
              attribute (presumably symbol -> predicate; not defined in this
              section -- confirm).
    """
    converter = FOLToCNF()
    converter.parse_z3_code(z3_code)

    # Query negated: a contradiction proves the query TRUE.
    cnf_true, mappings = converter.convert_to_cnf(negate_query=True)

    # Query asserted: a contradiction proves the query FALSE.
    cnf_false, _ = converter.convert_to_cnf(negate_query=False)

    return {
        'pos': cnf_true,
        'neg': cnf_false,
        'mappings': mappings,
        'reverse_mappings': converter.reverse_mappings
    }


def _bump_cnf_header(file, extra_vars=0, extra_clauses=0):
    """Rewrite ``file`` with its DIMACS ``p cnf <vars> <clauses>`` header incremented.

    On every line starting with ``p cnf``, ``extra_vars`` is added to the
    variable count and ``extra_clauses`` to the clause count. All other lines
    are copied unchanged. Fields are split on any whitespace, so headers with
    repeated spaces or tabs are handled.
    """
    with open(file, 'r') as f:
        lines = f.readlines()

    rewritten = []
    for line in lines:
        if line.startswith('p cnf'):
            fields = line.split()
            num_var = int(fields[2]) + extra_vars
            num_clause = int(fields[3]) + extra_clauses
            rewritten.append(f'p cnf {num_var} {num_clause}\n')
        else:
            rewritten.append(line)

    with open(file, 'w') as f:
        f.writelines(rewritten)


def add_clause(file):
    """Increase the clause count in ``file``'s DIMACS header by one.

    Call this before appending a clause so the header stays consistent.
    """
    _bump_cnf_header(file, extra_clauses=1)


def add_var(file):
    """Increase the variable count in ``file``'s DIMACS header by one."""
    _bump_cnf_header(file, extra_vars=1)


def check_pattern(van, vap, vbn, vbp, vcn, vcp, patterns):
    """Return True if the (a, b, c) triple matches some entry of ``patterns``.

    Names in ``van``, ``vbn`` and ``vcn`` are canonicalised to the strings
    ``'0'``, ``'1'``, ... in order of first appearance across a, then b, then
    c. The comparison is therefore independent of the concrete names. The
    flags ``vap``/``vbp``/``vcp`` (presumably polarities) are compared as-is.

    Each entry of ``patterns`` is compared with ``==`` against
    ``(a_ids, vap, b_ids, vbp, c_ids, vcp)``, where the id groups are lists.
    """
    a_names = van if isinstance(van, list) else list(van)
    b_names = vbn if isinstance(vbn, list) else list(vbn)
    c_names = vcn if isinstance(vcn, list) else list(vcn)

    canon = {}
    for name in a_names + b_names + c_names:
        canon.setdefault(name, str(len(canon)))

    def ids(names):
        return [canon[n] for n in names]

    target = (ids(a_names), vap, ids(b_names), vbp, ids(c_names), vcp)
    return any(pat == target for pat in patterns)

def search_pattern(van, vap, vbn, vbp, patterns):
    """Return the instantiated conclusion of the first pattern whose premises match.

    Names in ``van`` and ``vbn`` are canonicalised to ``'0'``, ``'1'``, ... in
    order of first appearance. A pattern ``pat`` matches when
    ``(pat[0], pat[1], pat[2], pat[3])`` equals
    ``(a_ids, vap, b_ids, vbp)``.

    Returns:
        ``(names, pat[5])``, where ``names`` maps the canonical ids in
        ``pat[4]`` back to the caller's names, or ``None`` when no pattern can
        be instantiated. A matching pattern is skipped if its conclusion uses an
        id absent from the premises, or if it has no conclusion part.
    """
    a_names = van if isinstance(van, list) else list(van)
    b_names = vbn if isinstance(vbn, list) else list(vbn)

    canon = {}
    uncanon = {}
    for name in a_names + b_names:
        if name not in canon:
            canon[name] = str(len(canon))
            uncanon[canon[name]] = name

    key = ([canon[n] for n in a_names], vap, [canon[n] for n in b_names], vbp)

    for pat in patterns:
        if (pat[0], pat[1], pat[2], pat[3]) != key:
            continue
        try:
            return ([uncanon[i] for i in pat[4]], pat[5])
        except (KeyError, IndexError):
            # KeyError: the conclusion refers to an unbound id.
            # IndexError: the pattern has no conclusion part.
            # Either way it cannot be instantiated here; try the next one.
            continue
    return None

                
import math
def _append_blocking_clause(cnf_path, literals):
    """Append a clause that negates every entry of ``literals`` to ``cnf_path``.

    ``add_clause`` first increments the clause count in the ``p cnf`` header.
    Each literal is written negated and followed by a space; no ``0``
    terminator is added here. Assignments read from a solver's ``v`` lines
    already end with the ``0`` terminator token (``-0 == 0``), which closes the
    clause. NOTE(review): ``del_sols`` entries are assumed to follow the same
    convention.
    """
    add_clause(cnf_path)
    with open(cnf_path, 'a') as cf:
        cf.write('\n' + ''.join(f'{-lit} ' for lit in literals))


def get_sol(file, lim=10000, del_sols=None, seedrun=0):
    """Enumerate models of the ``pos_``/``neg_`` CNF pair next to ``file`` with cadical.

    For ``file = <dir>/<name>``, each of ``<dir>/pos_<name>`` and
    ``<dir>/neg_<name>`` is copied into ``<parent of dir>/tempfiles<seedrun>/``.
    The copy is solved repeatedly. After each model, a blocking clause negating
    it is appended to the copy, so the next run yields a different model.
    Enumeration for a file stops in either of two cases:

    * the last line of the solver log reports ``exit 20`` (UNSAT);
    * ``lim`` solver runs have been made.

    Args:
        file: Path whose directory holds the ``pos_``/``neg_`` CNF files.
        lim: Maximum solver runs per file. It is compared numerically, so a
            float such as ``math.inf`` also works.
        del_sols: Optional ``{'pos': [...], 'neg': [...]}`` of assignments to
            exclude up front via blocking clauses.
        seedrun: Suffix selecting the ``tempfiles<seedrun>`` working directory.

    Returns:
        ``{'pos': [...], 'neg': [...]}``. Each entry is a list of models, and
        each model is the list of integers read from the solver's ``v`` lines.

    Raises:
        RuntimeError: if the solver log is empty or its last line has no
            ``exit <code>``.
    """
    solutions = {'pos': [], 'neg': []}
    path_parts = file.split('/')
    src_dir = '/'.join(path_parts[:-1])
    work_dir = '/'.join(path_parts[:-2]) + '/tempfiles' + str(seedrun)
    solver = USER_PATH + '/sat_gen/HardSATGEN/postprocess/cadical/build/cadical '

    # Classify by the prefix added here rather than by a substring search on
    # the full path: a directory name containing 'pos' would otherwise route
    # the neg_ file's models into solutions['pos'].
    for key in ('pos', 'neg'):
        cnf_name = key + '_' + path_parts[-1]
        work_cnf = work_dir + '/' + cnf_name
        log_path = work_dir + '/' + cnf_name[:-4] + '.log'
        shutil.copy(src_dir + '/' + cnf_name, work_cnf)

        if del_sols is not None:
            for sol in del_sols[key]:
                _append_blocking_clause(work_cnf, sol)

        count = 0
        while True:
            count += 1
            if count > lim:
                break
            os.system(solver + work_cnf + '> ' + log_path)

            with open(log_path, 'r') as lf:
                lines = lf.readlines()
            if not lines or 'exit ' not in lines[-1]:
                last = lines[-1] if lines else ''
                raise RuntimeError(f'cannot find solver exit code in {log_path}: {last!r}')
            if lines[-1].split('exit ')[1].strip('\n') == '20':
                break

            # Skip to the 's ...' status line, then collect the 'v ...' lines.
            idx = 1
            while not lines[idx].startswith('s '):
                idx += 1
            idx += 1
            solution = []
            while lines[idx].startswith('v '):
                solution.extend(map(int, lines[idx].strip('\n').split(' ')[1:]))
                idx += 1

            solutions[key].append(solution)
            _append_blocking_clause(work_cnf, solution)

    return solutions
    
# Sentinel returned when a step cannot be scored; count dicts that already
# contain this value are treated as failed counts as well.
_INVALID_SCORE = -1000000


def _max_var_count(cnf_result):
    """Return the DIMACS variable count needed for a ``fol_to_cnf`` result.

    For each of ``cnf_result['pos']`` and ``cnf_result['neg']`` that has at
    least one literal (``p<k>`` or ``~p<k>``), the count is ``max(k) + 1``,
    as in ``write_dimacs_cnf``'s header. The larger of the two is returned.
    Python ints are used, so later products with large model counts cannot
    overflow.

    Raises:
        ValueError: if neither side contains any literal.
    """
    counts = []
    for side in ('pos', 'neg'):
        indices = set()
        for clause in cnf_result[side]:
            for lit in clause:
                name = lit[1:] if lit.startswith('~') else lit
                indices.add(int(name.strip('p')))
        if indices:
            counts.append(max(indices) + 1)
    return max(counts)


def _bb_weight(mc, bb, n_var, key):
    """Return ``mc[key] * (n_var - len(bb[key]) + 1)``.

    NOTE(review): ``bb[key]`` presumably lists the backbone literals of side
    ``key`` (from ``get_bb``). This is then the count weighted by the number of
    non-backbone variables plus one -- confirm.
    """
    return mc[key] * (n_var - len(bb[key]) + 1)


def _backbone_inputs(context, step):
    """Return ``(bb, bb_tminus, n_var, n_var_tminus)`` for ``context + step`` and ``context``."""
    bb = get_bb(context + step)
    bb_tminus = get_bb(context)
    result = fol_to_cnf(context + step)
    result_tminus = fol_to_cnf(context)
    return bb, bb_tminus, _max_var_count(result), _max_var_count(result_tminus)


def _count_total(mc):
    """Sum of all values of the count dict ``mc``."""
    return np.sum(list(mc.values()))


def score(context, step, mode='conf', cheat_idx=None):
    """Score a reasoning ``step`` appended to ``context`` using model counts.

    ``step`` alone is converted with ``fol_to_cnf``. It is rejected
    (``-1000000``) in two cases:

    * its ``'pos'`` clause list is empty;
    * a mapped atom does not contain exactly one ``(``.

    Otherwise ``get_mc`` is called on ``context + step`` (``mc_t``) and on
    ``context`` (``mc_tminus``). Both are indexed by ``'pos'``/``'neg'``.

    Modes. All except 'raw-cheating' return ``-1000000`` when a count dict
    contains the sentinel or a total is non-positive; 'even' has its own
    checks.

    * 'conf': ``max/sum`` of mc_t minus that of mc_tminus.
    * 'conf-bb-norm': like 'conf', but each count is weighted by
      ``n_var - len(bb) + 1``.
    * 'raw-bb': decrease in the weighted total count.
    * 'conf-bb': change in ``mc[max] * (n_var - len(bb[max]))``.
    * 'raw-conf': decrease in the smaller count.
    * 'raw': decrease in the total count.
    * 'even': ``-|pos - neg|`` of mc_t.
    * 'raw-cheating': decrease in the count picked by ``cheat_idx``
      (0 -> 'neg', 1 -> 'pos').

    Any other ``mode`` returns ``None``.
    """
    ftc = fol_to_cnf(step)
    if ftc['pos'] == []:
        return _INVALID_SCORE
    # Reject steps whose atoms are not flat predicate applications "Name(args)".
    if any(len(key.split('(')) != 2 for key in ftc['mappings']):
        return _INVALID_SCORE

    mc_t = get_mc(context + step)
    mc_tminus = get_mc(context)

    def counts_invalid():
        # Shared guard: a failed count, or a non-positive total on either side.
        return (_INVALID_SCORE in mc_t.values()
                or _INVALID_SCORE in mc_tminus.values()
                or _count_total(mc_t) <= 0
                or _count_total(mc_tminus) <= 0)

    if mode == 'conf':
        if counts_invalid():
            return _INVALID_SCORE
        conf_t = np.max(list(mc_t.values())) / _count_total(mc_t)
        conf_tminus = np.max(list(mc_tminus.values())) / _count_total(mc_tminus)
        return conf_t - conf_tminus

    if mode in ('conf-bb-norm', 'raw-bb', 'conf-bb'):
        # Validate the counts before the comparatively expensive backbone and
        # CNF computations; invalid counts make their results irrelevant.
        if counts_invalid():
            return _INVALID_SCORE
        bb, bb_tminus, n_var, n_var_tminus = _backbone_inputs(context, step)
        max_t, min_t = max(mc_t, key=mc_t.get), min(mc_t, key=mc_t.get)
        max_tm, min_tm = max(mc_tminus, key=mc_tminus.get), min(mc_tminus, key=mc_tminus.get)

        if mode == 'conf-bb':
            value_t = mc_t[max_t] * (n_var - len(bb[max_t]))
            value_tminus = mc_tminus[max_tm] * (n_var_tminus - len(bb_tminus[max_tm]))
            return value_t - value_tminus

        if mc_t[max_t] == 0:
            return _INVALID_SCORE
        hi_t = _bb_weight(mc_t, bb, n_var, max_t)
        total_t = _bb_weight(mc_t, bb, n_var, min_t) + hi_t
        hi_tm = _bb_weight(mc_tminus, bb_tminus, n_var_tminus, max_tm)
        total_tm = _bb_weight(mc_tminus, bb_tminus, n_var_tminus, min_tm) + hi_tm

        if mode == 'raw-bb':
            return total_tm - total_t

        # 'conf-bb-norm': a zero weighted total cannot be normalised.
        if total_t == 0 or total_tm == 0:
            return _INVALID_SCORE
        return hi_t / total_t - hi_tm / total_tm

    if mode == 'raw-conf':
        if counts_invalid():
            return _INVALID_SCORE
        return np.min(list(mc_tminus.values())) - np.min(list(mc_t.values()))

    if mode == 'raw':
        if counts_invalid():
            return _INVALID_SCORE
        return _count_total(mc_tminus) - _count_total(mc_t)

    if mode == 'even':
        for mc in (mc_t, mc_tminus):
            if _count_total(mc) == 0 or mc['pos'] == -1 or mc['neg'] == -1:
                return _INVALID_SCORE
        if _INVALID_SCORE in mc_t.values() or _INVALID_SCORE in mc_tminus.values():
            return _INVALID_SCORE
        return -(np.abs(mc_t['pos'] - mc_t['neg']))

    if mode == 'raw-cheating':
        side = ['neg', 'pos'][cheat_idx]
        return mc_tminus[side] - mc_t[side]

    # Unknown mode.
    return None

def get_mc(text, seedrun=0):
    """Count the models of the 'pos' and 'neg' CNF encodings of *text* with ganak.

    ``fol_to_cnf(text)`` (defined elsewhere in this file) is indexed with the
    keys 'pos' and 'neg'. Each entry is written as DIMACS with
    ``write_dimacs_cnf``, moved into the scratch directory
    ``/work/XXXX/tempfiles<seedrun>`` and counted by
    ``/home/XXXX/ganak --appmct 100 --mode 1`` under a 500-second ``timeout``.
    Every intermediate file is deleted before returning.

    Args:
        text: Logic text accepted by ``fol_to_cnf``.
        seedrun: Selects the scratch directory, so runs with different
            ``seedrun`` values never share temp files.

    Returns:
        ``{'pos': int, 'neg': int}``: the count parsed from the second-to-last
        line of ganak's output (the text after ``'exact arb frac '``), or
        ``-1000000`` when that line is missing or not an integer (e.g. the
        500 s timeout fired).
    """
    import subprocess
    import uuid

    result = fol_to_cnf(text)
    # time.time() alone can collide when several processes share src_dir;
    # the uuid suffix keeps every file name unique.
    run_id = str(time.time()) + '_' + uuid.uuid4().hex
    src_dir = '/home/XXXX/clause/folio_newlogic_/getmc/'
    tmp_dir = os.path.join('/work/XXXX', 'tempfiles' + str(seedrun))
    # Create the scratch directory on demand so the move below cannot fail.
    os.makedirs(tmp_dir, exist_ok=True)

    mc = {'pos': -1, 'neg': -1}
    for polarity in ('pos', 'neg'):
        stem = polarity + '_' + run_id
        src_cnf = src_dir + stem + '.cnf'
        tmp_cnf = os.path.join(tmp_dir, stem + '.cnf')
        mc_file = os.path.join(tmp_dir, stem + '.mc')

        write_dimacs_cnf(result[polarity], src_cnf)
        shutil.move(src_cnf, tmp_cnf)

        # Argument list instead of a shell string: no quoting issues with paths.
        with open(mc_file, 'w') as out:
            subprocess.run(
                ['timeout', '500', '/home/XXXX/ganak', '--appmct', '100',
                 '--mode', '1', tmp_cnf],
                stdout=out,
            )
        with open(mc_file, 'r') as cf:
            lines = cf.readlines()

        try:
            # An empty/truncated output (timeout) raises IndexError; a missing
            # or non-integer count raises IndexError/ValueError.
            mc[polarity] = int(lines[-2].split('exact arb frac ')[1].strip('\n'))
        except (IndexError, ValueError):
            print('timeout')
            mc[polarity] = -1000000

        os.remove(mc_file)
        os.remove(tmp_cnf)
    return mc
    

#     return mc
# def get_mc(file, seedrun=0):
#     from concurrent.futures import ThreadPoolExecutor
#     import shutil
#     import os

#     def process_file(file, seedrun):
#         """Process a single file and return its model count"""
#         # Copy file to temp directory
#         temp_dir = '/'.join(file.split('/')[:-2]) + '/tempfiles' + str(seedrun)
#         temp_file = temp_dir + '/' + file.split('/')[-1]
#         shutil.copy(file, temp_file)
        
#         # Run ganak
#         output_file = temp_file[:-4] + '.mc'
#         os.system(f'timeout 500 /home/XXXX/ganak --appmct 100 --mode 1 {temp_file} > {output_file}')
        
#         # Read result
#         try:
#             with open(output_file, 'r') as cf:
#                 lines = cf.readlines()
#                 el = lines[-2]
#                 ec = el.split('exact arb frac ')[1].strip('\n')
#                 return int(ec)
#         except:
#             print('timeout')
#             return -1

#     # Prepare file paths
#     base_path = '/'.join(file.split('/')[:-1])
#     filename = file.split('/')[-1]
#     files = {
#         'pos': base_path + '/pos_' + filename,
#         'neg': base_path + '/neg_' + filename
#     }

#     # Process files in parallel
#     mc = {'pos': -1, 'neg': -1}
#     with ThreadPoolExecutor(max_workers=2) as executor:
#         futures = {key: executor.submit(process_file, filepath, seedrun) 
#                 for key, filepath in files.items()}
        
#         for key, future in futures.items():
#             mc[key] = future.result()

#     # Check for value of 1
#     # for value in mc.values():
#     #     if value == 1:
#     #         breakpoint()

#     return mc


def get_bb(text, del_sols=None, seedrun=0, ablate=False):
    """Collect the backbone literals of the 'pos' and 'neg' CNF encodings of *text*.

    ``fol_to_cnf(text)`` (defined elsewhere in this file) is indexed with the
    keys 'pos' and 'neg'. Each entry is written with ``write_dimacs_cnf``,
    moved into the scratch directory ``/work/XXXX/tempfiles<seedrun>`` and
    passed to ``USER_PATH + '/LLM-project/cadiback/cadiback'`` under a
    5000-second ``timeout``. Every output line that starts with ``'b'``
    contributes its non-zero integer tokens. Every intermediate file is
    deleted before returning.

    Args:
        text: Logic text accepted by ``fol_to_cnf``.
        del_sols: Unused; kept for call compatibility.
        seedrun: Selects the scratch directory, so runs with different
            ``seedrun`` values never share temp files.
        ablate: Unused; kept for call compatibility.

    Returns:
        ``{'pos': [int, ...], 'neg': [int, ...]}``: the signed literals from
        cadiback's ``b`` lines in output order (empty if none were printed,
        e.g. when the timeout fired first).
    """
    import subprocess
    import uuid

    result = fol_to_cnf(text)
    # time.time() alone can collide when several processes share src_dir;
    # the uuid suffix keeps every file name unique.
    run_id = str(time.time()) + '_' + uuid.uuid4().hex
    src_dir = '/home/XXXX/clause/folio_newlogic_/getbb/'
    tmp_dir = os.path.join('/work/XXXX', 'tempfiles' + str(seedrun))
    # Create the scratch directory on demand so the move below cannot fail.
    os.makedirs(tmp_dir, exist_ok=True)
    cadiback = USER_PATH + "/LLM-project/cadiback/cadiback"

    bb = {'pos': [], 'neg': []}
    for polarity in ('pos', 'neg'):
        stem = polarity + '_' + run_id
        src_cnf = src_dir + stem + '.cnf'
        tmp_cnf = os.path.join(tmp_dir, stem + '.cnf')
        bbone_file = os.path.join(tmp_dir, stem + '.bbone')

        write_dimacs_cnf(result[polarity], src_cnf)
        shutil.move(src_cnf, tmp_cnf)

        # Argument list instead of a shell string: no quoting issues with paths.
        with open(bbone_file, 'w') as out:
            subprocess.run(['timeout', '5000', cadiback, tmp_cnf], stdout=out)

        with open(bbone_file, 'r') as fh:
            for line in fh:
                if not line.startswith('b'):
                    continue
                # split() tolerates repeated/trailing whitespace; '0' only
                # terminates the literal list.
                for token in line.split()[1:]:
                    if token != '0':
                        bb[polarity].append(int(token))

        os.remove(bbone_file)
        os.remove(tmp_cnf)
    return bb


import os, shutil, json, random, subprocess, uuid
import numpy as np
from concurrent.futures import ProcessPoolExecutor

# USER_PATH = "/your/path"   # same as before

def make_tempdir(base_dir):
    """Create a uniquely named ``temp_<32 hex chars>`` directory inside *base_dir*.

    The name comes from a random UUID, so concurrent callers (e.g. pool
    workers) get separate directories. Returns the new directory's path.
    """
    unique_name = "temp_" + uuid.uuid4().hex
    new_dir = os.path.join(base_dir, unique_name)
    os.makedirs(new_dir, exist_ok=True)
    return new_dir

def run_cadical(infile):
    """Run CaDiCaL on *infile* and return its result code.

    Solver stdout and stderr are written to ``infile[:-4] + '.log'``. The
    code is parsed from the last log line (the text after ``'exit '``).
    Callers treat 20 as UNSAT. If the log is empty or its last line does not
    parse (e.g. the solver crashed), the process exit status is returned
    instead. CaDiCaL uses the same 10/20 exit-code convention.
    """
    logfile = infile[:-4] + ".log"
    solver = f"{USER_PATH}/sat_gen/HardSATGEN/postprocess/cadical/build/cadical"
    # Argument list without a shell: paths are passed verbatim, no quoting issues.
    with open(logfile, "w") as logf:
        proc = subprocess.run([solver, infile], stdout=logf, stderr=subprocess.STDOUT)

    with open(logfile, "r") as f:
        lines = f.readlines()
    try:
        return int(lines[-1].split("exit ")[-1].strip())
    except (IndexError, ValueError):
        return proc.returncode

def random_clause_worker(input_file):
    """Add random 3-literal clauses to copies of a pos/neg CNF pair until one becomes UNSAT.

    The files ``pos_<name>`` and ``neg_<name>`` next to *input_file* are
    copied into a fresh ``make_tempdir`` directory. That directory is removed
    again before returning. Each round works as follows:

    * Draw one random clause over variables ``1..n_var``. ``n_var`` is the
      second-to-last token of the pos file's first line.
    * Run ``run_cadical`` on both copies.
    * Append the clause to both copies and call ``add_clause`` on each.

    The loop stops once CaDiCaL returns 20 for either copy.

    Args:
        input_file: Path whose directory holds the ``pos_``/``neg_`` files.

    Returns:
        A tuple ``(breakflag, logics, loopcounts)``:

        * breakflag: 0 or 1, the index (0 = pos, 1 = neg) of the single copy
          found UNSAT. It is -1 when both were UNSAT in the same round, or
          when a copy exceeded 1000 solver calls.
        * logics: the ``readlines()`` of the pos and neg copies.
        * loopcounts: ``[pos_calls, neg_calls]``, a list on every path. The
          cap path used to return a bare int here.
    """

    def _read_lines(paths):
        # Read each file fully and close it immediately.
        contents = []
        for path in paths:
            with open(path, "r") as fh:
                contents.append(fh.readlines())
        return contents

    base_dir = os.path.dirname(input_file)
    # Isolated directory so parallel workers never touch each other's files.
    tdir = make_tempdir(base_dir)
    try:
        pos = os.path.join(tdir, "pos.cnf")
        neg = os.path.join(tdir, "neg.cnf")
        shutil.copy(os.path.join(base_dir, "pos_" + os.path.basename(input_file)), pos)
        shutil.copy(os.path.join(base_dir, "neg_" + os.path.basename(input_file)), neg)

        cnfs = [pos, neg]
        loopcounts = [0, 0]

        # Assumes a DIMACS "p cnf <n_vars> <n_clauses>" header line.
        with open(pos, "r") as cf:
            n_var = int(cf.readline().strip().split()[-2])

        while True:
            clause = [
                str((1 if random.random() > 0.5 else -1) * random.randint(1, n_var))
                for _ in range(3)
            ]
            sat = [1, 1]
            for i, f in enumerate(cnfs):
                # NOTE(review): the solver runs *before* this round's clause is
                # appended, so the returned logics contain one clause more than
                # the formula that was found UNSAT. Confirm this is intended.
                if run_cadical(f) == 20:
                    sat[i] = 0

                loopcounts[i] += 1
                if loopcounts[i] > 1000:
                    return -1, _read_lines(cnfs), loopcounts

                with open(f, "a") as cf:
                    cf.write("\n" + " ".join(clause) + " 0")
                add_clause(f)
            if 0 in sat:
                break

        # After the loop at least one entry is 0; equal entries mean both are UNSAT.
        breakflag = -1 if sat[0] == sat[1] else int(np.argmin(sat))
        return breakflag, _read_lines(cnfs), loopcounts
    finally:
        # The results have been read into memory; drop the scratch copies.
        shutil.rmtree(tdir, ignore_errors=True)

def random_search(file, n=100, max_workers=72):
    """Run *n* independent ``random_clause_worker(file)`` trials in a process pool.

    Args:
        file: Passed unchanged to every worker.
        n: Number of trials.
        max_workers: Size of the process pool (default 72).

    Returns:
        A tuple ``(votes, logics, mapping, loopcounts)``. ``votes``, ``logics``
        and ``loopcounts`` are per-trial lists of the workers' ``breakflag``,
        ``logics`` and ``loopcounts``, in submission order. ``mapping`` is
        always ``None`` because loading the variable map is disabled.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(random_clause_worker, file) for _ in range(n)]
        results = [future.result() for future in futures]

    votes = [r[0] for r in results]
    logics = [r[1] for r in results]
    loopcounts = [r[2] for r in results]

    # The variable mapping (<dir>/neg_<name>.maptxt) is not loaded any more.
    mapping = None

    return votes, logics, mapping, loopcounts

# Marker that opens every example, both in FEWSHOT and in the problem prompt.
_CONTEXT_MARKER = 'Here is some context'
# Largest n_samples requested from the LLM in a single call.
_MAX_SAMPLES_PER_CALL = 20


def _sample_completions(llm, generation_prompt, step_num, num_samples,
                        batch_size=_MAX_SAMPLES_PER_CALL):
    """Request exactly *num_samples* completions, at most *batch_size* per call."""
    completions = []
    requested = 0
    while requested < num_samples:
        # Ask only for what is still missing so the total never overshoots.
        n = min(batch_size, num_samples - requested)
        completions += llm.generate_next_thought(
            generation_prompt, step_idx=step_num, n_samples=n, max_new=100)
        requested += n
    return completions


def _extract_next_steps(completions, generation_prompt, step_num):
    """Cut the newly generated ``#<step_num>.`` step out of each completion.

    Completions are presumably the decoded prompt followed by the generated
    text (TODO confirm against ``llm.generate_next_thought``). The generated
    part is located by skipping as many context markers, and then as many
    ``#<step_num>.`` headers, as the prompt itself contains. A bare number
    such as "1." inside the context is therefore never mistaken for the step
    header. Completions that cannot be parsed are skipped with a warning.
    """
    marker_idx = generation_prompt.count(_CONTEXT_MARKER)
    prompt_tail = generation_prompt.split(_CONTEXT_MARKER)[-1]
    this_tag = '#' + str(step_num) + '.'
    next_tag = '#' + str(step_num + 1) + '.'
    tag_idx = prompt_tail.count(this_tag)

    steps = []
    for completion in completions:
        try:
            generated = completion.split(_CONTEXT_MARKER)[marker_idx]
            body = generated.split(this_tag)[tag_idx].split(next_tag)[0].strip('\n')
        except IndexError:
            warnings.warn('beam_search_cot: could not parse step ' + str(step_num)
                          + ' from a completion; skipping it')
            continue
        steps.append('\n' + this_tag + body)
    return steps


def _is_terminal_step(step, stop_token):
    """Return True if *step* ends the chain of thought."""
    # 'return' also covers 'return True', 'return False', 'return(True)', 'return(False)'.
    return 'return' in step or 'Final Answer' in step or stop_token in step


def beam_search_cot(
    prompt: str,
    llm,
    score_fn: Callable[..., float],
    beam_width: int = 5,
    num_samples: int = 4,
    max_steps: int = 10,
    stop_token: str = "Final Answer: ",
    mode: str = 'conf',
    cheat_idx=None
):
    """Beam search over chain-of-thought steps.

    At step k, every unfinished beam item is handled as follows:

    * The LLM is prompted with ``FEWSHOT``, the context and the steps so far,
      followed by a new line that starts with ``#k.``.
    * ``num_samples`` continuations are drawn, and the ``#k.`` step is cut out
      of each one.
    * A step containing 'return', 'Final Answer' or *stop_token* finishes the
      item. The new candidate keeps its parent's score and records
      ``get_mc(context)`` (the context *before* that step) under ``'mc'``.
    * Any other step is scored with
      ``score_fn(context, step, mode, cheat_idx=cheat_idx)``, and that score
      is added to the parent's.

    Candidates are then sorted by score and the top ``beam_width`` are kept.

    Args:
        prompt: Problem context (already formatted logic text).
        llm: Object exposing ``generate_next_thought(prompt, step_idx=...,
            n_samples=..., max_new=...)`` that returns decoded strings.
        score_fn: Callable ``(context, step, mode, cheat_idx=...) -> float``.
        beam_width: Number of candidates kept after each step.
        num_samples: Continuations sampled per beam item per step. They are
            requested at most 20 per LLM call.
        max_steps: Maximum number of reasoning steps.
        stop_token: Extra substring that marks a finishing step.
        mode: Scoring mode forwarded to *score_fn*. With ``'conf-bb'`` the
            root score starts at ``int`` 0 instead of ``0.0``.
        cheat_idx: Forwarded to *score_fn* unchanged.

    Returns:
        The final beam followed by every completed candidate that is not
        already in it, so each item appears only once. Items are dicts with
        ``og_mc`` (``get_mc(prompt)``), ``steps``, ``score``, ``scorelist``
        and ``completed``. Completed items also carry ``mc``.
    """
    beam = [{"steps": [], "score": 0.0, "scorelist": [], 'mc': {}}]
    if mode == 'conf-bb':
        beam[0]['score'] = int(0)
    completeds = []

    og_mc = get_mc(prompt)
    for step_idx in range(max_steps):
        step_num = step_idx + 1
        candidates = []

        for item in beam:
            if item.get('completed'):
                continue
            cot_so_far = prompt + '\n' + "\n".join(item["steps"])
            generation_prompt = FEWSHOT + cot_so_far + '\n#' + str(step_num) + '. '

            completions = _sample_completions(llm, generation_prompt, step_num, num_samples)
            # get_mc(cot_so_far) runs two external solver jobs and depends only
            # on this item, so compute it at most once per item.
            context_mc = None

            for step in _extract_next_steps(completions, generation_prompt, step_num):
                step = step.strip()

                if _is_terminal_step(step, stop_token):
                    if context_mc is None:
                        context_mc = get_mc(cot_so_far)
                    new_score = item["score"]
                    candidates.append({
                        'og_mc': og_mc,
                        "steps": item["steps"] + [step],
                        "score": new_score,
                        'scorelist': item['scorelist'] + [new_score],
                        'mc': dict(context_mc),  # independent copy per candidate
                        "completed": True,
                    })
                    completeds.append(candidates[-1])
                    continue

                step_score = score_fn(cot_so_far, step, mode, cheat_idx=cheat_idx)
                candidates.append({
                    "og_mc": og_mc,
                    "steps": item["steps"] + [step],
                    "score": item["score"] + step_score,
                    'scorelist': item['scorelist'] + [step_score],
                    "completed": False,
                })

        # Keep the best-scoring candidates.
        candidates.sort(key=lambda x: x["score"], reverse=True)
        beam = candidates[:beam_width]

        if all(item.get("completed", False) for item in beam):
            break

    # Completed items that survived in the beam are already in `completeds`;
    # return each object once so it is not counted twice by the caller.
    in_beam = {id(item) for item in beam}
    return beam + [item for item in completeds if id(item) not in in_beam]


if __name__ == '__main__':
    import pickle as pkl

    # Model translations, one per example. value['out'] holds the generated
    # "def translation():" code. (Local, trusted pickle files.)
    with open('/home/XXXX/quail_new.pkl', 'rb') as fh:
        dump = pkl.load(fh)
    llm = LLM()

    # Gold labels: data[i]['label'] is compared (lower-cased) with 'true'/'false'.
    with open('/home/XXXX/clause/dataset/quail.pkl', 'rb') as fh:
        data = pkl.load(fh)

    acc = 0
    counter = 0
    # Other scoring modes used in experiments: 'raw-cheating', 'conf-bb-norm',
    # 'raw-conf', 'conf-bb', 'raw-bb'.
    config = 'quail-5-4'
    mode = 'conf'
    out_path = '/home/XXXX/clause/beam_outs_' + mode + '_' + config + '.pkl'
    oopsies = []  # keys whose model-count vote was empty (text fallback used)
    # Results keyed by example index, plus the bookkeeping entries 'idxs',
    # 'oopsies' and 'logic_fail'. If a previous pickle is loaded here instead,
    # the loop below skips keys that are already done.
    all_outs = {}

    idxs = list(range(len(dump)))
    keys = list(range(len(idxs)))
    random.shuffle(keys)
    all_outs['idxs'] = idxs
    logic_fail = []

    for key in (pbar := tqdm(keys)):
        if key in all_outs:
            continue
        value = dump[key]
        try:
            translation = value['out'].split('def translation():')[1]
        except (KeyError, IndexError, AttributeError, TypeError):
            # No usable translation for this example.
            logic_fail.append(key)
            continue
        logic = ('\"\"\"\nHere is some context:\n' + translation.strip('```') + '\n## Let\'s think step by step:').replace("# Is the student's answer True or False?", "# Instruction: determine if the student's answer is True or False. ").replace('    ', '')

        label = data[idxs[key]]['label']
        if mode == 'raw-cheating':
            # Oracle mode: the scorer is told which side is correct.
            cheat_idx = label.lower().strip() == 'true'
        else:
            cheat_idx = None

        out = beam_search_cot(prompt=logic, llm=llm, score_fn=score, mode=mode, cheat_idx=cheat_idx)
        all_outs[key] = out

        # Vote over completed beams: preds[0] counts 'true', preds[1] 'false'.
        # A larger 'pos' model count votes 'false', a smaller one votes 'true'.
        completed = [o for o in out if o['completed']]
        preds = [0, 0]
        for o in completed:
            if o['mc']['pos'] > o['mc']['neg']:
                preds[1] += 1
            elif o['mc']['pos'] < o['mc']['neg']:
                preds[0] += 1

        if preds == [0, 0]:
            # No model-count signal: fall back to the final step's text. The
            # earlier of 'true'/'false' wins, and a step mentioning neither
            # counts as 'false'.
            oopsies.append(key)
            all_outs['oopsies'] = oopsies
            for o in completed:
                last = o['steps'][-1].lower()
                if 'true' in last and 'false' in last:
                    preds[0 if last.index('true') < last.index('false') else 1] += 1
                elif 'true' in last:
                    preds[0] += 1
                else:
                    preds[1] += 1

        predvote = np.argmax(preds)  # ties resolve to index 0 ('true')
        if ['true', 'false'][predvote] == label.lower().strip():
            acc += 1
        counter += 1

        pbar.set_description('Mode: ' + mode + ' | config: ' + config + ' | oopsies: ' + str(len(oopsies)) + ' |  Acc: ' + str(acc / counter))

        # Checkpoint after every example.
        all_outs['logic_fail'] = logic_fail
        with open(out_path, 'wb') as fh:
            pkl.dump(all_outs, fh)
