# Natural Language Toolkit: Interface to the Prover9 Theorem Prover
#
# Copyright (C) 2001-2023 NLTK Project
# Author: Dan Garrette <dhgarrette@gmail.com>
#         Ewan Klein <ewan@inf.ed.ac.uk>
#
# URL: <https://www.nltk.org/>
# For license information, see LICENSE.TXT
"""
A theorem prover that makes use of the external 'Prover9' package.
"""

import os
import subprocess
from copy import deepcopy
from itertools import product
from typing import List, Tuple

import nltk
from nltk.inference.api import BaseProverCommand, Prover
from nltk.sem.logic import (
    AllExpression,
    AndExpression,
    EqualityExpression,
    ExistsExpression,
    Expression,
    IffExpression,
    ImpExpression,
    NegatedExpression,
    OrExpression,
)

# Mapping from Prover9 exit codes to their meaning.  ``Prover9Exception``
# looks up the label for a failing return code in this table.
#
# Historical note kept from the original source: "Return code for 2 actually
# realized as 512."  NOTE(review): presumably this refers to a raw wait
# status (2 << 8); confirm before relying on it.
p9_return_codes = {
    # Success.
    0: True,
    # A fatal error occurred (user's syntax error).
    1: "(FATAL)",
    # (SOS_EMPTY) Prover9 ran out of things to do (sos list exhausted).
    2: False,
    # The max_megs (memory limit) parameter was exceeded.
    3: "(MAX_MEGS)",
    # The max_seconds parameter was exceeded.
    4: "(MAX_SECONDS)",
    # The max_given parameter was exceeded.
    5: "(MAX_GIVEN)",
    # The max_kept parameter was exceeded.
    6: "(MAX_KEPT)",
    # A Prover9 action terminated the search.
    7: "(ACTION)",
    # Prover9 crashed, most probably due to a bug.
    101: "(SIGSEGV)",
}

class Prover9CommandParent:
    """
    Shared base for ``Prover9Command`` and ``MaceCommand``.

    It relies on the concrete command class to provide ``assumptions()``
    and offers a way to print those assumptions either as NLTK expressions
    or in Prover9 syntax.
    """

    def print_assumptions(self, output_format="nltk"):
        """
        Print the current assumptions, one per line.

        :param output_format: ``"nltk"`` prints the assumptions as they are;
            ``"prover9"`` prints them converted with ``convert_to_prover9``.
            The comparison is case-insensitive.
        :type output_format: str
        :raise NameError: if ``output_format`` is neither of the above.
        """
        chosen_format = output_format.lower()

        if chosen_format == "nltk":
            printable = self.assumptions()
        elif chosen_format == "prover9":
            printable = convert_to_prover9(self.assumptions())
        else:
            raise NameError(
                "Unrecognized value for 'output_format': %s" % output_format
            )

        for entry in printable:
            print(entry)



class Prover9Command(Prover9CommandParent, BaseProverCommand):
    """
    A ``ProverCommand`` specific to the ``Prover9`` prover.  Through
    ``Prover9CommandParent`` it offers ``print_assumptions()``, which prints
    the list of assumptions in multiple formats.
    """

    def __init__(self, goal=None, assumptions=None, timeout=60, prover=None):
        """
        :param goal: Input expression to prove
        :type goal: sem.Expression
        :param assumptions: Input expressions to use as assumptions in
            the proof.  A missing or empty value becomes a fresh empty list.
        :type assumptions: list(sem.Expression)
        :param timeout: number of seconds before timeout; set to 0 for
            no timeout.  Only used when a new prover has to be created.
        :type timeout: int
        :param prover: a prover.  If not set, one will be created.
        :type prover: Prover9
        """
        if prover is None:
            prover = Prover9(timeout)
        else:
            assert isinstance(prover, Prover9)

        BaseProverCommand.__init__(self, prover, goal, assumptions or [])

    def decorate_proof(self, proof_string, simplify=True):
        """
        :see BaseProverCommand.decorate_proof()

        :param proof_string: the proof text to post-process.
        :param simplify: when true, the proof is passed through the prover's
            ``prooftrans`` call with the ``striplabels`` argument first.
        :return: the (optionally simplified) proof with trailing
            whitespace removed.
        """
        if not simplify:
            return proof_string.rstrip()

        # ``_call_prooftrans`` yields (stdout, returncode); only the text
        # is of interest here.
        transformed = self._prover._call_prooftrans(proof_string, ["striplabels"])[0]
        return transformed.rstrip()



class Prover9Parent:
    """
    A common class extended by both ``Prover9`` and ``Mace <mace.Mace>``.
    It contains the functionality required to convert NLTK-style
    expressions into Prover9-style expressions, to locate the external
    binaries and to run them.
    """

    # Directory of the explicitly configured prover9 binary, or None when
    # ``config_prover9`` has not been called (or was called with None).
    _binary_location = None

    def config_prover9(self, binary_location, verbose=False):
        """
        Configure where the ``prover9`` binary is located.

        :param binary_location: passed to ``nltk.internals.find_binary`` as
            ``path_to_bin``.  ``None`` clears any previous configuration.
        :param verbose: forwarded to ``find_binary``.
        """
        if binary_location is None:
            self._binary_location = None
            self._prover9_bin = None
            return

        name = "prover9"
        self._prover9_bin = nltk.internals.find_binary(
            name,
            path_to_bin=binary_location,
            env_vars=["PROVER9"],
            url="https://www.cs.unm.edu/~mccune/prover9/",
            binary_names=[name, name + ".exe"],
            verbose=verbose,
        )
        # Keep the directory only.  ``_find_binary`` appends this value to a
        # list of search directories, so it has to be a single path string
        # rather than a ``[directory, filename]`` pair.
        self._binary_location = os.path.dirname(self._prover9_bin)

    def prover9_input(self, goal, assumptions):
        """
        :return: The input string that should be provided to the
            prover9 binary.  This string is formed based on the goal
            and assumptions; sections for an empty goal or empty
            assumptions are omitted.
        """
        parts = []

        if assumptions:
            parts.append("formulas(assumptions).\n")
            parts.extend(
                "    %s.\n" % p9_assumption
                for p9_assumption in convert_to_prover9(assumptions)
            )
            parts.append("end_of_list.\n\n")

        if goal:
            parts.append("formulas(goals).\n")
            parts.append("    %s.\n" % convert_to_prover9(goal))
            parts.append("end_of_list.\n\n")

        return "".join(parts)

    def binary_locations(self):
        """
        A list of directories that should be searched for the prover9
        executables.  This list is used by ``_find_binary`` when searching
        for the prover9 executables.
        """
        return [
            "/usr/local/bin/prover9",
            "/usr/local/bin/prover9/bin",
            "/usr/local/bin",
            "/usr/bin",
            "/usr/local/prover9",
            "/usr/local/share/prover9",
            # NOTE(review): machine-specific install location; kept so the
            # existing setup keeps working, but it should come from the
            # PROVER9 environment variable or ``config_prover9`` instead.
            "/home/qcw/My_Paper_Project/logic_reasoning_for_inverse_scaling/LADR-2009-11A/bin",
        ]

    def _find_binary(self, name, verbose=False):
        """
        Locate the executable called ``name``.

        The search path is ``binary_locations()`` plus, when set, the
        directory remembered by ``config_prover9``.

        :param name: base name of the binary (``.exe`` is also tried).
        :param verbose: forwarded to ``nltk.internals.find_binary``.
        :return: whatever ``nltk.internals.find_binary`` returns.
        """
        # Work on a copy so a list shared by ``binary_locations`` (for
        # instance in a subclass) is never modified.
        search_path = list(self.binary_locations())
        if self._binary_location is not None:
            search_path.append(self._binary_location)
        return nltk.internals.find_binary(
            name,
            searchpath=search_path,
            env_vars=["PROVER9"],
            url="https://www.cs.unm.edu/~mccune/prover9/",
            binary_names=[name, name + ".exe"],
            verbose=verbose,
        )

    def _call(self, input_str, binary, args=None, verbose=False):
        """
        Call the binary with the given input.

        :param input_str: A string (or bytes) whose contents are used as
            stdin.  Strings are encoded as UTF-8.
        :param binary: The location of the binary to call
        :param args: A list of command-line arguments; ``None`` means none.
        :param verbose: print the call details and the captured output.
        :return: A tuple (stdout, returncode), where stdout is decoded as
            UTF-8 and also contains anything written to stderr.
        :see: ``config_prover9``
        """
        arg_list = [] if args is None else list(args)

        if verbose:
            print("Calling:", binary)
            print("Args:", arg_list)
            print("Input:\n", input_str, "\n")

        if isinstance(input_str, str):
            input_str = input_str.encode("utf8")

        # stderr is merged into stdout, so there is a single stream to read.
        with subprocess.Popen(
            [binary] + arg_list,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            stdin=subprocess.PIPE,
        ) as process:
            raw_stdout, _ = process.communicate(input=input_str)
            returncode = process.returncode

        stdout = raw_stdout.decode("utf-8")

        if verbose:
            print("Return code:", returncode)
            if stdout:
                print("stdout:\n", stdout, "\n")

        return (stdout, returncode)



def convert_to_prover9(input):
    """
    Convert a ``logic.Expression`` to Prover9 format.

    :param input: a single expression, or a list of expressions.
    :return: a Prover9-formatted string, or a list of such strings when
        ``input`` is a list.
    :raise Exception: whatever simplification or conversion raises; a
        diagnostic line naming ``input`` is printed before re-raising.
    """
    if isinstance(input, list):
        return [_simplify_and_convert(item, input) for item in input]
    return _simplify_and_convert(input, input)


def _simplify_and_convert(expression, reported_input):
    """Simplify and convert one expression, reporting ``reported_input`` on failure."""
    try:
        simplified = expression.simplify()
        return _convert_to_prover9(simplified)
    except Exception:
        # An f-string is used on purpose: ``"%s" % value`` breaks when the
        # value is a tuple, which would hide the original error.
        print(f"input {reported_input!s} cannot be converted to Prover9 input syntax")
        raise



def _convert_to_prover9(expression):
    """
    Convert ``logic.Expression`` to Prover9 formatted string.

    Quantified, negated and binary expressions are rendered recursively;
    any other expression is rendered with ``str``.
    """
    # Quantifiers: "<keyword> <variable> <body>".
    quantifier_keywords = (
        (ExistsExpression, "exists"),
        (AllExpression, "all"),
    )
    for quantifier_type, keyword in quantifier_keywords:
        if isinstance(expression, quantifier_type):
            body = _convert_to_prover9(expression.term)
            return " ".join((keyword, str(expression.variable), body))

    if isinstance(expression, NegatedExpression):
        return "-(" + _convert_to_prover9(expression.term) + ")"

    # Binary connectives: "(<left> <operator> <right>)".
    binary_operators = (
        (AndExpression, "&"),
        (OrExpression, "|"),
        (ImpExpression, "->"),
        (IffExpression, "<->"),
        (EqualityExpression, "="),
    )
    for binary_type, operator in binary_operators:
        if isinstance(expression, binary_type):
            left = _convert_to_prover9(expression.first)
            right = _convert_to_prover9(expression.second)
            return "(" + left + " " + operator + " " + right + ")"

    return str(expression)


class Prover9(Prover9Parent, Prover):
    """
    Prover that delegates to the external ``prover9`` binary and to
    ``prooftrans`` for post-processing proofs.
    """

    # Cached paths of the external binaries; resolved on first use.
    _prover9_bin = None
    _prooftrans_bin = None

    def __init__(self, timeout=60):
        # The timeout value for prover9.  If a proof can not be found
        # in this amount of time, then prover9 will return false.
        # (Use 0 for no timeout.)
        self._timeout = timeout

    def _prove(self, goal=None, assumptions=None, verbose=False):
        """
        Use Prover9 to prove a theorem.

        :return: A pair whose first element is a boolean indicating if the
            proof was successful (i.e. returns value of 0) and whose second
            element is the output of the prover.
        """
        prover_input = self.prover9_input(goal, assumptions or [])
        stdout, returncode = self._call_prover9(prover_input, verbose=verbose)
        return (returncode == 0, stdout)

    def prover9_input(self, goal, assumptions):
        """
        :see: Prover9Parent.prover9_input
        """
        # only one proof required
        header = "clear(auto_denials).\n"
        return header + Prover9Parent.prover9_input(self, goal, assumptions)

    def _call_prover9(self, input_str, args=[], verbose=False):
        """
        Call the ``prover9`` binary with the given input.

        :param input_str: A string whose contents are used as stdin.
        :param args: A list of command-line arguments (never modified here).
        :return: A tuple (stdout, returncode)
        :raise Prover9LimitExceededException: for return codes 3 to 6.
        :raise Prover9FatalException: for any other return code besides
            0 and 2.
        :see: ``config_prover9``
        """
        if self._prover9_bin is None:
            self._prover9_bin = self._find_binary("prover9", verbose)

        # A positive timeout is passed to prover9 as a leading directive.
        timeout_directive = ""
        if self._timeout > 0:
            timeout_directive = "assign(max_seconds, %d).\n\n" % self._timeout
        full_input = timeout_directive + input_str

        stdout, returncode = self._call(full_input, self._prover9_bin, args, verbose)

        if returncode in (0, 2):
            return stdout, returncode

        # Everything from this literal marker onwards is reported as the
        # error message, when the marker is present in the output.
        marker = "%%ERROR:"
        errormsg = stdout[stdout.index(marker):].strip() if marker in stdout else None

        if returncode in (3, 4, 5, 6):
            raise Prover9LimitExceededException(returncode, errormsg)
        raise Prover9FatalException(returncode, errormsg)

    def _call_prooftrans(self, input_str, args=[], verbose=False):
        """
        Call the ``prooftrans`` binary with the given input.

        :param input_str: A string whose contents are used as stdin.
        :param args: A list of command-line arguments (never modified here).
        :return: A tuple (stdout, returncode)
        :see: ``config_prover9``
        """
        if self._prooftrans_bin is None:
            self._prooftrans_bin = self._find_binary("prooftrans", verbose)

        return self._call(input_str, self._prooftrans_bin, args, verbose)



class Prover9Exception(Exception):
    """
    Base class for errors reported by the ``prover9`` binary.

    The exception text is the label that ``p9_return_codes`` associates
    with the return code, followed by the prover's own message if any.
    """

    def __init__(self, returncode, message):
        """
        :param returncode: return code of the prover9 process.
        :param message: error text extracted from the prover output, or
            a false value when there is none.
        """
        label = p9_return_codes.get(returncode)
        if not isinstance(label, str):
            # Codes missing from the table and the boolean entries (0 and 2)
            # have no textual label; fall back to a generic one instead of
            # failing with KeyError/TypeError while building the message.
            label = f"(RETURN CODE {returncode})"

        msg = label
        if message:
            msg += "\n%s" % message
        Exception.__init__(self, msg)


class Prover9FatalException(Prover9Exception):
    """
    Raised by ``Prover9._call_prover9`` for return codes other than
    0, 2 and the limit codes 3 to 6.
    """


class Prover9LimitExceededException(Prover9Exception):
    """
    Raised by ``Prover9._call_prover9`` when the return code is one of
    3, 4, 5 or 6 (the ``MAX_*`` entries of ``p9_return_codes``).
    """


def prover9_prove(arguments):
    """
    Try some proofs and return the outcome of the last one.

    :param arguments: iterable of ``(goal, assumptions)`` pairs, where
        ``goal`` is a formula string and ``assumptions`` a list of formula
        strings, all parsed with ``Expression.fromstring``.
    :return: the pair produced by the last proof attempt, or
        ``(None, None)`` when ``arguments`` is empty.

    NOTE(review): this unpacks the value of ``Prover9Command.prove`` as a
    ``(result, output)`` pair; confirm that the ``BaseProverCommand`` in
    use really returns a pair.
    """
    # Defined up front so an empty ``arguments`` does not fail at the return.
    p, stdout = None, None

    for goal, assumptions in arguments:
        g = Expression.fromstring(goal)
        alist = [Expression.fromstring(a) for a in assumptions]
        p, stdout = Prover9Command(g, assumptions=alist).prove(verbose=False)

    return p, stdout


# Sample problem consumed by ``demo``: a list of (goal, assumptions) pairs.
# ``demo`` deep-copies this list and appends extra facts to the assumptions.
arguments = [
    (
        "some x. (bezier(severin) | (slide(sergeant) -> distract(severin, purcell)))",
        [
            "all x. (slide(x))",
            # Further sample assumptions, currently disabled:
            #   "all x. (ugly(x) -> not popular(x))"
            #   "all x. (love(children, x) -> funny(x))"
            #   "all x. (simpsons(x) -> love(children, x))"
            #   "all x. (yellow(x) -> simpsons(x))"
            #   "not (simpsons(ben) <-> funny(ben))"
        ],
    ),
]

def generate_boolean_combinations(num_variables):
    """
    Return every assignment of ``num_variables`` booleans as a list of
    tuples, with ``True`` ordered before ``False`` in each position.
    """
    truth_values = (True, False)
    return [*product(truth_values, repeat=num_variables)]
        

def demo():
    """
    Run the sample problem in ``arguments`` once for every truth assignment
    of two ground atoms and print the outcome of each run.
    """
    print("Testing proofs")
    variables = ["bezier(severin)", "distract(severin, purcell)"]

    truth_table = generate_boolean_combinations(len(variables))

    # One list of facts per row: the atom itself when true, its negation
    # when false.
    facts_list = [
        [atom if is_true else f"not {atom}" for atom, is_true in zip(variables, row)]
        for row in truth_table
    ]

    for row, facts in zip(truth_table, facts_list):
        # Copy first so the module-level sample problem stays untouched.
        argument_list = deepcopy(arguments)
        argument_list[0][1].extend(facts)
        result, _prover_output = prover9_prove(argument_list)

        print(f"{row} |- {result}")


class FOL2Prover9Converter:
    """
    Converter from standard First Order Logic notation (with Unicode
    connectives) into the textual format accepted by the nltk prover9
    interface.
    """

    def __init__(self) -> None:
        # Unicode connective/quantifier -> textual replacement.
        self.symbol2text = {
            "¬": "not ",
            "∃": "some ",
            "∀": "all ",
            "→": "->",
            "⟷": "<->",
            "∧": "&",
            "∨": "|",
            "↔": "<->",
        }
        # 'a'..'z' followed by '0'..'9'.
        letters = [chr(code) for code in range(ord("a"), ord("z") + 1)]
        digits = [str(digit) for digit in range(10)]
        self.lowercase_alphabet = letters + digits

    def convert_fol_instance(self, assumption: List[str], goal: str) -> Tuple[List[str], str]:
        """Convert a folio data instance.

        Args:
            assumption: a list of fol expressions, including the facts and rules of the current data instance
            goal: the fol expression that needs to be proved True or False.

        Returns:
            the converted list of assumptions and the converted goal
        """
        converted_assumptions = [self.convert_expression(item) for item in assumption]
        converted_goal = self.convert_expression(goal)
        return converted_assumptions, converted_goal

    @staticmethod
    def _left_operand_start(text: str, xor_index: int) -> int:
        """Scan left from the XOR sign for the start of its left operand.

        Returns the index just after the nearest unmatched '(' found at an
        index greater than 0, otherwise the index at which the scan stopped.
        """
        depth = 0
        position = xor_index - 1
        while position > 0:
            char = text[position]
            if char == ")":
                depth += 1
            elif char == "(":
                if depth == 0:
                    return position + 1
                depth -= 1
            position -= 1
        return position

    @staticmethod
    def _right_operand_end(text: str, xor_index: int) -> int:
        """Scan right from the XOR sign for the end of its right operand.

        Returns the index of the nearest unmatched ')', or ``len(text)``
        when there is none.
        """
        depth = 0
        position = xor_index + 1
        while position < len(text):
            char = text[position]
            if char == "(":
                depth += 1
            elif char == ")":
                if depth == 0:
                    break
                depth -= 1
            position += 1
        return position

    def convert_expression(self, fol_expression: str) -> str:
        """Accept a first order logic expression and output the corresponding format.

        Args:
            fol_expression (str): the standard first order logic expression

        Returns:
            str: converted version of the input expression
        """
        text = fol_expression
        for symbol, replacement in self.symbol2text.items():
            text = text.replace(symbol, replacement)

        text = text.replace("  ", " ")

        # Merge consecutive quantifiers over x, y, z and add the dot.  The
        # '*' temporarily marks merged prefixes so that the single-variable
        # rule does not match them again; it is removed afterwards.
        for quantifier in ("all", "some"):
            text = text.replace(
                f"{quantifier} x {quantifier} y {quantifier} z", f"{quantifier} *x y z."
            )
            text = text.replace(f"{quantifier} x {quantifier} y", f"{quantifier} *x y.")
            text = text.replace(f"{quantifier} x", f"{quantifier} x.")
            text = text.replace("*", "")

        # Rewrite every exclusive-or "A ⊕ B" as "not (A <-> B)", leftmost first.
        while "⊕" in text:
            xor_index = text.find("⊕")
            left = self._left_operand_start(text, xor_index)
            right = self._right_operand_end(text, xor_index)

            has_open_before = left >= 1 and text[left - 1] == "("
            has_close_after = right + 1 < len(text) and text[right + 1] == ")"

            if has_open_before and has_close_after:
                # Reuse the parentheses that already surround the operands.
                segment = text[left - 1:right + 1].replace("⊕", "<->", 1)
                text = text[:left - 1] + "(not " + segment + ")" + text[right + 1:]
            else:
                segment = text[left:right].replace("⊕", "<->", 1)
                text = text[:left] + "(not (" + segment + "))" + text[right:]

        return text


# Run the truth-table demo when this file is executed as a script.
if __name__ == "__main__":
    demo()


