﻿from typing import List, Dict, Callable
import math
import os
import shutil
from transformers.generation.utils import GenerationConfig
from transformers import StoppingCriteriaList, GenerationConfig


import numpy as np
import time
import json

# GPU indices exposed to this process. The variable is exported here, ahead of
# the `import torch` statement that follows in this file.
devices = [5]
# Comma-separated form, e.g. [0, 1] -> "0,1" (an empty list yields "").
d_str = ','.join(str(d) for d in devices)
os.environ["CUDA_VISIBLE_DEVICES"] = d_str


import torch
from torch.utils.data import DataLoader
USER_PATH = '/home/XXXX/home/XXXX/6_26_backup/fs_backup_feb13/'

# Blank out the CA-bundle variables (used together with no_ssl_verification()).
os.environ["CURL_CA_BUNDLE"]=""
os.environ["REQUESTS_CA_BUNDLE"]=""
# os.environ['TRANSFORMERS_CACHE'] = USER_PATH + '/.cache/huggingface/hub'
# Directory that is handed to Hugging Face both through the environment and
# through the explicit `cache_dir=` arguments used when loading the model.
cache_dir = '/work/XXXX/'
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
# os.environ['HF_HUB_OFFLINE'] ='1'
# Few-shot prompt text, read once at import time. The context manager closes
# the file handle deterministically instead of leaving it to the GC.
with open('/home/XXXX/clause/logic_cot_fewshot.txt', 'r') as _fewshot_file:
    FEWSHOT = _fewshot_file.read()
# import transformers

# import urllib3
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import argparse
from tqdm import tqdm
import time
import datetime
import warnings
import contextlib

import requests
from urllib3.exceptions import InsecureRequestWarning

# Keep a handle on the unpatched method so the context manager can restore it.
old_merge_environment_settings = requests.Session.merge_environment_settings

@contextlib.contextmanager
def no_ssl_verification():
    """Temporarily disable TLS certificate verification for `requests`.

    While the context is active, `requests.Session.merge_environment_settings`
    is replaced by a wrapper that forces `verify=False` and
    `InsecureRequestWarning` is silenced. On exit the original method is
    restored and every adapter that was used inside the context is closed.

    The patch is applied on the `Session` class, so it affects every session
    in the process for the duration of the `with` block.
    """
    opened_adapters = set()

    def merge_environment_settings(self, url, proxies, stream, verify, cert):
        # Verification happens only once per connection so we need to close
        # all the opened adapters once we're done. Otherwise, the effects of
        # verify=False persist beyond the end of this context manager.
        opened_adapters.add(self.get_adapter(url))

        settings = old_merge_environment_settings(self, url, proxies, stream, verify, cert)
        settings['verify'] = False

        return settings

    requests.Session.merge_environment_settings = merge_environment_settings

    try:
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', InsecureRequestWarning)
            yield
    finally:
        requests.Session.merge_environment_settings = old_merge_environment_settings

        for adapter in opened_adapters:
            try:
                adapter.close()
            except Exception:
                # Best-effort cleanup: a failing close must not mask an error
                # raised inside the `with` body. Unlike a bare `except:`, this
                # lets KeyboardInterrupt / SystemExit propagate.
                pass

class Struct:
    """Attribute-style view of keyword arguments: ``Struct(a=1).a == 1``."""

    def __init__(self, **entries):
        # Copy every keyword straight into the instance namespace.
        vars(self).update(entries)

# Run configuration. It is built as a plain dict first and then wrapped in a
# Struct so the rest of the file can read it as `args.model_choice`, etc.
args = dict(
    train_file_path='./example_data',
    test_file_path='./example_data',
    save_path='./../SFT_train_res',
    model_choice='meta-llama/Llama-2-13b-chat-hf',
    n_rows=20,
    max_length=1000,
    temperature=1,
    lr=5e-05,
    weight_decay=0.0,
    epochs=10,
    max_grad_norm=1.0,
    batch_size=2,
    save_strategy='no',
    use_lora=True,
)
# Model override; the commented line is an alternative checkpoint.
# args['model_choice'] = 'mistralai/Mistral-7B-Instruct-v0.3'
args['model_choice'] = 'meta-llama/Meta-Llama-3-8B-Instruct'

args = Struct(**args)

from transformers import StoppingCriteria
from transformers import StoppingCriteria

class StopOnNextStepTokens(StoppingCriteria):
    """Stop generation once every sequence has produced the stop-token run.

    Args:
        stop_token_ids: token ids of the marker that ends a step.
        prompt_len: number of leading positions in `input_ids` that belong to
            the prompt; only positions after it are inspected.
        device: device on which the stop-token tensor is created (it is
            compared against slices of `input_ids`).
    """

    def __init__(self, stop_token_ids, prompt_len, device):
        self.stop_token_ids = torch.tensor(
            stop_token_ids, device=device, dtype=torch.long
        )
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores, **kwargs):
        """Return True when ALL rows of `input_ids` contain the stop run.

        The run may appear anywhere in the generated part, not only at the
        very end. Requiring every row to *end* with it on the same step means
        a row that emitted the marker earlier (and kept sampling) prevents the
        criterion from ever firing when several samples are drawn at once.
        """
        stop_len = self.stop_token_ids.shape[0]
        # Check only newly generated tokens
        generated = input_ids[:, self.prompt_len:]

        # Nothing to match against, or not enough new tokens yet.
        if stop_len == 0 or generated.shape[1] < stop_len:
            return False

        # (batch, n_windows, stop_len): every contiguous window of stop_len tokens.
        windows = generated.unfold(1, stop_len, 1)
        row_hit = (windows == self.stop_token_ids).all(dim=-1).any(dim=-1)
        return bool(row_hit.all())

class StopOnNextStep(StoppingCriteria):
    """Text-level stopping rule: halt when every sample shows a step boundary.

    A boundary is either the header of the step after the current one
    (``"\\n#<step_idx + 2>."``) or the ``"\\nFinal Answer:"`` marker.

    NOTE: the full `input_ids` are decoded, prompt included, so a marker that
    already occurs in the prompt text counts as a boundary as well.
    """

    def __init__(self, tokenizer, step_idx):
        self.tokenizer = tokenizer
        self.next_step = f"\n#{step_idx + 2}."
        self.final = "\nFinal Answer:"

    def __call__(self, input_ids, scores, **kwargs):
        # input_ids is (batch_size, seq_len); decode each row back to text.
        texts = self.tokenizer.batch_decode(
            input_ids, skip_special_tokens=True
        )
        boundaries = (self.next_step, self.final)

        # Stop only when *every* decoded sample contains one of the markers.
        return all(
            any(marker in text for marker in boundaries)
            for text in texts
        )

class LLM():
    def __init__(self):
        """Load the tokenizer and 4-bit quantized model named by `args.model_choice`.

        Both are fetched inside `no_ssl_verification()` and cached under the
        module-level `cache_dir`. The tokenizer is configured to pad on the
        left with the EOS token.
        """
        # NF4 4-bit weights with double quantization, bfloat16 compute.
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype="bfloat16",
            bnb_4bit_use_double_quant=True,
        )
        # Keyword arguments shared by the tokenizer and the model loader.
        hub_kwargs = dict(
            cache_dir=cache_dir,
            token='  ',
            attn_implementation="flash_attention_2",
        )

        with no_ssl_verification():
            self.tokenizer = AutoTokenizer.from_pretrained(
                args.model_choice,
                **hub_kwargs,
            )
            # Reuse EOS as the padding token and pad on the left.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            self.tokenizer.padding_side = 'left'

            self.model = AutoModelForCausalLM.from_pretrained(
                args.model_choice,
                quantization_config=nf4_config,
                device_map='auto',
                **hub_kwargs,
            )

        self.tokenizer.pad_token = self.tokenizer.eos_token
    
    def sentence_probabilities(self, sentences):
        """Score each sentence by the sum of its token log-probabilities.

        For every sentence the log-probability of token t given the tokens
        before it is gathered for t >= 1, multiplied by the attention mask and
        summed. The result is a CPU tensor with one entry per sentence.

        NOTE(review): the split point is the first column where the padded
        batch disagrees with row 0; if no column differs, or column 0 already
        differs, `argmin` presumably yields 0 and the shared prefix is empty —
        confirm the model call tolerates that.
        """
        with torch.no_grad():
            encoded = self.tokenizer(sentences, return_tensors='pt', padding=True)
            token_ids = encoded.input_ids.cuda()
            n_sentences = len(sentences)

            # Little hack to cut down inference time by 4-5x (leads to some imprecisions when using quantization)
            # Run the prefix shared by all sentences through the model once and reuse its cache.
            columns_agree = (token_ids == token_ids[0, :].unsqueeze(0)).all(dim=0)
            split_at = columns_agree.long().argmin()
            shared_ids = token_ids[0, :split_at].unsqueeze(0)
            shared_out = self.model(shared_ids, use_cache=True)
            # Broadcast the single-row cache to the whole batch.
            shared_cache = tuple(
                tuple(entry.expand(n_sentences, -1, -1, -1) for entry in layer)
                for layer in shared_out.past_key_values
            )

            # Only the differing tails are run per sentence.
            tail_out = self.model(token_ids[:, split_at:], past_key_values=shared_cache)
            logits = torch.concat(
                [shared_out.logits.expand(n_sentences, -1, -1), tail_out.logits],
                dim=1,
            ).cuda()

            # Position t of the logits predicts token t + 1.
            token_log_probs = logits.log_softmax(-1)
            next_ids = token_ids[:, 1:][:, :, None]
            token_log_probs = token_log_probs[:, :-1, :].gather(2, next_ids).squeeze(-1).cuda()
            mask = encoded.attention_mask.cuda()[:, 1:]
            totals = (token_log_probs * mask).sum(-1).cpu()
        return totals
    def nli(self, sentences, unknown):
        # true_probs = self.sentence_probabilities(sentences + " True.")
        # false_probs = self.sentence_probabilities(sentences + " False.")q[]
        # maybe_probs = self.sentence_probabilities(sentences + " Maybe.")
        if unknown:
            true_probs, maybe_probs, false_probs =  (self.sentence_probabilities([sentences + "(A)", sentences + "(B)", sentences + "(C)"]))
            return {'True': true_probs, 'Maybe': maybe_probs, 'False': false_probs}
        else:
            true_probs, false_probs =  (self.sentence_probabilities([sentences + "True", sentences + "False"]).softmax(-1))
            return {'True': true_probs, 'False': false_probs}
    def yn(self, sentences, norm=True, relaxed=False, obvious=False, fewshot=None, maybe=False):
        yns = []
        for sentence in sentences:
            if fewshot:
                sentence = fewshot + sentence
            
            if relaxed:
                yns.append(sentence + "Most likely")
                yns.append(sentence + "Not necessarily")
            elif obvious:
                yns.append(sentence + "obviously true.")
                yns.append(sentence + "not obviously true.")
            elif maybe:
                yns.append(sentence + "Yes")
                yns.append(sentence + "Maybe")
                yns.append(sentence + "No")
            else:
                yns.append(sentence + "Yes")
                yns.append(sentence + "No")
        # if norm:
        #     norms = self.sentence_probabilities(sentences)
        probs = []
        batch_size = 256
        for i in range(0, len(yns), batch_size):
            if i+batch_size < len(yns):
                probs += list(self.sentence_probabilities(yns[i:i+batch_size]))
            else: 
                probs += list(self.sentence_probabilities(yns[i:]))
        probs=torch.tensor(probs)
        #   
        # probs = (self.sentence_probabilities(yns))
        # probs = torch.exp(probs)
        pyes = []
        pno = []
        pmaybe = []
        if maybe:
            z = 3
        else:
            z = 2
        for i in range(0,len(probs), z):
            # if yns[i] not in cache.keys():
                # yes, no = self.sentence_probabilities([yns[i], yns[i+1]])
            
            if maybe:
                
                yes, maybe, no = probs[i], probs[i+1], probs[i+2]
                
                      
            else:
                yes, no = probs[i], probs[i+1]
            if norm:
                if maybe: 
                    y,m,n = torch.tensor([yes, maybe, no]).softmax(-1)
                else:
                    y,n = torch.tensor([yes, no]).softmax(-1)
              
                # cache[yns[i]] = y
                # cache[yns[i+1]] = n
                pyes.append(y)
                pno.append(n)
                if maybe:
                    pmaybe.append(m)
            else:
                pyes.append(1-yes/(yes + no))
            # else:
            #     y, n = cache[yns[i]], cache[yns[i+1]]
            #     pyes.append(y)
                # pno.append(n)/
        # print('cache length', len(cache))
        # if maybe:
        
        if maybe: return torch.stack([torch.tensor(pyes), torch.tensor(pmaybe), torch.tensor(pno)])
        return torch.tensor(pyes), torch.tensor(pmaybe), torch.tensor(pno)
    def complete(
        self,
        prompt,
        max_new=25,
        temp=1.0,
        topk=0,
        n_samples=1,
    ):
        """Sample continuations for one prompt or a batch of prompts.

        Args:
            prompt: a string or list of strings; inputs are padded and
                truncated to at most 3000 tokens.
            max_new: maximum number of new tokens to generate.
            temp: sampling temperature.
            topk: `top_k` value handed to the generation config.
            n_samples: number of sequences returned per prompt.

        Returns:
            The decoded `outputs.sequences` (special tokens skipped), as a
            list of strings.
        """
        encoded = self.tokenizer(
            prompt,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=3000
        )
        encode_ids = encoded.input_ids.cuda()
        # The pad token is the EOS token here, so padding cannot be told apart
        # from content by id; pass the tokenizer's mask explicitly.
        attention_mask = encoded.attention_mask.cuda()

        gc = GenerationConfig(
            temperature=temp,
            do_sample=True,
            top_k=topk,
            max_new_tokens=max_new,
            num_return_sequences=n_samples,
            return_dict_in_generate=True,
            output_scores=False,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        outputs = self.model.generate(
            encode_ids,
            attention_mask=attention_mask,
            generation_config=gc,
        )

        responses = self.tokenizer.batch_decode(
            outputs.sequences,
            skip_special_tokens=True
        )

        return responses
    

    def generate_next_thought(self,
        prompt,
        step_idx,
        n_samples=8,
        max_new=100,
        temp=1.0,
        topk=50,
    ):
        """Sample `n_samples` candidate continuations for the next reasoning step.

        Generation is capped at `max_new` new tokens and additionally uses a
        `StopOnNextStepTokens` criterion built from the tokenization of
        ``"\\n#<step_idx + 2>."``. Returns the decoded output sequences
        (special tokens skipped) as a list of strings.
        """
        prompt_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()

        # Token ids of the header that opens the step after the one being written.
        marker_ids = self.tokenizer.encode(
            f"\n#{step_idx + 2}.", add_special_tokens=False
        )
        halt_on_marker = StopOnNextStepTokens(
            stop_token_ids=marker_ids,
            prompt_len=prompt_ids.shape[1],
            device=prompt_ids.device,
        )

        sampling = GenerationConfig(
            do_sample=True,
            temperature=temp,
            top_k=topk,
            max_new_tokens=max_new,
            num_return_sequences=n_samples,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        generated = self.model.generate(
            prompt_ids,
            generation_config=sampling,
            stopping_criteria=StoppingCriteriaList([halt_on_marker]),
        )

        return self.tokenizer.batch_decode(generated, skip_special_tokens=True)
    # def generate_next_thought_batch(self, prompts, step_idx, n_samples=8, max_new=100, temp=1, topk=0):
    #     """
    #     Generate next thoughts for multiple prompts at once.
    #     Returns list of lists: outer list = prompts, inner list = n_samples per prompt
    #     """
    #     # Tokenize all prompts with padding
    #     inputs = self.tokenizer(
    #         prompts, 
    #         return_tensors="pt", 
    #         padding=True,
    #         truncation=True,
    #         max_length=3000
    #     ).to('cuda')
        
    #     stop_str = f"\n#{step_idx + 2}."
    #     stop_ids = self.tokenizer.encode(stop_str, add_special_tokens=False)
        
    #     stopper = StopOnNextStepTokens(
    #         stop_token_ids=stop_ids,
    #         prompt_len=inputs.input_ids.shape[1],
    #         device=inputs.input_ids.device,
    #     )
        
    #     gc = GenerationConfig(
    #         do_sample=True,
    #         temperature=temp,
    #         top_k=topk,
    #         max_new_tokens=max_new,
    #         num_return_sequences=n_samples,
    #         pad_token_id=self.tokenizer.eos_token_id,
    #     )
        
    #     out = self.model.generate(
    #         inputs.input_ids,
    #         attention_mask=inputs.attention_mask,
    #         generation_config=gc,
    #         stopping_criteria=StoppingCriteriaList([stopper]),
    #     )
        
    #     texts = self.tokenizer.batch_decode(out, skip_special_tokens=True)
        
    #     # Reshape: texts is flat list of len(prompts) * n_samples
    #     # Reorganize into list of lists
    #     result = []
    #     for i in range(len(prompts)):
    #         result.append(texts[i * n_samples : (i + 1) * n_samples])
        
    #     return result




import re
from typing import List, Dict, Tuple, Set, Any
from itertools import product

class FOLToCNF:
    def __init__(self):
        self.symbol_counter = 0
        self.mappings = {}
        self.reverse_mappings = {}
        self.facts = []
        self.rules = []
        self.formulas = []  # Non-quantified formulas (Or, And, Not combinations)
        self.query = None
        self.constants = set()
        self.predicates = {}
        
    def normalize_predicate(self, pred_name: str, args: List[str]) -> str:
        """Normalize a predicate to a consistent string format (no spaces after commas)."""
        return f"{pred_name}({','.join(args)})"
    
    def get_symbol(self, name: str) -> str:
        """Get or create a propositional symbol for a ground predicate."""
        # Normalize the name to ensure consistency
        name = name.replace(', ', ',').replace(' ,', ',')
        
        if name not in self.mappings:
            symbol = f"p{self.symbol_counter}"
            self.symbol_counter += 1
            self.mappings[name] = symbol
            self.reverse_mappings[symbol] = name
        return self.mappings[name]
    
    def parse_predicate(self, pred_str: str) -> Tuple[str, List[str]]:
        """Parse a predicate string like 'previous_stop(x, y)' into (name, [args])."""
        match = re.match(r'(\w+)\(([^)]+)\)', pred_str.strip())
        if match:
            pred_name = match.group(1)
            args = [arg.strip() for arg in match.group(2).split(',')]
            return pred_name, args
        return None, []
    
    def extract_predicates_from_and(self, and_content: str) -> List[Tuple[str, List[str]]]:
        """Extract all predicates from And(...) content."""
        predicates = []
        pattern = r'(\w+)\(([^)]+)\)'
        for match in re.finditer(pattern, and_content):
            pred_name = match.group(1)
            args_str = match.group(2)
            args = [arg.strip() for arg in args_str.split(',')]
            predicates.append((pred_name, args))
        return predicates
    
    # def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
    #     """
    #     Parse a propositional formula and convert it to CNF.
    #     Handles Or, And, Not combinations of ground predicates.
    #     """
    #     # This is a simplified conversion - for complex cases you'd need a full formula parser
    #     # We'll handle the most common patterns
        
    #     formula_str = formula_str.strip()
        
    #     # Handle Or(And(...), And(...)) - disjunction of conjunctions
    #     if formula_str.startswith('Or(And('):
    #         return self.parse_or_of_ands(formula_str)
        
    #     # Handle simple And(...) - already in CNF (single clause with all literals)
    #     if formula_str.startswith('And('):
    #         return self.parse_simple_and(formula_str)
        
    #     # Handle simple Or(...) - single clause
    #     if formula_str.startswith('Or('):
    #         return self.parse_simple_or(formula_str)
        
    #     return []
    # def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
    #     """
    #     Parse a propositional formula and convert it to CNF.
    #     Handles Or, And, Not combinations of ground predicates.
    #     """
    #     formula_str = formula_str.strip()
        
    #     # Handle Or(And(...), And(...)) - disjunction of conjunctions
    #     if formula_str.startswith('Or(And('):
    #         return self.parse_or_of_ands(formula_str)
        
    #     # Handle simple And(...) - already in CNF (single clause with all literals)
    #     if formula_str.startswith('And('):
    #         return self.parse_simple_and(formula_str)
        
    #     # Handle simple Or(...) - single clause
    #     if formula_str.startswith('Or('):
    #         return self.parse_simple_or(formula_str)
        
    #     # Handle simple Not(...) - single negated literal
    #     if formula_str.startswith('Not('):
    #         return self.parse_simple_not(formula_str)
        
    #     # Handle single predicate (no logical operator)
    #     if '(' in formula_str and not any(formula_str.startswith(op) for op in ['Or(', 'And(', 'Not(']):
    #         literals = self.extract_literals_from_formula(formula_str)
    #         if literals:
    #             return [[literals[0]]]  # Single unit clause
        
    #     return []

    def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
        """
        Parse a propositional formula and convert it to CNF.
        Handles Or, And, Not, Implies combinations of ground predicates.
        """
        formula_str = formula_str.strip()
        
        # Handle Implies - convert to Or(Not(A), B)
        if formula_str.startswith('Implies('):
            return self.parse_implies_formula(formula_str)
        
        # Handle Or(And(...), And(...)) - disjunction of conjunctions
        if formula_str.startswith('Or(And('):
            return self.parse_or_of_ands(formula_str)
        
        # Handle simple And(...) - already in CNF (single clause with all literals)
        if formula_str.startswith('And('):
            return self.parse_simple_and(formula_str)
        
        # Handle simple Or(...) - single clause
        if formula_str.startswith('Or('):
            return self.parse_simple_or(formula_str)
        
        # Handle simple Not(...) - single negated literal
        if formula_str.startswith('Not('):
            return self.parse_simple_not(formula_str)
        
        # Handle single predicate (no logical operator)
        if '(' in formula_str and not any(formula_str.startswith(op) for op in ['Or(', 'And(', 'Not(', 'Implies(']):
            literals = self.extract_literals_from_formula(formula_str)
            if literals:
                return [[literals[0]]]  # Single unit clause
        
        return []

    def parse_implies_formula(self, formula_str: str) -> List[List[str]]:
        """
        Parse Implies(A, B) and convert to CNF.
        Implies(A, B) ≡ Or(Not(A), B) ≡ ~A ∨ B
        """
        # Extract the two parts of the implication
        inner_start = 8  # len('Implies(')
        depth = 0
        i = inner_start
        comma_pos = -1
        
        # Find the comma separating antecedent and consequent at depth 0
        while i < len(formula_str):
            if formula_str[i] == '(':
                depth += 1
            elif formula_str[i] == ')':
                depth -= 1
            elif formula_str[i] == ',' and depth == 0:
                comma_pos = i
                break
            i += 1
        
        if comma_pos == -1:
            return []
        
        antecedent = formula_str[inner_start:comma_pos].strip()
        # Find the end of consequent (last closing paren)
        consequent_start = comma_pos + 1
        depth = 0
        i = len(formula_str) - 1
        while i >= consequent_start and formula_str[i] != ')':
            i -= 1
        consequent = formula_str[consequent_start:i].strip()
        
        # Convert Implies(A, B) to Or(Not(A), B)
        # First, negate the antecedent
        if antecedent.startswith('Not('):
            # Not(Not(A)) → A (double negation)
            # Extract what's inside the Not
            inner_start_not = 4
            depth = 1
            j = inner_start_not
            while j < len(antecedent) and depth > 0:
                if antecedent[j] == '(':
                    depth += 1
                elif antecedent[j] == ')':
                    depth -= 1
                j += 1
            antecedent_inner = antecedent[inner_start_not:j-1]
            negated_antecedent = antecedent_inner
        else:
            negated_antecedent = f"Not({antecedent})"
        
        # Now parse Or(Not(A), B)
        or_formula = f"Or({negated_antecedent}, {consequent})"
        return self.parse_simple_or(or_formula)

    def parse_simple_not(self, formula_str: str) -> List[List[str]]:
        """Parse Not(...) and apply De Morgan's laws if needed."""
        # Extract what's inside the Not()
        if not formula_str.startswith('Not('):
            return []
        
        inner_start = 4  # len('Not(')
        # Find the content inside Not(...)
        depth = 1
        i = inner_start
        while i < len(formula_str) and depth > 0:
            if formula_str[i] == '(':
                depth += 1
            elif formula_str[i] == ')':
                depth -= 1
            i += 1
        
        inner_content = formula_str[inner_start:i-1].strip()
        
        # Apply De Morgan's laws
        # Not(And(A, B, ...)) → Or(Not(A), Not(B), ...) → single clause [~A, ~B, ...]
        if inner_content.startswith('And('):
            literals = self.extract_literals_from_formula(inner_content)
            # Negate all literals
            negated_literals = []
            for lit in literals:
                if lit.startswith('~'):
                    negated_literals.append(lit[1:])  # Remove negation
                else:
                    negated_literals.append(f"~{lit}")  # Add negation
            return [negated_literals]  # Single clause with all negated literals
        
        # Not(Or(A, B, ...)) → And(Not(A), Not(B), ...) → multiple unit clauses [[~A], [~B], ...]
        elif inner_content.startswith('Or('):
            literals = self.extract_literals_from_formula(inner_content)
            # Negate all literals and return as separate clauses
            negated_clauses = []
            for lit in literals:
                if lit.startswith('~'):
                    negated_clauses.append([lit[1:]])  # Remove negation
                else:
                    negated_clauses.append([f"~{lit}"])  # Add negation
            return negated_clauses
        
        # Not(Not(A)) → A (double negation elimination)
        elif inner_content.startswith('Not('):
            # Recursively parse the inner Not, then negate the result
            inner_cnf = self.parse_simple_not(inner_content)
            # Negate the result
            result = []
            for clause in inner_cnf:
                negated_clause = []
                for lit in clause:
                    if lit.startswith('~'):
                        negated_clause.append(lit[1:])
                    else:
                        negated_clause.append(f"~{lit}")
                result.append(negated_clause)
            return result
        
        # Not(predicate(...)) → single negated literal
        else:
            literals = self.extract_literals_from_formula(formula_str)
            if literals:
                return [[literals[0]]]  # Should be a single negated literal
            return []
    def parse_simple_and(self, formula_str: str) -> List[List[str]]:
        """Parse And(a, b, c) -> [[a], [b], [c]]"""
        literals = self.extract_literals_from_formula(formula_str)
        return [[lit] for lit in literals]
    
    def parse_simple_or(self, formula_str: str) -> List[List[str]]:
        """Parse Or(a, b, c) -> [[a, b, c]]"""
        literals = self.extract_literals_from_formula(formula_str)
        return [literals]
    
    def parse_or_of_ands(self, formula_str: str) -> List[List[str]]:
        """
        Parse Or(And(a,b), And(c,d)) and convert to CNF.
        This requires distribution: (a∧b)∨(c∧d) = (a∨c)∧(a∨d)∧(b∨c)∧(b∨d)
        """
        # Extract the And(...) components
        and_groups = []
        depth = 0
        current = []
        in_and = False
        
        i = 3  # Skip "Or("
        while i < len(formula_str):
            if formula_str[i:i+4] == 'And(':
                in_and = True
                depth = 1
                i += 4
                start = i
                and_content = []
                
                while depth > 0 and i < len(formula_str):
                    if formula_str[i] == '(':
                        depth += 1
                    elif formula_str[i] == ')':
                        depth -= 1
                    i += 1
                
                and_str = formula_str[start:i-1]
                and_literals = self.extract_literals_from_formula('And(' + and_str + ')')
                and_groups.append(and_literals)
            else:
                i += 1
        
        # Distribute: generate all combinations
        if len(and_groups) == 0:
            return []
        
        # Create CNF by taking one literal from each And group
        clauses = []
        for combo in product(*and_groups):
            clauses.append(list(combo))
        
        return clauses
    
    # def extract_literals_from_formula(self, formula_str: str) -> List[str]:
    #     """Extract all literals from a formula, handling Not(...)."""
    #     literals = []
        
    #     # Find all Not(...) patterns
    #     not_pattern = r'Not\((\w+)\(([^)]+)\)\)'
    #     for match in re.finditer(not_pattern, formula_str):
    #         pred_name = match.group(1)
    #         args_str = match.group(2)
    #         args = [arg.strip() for arg in args_str.split(',')]
    #         pred_str = self.normalize_predicate(pred_name, args)
    #         symbol = self.get_symbol(pred_str)
    #         literals.append(f"~{symbol}")
            
    #         # Track constants
    #         self.constants.update(args)
        
    #     # Find all positive predicates (not preceded by Not)
    #     # Remove the Not(...) parts first
    #     cleaned = re.sub(r'Not\([^)]+\)', '', formula_str)
        
    #     pred_pattern = r'(\w+)\(([^)]+)\)'
    #     for match in re.finditer(pred_pattern, cleaned):
    #         pred_name = match.group(1)
    #         # Skip logical operators
    #         if pred_name in ['Or', 'And', 'Not', 'Implies', 'ForAll']:
    #             continue
    #         args_str = match.group(2)
    #         args = [arg.strip() for arg in args_str.split(',')]
    #         pred_str = self.normalize_predicate(pred_name, args)
    #         symbol = self.get_symbol(pred_str)
    #         literals.append(symbol)
            
    #         # Track constants
    #         self.constants.update(args)
        
    #     return literals
    def extract_literals_from_formula(self, formula_str: str) -> List[str]:
        """Extract all literals from a formula, handling Not(...)."""
        literals = []
        
        # Find all Not(predicate(...)) patterns
        # We need to handle the nested parentheses properly
        i = 0
        positions_to_remove = []  # Track what we've processed as Not
        
        while i < len(formula_str):
            if formula_str[i:i+4] == 'Not(':
                # Found a Not - now find the complete predicate inside
                start = i
                i += 4
                
                # Find the predicate name
                pred_match = re.match(r'(\w+)\(', formula_str[i:])
                if pred_match:
                    pred_name = pred_match.group(1)
                    i += len(pred_match.group(0))
                    
                    # Now find the arguments by tracking parentheses
                    depth = 1
                    arg_start = i
                    while i < len(formula_str) and depth > 0:
                        if formula_str[i] == '(':
                            depth += 1
                        elif formula_str[i] == ')':
                            depth -= 1
                        i += 1
                    
                    # Extract arguments
                    args_str = formula_str[arg_start:i-1]
                    args = [arg.strip() for arg in args_str.split(',')]
                    pred_str = self.normalize_predicate(pred_name, args)
                    symbol = self.get_symbol(pred_str)
                    literals.append(f"~{symbol}")
                    
                    # Track constants
                    self.constants.update(args)
                    
                    # Skip the closing paren of Not()
                    if i < len(formula_str) and formula_str[i] == ')':
                        i += 1
                    
                    # Remember this range to skip when looking for positive predicates
                    positions_to_remove.append((start, i))
                else:
                    i += 1
            else:
                i += 1
        
        # Now find positive predicates, skipping the Not(...) regions
        i = 0
        while i < len(formula_str):
            # Check if we're in a Not region
            in_not_region = any(start <= i < end for start, end in positions_to_remove)
            if in_not_region:
                i += 1
                continue
            
            # Look for pattern: word followed by (
            match = re.match(r'(\w+)\(', formula_str[i:])
            if match:
                pred_name = match.group(1)
                
                # Skip logical operators
                if pred_name in ['Or', 'And', 'Not', 'Implies', 'ForAll']:
                    i += len(match.group(0))
                    continue
                
                # This is a real predicate - extract its arguments
                start = i + len(match.group(0))
                depth = 1
                j = start
                
                while j < len(formula_str) and depth > 0:
                    if formula_str[j] == '(':
                        depth += 1
                    elif formula_str[j] == ')':
                        depth -= 1
                    j += 1
                
                # Extract arguments
                args_str = formula_str[start:j-1]
                args = [arg.strip() for arg in args_str.split(',')]
                pred_str = self.normalize_predicate(pred_name, args)
                symbol = self.get_symbol(pred_str)
                literals.append(symbol)
                
                # Track constants
                self.constants.update(args)
                
                i = j
            else:
                i += 1
        
        return literals
    # def parse_z3_code(self, code: str):
    #     """Parse Z3 code and extract facts, rules, and query."""
    #     lines = code.strip().split('\n')
        
    #     for line in lines:
    #         line = line.strip()
    #         if not line or line.startswith('#'):
    #             continue
                
    #         # Check for ForAll (rules)
    #         if line.startswith('ForAll'):
    #             self.parse_rule(line)
    #         # Check for return statement (query)
    #         elif line.startswith('return'):
    #             query_match = re.search(r'return\s+(.+)', line)
    #             if query_match:
    #                 self.query = query_match.group(1).strip()
    #         # Check for complex formulas (Or, And with predicates)
    #         elif any(line.startswith(op) for op in ['Or(', 'And(', 'Not(']):
    #             self.formulas.append(line)
    #         # Otherwise it's a simple fact
    #         else:
    #             pred_name, args = self.parse_predicate(line)
    #             if pred_name:
    #                 # Normalize and store
    #                 normalized = self.normalize_predicate(pred_name, args)
    #                 # Parse it back to get clean args
    #                 pred_name, args = self.parse_predicate(normalized)
    #                 self.facts.append((pred_name, args))
    #                 self.constants.update(args)
    #                 if pred_name not in self.predicates:
    #                     self.predicates[pred_name] = len(args)
    def parse_z3_code(self, code: str):
        """Parse Z3-style code and populate facts, rules, formulas and the query.

        Blank lines and lines starting with ``#`` are ignored. Every other
        line is classified as follows:

        * ``ForAll...``: passed to ``self.parse_rule``.
        * ``return <expr>``: ``<expr>`` (stripped) becomes ``self.query``.
        * ``Implies(``, ``Or(``, ``And(`` or ``Not(`` prefix: appended
          verbatim to ``self.formulas``.
        * anything else: parsed with ``self.parse_predicate``; when that
          yields a predicate name it is recorded in ``self.facts``, its
          arguments are added to ``self.constants`` and its arity is
          registered in ``self.predicates`` (first occurrence wins).

        Args:
            code: Newline-separated source text.
        """
        formula_prefixes = ('Implies(', 'Or(', 'And(', 'Not(')

        for raw_line in code.strip().split('\n'):
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue

            # Quantified rule.
            if line.startswith('ForAll'):
                self.parse_rule(line)
                continue

            # Return statement (query). The word boundary keeps predicates
            # whose names merely begin with "return" (e.g. returned(a, b))
            # from being swallowed here and silently dropped.
            if re.match(r'return\b', line):
                query_match = re.search(r'return\s+(.+)', line)
                if query_match:
                    self.query = query_match.group(1).strip()
                continue

            # Ground (non-quantified) compound formula.
            if line.startswith(formula_prefixes):
                self.formulas.append(line)
                continue

            # Otherwise it is a simple fact.
            pred_name, args = self.parse_predicate(line)
            if pred_name:
                # Normalize, then parse back to obtain the cleaned arguments.
                normalized = self.normalize_predicate(pred_name, args)
                pred_name, args = self.parse_predicate(normalized)
                self.facts.append((pred_name, args))
                self.constants.update(args)
                if pred_name not in self.predicates:
                    self.predicates[pred_name] = len(args)

                        
    def parse_rule(self, line: str):
        """Dispatch a ``ForAll([vars], body)`` line to the matching rule parser.

        The quantified variable names are read from the bracketed list. If
        the line contains an ``Implies(...)`` its contents go to
        ``parse_implies_rule``; otherwise, if it contains an ``Or(...)``,
        its contents go to ``parse_or_rule``. Lines with no variable list or
        with neither construct are ignored.
        """
        quantifier = re.search(r'ForAll\(\[([^\]]+)\]', line)
        if quantifier is None:
            return

        variables = [name.strip() for name in quantifier.group(1).split(',')]

        # An implication takes precedence over a disjunction.
        implication = re.search(r'Implies\((.*)\)\s*\)', line)
        if implication is not None:
            self.parse_implies_rule(variables, implication.group(1))
            return

        # ForAll([x], Or(...)) is already in clause form.
        disjunction = re.search(r'Or\((.*)\)\s*\)', line)
        if disjunction is not None:
            self.parse_or_rule(variables, disjunction.group(1))
    
    def parse_or_rule(self, variables: List[str], or_body: str):
        """Store the body of ``ForAll([x], Or(...))`` as a disjunctive rule.

        Literals are collected as ``(pred_name, args, is_negated)`` triples:
        every ``Not(pred(args))`` first, then the remaining positive
        predicates. The arity of each predicate not yet known is recorded in
        ``self.predicates``. The rule is appended to ``self.rules`` as
        ``(variables, [], literals)``, i.e. with an empty antecedent list.
        """
        def split_args(raw):
            return [piece.strip() for piece in raw.split(',')]

        # Negated literals: Not(pred(args)).
        negative = [
            (found.group(1), split_args(found.group(2)), True)
            for found in re.finditer(r'Not\((\w+)\(([^)]+)\)\)', or_body)
        ]

        # Positive literals: strip the Not(...) prefixes, then scan what is left,
        # skipping logical operator names.
        remainder = re.sub(r'Not\([^)]+\)', '', or_body)
        positive = [
            (found.group(1), split_args(found.group(2)), False)
            for found in re.finditer(r'(\w+)\(([^)]+)\)', remainder)
            if found.group(1) not in ('Or', 'And', 'Not', 'Implies')
        ]

        literals_info = negative + positive
        for name, terms, _negated in literals_info:
            if name not in self.predicates:
                self.predicates[name] = len(terms)

        # Empty antecedent list plus a list of literals marks a disjunction.
        self.rules.append((variables, [], literals_info))
    
    def parse_implies_rule(self, variables: List[str], implies_body: str):
        """Store the body of ``Implies(antecedent, consequent)`` as a rule.

        ``implies_body`` is split at its top-level commas and must yield
        exactly two parts; otherwise nothing is stored. An antecedent that
        starts with ``And(`` is expanded with
        ``self.extract_predicates_from_and``; any other antecedent is parsed
        as a single predicate. The consequent must parse as a predicate or
        the rule is dropped. Arities of predicates not yet known are recorded
        in ``self.predicates`` and the rule is appended to ``self.rules`` as
        ``(variables, antecedents, (consequent_name, consequent_args))``.
        """
        # Split on commas that are not nested inside parentheses.
        pieces = []
        nesting = 0
        buffer = ''
        for ch in implies_body:
            if ch == ',' and nesting == 0:
                pieces.append(buffer.strip())
                buffer = ''
                continue
            if ch == '(':
                nesting += 1
            elif ch == ')':
                nesting -= 1
            buffer += ch
        if buffer:
            pieces.append(buffer.strip())

        if len(pieces) != 2:
            return
        antecedent_part, consequent_part = (piece.strip() for piece in pieces)

        # Antecedent: either a conjunction or a single predicate.
        antecedents = []
        if antecedent_part.startswith('And('):
            conjunction = re.match(r'And\((.*)\)', antecedent_part)
            if conjunction:
                antecedents = self.extract_predicates_from_and(conjunction.group(1))
        else:
            name, terms = self.parse_predicate(antecedent_part)
            if name:
                antecedents = [(name, terms)]

        # Antecedent arities are registered even if the consequent is rejected.
        for name, terms in antecedents:
            if name not in self.predicates:
                self.predicates[name] = len(terms)

        # Consequent: must be a single predicate.
        head_name, head_args = self.parse_predicate(consequent_part)
        if not head_name:
            return
        if head_name not in self.predicates:
            self.predicates[head_name] = len(head_args)

        self.rules.append((variables, antecedents, (head_name, head_args)))
    
    def ground_predicate(self, pred_name: str, args: List[str], substitution: Dict[str, str]) -> str:
        """Apply ``substitution`` to ``args`` and return the normalized predicate.

        Arguments that are keys of ``substitution`` are replaced by their
        mapped value; all others are kept unchanged. The result is formatted
        by ``self.normalize_predicate``.
        """
        resolved = []
        for term in args:
            resolved.append(substitution[term] if term in substitution else term)
        return self.normalize_predicate(pred_name, resolved)
    
    def ground_rules(self) -> List[List[str]]:
        """Ground every rule over all known constants and return the clauses.

        For each rule in ``self.rules`` and each assignment of constants to
        its variables, one clause (a list of literal strings) is produced:

        * Disjunction rules (empty antecedent list, consequent is a list of
          ``(pred_name, args, is_negated)`` triples) yield one literal per
          triple, prefixed with ``~`` when negated.
        * Implication rules yield ``~A1, ~A2, ..., C``.

        Constants are iterated in sorted order so that the clause order, and
        the order in which atoms are first passed to ``self.get_symbol``, is
        the same on every run (plain set iteration order is not).

        Returns:
            The list of grounded clauses.
        """
        clauses = []
        constants_list = sorted(self.constants)

        for variables, antecedents, consequent in self.rules:
            # The rule shape does not depend on the substitution; test it once.
            is_disjunction = len(antecedents) == 0 and isinstance(consequent, list)

            # Every possible assignment of constants to the rule's variables.
            for substitution_values in product(constants_list, repeat=len(variables)):
                substitution = dict(zip(variables, substitution_values))
                clause = []

                if is_disjunction:
                    # ForAll([x], Or(...)) form.
                    for pred_name, args, is_negated in consequent:
                        grounded = self.ground_predicate(pred_name, args, substitution)
                        symbol = self.get_symbol(grounded)
                        clause.append(f"~{symbol}" if is_negated else symbol)
                else:
                    # Standard Implies rule: ~A1 v ~A2 v ... v C
                    for pred_name, args in antecedents:
                        grounded = self.ground_predicate(pred_name, args, substitution)
                        clause.append(f"~{self.get_symbol(grounded)}")

                    grounded_consequent = self.ground_predicate(
                        consequent[0], consequent[1], substitution
                    )
                    clause.append(self.get_symbol(grounded_consequent))

                clauses.append(clause)

        return clauses
    
    def convert_to_cnf(self, negate_query: bool = False) -> Tuple[List[List[str]], Dict[str, str]]:
        """
        Convert the FOL problem to CNF.

        The returned clause list is ordered: facts (unit clauses), grounded
        rules, non-quantified formulas, then the query.

        Formulas and the query are parsed *before* the rules are grounded.
        Formula parsing can add constants to ``self.constants``; grounding
        first would leave the rules uninstantiated for constants that occur
        only in formulas or in the query.

        Args:
            negate_query: If True, negate the query (to prove it's true via contradiction)
                         If False, assert the query (to prove it's false via contradiction)

        Returns:
            Tuple of (clauses, mappings), where mappings is a copy of
            ``self.mappings``.
        """
        # Facts become unit clauses.
        clauses = []
        for pred_name, args in self.facts:
            fact_str = self.normalize_predicate(pred_name, args)
            clauses.append([self.get_symbol(fact_str)])

        # Non-quantified formulas converted to CNF (parsed before grounding,
        # see docstring).
        formula_clauses = []
        for formula in self.formulas:
            formula_clauses.extend(self.parse_formula_to_cnf(formula))

        # Query: wrapped in Not(...) when proving by contradiction, so the
        # formula parser performs the negation itself.
        query_clauses = []
        if self.query:
            query_text = 'Not(' + self.query + ')' if negate_query else self.query
            query_clauses = self.parse_formula_to_cnf(query_text)

        # Ground the rules now that all constants have been collected.
        clauses.extend(self.ground_rules())
        clauses.extend(formula_clauses)
        clauses.extend(query_clauses)

        return clauses, dict(self.mappings)
# import re
# from typing import List, Dict, Tuple, Set, Any
# from itertools import product

# class FOLToCNF:
#     def __init__(self):
#         self.symbol_counter = 0
#         self.mappings = {}
#         self.reverse_mappings = {}
#         self.facts = []
#         self.rules = []
#         self.formulas = []  # Non-quantified formulas (Or, And, Not combinations)
#         self.query = None
#         self.constants = set()
#         self.predicates = {}
        
#     def normalize_predicate(self, pred_name: str, args: List[str]) -> str:
#         """Normalize a predicate to a consistent string format (no spaces after commas)."""
#         return f"{pred_name}({','.join(args)})"
    
#     def get_symbol(self, name: str) -> str:
#         """Get or create a propositional symbol for a ground predicate."""
#         try:
#             # Normalize the name to ensure consistency
#             name = name.replace(', ', ',').replace(' ,', ',')
            
#             # Validate that this is actually a predicate, not a formula
#             # If it's a formula, return a placeholder symbol and continue
#             if any(name.startswith(op) for op in ['Implies(', 'ForAll(', 'Exists(', 'Or(', 'And(']):
#                 # Create a special symbol for malformed input
#                 if name not in self.mappings:
#                     symbol = f"p{self.symbol_counter}"
#                     self.symbol_counter += 1
#                     self.mappings[name] = symbol
#                     self.reverse_mappings[symbol] = name
#                 return self.mappings[name]
            
#             if name not in self.mappings:
#                 symbol = f"p{self.symbol_counter}"
#                 self.symbol_counter += 1
#                 self.mappings[name] = symbol
#                 self.reverse_mappings[symbol] = name
#             return self.mappings[name]
#         except Exception:
#             # If anything goes wrong, return a default symbol
#             return "p_error"
    
#     def parse_predicate(self, pred_str: str) -> Tuple[str, List[str]]:
#         """Parse a predicate string like 'previous_stop(x, y)' into (name, [args])."""
#         try:
#             pred_str = pred_str.strip()
            
#             # Reject formulas - they should not be parsed as predicates
#             if any(pred_str.startswith(op) for op in ['Implies(', 'ForAll(', 'Exists(', 'Or(', 'And(', 'Not(']):
#                 return None, []
            
#             match = re.match(r'(\w+)\(([^)]+)\)', pred_str)
#             if match:
#                 pred_name = match.group(1)
                
#                 # Validate that pred_name is not a logical operator
#                 if pred_name in ['Implies', 'ForAll', 'Exists', 'Or', 'And', 'Not']:
#                     return None, []
                
#                 args = [arg.strip() for arg in match.group(2).split(',')]
#                 return pred_name, args
#             return None, []
#         except Exception:
#             return None, []
    
#     def extract_predicates_from_and(self, and_content: str) -> List[Tuple[str, List[str]]]:
#         """Extract all predicates from And(...) content."""
#         predicates = []
#         pattern = r'(\w+)\(([^)]+)\)'
#         for match in re.finditer(pattern, and_content):
#             pred_name = match.group(1)
#             args_str = match.group(2)
#             args = [arg.strip() for arg in args_str.split(',')]
#             predicates.append((pred_name, args))
#         return predicates
    
#     def extract_top_level_args(self, content: str) -> List[str]:
#         """
#         Extract top-level comma-separated arguments, respecting parentheses nesting.
#         E.g., "a, Or(b, c), d" → ["a", "Or(b, c)", "d"]
#         """
#         try:
#             args = []
#             current = []
#             depth = 0
            
#             for char in content:
#                 if char == '(':
#                     depth += 1
#                     current.append(char)
#                 elif char == ')':
#                     depth -= 1
#                     current.append(char)
#                 elif char == ',' and depth == 0:
#                     args.append(''.join(current).strip())
#                     current = []
#                 else:
#                     current.append(char)
            
#             if current:
#                 args.append(''.join(current).strip())
            
#             return args
#         except Exception:
#             # If parsing fails, return the whole content as a single arg
#             return [content] if content else []
    
#     def parse_formula_to_cnf(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse a propositional formula and convert it to CNF.
#         Handles Or, And, Not, Implies combinations of ground predicates.
#         """
#         try:
#             formula_str = formula_str.strip()
            
#             if not formula_str:
#                 return []
            
#             # Handle Exists - ground with all constants
#             if formula_str.startswith('Exists:'):
#                 return self.parse_exists_formula(formula_str)
            
#             # Handle Implies - convert to Or(Not(A), B)
#             if formula_str.startswith('Implies('):
#                 return self.parse_implies_to_cnf(formula_str)
            
#             # Handle And - need to check if it contains nested Or/And
#             if formula_str.startswith('And('):
#                 return self.parse_and_to_cnf(formula_str)
            
#             # Handle Or - single clause or needs distribution
#             if formula_str.startswith('Or('):
#                 return self.parse_or_to_cnf(formula_str)
            
#             # Handle Not
#             if formula_str.startswith('Not('):
#                 return self.parse_not_to_cnf(formula_str)
            
#             # Handle single predicate (no logical operator)
#             pred_name, args = self.parse_predicate(formula_str)
#             if pred_name:
#                 pred_str = self.normalize_predicate(pred_name, args)
#                 symbol = self.get_symbol(pred_str)
#                 self.constants.update(args)
#                 return [[symbol]]
            
#             return []
#         except Exception:
#             # If parsing fails completely, return empty
#             return []
    
#     def parse_implies_to_cnf(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse Implies(A, B) and convert to CNF.
#         Implies(A, B) ≡ Or(Not(A), B) ≡ ~A ∨ B
#         """
#         try:
#             # Extract the content inside Implies(...)
#             if len(formula_str) < 10:  # "Implies()" is 10 chars
#                 return []
            
#             inner = formula_str[8:-1]  # Remove 'Implies(' and ')'
#             args = self.extract_top_level_args(inner)
            
#             if len(args) != 2:
#                 return []
            
#             antecedent = args[0].strip()
#             consequent = args[1].strip()
            
#             # Convert: Implies(A, B) → Or(Not(A), B)
#             # Build the Or formula
#             or_formula = f"Or(Not({antecedent}), {consequent})"
            
#             # Parse it as an Or
#             return self.parse_or_to_cnf(or_formula)
#         except Exception:
#             return []
    
#     def parse_and_to_cnf(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse And(...) and convert to CNF.
#         And is already a conjunction, so we convert each sub-formula to CNF
#         and combine all clauses.
#         """
#         try:
#             if len(formula_str) < 6:  # "And()" is 5 chars
#                 return []
            
#             # Extract all top-level arguments of And
#             inner = formula_str[4:-1]  # Remove 'And(' and ')'
#             sub_formulas = self.extract_top_level_args(inner)
            
#             all_clauses = []
#             for sub in sub_formulas:
#                 sub_cnf = self.parse_formula_to_cnf(sub)
#                 all_clauses.extend(sub_cnf)
            
#             return all_clauses
#         except Exception:
#             return []
    
#     def parse_or_to_cnf(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse Or(...) and convert to CNF.
#         Or is a disjunction - need to distribute if it contains conjunctions.
#         """
#         try:
#             if len(formula_str) < 5:  # "Or()" is 4 chars
#                 return []
            
#             # Extract all top-level arguments of Or
#             inner = formula_str[3:-1]  # Remove 'Or(' and ')'
#             sub_formulas = self.extract_top_level_args(inner)
            
#             # Convert each sub-formula to CNF
#             sub_cnfs = []
#             for sub in sub_formulas:
#                 sub_cnf = self.parse_formula_to_cnf(sub)
#                 if not sub_cnf:
#                     sub_cnf = [[]]  # Empty disjunct
#                 sub_cnfs.append(sub_cnf)
            
#             if len(sub_cnfs) == 0:
#                 return []
            
#             # Distribute: Or of CNFs
#             # Start with first CNF
#             result = sub_cnfs[0]
            
#             # Distribute with each subsequent CNF
#             for cnf in sub_cnfs[1:]:
#                 new_result = []
#                 for clause1 in result:
#                     for clause2 in cnf:
#                         new_result.append(clause1 + clause2)
#                 result = new_result
            
#             return result
#         except Exception:
#             return []
    
#     def parse_not_to_cnf(self, formula_str: str) -> List[List[str]]:
#         """Parse Not(...) and apply De Morgan's laws if needed."""
#         try:
#             if len(formula_str) < 6:  # "Not()" is 5 chars
#                 return []
            
#             # Extract what's inside the Not()
#             inner = formula_str[4:-1]  # Remove 'Not(' and ')'
#             inner = inner.strip()
            
#             if not inner:
#                 return []
            
#             # Apply De Morgan's laws
#             # Not(And(A, B, ...)) → Or(Not(A), Not(B), ...)
#             if inner.startswith('And('):
#                 and_inner = inner[4:-1]  # Remove 'And(' and ')'
#                 sub_formulas = self.extract_top_level_args(and_inner)
#                 # Negate each and create Or
#                 negated = [f"Not({sub})" for sub in sub_formulas]
#                 or_formula = f"Or({', '.join(negated)})"
#                 return self.parse_or_to_cnf(or_formula)
            
#             # Not(Or(A, B, ...)) → And(Not(A), Not(B), ...)
#             elif inner.startswith('Or('):
#                 or_inner = inner[3:-1]  # Remove 'Or(' and ')'
#                 sub_formulas = self.extract_top_level_args(or_inner)
#                 # Negate each and create And
#                 all_clauses = []
#                 for sub in sub_formulas:
#                     not_sub = f"Not({sub})"
#                     clauses = self.parse_formula_to_cnf(not_sub)
#                     all_clauses.extend(clauses)
#                 return all_clauses
            
#             # Not(Not(A)) → A (double negation elimination)
#             elif inner.startswith('Not('):
#                 double_neg_inner = inner[4:-1]  # Remove inner 'Not(' and ')'
#                 return self.parse_formula_to_cnf(double_neg_inner)
            
#             # Not(Implies(A, B)) → Not(Or(Not(A), B)) → And(Not(Not(A)), Not(B)) → And(A, Not(B))
#             elif inner.startswith('Implies('):
#                 implies_inner = inner[8:-1]
#                 args = self.extract_top_level_args(implies_inner)
#                 if len(args) == 2:
#                     # Not(Implies(A, B)) ≡ A ∧ Not(B)
#                     and_formula = f"And({args[0]}, Not({args[1]}))"
#                     return self.parse_and_to_cnf(and_formula)
            
#             # Not(predicate(...)) → single negated literal
#             else:
#                 pred_name, args = self.parse_predicate(inner)
#                 if pred_name:
#                     pred_str = self.normalize_predicate(pred_name, args)
#                     symbol = self.get_symbol(pred_str)
#                     self.constants.update(args)
#                     return [[f"~{symbol}"]]
            
#             return []
#         except Exception:
#             return []
    
#     def parse_exists_formula(self, formula_str: str) -> List[List[str]]:
#         """
#         Parse Exists formula: Exists:num_vars:body
#         Ground it with all constants and create a disjunction.
#         Exists([x], P(x)) ≡ P(c1) ∨ P(c2) ∨ ...
#         Exists([x], And(P(x), Q(x))) ≡ And(P(c1), Q(c1)) ∨ And(P(c2), Q(c2)) ∨ ...
#         """
#         parts = formula_str.split(':', 2)
#         if len(parts) != 3:
#             return []
        
#         num_vars = int(parts[1])
#         body = parts[2]
        
#         # Get constants to ground with
#         constants_list = list(self.constants) if self.constants else ['_skolem_constant']
        
#         # For a single variable (most common case)
#         if num_vars == 1:
#             # Create grounded versions of the body for each constant
#             grounded_bodies = []
#             for const in constants_list:
#                 # Replace variable 'x' with constant in the body
#                 # This is a simple string replacement approach
#                 grounded_body = body.replace('(x)', f'({const})').replace('(x,', f'({const},').replace(',x)', f',{const})')
#                 grounded_bodies.append(grounded_body)
            
#             # Now we have Exists([x], body) ≡ body[x:=c1] ∨ body[x:=c2] ∨ ...
#             # Create an Or of all grounded bodies
#             if len(grounded_bodies) == 1:
#                 # Just one constant, convert body directly
#                 return self.parse_formula_to_cnf(grounded_bodies[0])
#             else:
#                 # Multiple constants, create disjunction
#                 or_formula = f"Or({', '.join(grounded_bodies)})"
#                 return self.parse_or_to_cnf(or_formula)
        
#         # For multiple variables, we'd need to handle all combinations
#         # For now, just return empty (this is a simplification)
#         return []
    
#     def parse_z3_code(self, code: str):
#         """Parse Z3 code and extract facts, rules, and query."""
#         lines = code.strip().split('\n')
        
#         for line in lines:
#             line = line.strip()
#             if not line or line.startswith('#'):
#                 continue
                
#             # Check for ForAll (rules)
#             if line.startswith('ForAll'):
#                 self.parse_rule(line)
#             # Check for Exists - treat as a formula (will be grounded later)
#             elif line.startswith('Exists'):
#                 self.parse_exists(line)
#             # Check for return statement (query)
#             elif line.startswith('return'):
#                 query_match = re.search(r'return\s+(.+)', line)
#                 if query_match:
#                     # breakpoint()
#                     self.query = query_match.group(1).strip()
#             # Check for Implies at top level (ground implication)
#             elif line.startswith('Implies('):
#                 self.formulas.append(line)
#             # Check for complex formulas (Or, And, Not with predicates)
#             elif any(line.startswith(op) for op in ['Or(', 'And(', 'Not(']):
#                 self.formulas.append(line)
#             # Otherwise it's a simple fact
#             else:
#                 pred_name, args = self.parse_predicate(line)
#                 if pred_name:
#                     # Normalize and store
#                     normalized = self.normalize_predicate(pred_name, args)
#                     # Parse it back to get clean args
#                     pred_name, args = self.parse_predicate(normalized)
#                     self.facts.append((pred_name, args))
#                     self.constants.update(args)
#                     if pred_name not in self.predicates:
#                         self.predicates[pred_name] = len(args)
    
#     def parse_exists(self, line: str):
#         """
#         Parse an Exists statement.
#         Exists([x], formula) means there exists at least one x that satisfies the formula.
#         In CNF conversion, we can either:
#         1. Skolemize (replace with a constant)
#         2. Ground with all constants (instantiate for all possible values)
        
#         For simplicity, we'll ground it with all known constants.
#         """
#         # Extract variables
#         vars_match = re.search(r'Exists\(\[([^\]]+)\]', line)
#         if not vars_match:
#             return
        
#         variables = [v.strip() for v in vars_match.group(1).split(',')]
        
#         # Extract the body formula
#         # Find the comma after the variable list
#         bracket_end = line.find(']')
#         if bracket_end == -1:
#             return
        
#         # Find the content after the comma
#         rest = line[bracket_end+1:].strip()
#         if rest.startswith(','):
#             rest = rest[1:].strip()
        
#         # Remove the final closing paren
#         if rest.endswith(')'):
#             rest = rest[:-1].strip()
        
#         # Store as a special existential formula
#         # We'll treat it as a disjunction over all groundings
#         # Exists([x], P(x)) ≡ P(c1) ∨ P(c2) ∨ ... for all constants
#         self.formulas.append(f"Exists:{len(variables)}:{rest}")
    
#     def parse_rule(self, line: str):
#         """Parse a ForAll rule with better parsing."""
#         # Extract variables
#         vars_match = re.search(r'ForAll\(\[([^\]]+)\]', line)
#         if not vars_match:
#             return
        
#         variables = [v.strip() for v in vars_match.group(1).split(',')]
        
#         # Find the Implies or Or part
#         implies_match = re.search(r'Implies\((.*)\)\s*\)', line)
#         or_match = re.search(r'Or\((.*)\)\s*\)', line) if not implies_match else None
        
#         if implies_match:
#             self.parse_implies_rule(variables, implies_match.group(1))
#         elif or_match:
#             # ForAll([x], Or(...)) is already in clause form
#             self.parse_or_rule(variables, or_match.group(1))
    
#     def parse_or_rule(self, variables: List[str], or_body: str):
#         """
#         Parse ForAll([x], Or(a, b, c)) - already a clause.
#         But if the Or contains complex formulas, we need to convert to CNF first.
#         """
#         # Check if the or_body contains nested complex formulas (And, Or, Implies)
#         # If so, we need to handle it differently
        
#         # Extract top-level arguments of the Or
#         args = self.extract_top_level_args(or_body)
        
#         # Check if any argument is a complex formula
#         has_complex = any(
#             arg.strip().startswith(('And(', 'Or(', 'Implies('))
#             for arg in args
#         )
        
#         if has_complex:
#             # This is a complex rule that needs to be converted via CNF
#             # Store it as a special rule that will be handled during grounding
#             self.rules.append((variables, [], ('COMPLEX_OR', or_body)))
#             return
        
#         # Simple case: all arguments are literals (predicates or Not(predicates))
#         literals_info = []
        
#         for arg in args:
#             arg = arg.strip()
            
#             # Check if it's a negated predicate
#             if arg.startswith('Not('):
#                 # Extract the predicate inside Not(...)
#                 inner = arg[4:-1]  # Remove 'Not(' and ')'
#                 pred_name, pred_args = self.parse_predicate(inner)
#                 if pred_name:
#                     literals_info.append((pred_name, pred_args, True))  # True = negated
#                     if pred_name not in self.predicates:
#                         self.predicates[pred_name] = len(pred_args)
#             else:
#                 # Positive predicate
#                 pred_name, pred_args = self.parse_predicate(arg)
#                 if pred_name:
#                     literals_info.append((pred_name, pred_args, False))  # False = positive
#                     if pred_name not in self.predicates:
#                         self.predicates[pred_name] = len(pred_args)
        
#         # Store as a special kind of rule (disjunction)
#         self.rules.append((variables, [], literals_info))  # Empty antecedent, list of literals
    
#     def parse_implies_rule(self, variables: List[str], implies_body: str):
#         """Parse Implies(And(...), consequent) or Implies(pred, consequent)."""
#         # Split by comma at the top level
#         depth = 0
#         parts = []
#         current = []
        
#         for char in implies_body:
#             if char == '(':
#                 depth += 1
#                 current.append(char)
#             elif char == ')':
#                 depth -= 1
#                 current.append(char)
#             elif char == ',' and depth == 0:
#                 parts.append(''.join(current).strip())
#                 current = []
#             else:
#                 current.append(char)
        
#         if current:
#             parts.append(''.join(current).strip())
        
#         if len(parts) != 2:
#             return
        
#         antecedent_part = parts[0].strip()
#         consequent_part = parts[1].strip()
        
#         # Check if antecedent or consequent is complex
#         # If so, store the entire rule for later processing during grounding
#         has_complex_antecedent = any(
#             antecedent_part.startswith(op) for op in ['Or(', 'Implies(', 'Not(And', 'Not(Or']
#         )
#         has_complex_consequent = consequent_part.startswith(('Or(', 'And('))
        
#         if has_complex_antecedent or has_complex_consequent:
#             # Store as complex implies rule
#             self.rules.append((variables, [], ('COMPLEX_IMPLIES', implies_body)))
#             return
        
#         # Parse antecedent
#         antecedents = []
#         if antecedent_part.startswith('And('):
#             and_match = re.match(r'And\((.*)\)', antecedent_part)
#             if and_match:
#                 antecedents = self.extract_predicates_from_and(and_match.group(1))
#         else:
#             # Single predicate antecedent
#             pred_name, args = self.parse_predicate(antecedent_part)
#             if pred_name:
#                 antecedents = [(pred_name, args)]
        
#         # Track predicates
#         for pred_name, args in antecedents:
#             if pred_name not in self.predicates:
#                 self.predicates[pred_name] = len(args)
        
#         # Parse consequent - check if it's a complex formula
#         if consequent_part.startswith('Or('):
#             # Implies(A, Or(B, C)) ≡ ~A ∨ B ∨ C
#             # This is already in clause form!
#             # Extract all literals from the Or
#             or_literals = []
            
#             # Extract the content of Or(...)
#             or_match = re.match(r'Or\((.*)\)', consequent_part)
#             if or_match:
#                 or_content = or_match.group(1)
#                 or_predicates = self.extract_predicates_from_and(or_content)  # Reuse this helper
                
#                 for pred_name, args in or_predicates:
#                     or_literals.append((pred_name, args, False))  # False = positive
#                     if pred_name not in self.predicates:
#                         self.predicates[pred_name] = len(args)
                
#                 # Store as: antecedents lead to a disjunction
#                 # We'll convert: Implies(A, Or(B, C)) to Or(~A, B, C)
#                 # So store all antecedents as negated, plus all consequent literals as positive
#                 all_literals = []
#                 for pred_name, args in antecedents:
#                     all_literals.append((pred_name, args, True))  # True = negated (antecedent)
#                 all_literals.extend(or_literals)  # Add consequent literals (positive)
                
#                 self.rules.append((variables, [], all_literals))  # Store as disjunction form
#                 return
        
#         elif consequent_part.startswith('And('):
#             # Implies(A, And(B, C)) ≡ (~A ∨ B) ∧ (~A ∨ C)
#             # Need to create multiple clauses
#             and_match = re.match(r'And\((.*)\)', consequent_part)
#             if and_match:
#                 and_predicates = self.extract_predicates_from_and(and_match.group(1))
                
#                 # Create one clause for each consequent literal
#                 for cons_pred_name, cons_args in and_predicates:
#                     if cons_pred_name not in self.predicates:
#                         self.predicates[cons_pred_name] = len(cons_args)
                    
#                     # Each clause: ~A1 ∨ ~A2 ∨ ... ∨ Ci
#                     self.rules.append((variables, antecedents, (cons_pred_name, cons_args)))
#                 return
        
#         # Check for Not in consequent
#         if consequent_part.startswith('Not('):
#             # Extract what's inside Not(...)
#             not_match = re.match(r'Not\((\w+)\(([^)]+)\)\)', consequent_part)
#             if not_match:
#                 cons_pred_name = not_match.group(1)
#                 cons_args_str = not_match.group(2)
#                 cons_args = [arg.strip() for arg in cons_args_str.split(',')]
                
#                 if cons_pred_name not in self.predicates:
#                     self.predicates[cons_pred_name] = len(cons_args)
                
#                 # Store with negated consequent marker
#                 # We'll handle this specially in grounding
#                 self.rules.append((variables, antecedents, (cons_pred_name, cons_args, True)))
#                 return
        
#         # Simple consequent (single predicate)
#         consequent_name, consequent_args = self.parse_predicate(consequent_part)
#         if not consequent_name:
#             return
        
#         if consequent_name not in self.predicates:
#             self.predicates[consequent_name] = len(consequent_args)
        
#         self.rules.append((variables, antecedents, (consequent_name, consequent_args)))
    
#     def ground_predicate(self, pred_name: str, args: List[str], substitution: Dict[str, str]) -> str:
#         """Ground a predicate with substitution and return normalized form."""
#         grounded_args = [substitution.get(arg, arg) for arg in args]
#         return self.normalize_predicate(pred_name, grounded_args)
    
#     def ground_rules(self) -> List[List[str]]:
#         """Ground all rules with all possible constant substitutions."""
#         clauses = []
#         constants_list = list(self.constants)
        
#         for variables, antecedents, consequent in self.rules:
#             # Generate all possible substitutions
#             for substitution_values in product(constants_list, repeat=len(variables)):
#                 substitution = dict(zip(variables, substitution_values))
                
#                 # Check if this is a complex OR rule
#                 if isinstance(consequent, tuple) and len(consequent) == 2 and consequent[0] == 'COMPLEX_OR':
#                     # Ground the complex Or formula and convert to CNF
#                     or_body = consequent[1]
                    
#                     # Replace variables with constants in the formula
#                     grounded_formula = or_body
#                     for var, const in substitution.items():
#                         # Strip whitespace from variable name
#                         var = var.strip()
#                         # Use regex with word boundaries for precise replacement
#                         # This handles: (x), (x,, ,x), ,x,, etc.
#                         grounded_formula = re.sub(
#                             rf'\({var}\)',  # (x) -> (const)
#                             f'({const})',
#                             grounded_formula
#                         )
#                         grounded_formula = re.sub(
#                             rf'\({var},',  # (x, -> (const,
#                             f'({const},',
#                             grounded_formula
#                         )
#                         grounded_formula = re.sub(
#                             rf',{var}\)',  # ,x) -> ,const)
#                             f',{const})',
#                             grounded_formula
#                         )
#                         grounded_formula = re.sub(
#                             rf',{var},',  # ,x, -> ,const,
#                             f',{const},',
#                             grounded_formula
#                         )
                    
#                     # Now convert this grounded Or formula to CNF
#                     or_formula = f"Or({grounded_formula})"
#                     cnf_clauses = self.parse_or_to_cnf(or_formula)
#                     clauses.extend(cnf_clauses)
#                     continue
                
#                 # Check if this is a disjunction rule (empty antecedents, consequent is list)
#                 if len(antecedents) == 0 and isinstance(consequent, list):
#                     # This is ForAll([x], Or(...)) form with simple literals
#                     clause = []
#                     for pred_name, args, is_negated in consequent:
#                         grounded = self.ground_predicate(pred_name, args, substitution)
#                         symbol = self.get_symbol(grounded)
#                         if is_negated:
#                             clause.append(f"~{symbol}")
#                         else:
#                             clause.append(symbol)
#                     clauses.append(clause)
#                 else:
#                     # Standard Implies rule: ~A1 v ~A2 v ... v C
#                     clause = []
                    
#                     # Add negated antecedents
#                     for pred, args in antecedents:
#                         grounded = self.ground_predicate(pred, args, substitution)
#                         clause.append(f"~{self.get_symbol(grounded)}")
                    
#                     # Check if consequent is negated (tuple with 3 elements)
#                     if isinstance(consequent, tuple) and len(consequent) == 3:
#                         # Negated consequent
#                         grounded_consequent = self.ground_predicate(consequent[0], consequent[1], substitution)
#                         clause.append(f"~{self.get_symbol(grounded_consequent)}")
#                     else:
#                         # Positive consequent
#                         grounded_consequent = self.ground_predicate(consequent[0], consequent[1], substitution)
#                         clause.append(self.get_symbol(grounded_consequent))
                    
#                     clauses.append(clause)
        
#         return clauses
    
#     def convert_to_cnf(self, negate_query: bool = False) -> Tuple[List[List[str]], Dict[str, str]]:
#         """
#         Convert the FOL problem to CNF.
        
#         Args:
#             negate_query: If True, negate the query (to prove it's true via contradiction)
#                          If False, assert the query (to prove it's false via contradiction)
        
#         Returns:
#             Tuple of (clauses, mappings)
#         """
#         clauses = []
        
#         # Add facts as unit clauses
#         for pred_name, args in self.facts:
#             fact_str = self.normalize_predicate(pred_name, args)
#             symbol = self.get_symbol(fact_str)
#             clauses.append([symbol])
        
#         # Add grounded rules
#         clauses.extend(self.ground_rules())
        
#         # Add non-quantified formulas converted to CNF
#         for formula in self.formulas:
#             formula_cnf = self.parse_formula_to_cnf(formula)
#             clauses.extend(formula_cnf)
        
#         # Add query (negated for proving true, asserted for proving false)
#         if self.query:
#             # Convert query to CNF
#             query_cnf = self.parse_formula_to_cnf(self.query)
            
#             if negate_query:
#                 # Negate the entire query
#                 # If query is a single literal, just negate it
#                 if len(query_cnf) == 1 and len(query_cnf[0]) == 1:
#                     lit = query_cnf[0][0]
#                     if lit.startswith('~'):
#                         clauses.append([lit[1:]])
#                     else:
#                         clauses.append([f"~{lit}"])
#                 else:
#                     # For complex queries, we need to negate the whole formula
#                     # Convert query back to formula, wrap in Not(), then convert to CNF
#                     negated_query = f"Not({self.query})"
#                     negated_cnf = self.parse_formula_to_cnf(negated_query)
#                     clauses.extend(negated_cnf)
#             else:
#                 clauses.extend(query_cnf)
        
#         return clauses, dict(self.mappings)

        
def write_dimacs_cnf(clauses: List[List[str]], filename: str):
    """
    Serialise propositional clauses into a DIMACS CNF file.

    Each literal is a symbol such as ``p3`` (positive) or ``~p3`` (negated).
    Symbol ``pK`` is emitted as DIMACS variable ``K + 1`` (DIMACS variables
    are 1-based), negated literals as the corresponding negative integer.

    Clauses equal to ``[]`` are dropped and are not counted in the header.
    The header is ``p cnf <max symbol index + 1> <number of kept clauses>``;
    when no literal is present at all the variable count is written as 1.
    Every clause line is a space-separated list of integers ending in ``0``.
    """
    def parse(literal):
        # Split a literal into (is_negated, zero-based symbol index).
        negated = literal.startswith('~')
        name = literal[1:] if negated else literal
        return negated, int(name.strip('p'))

    kept = [clause for clause in clauses if not clause == []]
    parsed = [[parse(literal) for literal in clause] for clause in kept]

    seen = {index for row in parsed for _, index in row}
    num_vars = np.max(list(seen)) + 1 if len(seen) > 0 else 1

    with open(filename, 'w') as f:
        f.write(f"p cnf {num_vars} {len(kept)}\n")
        for row in parsed:
            numbers = [-(index + 1) if negated else index + 1 for negated, index in row]
            f.write(' '.join(map(str, numbers)) + ' 0\n')

def write_mapping_file(mappings: Dict[str, str], filename: str):
    """
    Dump the predicate/symbol table as text, one entry per line.

    ``mappings`` is keyed by predicate and holds the propositional symbol
    (e.g. ``p4``) as value.  Each line is written as ``symbol -> predicate``
    and lines are ordered by the numeric part of the symbol.
    """
    def symbol_number(entry):
        # entry is a (predicate, symbol) pair; order by the integer in the symbol.
        return int(entry[1].strip('p'))

    with open(filename, 'w') as f:
        ordered = sorted(mappings.items(), key=symbol_number)
        f.writelines(f"{symbol} -> {predicate}\n" for predicate, symbol in ordered)

def fol_to_cnf(z3_code: str) -> Dict:
    """
    Translate a Z3-style FOL problem into two propositional CNF encodings.

    The problem is parsed once with ``FOLToCNF`` and converted twice.

    Returns a dictionary with:
        - 'pos': clauses built with the query negated (used to test whether
          the query is True by looking for a contradiction)
        - 'neg': clauses built with the query asserted (used to test whether
          the query is False by looking for a contradiction)
        - 'mappings': the mapping returned by the negated-query conversion
        - 'reverse_mappings': the converter's ``reverse_mappings`` attribute
    """
    converter = FOLToCNF()
    converter.parse_z3_code(z3_code)

    # Conversion order is kept: negated query first, asserted query second.
    negated_clauses, symbol_table = converter.convert_to_cnf(negate_query=True)
    asserted_clauses, _ = converter.convert_to_cnf(negate_query=False)

    return dict(
        pos=negated_clauses,
        neg=asserted_clauses,
        mappings=symbol_table,
        reverse_mappings=converter.reverse_mappings,
    )


def add_clause(file):
    """
    Increment the clause count in the DIMACS header of ``file`` in place.

    Every line starting with ``p cnf`` is rewritten as
    ``p cnf <vars> <clauses + 1>``; all other lines are copied unchanged.
    Only the header is touched: the caller is responsible for appending the
    clause itself.

    The header is split on arbitrary whitespace so repeated spaces or tabs
    between fields do not shift the field positions.

    Raises:
        IndexError: a ``p cnf`` line has fewer than four fields.
        ValueError: the clause count is not an integer.
    """
    with open(file, 'r') as f:
        lines = f.readlines()

    updated = []
    for line in lines:
        if line.startswith('p cnf'):
            fields = line.split()
            num_var, num_clause = fields[2], fields[3]
            updated.append('p cnf ' + num_var + ' ' + str(int(num_clause) + 1) + '\n')
        else:
            updated.append(line)

    with open(file, 'w') as f:
        f.write(''.join(updated))
def add_var(file):
    """
    Increment the variable count in the DIMACS header of ``file`` in place.

    Every line starting with ``p cnf`` is rewritten as
    ``p cnf <vars + 1> <clauses>``; all other lines are copied unchanged.

    The header is split on arbitrary whitespace so repeated spaces or tabs
    between fields do not shift the field positions.

    Raises:
        IndexError: a ``p cnf`` line has fewer than four fields.
        ValueError: a header count is not an integer.
    """
    with open(file, 'r') as f:
        lines = f.readlines()

    updated = []
    for line in lines:
        if line.startswith('p cnf'):
            fields = line.split()
            num_var, num_clause = fields[2], fields[3]
            updated.append('p cnf ' + str(int(num_var) + 1) + ' ' + str(int(num_clause)) + '\n')
        else:
            updated.append(line)

    with open(file, 'w') as f:
        f.write(''.join(updated))
def save_pattern(van, vap, vbn, vbp, vcn, vcp, patterns):
    """
    Append a name-independent version of a three-part pattern to ``patterns``.

    The names in ``van``, ``vbn`` and ``vcn`` (each converted to a list if it
    is not one already) are replaced by string indices ``'0'``, ``'1'``, ...
    assigned in order of first appearance across the three sequences.
    ``vap``, ``vbp`` and ``vcp`` are stored unchanged next to their names.

    Returns:
        The same ``patterns`` list, with the tuple
        ``(ids_a, vap, ids_b, vbp, ids_c, vcp)`` appended.
    """
    groups = [seq if isinstance(seq, list) else list(seq) for seq in (van, vbn, vcn)]

    canonical = {}
    for group in groups:
        for name in group:
            # len(canonical) is the next free index when the name is new.
            canonical.setdefault(name, str(len(canonical)))

    ids_a, ids_b, ids_c = [[canonical[name] for name in group] for group in groups]
    patterns.append((ids_a, vap, ids_b, vbp, ids_c, vcp))
    return patterns


def check_pattern(van, vap, vbn, vbp, vcn, vcp, patterns):
    """
    Report whether the given three-part pattern is already in ``patterns``.

    The names in ``van``, ``vbn`` and ``vcn`` are replaced by string indices
    assigned in order of first appearance, and the resulting tuple
    ``(ids_a, vap, ids_b, vbp, ids_c, vcp)`` is compared for equality with
    every stored pattern.

    Returns:
        True if an equal pattern is stored, otherwise False.
    """
    groups = [seq if isinstance(seq, list) else list(seq) for seq in (van, vbn, vcn)]

    canonical = {}
    for group in groups:
        for name in group:
            canonical.setdefault(name, str(len(canonical)))

    ids_a, ids_b, ids_c = [[canonical[name] for name in group] for group in groups]
    candidate = (ids_a, vap, ids_b, vbp, ids_c, vcp)

    return any(stored == candidate for stored in patterns)

def search_pattern(van, vap, vbn, vbp, patterns):
    """
    Look up the third part of a stored pattern from its first two parts.

    The names in ``van`` and ``vbn`` (each converted to a list if needed) are
    replaced by string indices assigned in order of first appearance.  The
    first stored pattern whose leading four fields equal
    ``(ids_a, vap, ids_b, vbp)`` is used: its fifth field is translated back
    from indices to the caller's names and returned with its sixth field.

    A matching pattern whose fifth field uses an index that does not occur
    in ``van``/``vbn`` cannot be translated back and is skipped.  Only that
    ``KeyError`` is tolerated; any other exception propagates instead of
    being silently swallowed.

    Returns:
        ``(names_c, vcp)`` for the first usable match, or ``None``.
    """
    if not isinstance(van, list):
        van = list(van)
    if not isinstance(vbn, list):
        vbn = list(vbn)

    mapping = {}
    unmapping = {}
    for name in van + vbn:
        if name not in mapping:
            index = str(len(mapping))
            mapping[name] = index
            unmapping[index] = name

    pattern = ([mapping[name] for name in van], vap, [mapping[name] for name in vbn], vbp)

    for pat in patterns:
        if (pat[0], pat[1], pat[2], pat[3]) == pattern:
            try:
                return ([unmapping[name] for name in pat[4]], pat[5])
            except KeyError:
                # Third part refers to a name that has no counterpart here.
                continue
    return None

                
import math
def _append_blocking_clause(cnf_path, solution):
    """Bump the header clause count of ``cnf_path`` and append the negation of ``solution``."""
    add_clause(cnf_path)
    # No explicit terminator is written: the literals are negated as-is.
    # NOTE(review): this presumably relies on the solution already carrying
    # the solver's trailing 0 (and -0 == 0) -- confirm against the log format.
    with open(cnf_path, 'a') as cnf:
        cnf.write('\n' + ''.join(str(-lit) + ' ' for lit in solution))


def get_sol(file, lim=10000, del_sols=None, seedrun=0):
    """
    Enumerate solver models for the ``pos_`` and ``neg_`` variants of a CNF file.

    For ``file = <dir>/<name>`` the inputs are ``<dir>/pos_<name>`` and
    ``<dir>/neg_<name>``.  Each is copied to
    ``<parent of dir>/tempfiles<seedrun>/`` and the copy is solved
    repeatedly with cadical; after every model a clause negating it is
    appended so the next run must produce a different one.  Enumeration for
    a variant stops when the solver log ends with ``exit 20`` or after
    ``lim`` solver runs.

    Args:
        file: path whose directory holds the ``pos_``/``neg_`` CNF files.
        lim: maximum number of solver runs per variant.
        del_sols: optional dict with ``'pos'`` and ``'neg'`` lists of models
            to exclude before the first run.
        seedrun: suffix selecting the ``tempfiles<seedrun>`` work directory.

    Returns:
        ``{'pos': [...], 'neg': [...]}`` where every entry is the list of
        integers read from the solver's ``v`` lines for one model.

    Raises:
        RuntimeError: the solver log is empty or its last line has no
            ``exit <code>`` marker.
    """
    solutions = {'pos': [], 'neg': []}
    parent = '/'.join(file.split('/')[:-1])
    basename = file.split('/')[-1]

    # Iterate by polarity instead of searching for 'pos'/'neg' inside the
    # path, which misroutes results when a directory or the base name
    # happens to contain one of those substrings.
    for key in ('pos', 'neg'):
        source = parent + '/' + key + '_' + basename
        work_dir = '/'.join(source.split('/')[:-2]) + '/tempfiles' + str(seedrun) + '/'
        work_name = str(source.split('/')[-1])
        work_cnf = work_dir + work_name
        work_log = work_dir + work_name[:-4] + '.log'

        shutil.copy(source, work_cnf)

        if del_sols is not None:
            for sol in del_sols[key]:
                _append_blocking_clause(work_cnf, sol)

        command = (USER_PATH + '/sat_gen/HardSATGEN/postprocess/cadical/build/cadical '
                   + work_cnf + '> ' + work_log)

        # Counter loop (rather than range) so a non-integer lim such as
        # math.inf keeps working.
        count = 0
        while True:
            count += 1
            if count > lim:
                break
            os.system(command)

            with open(work_log, 'r') as log:
                lines = log.readlines()

            if not lines or 'exit ' not in lines[-1]:
                raise RuntimeError(
                    'solver log ' + work_log + ' does not end with an exit code: '
                    + repr(lines[-1] if lines else '')
                )
            exit_code = lines[-1].split('exit ')[1].strip('\n')
            if exit_code == '20':
                # Conventional UNSATISFIABLE exit code: nothing left to enumerate.
                break

            # Skip the first log line, find the status line ('s ...'), then
            # collect the run of value lines ('v ...') that follows it.
            pos = 1
            while not lines[pos].startswith('s '):
                pos += 1
            pos += 1
            solution = []
            while lines[pos].startswith('v '):
                solution += list(map(int, lines[pos].strip('\n').split(' ')[1:]))
                pos += 1

            solutions[key].append(solution)
            _append_blocking_clause(work_cnf, solution)

    return solutions
    
# Sentinel used both by get_mc results and by score() to signal "unusable".
_SCORE_FAILED = -1000000


def _mc_failed(mc_t, mc_tminus):
    """Return True when either count dict contains the failure sentinel."""
    return _SCORE_FAILED in mc_t.values() or _SCORE_FAILED in mc_tminus.values()


def _max_var_count(cnf_result):
    """
    Return the largest ``symbol index + 1`` over the 'pos' and 'neg' clause
    lists of a ``fol_to_cnf`` result, or 0 when neither contains a literal.
    """
    counts = []
    for polarity in ['pos', 'neg']:
        indices = set()
        for clause in cnf_result[polarity]:
            for lit in clause:
                if lit.startswith('~'):
                    lit = lit[1:]
                indices.add(int(lit.strip('p')))
        if len(indices) > 0:
            counts.append(np.max(list(indices)) + 1)
    if len(counts) == 0:
        return 0
    return max(counts)


def _bb_weighted(mc, bb, n_var):
    """
    Return ``(top, low)``: the counts of the largest and smallest entry of
    ``mc``, each multiplied by ``n_var - len(bb[key]) + 1``.
    """
    max_key = max(mc, key=mc.get)
    min_key = min(mc, key=mc.get)
    top = mc[max_key] * (n_var - len(bb[max_key]) + 1)
    low = mc[min_key] * (n_var - len(bb[min_key]) + 1)
    return top, low


def score(context, step, mode = 'conf', cheat_idx = None):
    """
    Score ``step`` by comparing ``get_mc`` results with and without it.

    ``mc_t = get_mc(context + step)`` and ``mc_tminus = get_mc(context)`` are
    dicts with 'pos' and 'neg' entries.  The ``*-bb*`` modes additionally use
    ``get_bb`` (entries are only used through ``len``) and the variable count
    derived from ``fol_to_cnf``.

    Modes (w = count * (n_var - len(bb) + 1) per polarity):
        'conf'          max/sum after the step minus max/sum before.
        'raw-bb'        (w_min + w_max) before minus (w_min + w_max) after.
        'conf-bb-norm'  w_max / (w_min + w_max) after minus the same before.
        'conf-bb'       max count * (n_var - len(bb)) after minus before.
        'raw-conf'      min count before minus min count after.
        'raw'           sum of counts before minus sum after.
        'even'          -|pos - neg| of the counts after the step.
        'raw-cheating'  count of ['neg', 'pos'][cheat_idx], before minus after.
        'conf-cheating' share of ['pos', 'neg'][cheat_idx], before minus after.

    NOTE(review): the two cheating modes index the polarities in opposite
    order; this is kept as found -- confirm which one is intended.

    Returns:
        The score, ``-1000000`` when the counts are unusable (failure
        sentinel present, non-positive total, or zero denominator in
        'conf-bb-norm'), or ``None`` for an unrecognised ``mode``.
    """
    mc_t = get_mc(context + step)
    mc_tminus = get_mc(context)

    if mode == 'conf':
        if _mc_failed(mc_t, mc_tminus):
            return _SCORE_FAILED
        if np.sum(list(mc_t.values())) <= 0:
            return _SCORE_FAILED
        conf_t = np.max(list(mc_t.values())) / np.sum(list(mc_t.values()))
        if np.sum(list(mc_tminus.values())) <= 0:
            return _SCORE_FAILED
        conf_tminus = np.max(list(mc_tminus.values())) / np.sum(list(mc_tminus.values()))
        return conf_t - conf_tminus

    elif mode in ('raw-bb', 'conf-bb-norm', 'conf-bb'):
        if _mc_failed(mc_t, mc_tminus):
            return _SCORE_FAILED

        # Call order is kept from the original implementation.
        bb = get_bb(context + step)
        bb_tminus = get_bb(context)
        result = fol_to_cnf(context + step)
        result_tminus = fol_to_cnf(context)
        n_var = _max_var_count(result)
        n_var_tminus = _max_var_count(result_tminus)

        if np.sum(list(mc_t.values())) <= 0:
            return _SCORE_FAILED

        if mode == 'conf-bb':
            max_key = max(mc_t, key=mc_t.get)
            value_t = mc_t[max_key] * (n_var - len(bb[max_key]))
            if np.sum(list(mc_tminus.values())) <= 0:
                return _SCORE_FAILED
            max_key = max(mc_tminus, key=mc_tminus.get)
            value_tminus = mc_tminus[max_key] * (n_var_tminus - len(bb_tminus[max_key]))
            return value_t - value_tminus

        if mc_t[max(mc_t, key=mc_t.get)] == 0:
            return _SCORE_FAILED
        top_t, low_t = _bb_weighted(mc_t, bb, n_var)
        total_t = low_t + top_t
        if total_t == 0:
            if mode == 'conf-bb-norm':
                # Ratio is undefined; treat like every other unusable count.
                return _SCORE_FAILED
            warnings.warn("score('raw-bb'): weighted count after the step is 0")

        if np.sum(list(mc_tminus.values())) <= 0:
            return _SCORE_FAILED
        top_tminus, low_tminus = _bb_weighted(mc_tminus, bb_tminus, n_var_tminus)
        total_tminus = low_tminus + top_tminus

        if mode == 'raw-bb':
            return total_tminus - total_t
        if total_tminus == 0:
            return _SCORE_FAILED
        return top_t / total_t - top_tminus / total_tminus

    elif mode == 'raw-conf':
        if _mc_failed(mc_t, mc_tminus):
            return _SCORE_FAILED
        if np.sum(list(mc_t.values())) <= 0:
            return _SCORE_FAILED
        if np.sum(list(mc_tminus.values())) <= 0:
            return _SCORE_FAILED
        return np.min(list(mc_tminus.values())) - np.min(list(mc_t.values()))

    elif mode == 'raw':
        if _mc_failed(mc_t, mc_tminus):
            return _SCORE_FAILED
        if np.sum(list(mc_t.values())) <= 0:
            return _SCORE_FAILED
        if np.sum(list(mc_tminus.values())) <= 0:
            return _SCORE_FAILED
        return np.sum(list(mc_tminus.values())) - np.sum(list(mc_t.values()))

    elif mode == 'even':
        # -1 is treated as "no count available" for a polarity.
        if np.sum(list(mc_t.values())) == 0 or mc_t['pos'] == -1 or mc_t['neg'] == -1:
            return _SCORE_FAILED
        if np.sum(list(mc_tminus.values())) == 0 or mc_tminus['pos'] == -1 or mc_tminus['neg'] == -1:
            return _SCORE_FAILED
        if _mc_failed(mc_t, mc_tminus):
            return _SCORE_FAILED
        return -(np.abs(mc_t['pos'] - mc_t['neg']))

    elif mode == 'raw-cheating':
        key = ['neg', 'pos'][cheat_idx]
        return mc_tminus[key] - mc_t[key]

    elif mode == 'conf-cheating':
        key = ['pos', 'neg'][cheat_idx]
        share_t = mc_t[key] / np.sum(list(mc_t.values()))
        share_tminus = mc_tminus[key] / np.sum(list(mc_tminus.values()))
        return share_tminus - share_t

def get_mc(text, seedrun=0):
    """Run the external `ganak` counter on the 'pos' and 'neg' CNFs of `text`.

    `text` is converted with `fol_to_cnf`; the 'pos' and 'neg' entries of the
    result are written as DIMACS files with `write_dimacs_cnf`, moved into the
    scratch directory `/work/XXXX//tempfiles<seedrun>/` and passed to `ganak`
    under a 500 second `timeout`. The scratch directory is not created here
    and must already exist.

    Args:
        text: Input handed unchanged to `fol_to_cnf`.
        seedrun: Suffix of the scratch directory name (`tempfiles<seedrun>`).

    Returns:
        dict with keys 'pos' and 'neg'. Each value is the integer found after
        'exact arb frac ' on the second-to-last line of ganak's output, or
        -1000000 when that line is missing (reported as 'timeout').

    Raises:
        OSError: If a CNF file cannot be copied or the solver cannot be started.
        ValueError: If the reported count is not an integer literal.
    """
    # Local import so this function does not depend on where the file-level
    # `import subprocess` happens to sit.
    import subprocess

    result = fol_to_cnf(text)
    stamp = str(time.time())
    staging_dir = '/home/XXXX/clause/folio_newlogic_ /getmc/'
    scratch_dir = '/work/XXXX/' + '/tempfiles' + str(seedrun) + '/'

    # Write both encodings first, then count them one after the other.
    cnf_names = {}
    for polarity in ('pos', 'neg'):
        cnf_names[polarity] = polarity + '_' + stamp + '.cnf'
        write_dimacs_cnf(result[polarity], staging_dir + cnf_names[polarity])

    mc = {'pos': -1, 'neg': -1}
    for polarity, cnf_name in cnf_names.items():
        staged_path = staging_dir + cnf_name
        cnf_path = scratch_dir + cnf_name
        out_path = cnf_path[:-4] + '.mc'

        # A failed copy now propagates instead of dropping into the debugger
        # and then running the solver on a file that is not there.
        shutil.copy(staged_path, cnf_path)
        os.remove(staged_path)

        # Argument list instead of a shell string: no quoting issues, and the
        # redirect is an ordinary file handle that is closed deterministically.
        with open(out_path, 'w') as out:
            subprocess.run(
                ['timeout', '500', '/home/XXXX/ganak',
                 '--appmct', '100', '--mode', '1', cnf_path],
                stdout=out,
            )
        with open(out_path, 'r') as cf:
            lines = cf.readlines()
        os.remove(out_path)

        try:
            # Indexing the line is inside the try as well: an empty or
            # truncated output is treated like a missing count, not a crash.
            count = lines[-2].split('exact arb frac ')[1].strip('\n')
        except IndexError:
            print('timeout')
            count = -1000000

        # Keyed by the polarity we are iterating over rather than by searching
        # the path for 'pos'/'neg', which could match a directory name.
        mc[polarity] = int(count)

    return mc
    
#     return mc
# def get_mc(file, seedrun=0):
#     from concurrent.futures import ThreadPoolExecutor
#     import shutil
#     import os

#     def process_file(file, seedrun):
#         """Process a single file and return its model count"""
#         # Copy file to temp directory
#         temp_dir = '/'.join(file.split('/')[:-2]) + '/tempfiles' + str(seedrun)
#         temp_file = temp_dir + '/' + file.split('/')[-1]
#         shutil.copy(file, temp_file)
        
#         # Run ganak
#         output_file = temp_file[:-4] + '.mc'
#         os.system(f'timeout 500 /home/XXXX/ganak --appmct 100 --mode 1 {temp_file} > {output_file}')
        
#         # Read result
#         try:
#             with open(output_file, 'r') as cf:
#                 lines = cf.readlines()
#                 el = lines[-2]
#                 ec = el.split('exact arb frac ')[1].strip('\n')
#                 return int(ec)
#         except:
#             print('timeout')
#             return -1

#     # Prepare file paths
#     base_path = '/'.join(file.split('/')[:-1])
#     filename = file.split('/')[-1]
#     files = {
#         'pos': base_path + '/pos_' + filename,
#         'neg': base_path + '/neg_' + filename
#     }

#     # Process files in parallel
#     mc = {'pos': -1, 'neg': -1}
#     with ThreadPoolExecutor(max_workers=2) as executor:
#         futures = {key: executor.submit(process_file, filepath, seedrun) 
#                 for key, filepath in files.items()}
        
#         for key, future in futures.items():
#             mc[key] = future.result()

#     # Check for value of 1
#     # for value in mc.values():
#     #     if value == 1:
#     #         breakpoint()

#     return mc


    
def get_bb(text, del_sols=None, seedrun=0, ablate=False):
    """Collect the literals `cadiback` reports for the 'pos' and 'neg' CNFs of `text`.

    `text` is converted with `fol_to_cnf`; the 'pos' and 'neg' entries of the
    result are written as DIMACS files with `write_dimacs_cnf`, moved into the
    scratch directory `/work/XXXX//tempfiles<seedrun>/` and passed to the
    `cadiback` binary under `USER_PATH`, wrapped in a 5000 second `timeout`.
    Its output is kept next to the CNF copy as `<name>.bbone`. The scratch
    directory is not created here and must already exist.

    Args:
        text: Input handed unchanged to `fol_to_cnf`.
        del_sols: Not used; kept so existing callers keep working.
        seedrun: Suffix of the scratch directory name (`tempfiles<seedrun>`).
        ablate: Not used; kept so existing callers keep working.

    Returns:
        dict with keys 'pos' and 'neg', each a list of the non-zero integer
        literals found on output lines starting with 'b' (presumably the
        backbone literals -- confirm against cadiback's output format).

    Raises:
        OSError: If a CNF file cannot be copied or the tool cannot be started.
        ValueError: If a 'b' line contains a token that is not an integer.
    """
    # Local import so this function does not depend on where the file-level
    # `import subprocess` happens to sit.
    import subprocess

    bb = {'pos': [], 'neg': []}

    result = fol_to_cnf(text)
    stamp = str(time.time())
    staging_dir = '/home/XXXX/clause/folio_newlogic_ /getbb/'
    scratch_dir = '/work/XXXX/' + '/tempfiles' + str(seedrun) + '/'

    # Write both encodings first, then analyse them one after the other.
    cnf_names = {}
    for polarity in ('pos', 'neg'):
        cnf_names[polarity] = polarity + '_' + stamp + '.cnf'
        write_dimacs_cnf(result[polarity], staging_dir + cnf_names[polarity])

    for polarity, cnf_name in cnf_names.items():
        staged_path = staging_dir + cnf_name
        cnf_path = scratch_dir + cnf_name
        out_path = cnf_path[:-4] + ".bbone"

        # A failed copy now propagates instead of dropping into the debugger
        # and then running the tool on a file that is not there.
        shutil.copy(staged_path, cnf_path)
        os.remove(staged_path)

        # Argument list instead of a shell string; the redirect is a regular
        # file handle that is closed as soon as the tool exits.
        with open(out_path, 'w') as out:
            subprocess.run(
                ['timeout', '5000',
                 USER_PATH + "/LLM-project/cadiback/cadiback", cnf_path],
                stdout=out,
            )

        with open(out_path, 'r') as bbone:
            for line in bbone:
                if not line.startswith('b'):
                    continue
                # split() without an argument tolerates repeated spaces and
                # the trailing newline; '0' tokens are skipped.
                for token in line.split()[1:]:
                    if token == '0':
                        continue
                    # Stored under the polarity being processed rather than
                    # by searching the path for 'pos'/'neg'.
                    bb[polarity].append(int(token))

    return bb

import os, shutil, json, random, subprocess, uuid
import numpy as np
from concurrent.futures import ProcessPoolExecutor

# USER_PATH = "/your/path"   # same as before

def make_tempdir(base_dir):
    """Create a uniquely named `temp_<uuid4 hex>` directory under `base_dir`.

    Args:
        base_dir: Directory in which the new directory is created.

    Returns:
        Path of the new directory.
    """
    unique_name = f"temp_{uuid.uuid4().hex}"
    target = os.path.join(base_dir, unique_name)
    os.makedirs(target, exist_ok=True)
    return target

def run_cadical(infile):
    """Run the CaDiCaL binary on `infile` and return the exit status it logs.

    The solver's stdout and stderr are written to a log file next to `infile`
    (the last four characters of `infile`, expected to be '.cnf', are replaced
    by '.log'). The status is read back from the text after 'exit ' on the
    last line of that log.

    Args:
        infile: Path of the CNF file to solve.

    Returns:
        The integer parsed from the last log line.

    Raises:
        OSError: If the solver binary cannot be started.
        IndexError: If the log is empty.
        ValueError: If the last log line does not end in an integer.
    """
    logfile = infile[:-4] + ".log"
    cmd = [
        f"{USER_PATH}/sat_gen/HardSATGEN/postprocess/cadical/build/cadical",
        infile,
    ]
    with open(logfile, "w") as logf:
        # Pass the argument list directly. Joining it into a string and using
        # shell=True split any path containing a space into several arguments
        # and exposed the path to shell interpretation.
        subprocess.run(cmd, stdout=logf, stderr=subprocess.STDOUT)

    with open(logfile, "r") as f:
        lines = f.readlines()
    ec = lines[-1].split("exit ")[-1].strip()
    return int(ec)

def random_clause_worker(input_file):
    """Append random 3-literal clauses to a pos/neg CNF pair until one reports status 20.

    The files `pos_<name>` and `neg_<name>` located next to `input_file` are
    copied into a fresh directory from `make_tempdir`. Every round draws one
    random clause over variables 1..n_var (n_var is the second-to-last field
    of the first line of the positive copy). For each copy in turn it then
    runs `run_cadical`, appends the clause and calls `add_clause`. The loop
    ends after the first round in which a run returned 20, or as soon as one
    copy has been solved more than 1000 times.

    The temporary directory and its log files are left on disk.

    Args:
        input_file: Path whose directory contains the two CNF files and whose
            basename is their common suffix.

    Returns:
        Tuple `(breakflag, logics, loopcounts)`:
            breakflag: 0 if only the positive copy returned 20, 1 if only the
                negative copy did, -1 if both did or the 1000-run cap was hit.
            logics: `[pos_lines, neg_lines]`, the final contents of both copies.
            loopcounts: `[pos_runs, neg_runs]` on the normal path.
                NOTE(review): when the cap is hit this is a single int (the
                count of the copy that exceeded it); kept as-is because
                callers may rely on it -- confirm which shape is intended.
    """
    base_dir = os.path.dirname(input_file)
    name = os.path.basename(input_file)
    tdir = make_tempdir(base_dir)

    # Isolated working copies: index 0 is the positive file, 1 the negative.
    cnfs = []
    for prefix in ("pos", "neg"):
        target = os.path.join(tdir, prefix + ".cnf")
        shutil.copy(os.path.join(base_dir, prefix + "_" + name), target)
        cnfs.append(target)
    loopcounts = [0, 0]

    def _read_all(paths):
        # Read every file with a context manager so no handle is leaked.
        contents = []
        for path in paths:
            with open(path) as handle:
                contents.append(handle.readlines())
        return contents

    with open(cnfs[0], "r") as cf:
        header = cf.readline().strip().split()
        n_var = int(header[-2])

    while True:
        clause = [
            str((1 if random.random() > 0.5 else -1) * random.randint(1, n_var))
            for _ in range(3)
        ]
        sat = [1, 1]

        for i, path in enumerate(cnfs):
            # The solver sees the file as it was before this round's clause.
            if run_cadical(path) == 20:
                sat[i] = 0

            loopcounts[i] += 1
            if loopcounts[i] > 1000:
                return -1, _read_all(cnfs), loopcounts[i]

            with open(path, "a") as cf:
                cf.write("\n" + " ".join(clause) + " 0")

            add_clause(path)
        if 0 in sat:
            break

    if sat == [0, 0] or sat == [1, 1]:
        breakflag = -1
    else:
        # Index of the copy that returned 20.
        breakflag = int(np.argmin(sat))

    return breakflag, _read_all(cnfs), loopcounts
def random_search(file, n=100, max_workers=72):
    """Run `random_clause_worker` on `file` `n` times in a process pool.

    Args:
        file: Path forwarded unchanged to every `random_clause_worker` call.
        n: Number of independent worker runs to submit.
        max_workers: Size of the process pool. Defaults to 72, the value that
            used to be hard-coded.

    Returns:
        Tuple `(votes, logics, mapping, loopcounts)`. `votes`, `logics` and
        `loopcounts` hold the first, second and third element of every worker
        result, in submission order. `mapping` is always None: loading the
        `.maptxt` mapping is disabled, and the slot is kept so the return
        shape stays the same.
    """
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(random_clause_worker, file) for _ in range(n)]
        # Collect in submission order; result() re-raises worker exceptions.
        results = [future.result() for future in futures]

    votes = [r[0] for r in results]
    logics = [r[1] for r in results]
    loopcounts = [r[2] for r in results]

    mapping = None

    return votes, logics, mapping, loopcounts

def beam_search_cot(
    prompt: str,
    llm,
    score_fn: Callable[[str, str, str], float],
    beam_width: int = 10,
    num_samples: int = 5,
    max_steps: int = 10,
    stop_token: str = "Final Answer: ",
    mode='conf',
    cheat_idx=None
):
    """
    Beam search over chain-of-thought steps.

    At every step each unfinished beam item is extended with up to
    `num_samples` sampled next steps. A step containing 'return True',
    'return False' or `stop_token` finishes its candidate, which inherits the
    parent's score; every other step is scored with `score_fn` and added to
    the parent's running score. The `beam_width` best candidates survive.

    Args:
        prompt: Base problem prompt.
        llm: Object exposing
            `generate_next_thought(prompt, step_idx=..., n_samples=..., max_new=...)`
            returning a list of strings that each contain the FEWSHOT prefix.
        score_fn: Called as `score_fn(context, step, mode, cheat_idx=cheat_idx)`
            and expected to return a number.
        beam_width: Beam size.
        num_samples: Number of next-step samples per beam element. Requests
            are issued in batches of at most 20.
        max_steps: Maximum CoT length.
        stop_token: Token indicating completion.
        mode: Scoring mode forwarded to `score_fn`.
        cheat_idx: Forwarded to `score_fn`.

    Returns:
        The final beam followed by every finished candidate produced during
        the search. A finished candidate that is still in the final beam
        therefore appears twice.
        NOTE(review): this duplication is kept from the original code because
        downstream vote counting may depend on it -- confirm it is intended.
    """
    batch_size = 20

    # Each beam item: {"steps": List[str], "score": float, ...}
    beam = [{"steps": [], "score": 0.0, "scorelist": [], 'mc': {}}]
    if mode == 'conf-bb':
        beam[0]['score'] = int(0)
    completeds = []

    og_mc = get_mc(prompt)
    for step_idx in range(max_steps):
        step_no = str(step_idx + 1)
        next_no = str(step_idx + 2)
        candidates = []

        for item in beam:
            if item.get('completed', False):
                continue

            cot_so_far = prompt + '\n' + "\n".join(item["steps"])
            generation_prompt = FEWSHOT + cot_so_far + '\n#' + step_no + '. '

            # Request exactly num_samples completions. The last batch asks only
            # for the remainder instead of a full batch whose surplus was
            # generated and then thrown away.
            completions = []
            remaining = num_samples
            while remaining > 0:
                batch = min(batch_size, remaining)
                completions += llm.generate_next_thought(
                    generation_prompt, step_idx=step_idx + 1, n_samples=batch, max_new=100
                )
                remaining -= batch

            next_steps = []
            for raw in completions[:num_samples]:
                try:
                    completion = raw.split(FEWSHOT)[1]
                    body = completion.split(step_no + '.')[1].split('#' + next_no + '.')[0].strip('\n')
                except IndexError:
                    # Missing FEWSHOT prefix or missing step marker: skip this
                    # sample instead of stopping the run in the debugger.
                    print('beam_search_cot: skipping unparsable completion at step ' + step_no)
                    continue
                next_steps.append('\n#' + step_no + '.' + body)

            # get_mc depends only on cot_so_far, so it is computed at most once
            # per beam item and only if one of its steps finishes.
            item_mc = None
            for step in next_steps:
                step = step.strip()

                if 'return True' in step or 'return False' in step or stop_token in step:
                    if item_mc is None:
                        item_mc = get_mc(cot_so_far)
                    new_score = item["score"]
                    finished = {
                        'og_mc': og_mc,
                        "steps": item["steps"] + [step],
                        "score": new_score,
                        'scorelist': item['scorelist'] + [new_score],
                        'mc': dict(item_mc),
                        "completed": True
                    }
                    candidates.append(finished)
                    completeds.append(finished)
                    continue

                step_score = score_fn(cot_so_far, step, mode, cheat_idx=cheat_idx)
                candidates.append({
                    "og_mc": og_mc,
                    "steps": item["steps"] + [step],
                    "score": item["score"] + step_score,
                    'scorelist': item['scorelist'] + [step_score],
                    "completed": False
                })

        # Sort by score (descending) and prune to the beam width.
        candidates.sort(key=lambda x: x["score"], reverse=True)
        beam = candidates[:beam_width]

        # Stop once every surviving item is finished (also true for an empty beam).
        if all(item.get("completed", False) for item in beam):
            break

    return beam + completeds
def exhaustive_search_cot(
    prompt: str,
    llm,
    score_fn: Callable[[str, str, str], float],
    branches_per_step: int = 3,
    max_steps: int = 10,
    stop_token: str = "Final Answer: ",
    mode: str = 'conf',
    cheat_idx=None,
    max_paths: int = 1000  # Safety limit to prevent explosion
):
    """
    Exhaustive search over chain-of-thought steps.

    Every active path is extended with up to `branches_per_step` sampled next
    steps at each depth; nothing is pruned by score. A step containing
    'return True', 'return False' or `stop_token` finishes its path, which
    inherits the parent's score; every other step is scored with `score_fn`
    and the path stays active.

    Args:
        prompt: Base problem prompt.
        llm: Object exposing
            `generate_next_thought(prompt, step_idx=..., n_samples=..., max_new=...)`
            returning a list of strings that each contain the FEWSHOT prefix.
        score_fn: Called as `score_fn(context, step, mode, cheat_idx=cheat_idx)`
            and expected to return a number.
        branches_per_step: Number of branches to explore at each step.
        max_steps: Maximum CoT length.
        stop_token: Token indicating completion.
        mode: Scoring mode forwarded to `score_fn`.
        cheat_idx: Forwarded to `score_fn`.
        max_paths: Once the paths created in the current step plus all
            finished paths reach this number, remaining active paths are
            moved to the result without being extended.

    Returns:
        List of all explored paths (dicts with steps and scores), sorted by
        score in descending order. Paths still active after `max_steps` are
        marked completed and included.
    """
    # Each path item: {"steps": List[str], "score": float, "completed": bool}
    active_paths = [{"steps": [], "score": 0.0, "scorelist": [], "mc": {}, "completed": False}]
    completed_paths = []
    og_mc = get_mc(prompt)

    for step_idx in range(max_steps):
        step_no = str(step_idx + 1)
        next_no = str(step_idx + 2)
        new_active_paths = []

        for path in active_paths:
            # Skip already completed paths
            if path.get("completed", False):
                completed_paths.append(path)
                continue

            # Check if we've hit the safety limit
            if len(new_active_paths) + len(completed_paths) >= max_paths:
                print(f"Warning: Reached max_paths limit of {max_paths}")
                completed_paths.append(path)
                continue

            cot_so_far = prompt + '\n' + "\n".join(path["steps"])
            generation_prompt = FEWSHOT + cot_so_far + '\n#' + step_no + '. '

            # Generate branches for this path
            completions = llm.generate_next_thought(
                generation_prompt,
                step_idx=step_idx + 1,
                n_samples=branches_per_step,
                max_new=100
            )

            next_steps = []
            for raw in completions[:branches_per_step]:
                try:
                    completion = raw.split(FEWSHOT)[1]
                    body = completion.split(step_no + '.')[1].split('#' + next_no + '.')[0].strip('\n')
                except IndexError:
                    # Missing FEWSHOT prefix or missing step marker: skip this
                    # sample instead of stopping the run in the debugger.
                    print('exhaustive_search_cot: skipping unparsable completion at step ' + step_no)
                    continue
                next_steps.append('\n#' + step_no + '.' + body)

            if not next_steps:
                continue

            # Every branch of this path records the model count of the same
            # context, so run the external counter once instead of per branch.
            path_mc = get_mc(cot_so_far)

            for step in next_steps:
                step = step.strip()

                # Check for completion
                if 'return True' in step or 'return False' in step or stop_token in step:
                    new_score = path["score"]
                    completed_paths.append({
                        "og_mc": og_mc,
                        "steps": path["steps"] + [step],
                        "score": new_score,
                        "scorelist": path["scorelist"] + [new_score],
                        "mc": dict(path_mc),
                        "completed": True
                    })
                else:
                    # Score and continue exploring
                    step_score = score_fn(cot_so_far, step, mode, cheat_idx=cheat_idx)
                    new_active_paths.append({
                        "og_mc": og_mc,
                        "steps": path["steps"] + [step],
                        "score": path["score"] + step_score,
                        "scorelist": path["scorelist"] + [step_score],
                        "mc": dict(path_mc),
                        "completed": False
                    })

        # Update active paths for next iteration
        active_paths = new_active_paths

        # Early stopping if all paths completed
        if not active_paths:
            break

        print(f"Step {step_idx + 1}: {len(active_paths)} active paths, {len(completed_paths)} completed")

    # Add any remaining active paths as completed (hit max_steps)
    for path in active_paths:
        path["completed"] = True
        completed_paths.append(path)

    # Sort by score (descending)
    completed_paths.sort(key=lambda x: x["score"], reverse=True)

    return completed_paths

if __name__ == '__main__':
    # Experiment driver: for every entry of the pickled dump, build a prompt
    # from the entry's generated 'solution' code, run beam search over
    # chain-of-thought steps, vote on a true/false prediction, compare it with
    # the dataset label and checkpoint all outputs to a pickle file.

    import os
    import pickle as pkl
    import json
    # dump = pkl.load( open('/home/XXXX-c/clause/llama70_200_dump.pkl', 'rb')
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load dumps from a trusted source. The file handle is not closed
    # explicitly.
    dump = pkl.load(open('/home/XXXX/folio_dump.pkl', 'rb'))
    skipped = 0  # not referenced again in the visible part of the script
    llm = LLM()


    USER_PATH = '/home/XXXX/home/XXXX/6_26_backup/fs_backup_feb13/'

    dataset = '/home/XXXX/clause/dataset/folio.json'

    # Labels are read from here; entries are looked up with the same index
    # that is used for `dump`.
    with open(dataset, 'r') as df:
        data = json.loads(df.read())

    acc = 0      # number of correct predictions so far
    counter = 0  # number of examples that were actually evaluated
    # Scoring mode passed to the score function; alternatives tried before:
    # mode = 'raw-cheating'
    # mode = 'raw-conf'
    # mode = 'conf-bb'
    # mode = 'raw-bb'
    # mode = 'conf'
    mode = 'conf-bb-norm'
    # mode = 'conf-cheating'
    # `config` is only a label: it appears in the progress bar and in the
    # output file name.
    # config = 'bw-10-nsamples-10'
    # config = '5-4-oldnl2fol'
    # config = '10-5'
    config = '5-4-l2178'
    oopsies = []   # keys for which the model-count vote produced no votes
    all_outs = {}  # key -> search output, plus the bookkeeping entries below
    idxs = list(range(len(dump)))
    # idxs = [289]
    # idxs = []
    # for i in range(len(data)):
    #     if 'false' in data[i]['label']: idxs.append(i)

    # Process the examples in random order; the order is stored with the results.
    random.shuffle(idxs)
    all_outs['idxs'] = idxs
    logic_fail = []  # keys whose dump entry has no 'def solution():' section
    # idxs = [44]
    # idxs = idxs[:100]
    for key in (pbar := tqdm(idxs)):
        value = dump[key]
        # if key < 8: continue
        # if key == 57: continue
        # logic = value['out'].split('user Now, its your turn:')[-1].split('def translation():')[1].strip('```')
        try:
            # Prompt = everything after 'def solution():' in the stored output,
            # wrapped in a fixed header and a step-by-step cue.
            logic = 'Here are some facts and rules:\n' + value['out'].split('def solution():')[1].strip('```') + '#Let\'s think step by step:' 
        except: 
            # NOTE(review): bare except; any failure while building the prompt
            # is recorded as a logic failure and the example is skipped.
            logic_fail.append(key)
            continue
        label = data[key]['label']
        # The cheating modes receive the gold label as a boolean (True for
        # label 'true'); every other mode gets None.
        if mode == 'raw-cheating' or mode == 'conf-cheating':
            cheat_idx = False
            if label.lower().strip() == 'true': cheat_idx = True
        else: cheat_idx=None
        out = beam_search_cot(prompt=logic, llm=llm, score_fn= score, mode = mode, cheat_idx=cheat_idx)
        # out = exhaustive_search_cot(prompt=logic, llm=llm, score_fn = score, branches_per_step=3, max_steps=5, mode='conf')

        all_outs[key] = out

        # for o in out:
        #     if -1000000 in o['scorelist']:
        #         zeroed.append(key)
        #         if counter == 0: cc = 1
        #         pbar.set_description('Mode: ' + mode + ' | n_zeroed:' + str(len(zeroed)) + ' | Acc: ' + str(acc / cc))
        #         break

                # continue
        # Vote counts: preds[0] is 'true', preds[1] is 'false' (see the
        # ['true', 'false'][predvote] lookup below).
        preds = [0,0]
        for o in out:
            if o['completed'] != True: continue
            # if 'true' in o['steps'][-1].lower():
            #     preds[0] += 1
            # else: preds[1] += 1
            # Primary vote from the model counts stored on the completed path:
            # pos > neg votes 'false', pos < neg votes 'true', a tie votes nothing.
            if o['mc']['pos'] > o['mc']['neg']: preds[1] += 1
            elif o['mc']['pos'] < o['mc']['neg']: preds[0] += 1
        if preds == [0,0]:
            # Fallback when the model counts gave no vote: read the answer from
            # the text of each completed path's last step.
            oopsies.append(key)
            # all_outs[key]['oopsies'] = True
            all_outs['oopsies'] = oopsies
            for o in out:
                if o['completed'] != True: 
                    # print('incomplete')
                    continue
                if 'true' in o['steps'][-1].lower() and 'false' in o['steps'][-1].lower():
                    # Both words occur: whichever comes first in the step wins.
                    if len(o['steps'][-1].lower().split('true')[0]) < len(o['steps'][-1].lower().split('false')[0]): preds[0] += 1
                    elif len(o['steps'][-1].lower().split('true')[0]) > len(o['steps'][-1].lower().split('false')[0]): preds[1] += 1
                elif 'true' in o['steps'][-1].lower():
                    preds[0] += 1
                    # print(i, 'true', o['steps'][-1])
                else: 
                    # No 'true' in the last step counts as a vote for 'false'.
                    preds[1] += 1
        # np.argmax returns the first maximum, so a tie (including no votes at
        # all) predicts 'true'.
        predvote = np.argmax(preds)
        if ['true', 'false'][predvote] == label.lower().strip():
            acc += 1
        counter += 1

        pbar.set_description('Mode: ' + mode + ' | config: ' + config + ' | oopsies: ' + str(len(oopsies)) + ' |  Acc: ' + str(acc / counter))

        # Checkpoint after every example: the whole result dict is rewritten.
        all_outs['logic_fail'] = logic_fail
        pkl.dump(all_outs, open('/home/XXXX/clause/beam_outs_' + mode + '_' + config +  '.pkl', 'wb'))
