class FOLFormula(object):
	"""Common base class of the first-order logic formula nodes in this file."""

	def __init__(self):
		"""Create a formula node; the base class itself stores no state."""
		pass

class FOLAnd(FOLFormula):
	"""Conjunction node holding an ordered sequence of operand formulas."""

	def __init__(self, conjuncts):
		"""Store `conjuncts` (a sequence of formulas) as `self.operands`."""
		self.operands = conjuncts

	def __eq__(self, other):
		"""Return True iff `other` is exactly a FOLAnd whose operands are pairwise equal, in order."""
		if type(other) != FOLAnd:
			return False
		if len(self.operands) != len(other.operands):
			return False
		# Operands are compared with `!=`, exactly as before.
		return not any(mine != theirs for mine, theirs in zip(self.operands, other.operands))

	def __ne__(self, other):
		"""Return the negation of `__eq__`."""
		return not self.__eq__(other)

	def __hash__(self):
		"""Return a hash consistent with the order-sensitive `__eq__`."""
		# Hash the operands as an ordered tuple. XOR-ing the operand hashes
		# made every permutation collide and let duplicated operands cancel
		# each other out (e.g. [a, a] hashed like []).
		return hash((FOLAnd, tuple(self.operands)))

	def apply(self, func):
		"""Return a new FOLAnd whose operands are `func(operand)` for each operand, in order."""
		return FOLAnd([func(operand) for operand in self.operands])

	def visit(self, visit):
		"""Call `visit(operand)` on each operand, in order."""
		for operand in self.operands:
			visit(operand)

class FOLOr(FOLFormula):
	"""Disjunction node holding an ordered sequence of operand formulas."""

	def __init__(self, disjuncts):
		"""Store `disjuncts` (a sequence of formulas) as `self.operands`."""
		self.operands = disjuncts

	def __eq__(self, other):
		"""Return True iff `other` is exactly a FOLOr whose operands are pairwise equal, in order."""
		if type(other) != FOLOr:
			return False
		if len(self.operands) != len(other.operands):
			return False
		# Operands are compared with `!=`, exactly as before.
		return not any(mine != theirs for mine, theirs in zip(self.operands, other.operands))

	def __ne__(self, other):
		"""Return the negation of `__eq__`."""
		return not self.__eq__(other)

	def __hash__(self):
		"""Return a hash consistent with the order-sensitive `__eq__`."""
		# Hash the operands as an ordered tuple. XOR-ing the operand hashes
		# made every permutation collide and let duplicated operands cancel
		# each other out (e.g. [a, a] hashed like []).
		return hash((FOLOr, tuple(self.operands)))

	def apply(self, func):
		"""Return a new FOLOr whose operands are `func(operand)` for each operand, in order."""
		return FOLOr([func(operand) for operand in self.operands])

	def visit(self, visit):
		"""Call `visit(operand)` on each operand, in order."""
		for operand in self.operands:
			visit(operand)

class FOLNot(FOLFormula):
	"""Negation node wrapping a single operand formula."""

	def __init__(self, operand):
		"""Store the negated formula as `self.operand`."""
		self.operand = operand

	def __eq__(self, other):
		"""Equal only to another FOLNot (exact type) with an equal operand."""
		if type(other) is FOLNot:
			return self.operand == other.operand
		return False

	def __ne__(self, other):
		"""Unequal to anything that is not a FOLNot, or whose operand differs."""
		if type(other) is FOLNot:
			return self.operand != other.operand
		return True

	def __hash__(self):
		"""Hash the node type together with the operand."""
		key = (FOLNot, self.operand)
		return hash(key)

	def apply(self, func):
		"""Return a new FOLNot wrapping `func(self.operand)`."""
		mapped = func(self.operand)
		return FOLNot(mapped)

	def visit(self, visit):
		"""Call `visit` on the operand."""
		visit(self.operand)

class FOLIfThen(FOLFormula):
	"""Implication node with an antecedent and a consequent formula."""

	def __init__(self, antecedent, consequent):
		"""Store both sides of the implication."""
		self.antecedent, self.consequent = antecedent, consequent

	def __eq__(self, other):
		"""Equal only to another FOLIfThen (exact type) with equal antecedent and consequent."""
		if type(other) is not FOLIfThen:
			return False
		same_antecedent = self.antecedent == other.antecedent
		return same_antecedent and self.consequent == other.consequent

	def __ne__(self, other):
		"""Unequal when `other` is not a FOLIfThen or either side differs."""
		if type(other) is not FOLIfThen:
			return True
		antecedent_differs = self.antecedent != other.antecedent
		return antecedent_differs or self.consequent != other.consequent

	def __hash__(self):
		"""Hash the node type together with both sides."""
		key = (FOLIfThen, self.antecedent, self.consequent)
		return hash(key)

	def apply(self, func):
		"""Return a new FOLIfThen built from `func` applied to the antecedent, then the consequent."""
		new_antecedent = func(self.antecedent)
		new_consequent = func(self.consequent)
		return FOLIfThen(new_antecedent, new_consequent)

	def visit(self, visit):
		"""Call `visit` on the antecedent, then on the consequent."""
		for side in (self.antecedent, self.consequent):
			visit(side)

class FOLIff(FOLFormula):
	"""Biconditional node with an antecedent and a consequent formula."""

	def __init__(self, antecedent, consequent):
		"""Store both sides of the biconditional."""
		self.antecedent, self.consequent = antecedent, consequent

	def __eq__(self, other):
		"""Equal only to another FOLIff (exact type) with equal antecedent and consequent."""
		if type(other) is not FOLIff:
			return False
		same_antecedent = self.antecedent == other.antecedent
		return same_antecedent and self.consequent == other.consequent

	def __ne__(self, other):
		"""Unequal when `other` is not a FOLIff or either side differs."""
		if type(other) is not FOLIff:
			return True
		antecedent_differs = self.antecedent != other.antecedent
		return antecedent_differs or self.consequent != other.consequent

	def __hash__(self):
		"""Hash the node type together with both sides."""
		key = (FOLIff, self.antecedent, self.consequent)
		return hash(key)

	def apply(self, func):
		"""Return a new FOLIff built from `func` applied to the antecedent, then the consequent."""
		new_antecedent = func(self.antecedent)
		new_consequent = func(self.consequent)
		return FOLIff(new_antecedent, new_consequent)

	def visit(self, visit):
		"""Call `visit` on the antecedent, then on the consequent."""
		for side in (self.antecedent, self.consequent):
			visit(side)

class FOLEquals(FOLFormula):
	"""Equality node relating a left and a right operand."""

	def __init__(self, left, right):
		"""Store both sides of the equality."""
		self.left, self.right = left, right

	def __eq__(self, other):
		"""Equal only to another FOLEquals (exact type) with equal `left` and equal `right`."""
		if type(other) is not FOLEquals:
			return False
		same_left = self.left == other.left
		return same_left and self.right == other.right

	def __ne__(self, other):
		"""Unequal when `other` is not a FOLEquals or either side differs."""
		if type(other) is not FOLEquals:
			return True
		left_differs = self.left != other.left
		return left_differs or self.right != other.right

	def __hash__(self):
		"""Hash the node type together with both sides."""
		key = (FOLEquals, self.left, self.right)
		return hash(key)

	def apply(self, func):
		"""Return a new FOLEquals built from `func` applied to the left side, then the right side."""
		new_left = func(self.left)
		new_right = func(self.right)
		return FOLEquals(new_left, new_right)

	def visit(self, visit):
		"""Call `visit` on the left side, then on the right side."""
		for side in (self.left, self.right):
			visit(side)

class FOLForAll(FOLFormula):
	"""Universal quantifier node binding `variable` over `operand`."""

	def __init__(self, variable, operand):
		"""Store the bound variable identifier and the quantified formula."""
		self.variable, self.operand = variable, operand

	def __eq__(self, other):
		"""Equal only to another FOLForAll (exact type) with equal variable and operand."""
		if type(other) is not FOLForAll:
			return False
		same_variable = self.variable == other.variable
		return same_variable and self.operand == other.operand

	def __ne__(self, other):
		"""Unequal when `other` is not a FOLForAll or the variable or operand differs."""
		if type(other) is not FOLForAll:
			return True
		variable_differs = self.variable != other.variable
		return variable_differs or self.operand != other.operand

	def __hash__(self):
		"""Hash the node type together with the variable and operand."""
		key = (FOLForAll, self.variable, self.operand)
		return hash(key)

	def apply(self, func):
		"""Return a new FOLForAll over the same variable with `func` applied to the operand."""
		mapped = func(self.operand)
		return FOLForAll(self.variable, mapped)

	def visit(self, visit):
		"""Call `visit` on the operand only; the bound variable is not visited."""
		visit(self.operand)

class FOLExists(FOLFormula):
	"""Existential quantifier node binding `variable` over `operand`."""

	def __init__(self, variable, operand):
		"""Store the bound variable identifier and the quantified formula."""
		self.variable, self.operand = variable, operand

	def __eq__(self, other):
		"""Equal only to another FOLExists (exact type) with equal variable and operand."""
		if type(other) is not FOLExists:
			return False
		same_variable = self.variable == other.variable
		return same_variable and self.operand == other.operand

	def __ne__(self, other):
		"""Unequal when `other` is not a FOLExists or the variable or operand differs."""
		if type(other) is not FOLExists:
			return True
		variable_differs = self.variable != other.variable
		return variable_differs or self.operand != other.operand

	def __hash__(self):
		"""Hash the node type together with the variable and operand."""
		key = (FOLExists, self.variable, self.operand)
		return hash(key)

	def apply(self, func):
		"""Return a new FOLExists over the same variable with `func` applied to the operand."""
		mapped = func(self.operand)
		return FOLExists(self.variable, mapped)

	def visit(self, visit):
		"""Call `visit` on the operand only; the bound variable is not visited."""
		visit(self.operand)

class FOLFuncApplication(FOLFormula):
	"""Application of a function/predicate symbol to an ordered argument list."""

	def __init__(self, function, args):
		"""Store the applied symbol as `self.function` and its arguments as `self.args`."""
		self.function = function
		self.args = args

	def __eq__(self, other):
		"""Return True iff `other` is exactly a FOLFuncApplication with an equal symbol and pairwise-equal arguments, in order."""
		if type(other) != FOLFuncApplication:
			return False
		if len(self.args) != len(other.args):
			return False
		if self.function != other.function:
			return False
		# Arguments are compared with `!=`, exactly as before.
		return not any(mine != theirs for mine, theirs in zip(self.args, other.args))

	def __ne__(self, other):
		"""Return the negation of `__eq__`."""
		return not self.__eq__(other)

	def __hash__(self):
		"""Return a hash consistent with the order-sensitive `__eq__`."""
		# Hash the arguments as an ordered tuple. XOR-ing the argument hashes
		# made f(a, b) collide with f(b, a) and let repeated arguments cancel
		# each other out (e.g. f(a, a) hashed like f()).
		return hash((FOLFuncApplication, self.function, tuple(self.args)))

	def apply(self, func):
		"""Return a new application of the same symbol to `func(arg)` for each argument, in order."""
		return FOLFuncApplication(self.function, [func(arg) for arg in self.args])

	def visit(self, visit):
		"""Call `visit(arg)` on each argument, in order; the symbol itself is not visited."""
		for arg in self.args:
			visit(arg)

class FOLTerm(object):
	"""Common base class of the term nodes (variables, constants, numbers) in this file."""

	def __init__(self):
		"""Create a term node; the base class itself stores no state."""
		pass

class FOLVariable(FOLTerm):
	"""Term node referring to a variable by its identifier."""

	def __init__(self, variable):
		"""Store the variable identifier as `self.variable`."""
		self.variable = variable

	def __eq__(self, other):
		"""Equal only to another FOLVariable (exact type) with an equal identifier."""
		if type(other) is FOLVariable:
			return self.variable == other.variable
		return False

	def __ne__(self, other):
		"""Unequal to anything that is not a FOLVariable, or whose identifier differs."""
		if type(other) is FOLVariable:
			return self.variable != other.variable
		return True

	def __hash__(self):
		"""Hash the node type together with the identifier."""
		key = (FOLVariable, self.variable)
		return hash(key)

	def apply(self, func):
		"""Return a fresh FOLVariable with the same identifier; `func` is not called."""
		return FOLVariable(self.variable)

	def visit(self, visit):
		"""Do nothing: a variable has no children to visit."""
		pass

class FOLConstant(FOLTerm):
	"""Term node holding a constant symbol."""

	def __init__(self, constant):
		"""Store the constant symbol as `self.constant`."""
		self.constant = constant

	def __eq__(self, other):
		"""Equal only to another FOLConstant (exact type) with an equal symbol."""
		if type(other) is FOLConstant:
			return self.constant == other.constant
		return False

	def __ne__(self, other):
		"""Unequal to anything that is not a FOLConstant, or whose symbol differs."""
		if type(other) is FOLConstant:
			return self.constant != other.constant
		return True

	def __hash__(self):
		"""Hash the node type together with the symbol."""
		key = (FOLConstant, self.constant)
		return hash(key)

	def apply(self, func):
		"""Return a fresh FOLConstant with the same symbol; `func` is not called."""
		return FOLConstant(self.constant)

	def visit(self, visit):
		"""Do nothing: a constant has no children to visit."""
		pass

class FOLNumber(FOLTerm):
	"""Term node holding a numeric value."""

	def __init__(self, number):
		"""Store the value as `self.number`."""
		self.number = number

	def __eq__(self, other):
		"""Equal only to another FOLNumber (exact type) with an equal value."""
		if type(other) is FOLNumber:
			return self.number == other.number
		return False

	def __ne__(self, other):
		"""Unequal to anything that is not a FOLNumber, or whose value differs."""
		if type(other) is FOLNumber:
			return self.number != other.number
		return True

	def __hash__(self):
		"""Hash the node type together with the value."""
		key = (FOLNumber, self.number)
		return hash(key)

	def apply(self, func):
		"""Return a fresh FOLNumber with the same value; `func` is not called."""
		return FOLNumber(self.number)

	def visit(self, visit):
		"""Do nothing: a number has no children to visit."""
		pass

def substitute(formula, src, dst):
	"""Return `formula` with every subtree equal to `src` replaced by `dst`.

	When `src` is a FOLVariable, a FOLForAll/FOLExists node that binds the
	same variable is returned as-is and its body is not searched. All other
	nodes are rebuilt through their `apply` method.
	"""
	def replace(node):
		if node == src:
			return dst
		# A quantifier re-binding the substituted variable is left untouched.
		if type(src) == FOLVariable and type(node) in (FOLForAll, FOLExists) and src.variable == node.variable:
			return node
		return node.apply(replace)

	return replace(formula)

def copy(formula):
	"""Rebuild `formula` by recursively calling each node's `apply` with the copying function itself."""
	def clone(node):
		return node.apply(clone)

	return formula.apply(clone)

def max_variable(formula):
	"""Return the largest `variable` found on any FOLVariable, FOLForAll or FOLExists node in `formula`, or 0 if none is larger."""
	largest = 0

	def scan(node):
		nonlocal largest
		if type(node) in (FOLVariable, FOLForAll, FOLExists):
			if node.variable > largest:
				largest = node.variable
		node.visit(scan)

	scan(formula)
	return largest

def free_variables(formula):
	"""Return the identifiers of variables occurring free in `formula`.

	Each identifier is listed once, in order of first occurrence. A variable
	occurrence counts as free when no enclosing FOLForAll/FOLExists on the
	current traversal path binds its identifier.
	"""
	free = []
	scope = []  # identifiers bound by the quantifiers currently being traversed

	def walk(node):
		binds = type(node) == FOLForAll or type(node) == FOLExists
		if type(node) == FOLVariable and node.variable not in scope and node.variable not in free:
			free.append(node.variable)
		if binds:
			scope.append(node.variable)
		node.visit(walk)
		if binds:
			scope.pop()

	walk(formula)
	return free

def bound_variables(formula):
	"""Return the identifiers bound by FOLForAll/FOLExists nodes in `formula`, each once, in order of first occurrence."""
	seen = []

	def walk(node):
		if type(node) in (FOLForAll, FOLExists) and node.variable not in seen:
			seen.append(node.variable)
		node.visit(walk)

	walk(formula)
	return seen

def predicates(formula):
	"""Return the `function` of every FOLFuncApplication node in `formula`, in traversal order (duplicates kept)."""
	symbols = []

	def collect(node):
		if type(node) == FOLFuncApplication:
			symbols.append(node.function)
		node.visit(collect)

	collect(formula)
	return symbols

def contains(formula, subtree):
	"""Return True iff some node strictly below `formula` equals `subtree`.

	`formula` itself is not compared against `subtree`; only the nodes
	reached through `visit` are.
	"""
	matched = False

	def search(node):
		nonlocal matched
		if matched:
			# A match was already found; skip the remaining nodes.
			return
		if node == subtree:
			matched = True
			return
		node.visit(search)

	formula.visit(search)
	return matched

def parse_fol_term_from_tptp(string, pos, var_map):
	"""Parse one term (or function application) from `string` starting at `pos`.

	Args:
		string: the text being parsed.
		pos: index of the first character of the term.
		var_map: mapping from variable names currently in scope to their
			identifiers; names found in it become FOLVariable nodes.

	Returns:
		A `(term, end)` pair where `end` is the index just past the term.
		The term is a FOLConstant (double-quoted text, quotes included, or a
		bare name not in `var_map`), a FOLVariable, or a FOLFuncApplication
		when the name is followed by a parenthesized argument list.

	Raises:
		Exception: on an undeclared `$`-prefixed name, a malformed or
			unterminated argument list, an unterminated quoted constant, or
			when the input ends where a term is expected.
	"""
	length = len(string)
	if pos >= length:
		raise Exception(f"parse_fol_term_from_tptp ERROR at {pos+1}: Unexpected end of input.")

	end = pos
	if string[end] == '"':
		end += 1
		# Scan to the closing quote, skipping quotes preceded by a backslash.
		while end < length and not (string[end] == '"' and string[end - 1] != '\\'):
			end += 1
		if end >= length:
			raise Exception(f"parse_fol_term_from_tptp ERROR at {pos+1}: Unterminated quoted constant.")
		end += 1
		return FOLConstant(string[pos:end]), end

	must_be_var = False
	if string[end] == '$':
		must_be_var = True
		end += 1
	# The bounds check lets a name be the last token of the input.
	while end < length and (string[end].isdigit() or string[end].isalpha() or string[end] == '≥' or string[end] == '.' or string[end] == '_'):
		end += 1
	name = string[pos:end]
	if name in var_map:
		term = FOLVariable(var_map[name])
	else:
		if must_be_var:
			raise Exception(f"parse_fol_term_from_tptp ERROR at {pos+1}: Variable undeclared.")
		term = FOLConstant(name)
	if end < length and string[end] == '(':
		end += 1
		args = []
		while True:
			arg, end = parse_fol_term_from_tptp(string, end, var_map)
			args.append(arg)
			if end < length and string[end] == ',':
				end += 1
				continue
			elif end < length and string[end] == ')':
				end += 1
				break
			else:
				raise Exception(f"parse_fol_term_from_tptp ERROR at {end+1}: Expected a closing parenthesis or comma in argument array.")
		if type(term) == FOLConstant:
			# A constant head contributes its bare name as the function symbol.
			term = FOLFuncApplication(term.constant, args)
		else:
			term = FOLFuncApplication(term, args)
	return term, end

def parse_fol_quantifier_from_tptp(string, pos, var_map):
	"""Parse the variable list and body of a quantifier.

	`pos` must point just past the opening `?[` or `![`. Each comma-separated
	name up to the next `]:` is registered in `var_map` with the identifier
	`len(var_map) + 1` while the body is parsed, and removed afterwards.

	Returns:
		A `(variables, formula, pos)` triple: the identifiers assigned to the
		declared names (in declaration order), the parsed body, and the index
		just past the body.

	Raises:
		Exception: if the closing `]:` is missing or a name is already
			declared in `var_map`.
	"""
	index = string.find(']:', pos)
	if index == -1:
		# Without this check the -1 would silently mis-slice the variable list.
		raise Exception(f"parse_fol_from_tptp ERROR at {pos+1}: Expected ']:' after the quantified variable list.")
	# Strip surrounding whitespace so "![X, Y]:" declares "Y" rather than " Y",
	# which could never match a name scanned by the term parser.
	var_names = [var_name.strip() for var_name in string[pos:index].split(',')]
	variables = []
	for var_name in var_names:
		if var_name in var_map:
			raise Exception(f"parse_fol_from_tptp ERROR at {pos+1}: Variable '{var_name}' already declared in this scope.")
		variable = len(var_map) + 1
		var_map[var_name] = variable
		variables.append(variable)
	(formula, pos) = parse_fol_literal_from_tptp(string, index + len(']:'), var_map)
	# The declared names go out of scope once the body has been parsed.
	for var_name in var_names:
		del var_map[var_name]
	return variables, formula, pos

def parse_fol_literal_from_tptp(string, pos, var_map):
	"""Parse one literal from `string` starting at `pos`.

	A literal is a quantified formula (`?[...]:` / `![...]:`), a negation
	(`~`), a parenthesized formula, or a term optionally followed by `=` and
	another literal.

	Returns:
		A `(formula, pos)` pair where `pos` is the index just past the literal.

	Raises:
		Exception: if a parenthesized formula is not followed by `)`, or
			whatever the nested parsers raise.
	"""
	if string.startswith('?[', pos):
		(variables, formula, pos) = parse_fol_quantifier_from_tptp(string, pos + len('?['), var_map)
		# Wrap from the innermost (last declared) variable outwards.
		for variable in reversed(variables):
			formula = FOLExists(variable, formula)
	elif string.startswith('![', pos):
		(variables, formula, pos) = parse_fol_quantifier_from_tptp(string, pos + len('!['), var_map)
		for variable in reversed(variables):
			formula = FOLForAll(variable, formula)
	elif string.startswith('~', pos):
		(operand, pos) = parse_fol_literal_from_tptp(string, pos + 1, var_map)
		formula = FOLNot(operand)
	elif string.startswith('(', pos):
		(formula, pos) = parse_fol_from_tptp(string, pos + 1, var_map)
		# Bounds-check first so input ending before ')' reports a parse error
		# instead of raising IndexError.
		if pos >= len(string) or string[pos] != ')':
			raise Exception(f"parse_fol_from_tptp ERROR at {pos+1}: Expected a closing parenthesis.")
		pos += 1
	else:
		formula, pos = parse_fol_term_from_tptp(string, pos, var_map)
		if pos < len(string) and string[pos] == '=':
			other, pos = parse_fol_literal_from_tptp(string, pos + 1, var_map)
			formula = FOLEquals(formula, other)
	return formula, pos

def parse_fol_from_tptp(string, pos, var_map):
	"""Parse a formula from `string` starting at `pos`.

	Grammar handled here: a literal, optionally followed by a chain of
	literals joined by one repeated `&` or `|` operator, optionally followed
	by `=>` or `<=>` and another formula.

	Returns:
		A `(formula, pos)` pair where `pos` is the index just past the
		consumed text; any unparsed remainder is left for the caller.

	Raises:
		Exception: if the input ends where a formula is expected, or
			whatever the nested parsers raise.
	"""
	length = len(string)
	while pos < length and string[pos].isspace():
		pos += 1
	if pos >= length:
		raise Exception(f"parse_fol_from_tptp ERROR at {pos+1}: Unexpected end of input.")

	# A leading '(' is parsed by parse_fol_literal_from_tptp like any other
	# literal. Returning right after the matching ')' (as was done before)
	# dropped any connective following the group, e.g. "(a & b) => c".
	formula, pos = parse_fol_literal_from_tptp(string, pos, var_map)

	while pos < length and string[pos].isspace():
		pos += 1
	if pos < length and (string[pos] == '&' or string[pos] == '|'):
		operands = [formula]
		operator = string[pos]
		pos += 1
		while True:
			while pos < length and string[pos].isspace():
				pos += 1
			other, pos = parse_fol_literal_from_tptp(string, pos, var_map)
			operands.append(other)
			while pos < length and string[pos].isspace():
				pos += 1
			# Keep chaining only while the same operator repeats.
			if pos < length and string[pos] == operator:
				pos += 1
			else:
				break
		if operator == '&':
			formula = FOLAnd(operands)
		else:
			formula = FOLOr(operands)
	if string.startswith('=>', pos):
		# The recursive call skips any whitespace after the connective.
		other, pos = parse_fol_from_tptp(string, pos + len('=>'), var_map)
		return FOLIfThen(formula, other), pos
	elif string.startswith('<=>', pos):
		# fol_to_tptp emits '<=>' for FOLIff, so accept it here as well.
		other, pos = parse_fol_from_tptp(string, pos + len('<=>'), var_map)
		return FOLIff(formula, other), pos
	else:
		return formula, pos

def do_parse_fol_from_tptp(string):
	"""Parse `string` from index 0 with an empty variable scope and return the formula; the end position is discarded."""
	parsed = parse_fol_from_tptp(string, 0, {})
	return parsed[0]

def fol_to_tptp(formula):
	"""Serialize a formula or term node to a TPTP-style string.

	Variables are written as `X<id>`; conjunctions, disjunctions,
	implications, biconditionals and equalities are parenthesized; quantifier
	bodies are wrapped in parentheses.

	Raises:
		Exception: if `formula` (or a nested node) is not one of the node
			types handled here.
	"""
	kind = type(formula)
	if kind == FOLVariable:
		return f'X{formula.variable}'
	elif kind == FOLConstant:
		return formula.constant
	elif kind == FOLNumber:
		# Always return a string so the result can be joined or concatenated.
		return str(formula.number)
	elif kind == FOLFuncApplication:
		function = formula.function
		if not isinstance(function, str):
			# The parser can produce a FOLVariable as the applied symbol;
			# serialize it rather than failing on str + node.
			function = fol_to_tptp(function)
		return function + '(' + ','.join([fol_to_tptp(arg) for arg in formula.args]) + ')'
	elif kind == FOLAnd:
		return '(' + ' & '.join([fol_to_tptp(operand) for operand in formula.operands]) + ')'
	elif kind == FOLOr:
		return '(' + ' | '.join([fol_to_tptp(operand) for operand in formula.operands]) + ')'
	elif kind == FOLNot:
		return '~' + fol_to_tptp(formula.operand)
	elif kind == FOLIfThen:
		return '(' + fol_to_tptp(formula.antecedent) + ' => ' + fol_to_tptp(formula.consequent) + ')'
	elif kind == FOLIff:
		return '(' + fol_to_tptp(formula.antecedent) + ' <=> ' + fol_to_tptp(formula.consequent) + ')'
	elif kind == FOLEquals:
		return '(' + fol_to_tptp(formula.left) + '=' + fol_to_tptp(formula.right) + ')'
	elif kind == FOLForAll:
		return f'![X{formula.variable}]:(' + fol_to_tptp(formula.operand) + ')'
	elif kind == FOLExists:
		return f'?[X{formula.variable}]:(' + fol_to_tptp(formula.operand) + ')'
	else:
		raise Exception(f"fol_to_tptp ERROR: Unrecognized formula type {type(formula)}.")

def unify_term(first, second, first_var_map, second_var_map):
	"""Try to unify two terms, recording variable bindings in the given maps.

	If `first` is a FOLVariable it is bound in `first_var_map`; otherwise, if
	`second` is a FOLVariable it is bound in `second_var_map`. A variable can
	only be bound to a FOLTerm instance, and an existing binding must equal
	the other side. With no variable on either side the terms must be equal.

	Returns:
		True if the terms unify, False otherwise.
	"""
	if type(first) == FOLVariable:
		name, value, bindings = first.variable, second, first_var_map
	elif type(second) == FOLVariable:
		name, value, bindings = second.variable, first, second_var_map
	else:
		return first == second

	if not isinstance(value, FOLTerm):
		return False
	if name in bindings:
		# Already bound: the earlier binding has to agree with this value.
		return bindings[name] == value
	bindings[name] = value
	return True

def unify(first, second, first_var_map, second_var_map):
	"""Try to unify two formulas structurally, recording bindings in the given maps.

	Both formulas must have the same node type. Connectives recurse into
	their children in order; equalities and function applications unify
	their terms with `unify_term`; quantifiers temporarily bind their
	variables to each other while their bodies are unified.

	Returns:
		True if the formulas unify, False otherwise.

	Raises:
		Exception: if both arguments share a type this function does not handle.
	"""
	kind = type(first)
	if kind != type(second):
		return False

	if kind == FOLAnd or kind == FOLOr:
		if len(first.operands) != len(second.operands):
			return False
		return all(
			unify(left, right, first_var_map, second_var_map)
			for left, right in zip(first.operands, second.operands)
		)

	if kind == FOLNot:
		return unify(first.operand, second.operand, first_var_map, second_var_map)

	if kind == FOLIfThen or kind == FOLIff:
		if not unify(first.antecedent, second.antecedent, first_var_map, second_var_map):
			return False
		return unify(first.consequent, second.consequent, first_var_map, second_var_map)

	if kind == FOLEquals:
		if not unify_term(first.left, second.left, first_var_map, second_var_map):
			return False
		return unify_term(first.right, second.right, first_var_map, second_var_map)

	if kind == FOLForAll or kind == FOLExists:
		first_var_map[first.variable] = FOLVariable(second.variable)
		second_var_map[second.variable] = FOLVariable(first.variable)
		if unify(first.operand, second.operand, first_var_map, second_var_map):
			# The temporary bindings are removed only when the bodies unified.
			del first_var_map[first.variable]
			del second_var_map[second.variable]
			return True
		return False

	if kind == FOLFuncApplication:
		if first.function != second.function or len(first.args) != len(second.args):
			return False
		return all(
			unify_term(left, right, first_var_map, second_var_map)
			for left, right in zip(first.args, second.args)
		)

	raise Exception(f"unify ERROR: Unrecognized formula type {type(first)}.")
