import json
import os
import numpy
import re

from pkg_resources import UnknownExtra
from sympy import sympify
from sympy.logic.boolalg import Or, And, Not, to_dnf, to_cnf, simplify_logic
from decimal import Decimal
# Predicate name -> largest number of arguments seen in that predicate's
# definitions. Written by normalize_predicate() and read by normalize_rule(),
# which rejects rules that call a predicate with more arguments than recorded.
# This module never resets it, so entries persist across calls in one process.
global_pre_argu_map = {}

def knowledge_value_neg(value):
    """Negate a three-valued truth label.

    "True" and "False" swap, "Unknown" stays "Unknown". Each label is accepted
    in its capitalized, all-upper and all-lower spelling. Any other value
    triggers a printed warning and is treated as "Unknown".

    :param value: the label to negate.
    :return: "True", "False" or "Unknown".
    """
    spelling_table = (
        (("Unknown", "UNKNOWN", "unknown"), "Unknown"),
        (("True", "TRUE", "true"), "False"),
        (("False", "FALSE", "false"), "True"),
    )
    for accepted_spellings, negated_label in spelling_table:
        if value in accepted_spellings:
            return negated_label

    print("Wrong format of the value, only 'Unknown', 'True', and 'False' are allowed.")
    return "Unknown"

def remove_last_zero(float_str):
    """Trim insignificant trailing zeros from a numeric string.

    Examples: "2.500" -> "2.5", "3.0" -> "3", "-0.0" -> "0", "+0" -> "0".
    Strings without a decimal point keep their digits. An exponent suffix
    ("e+20", "E-05") is preserved as-is and only the mantissa in front of it
    is trimmed, so "1.5e+20" stays "1.5e+20" rather than being corrupted
    into "1.5e+2".

    :param float_str: textual number, e.g. the output of ``str(float)``.
    :return: the trimmed string. A mantissa of "0", "-0" or "+0" (after
        trimming) yields "0".
    """
    # Separate an optional trailing exponent so its digits are never stripped.
    mantissa, exponent = re.fullmatch(
        r'(.*?)([eE][+-]?\d+)?', float_str, re.DOTALL).groups()

    if '.' in mantissa:
        mantissa = mantissa.rstrip('0')  # Remove trailing zeros
        if mantissa.endswith('.'):  # Remove trailing dot if necessary
            mantissa = mantissa[:-1]
        if mantissa in ('', '-', '+'):  # e.g. ".0" or "-.0" was a zero
            mantissa = '0'

    if mantissa in ('0', '-0', '+0'):
        return '0'

    return mantissa + (exponent or '')

def is_number(s):
    """Return True if *s* can be converted with ``float()``, otherwise False.

    Accepts numeric strings ("3", "-2.5", "1e3", " 4 ") and int/float values.
    Inputs ``float()`` cannot handle, such as None or a list, return False
    instead of raising TypeError. Note that ``float()`` also accepts "nan"
    and "inf", so those count as numbers.
    """
    try:
        float(s)
    except (TypeError, ValueError):  # TypeError: None, lists, dicts, ...
        return False
    return True

def decimal_eval(expression: str) -> Decimal:
    """Evaluate an arithmetic expression using Decimal arithmetic.

    Every numeric literal in *expression* is rewritten to
    ``Decimal('<literal>')`` before evaluation. As a result, "0.1 + 0.2" yields
    ``Decimal('0.3')`` instead of a binary-float approximation.

    Recognized literals are integers ("12"), decimals ("1.5", "3.", ".5")
    and scientific notation ("1e3", "2.5E-2"). Only basic arithmetic
    operators and parentheses are intended to be used.

    :param expression: the arithmetic expression to evaluate.
    :return: the evaluated result (a Decimal for purely arithmetic input).
    """
    # Match a whole literal, including any fraction and exponent, so that
    # Decimal parses it in one piece. Without this, "1e3" would become the
    # invalid "Decimal('1')eDecimal('3')".
    number = re.compile(r'(\d+(?:\.\d*)?(?:[eE][+-]?\d+)?|\.\d+(?:[eE][+-]?\d+)?)')
    expression_with_decimal = number.sub(r"Decimal('\1')", expression)

    # SECURITY NOTE(review): eval() executes arbitrary Python and builtins are
    # still reachable here. This is NOT a sandbox, so only pass trusted
    # expressions.
    result = eval(expression_with_decimal, {"Decimal": Decimal})
    return result

def to_DeNF(rule):
    """
    Convert a general propositional rule "<lhs> => <rhs>" into a set of DeNF rules.

    The lhs is converted to DNF and the rhs to CNF, both simplified by sympy.
    For every DNF clause D of the lhs, every CNF clause (l_1 | ... | l_k) of
    the rhs and every literal l_i of that clause, the following rule is
    produced:

        simplify_logic(D & ~l_1 & ... & ~l_(i-1) & ~l_(i+1) & ... & ~l_k) => l_i

    A rule is skipped when its simplified lhs is False, or when the lhs text
    already mentions l_i as a whole name. A positive literal p is therefore
    also skipped when ~p is on the left.

    :param rule: string "<lhs> => <rhs>" using sympy's ~, & and | connectives.
    :return: the set of DeNF rules, each a string "<lhs> => <literal>".

    Example:
        input:  ~ (p_1 & p_2) => p_3
        output: {"~p_1 => p_3", "~p_2 => p_3"}
    """
    DeNF_set_simplified = set()

    sides = rule.split('=>')
    lhs_expr = sympify(sides[0].strip())
    rhs_expr = sympify(sides[1].strip())
    lhs_dnf = to_dnf(lhs_expr, simplify=True)
    rhs_cnf = to_cnf(rhs_expr, simplify=True)
    lhs_dnf_clauses = lhs_dnf.args if isinstance(lhs_dnf, Or) else [lhs_dnf]
    rhs_cnf_clauses = rhs_cnf.args if isinstance(rhs_cnf, And) else [rhs_cnf]

    for each_dnf_clause in lhs_dnf_clauses:
        for each_cnf_clause in rhs_cnf_clauses:
            rhs_literals = each_cnf_clause.args if isinstance(each_cnf_clause, Or) else [each_cnf_clause]

            for single_rhs_literal in rhs_literals:
                # Move the other literals of the clause to the left, negated.
                lhs_original = And(each_dnf_clause, *[Not(lit) for lit in rhs_literals if lit != single_rhs_literal])
                lhs_simplified = simplify_logic(lhs_original)
                if lhs_simplified == False:  # sympy's `false` compares equal to False
                    continue

                literal_str = str(single_rhs_literal)
                lhs_str = str(lhs_simplified)
                # Whole-name match. A plain substring test would treat "p_1"
                # as occurring in "p_10" and wrongly drop the rule.
                if re.search(rf'(?<!\w){re.escape(literal_str)}(?!\w)', lhs_str):
                    continue

                DeNF_set_simplified.add(lhs_str + " => " + literal_str)

    return DeNF_set_simplified


def normalize_predicate(predicate, variables):
    """Rename a predicate's parameters to the canonical variable names.

    The i-th parameter of "Name(a, b, ...) := definition" is renamed to
    ``variables[i]``, both in the head and, as whole words, in the definition.
    The arity is recorded in ``global_pre_argu_map``, which keeps the largest
    arity seen for each predicate name.

    :param predicate: definition string "Name(args) := text", where Name
        starts with an uppercase letter.
    :param variables: ordered list of canonical variable names.
    :return: the normalized string "Name(v1, v2, ...) := text".

    On a malformed definition, or when the predicate declares more
    parameters than there are canonical variables, an error is printed and
    the process exits.
    """
    match = re.match(r'([A-Z][a-zA-Z_]*)\((.*?)\)\s*:=\s*(.+)', predicate)
    if not match:
        print("Error! Wrong predicate definition of " + predicate + "\nPlease follow the instruction.")
        exit(0)  # NOTE(review): exit status 0 is reported although this is an error

    predicate_name, args_str, definition = match.groups()
    # Strip so that "P(x, y)" and "P(x,y)" yield the same parameter names.
    args = [arg.strip() for arg in args_str.split(',')] if args_str else []
    if len(args) > len(variables):
        print("Error! More Variables are used in predicate \'" + predicate + "\' than the ones predefined!")
        print("Variables predefined: ", variables)
        exit(0)  # NOTE(review): exit status 0 is reported although this is an error

    sorted_args = list(variables[:len(args)])  # Reassign variables based on order

    # Rename every parameter in ONE pass. Renaming one name at a time breaks on
    # permutations: for "P(y, x) := y loves x" with variables [x, y], y->x and
    # then x->y would produce "y loves y".
    renames = {}
    for original, updated in zip(args, sorted_args):
        if original:  # an empty name would match at every word boundary
            renames.setdefault(original, updated)
    if renames:
        alternation = '|'.join(re.escape(name) for name in sorted(renames, key=len, reverse=True))
        # \b...\b matches "at", "at.", "(at)" but not "attempt" or "atlas"
        definition = re.sub(rf'\b(?:{alternation})\b',
                            lambda found: renames[found.group(0)],
                            definition)

    normalized_predicate = f"{predicate_name}({', '.join(sorted_args)}) := {definition}"

    global_pre_argu_map[predicate_name] = max(len(sorted_args),
                                              global_pre_argu_map.get(predicate_name, 0))

    return normalized_predicate


def clean_operators4vars(args):
    """Collect the identifier names that occur in predicate arguments.

    An argument may combine identifiers with operators or numbers (e.g.
    "x+1"). Only identifier-shaped tokens, meaning a letter or underscore
    followed by word characters, are kept.

    :param args: iterable of argument strings.
    :return: set of identifier names found across all arguments.
    """
    identifier = re.compile(r'\b[a-zA-Z_]\w*\b')
    return {name for text in args for name in identifier.findall(text)}


def normalize_rule(rule, variables, predicateNameSet):
    """Rename the variables of a rule to the canonical variable names.

    All identifiers used inside the arguments of predicate calls "Name(args)"
    are collected. The k distinct identifiers, sorted, are mapped one-to-one
    onto the first k canonical variables, also sorted, and rewritten in every
    call. Arguments are re-joined with ", ".

    :param rule: rule string such as "P(x) & Q(x,y) => R(y)".
    :param variables: ordered list of canonical variable names.
    :param predicateNameSet: names of all defined predicates.
    :return: the rule with its variables renamed.

    The process prints a message and exits when:
    - the rule uses more distinct variables than *variables* provides;
    - it calls an undefined predicate;
    - it passes a predicate more arguments than ``global_pre_argu_map``
      records for it.
    """
    predicate_call = r'([A-Z][a-zA-Z_]*)\((.*?)\)'
    match = re.findall(predicate_call, rule)

    # Gather the variable names used inside predicate arguments. Operators
    # (currently +, -, *, \) and numbers are dropped by clean_operators4vars.
    match_variable_set = set()
    for _pred, args_str in match:
        args = args_str.replace(' ', '').split(',') if args_str else []
        match_variable_set.update(clean_operators4vars(args))

    # check variables_used number is smaller than variables
    if len(match_variable_set) > len(variables):
        print("More variables are used beyond the ones predefined.")
        exit(0)  # NOTE(review): exit status 0 is reported although this is an error

    normalized_variables = variables[0:len(match_variable_set)]
    dict_with_sets = dict(zip(sorted(match_variable_set), sorted(normalized_variables)))

    def renamed_args(args_str):
        """Return one call's argument list with its variables renamed."""
        new_args_str = re.sub(r'\b[a-zA-Z_]\w*\b',
                              lambda found: dict_with_sets.get(found.group(0), found.group(0)),
                              args_str)
        return new_args_str.replace(' ', '').split(',') if new_args_str else []

    # Validate every call before rewriting anything. The checks, their order
    # and the messages are unchanged.
    for pred, args_str in match:
        sorted_args = renamed_args(args_str)

        if pred not in predicateNameSet:
            print("Error. Undefined predicate \'" + pred + "(*)\' in the rule: " + rule)
            print("All predicates defined are: " + str(predicateNameSet))
            exit(0)

        if int(global_pre_argu_map[pred]) < len(sorted_args):
            print("Wrong usage of predicate \'" + pred + "(*)\':", global_pre_argu_map[pred],
                  "arguments are allowed at most with", str(len(sorted_args)),
                  "provided!")
            exit(0)

    # Rewrite all calls in a single pass. Rewriting call by call on the
    # already-modified rule fails when the renaming permutes names. With x->y
    # and y->z, "P(x) & P(y)" first becomes "P(y) & P(y)" and then
    # "P(z) & P(z)". A callable replacement also stops a backslash in the
    # arguments from being read as a regex escape.
    return re.sub(r'\b' + predicate_call,
                  lambda found: f"{found.group(1)}({', '.join(renamed_args(found.group(2)))})",
                  rule)


def remove_redundant_rules(rules):
    """Drop rules that are identical once leading/trailing whitespace is stripped.

    Rules are keyed by their stripped text. For each key the LAST matching
    rule (unstripped) is kept, at the position where the key first appeared.

    :param rules: iterable of rule strings.
    :return: list of the surviving rules.
    """
    latest_by_text = {candidate.strip(): candidate for candidate in rules}
    return list(latest_by_text.values())


def pre_process(data):
    """Normalize whitespace in the entries of a parsed esl file, in place.

    For each entry:
    - variables: all spaces removed;
    - predicates: the head (before ":=") loses all spaces, and whitespace
      runs in the body are collapsed to one space, giving "Head := body";
    - rules: all spaces removed around the first "=>" split, which is then
      rejoined as "lhs => rhs", and "&" / "|" are padded with single spaces.

    :param data: list of entry dicts with "variables", "predicates" and "rules".
    :return: the same (mutated) list.
    """
    def tidy_predicate(text):
        pieces = text.split(":=")
        head = pieces[0].replace(" ", '')
        body = re.sub(r"\s+", ' ', pieces[1].strip())
        return head + " := " + body

    def tidy_rule(text):
        pieces = text.split("=>")
        joined = pieces[0].replace(" ", '') + " => " + pieces[1].replace(" ", '')
        return joined.replace('&', " & ").replace('|', " | ")

    for entry in data:
        entry["variables"] = [name.replace(" ", '') for name in entry['variables']]
        entry["predicates"] = [tidy_predicate(text) for text in entry["predicates"]]
        entry["rules"] = [tidy_rule(text) for text in entry["rules"]]
    return data


def normalize_json(data, output_file):
    """Merge and normalize a list of esl entries and write the result as JSON.

    Steps:
      1. pre_process() every entry (whitespace normalization; mutates *data*).
      2. Merge across entries:
         - domains, keeping first-seen order and dropping duplicates;
         - variables and rules, deduplicated;
         - predicates, where a later definition of the same predicate name
           replaces an earlier one (a notice is printed).
      3. Rename predicate parameters and rule variables to the sorted
         canonical variable list via normalize_predicate / normalize_rule,
         then drop redundant rules.
      4. Write ``[merged_data]`` to *output_file*, creating its directory
         if needed.

    :param data: list of dicts with "domain", "variables", "predicates" and
        "rules" lists.
    :param output_file: path of the JSON file to write.
    :return: ``(final_output, configuration)``. final_output is the list
        written to disk. configuration maps "variables" and "predicates" to
        their counts.
    """
    data = pre_process(data)  # normalize space usage

    # Initialize merged structure
    merged_data = {
        "domain": {},  # dict keys used as an insertion-ordered set
        "variables": set(),
        "predicates": {},  # predicate name -> definition (duplicates merged)
        "rules": set(),
    }

    configuration = {
        "variables": 0,   # number of canonical variables
        "predicates": 0,  # number of distinct predicates
    }

    # Merge data from all entries
    for entry in data:
        merged_data["domain"].update(dict.fromkeys(entry["domain"]))
        merged_data["variables"].update(entry["variables"])
        merged_data["rules"].update(entry["rules"])

        for predicate in entry["predicates"]:
            key = re.sub(r"\(.*?\).*", "", predicate).strip()  # Extract predicate name without variables

            if key in merged_data["predicates"]:
                print("we find overlapping predicates as follow and we remain the last one: " + predicate)
                print(merged_data["predicates"][key])
                print(predicate)
            merged_data["predicates"][key] = predicate  # the last definition wins

    merged_data["variables"] = sorted(merged_data["variables"])
    variables = merged_data["variables"]

    keySet = list(merged_data["predicates"].keys())

    merged_data["predicates"] = sorted(
        normalize_predicate(p, variables) for p in merged_data["predicates"].values())

    # Iterate rules in sorted order so the run (error messages, which duplicate
    # survives) does not depend on set iteration order.
    merged_data["rules"] = sorted(remove_redundant_rules(
        [normalize_rule(r, variables, keySet) for r in sorted(merged_data["rules"])]))
    # Use the merged domain. Previously the merged set was discarded for the
    # last entry's domain, which also raised NameError when data was empty.
    merged_data["domain"] = list(merged_data["domain"])
    final_output = [merged_data]

    # Write the processed data to the output file. A bare filename has no
    # directory part, and os.makedirs('') would raise.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(final_output, f, indent=4)

    configuration["variables"] = len(variables)
    configuration["predicates"] = len(merged_data["predicates"])

    return final_output, configuration