import os
import sys
project_root = os.path.dirname(
    os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from synthesizer.fact import AttributeFact, RelationFact
from synthesizer.rule import Rule
from synthesizer.expression import ConstantExpression, IdentityExpression, LinearExpression, BinaryExpression, PlaceHolder
from synthesizer.template import TemplateFactory
from synthesizer.reasoning_graph import ReasoningGraph, ReasoningNode
from synthesizer.theory_utils import trigger_all_rules
import random

# Relative sampling weights per expression class.
# "normal" drives the top-level expression of a synthesized rule conclusion;
# "binary" drives the two operands of a BinaryExpression and therefore leaves
# BinaryExpression out (Theory.__init__ asserts its weight there is absent or 0).
default_expression_weights = dict(
    normal={
        ConstantExpression: 1,
        IdentityExpression: 1,
        LinearExpression: 1,
        BinaryExpression: 1,
    },
    binary={
        ConstantExpression: 1,
        IdentityExpression: 1,
        LinearExpression: 1,
    },
)


class Theory:
    def __init__(self, template_factory, entities, attributes, relations, fact_num=10, rule_num=10, depth=4, interval=(-10, 10), condition_num_interval=(1, 1), expression_weights=default_expression_weights):
        """Synthesize a random theory whose reasoning graph answers one query.

        The reasoning graph is rebuilt until ``trigger_all_rules`` reports no
        conflict (returns a non-``None`` result). Irrelevant content is then
        added and the reasoning steps are processed.

        Args:
            template_factory: passed to every fact, rule and expression built.
            entities: entity names; keys of ``self.values`` and ``self.exist_relations``.
            attributes: attribute names; keys of ``self.values[entity]``.
            relations: relation names sampled for relation conditions.
            fact_num: stored on the instance; presumably bounds the irrelevant
                facts added by ``_synthesize_irrelevant`` -- TODO confirm.
            rule_num: stored on the instance; presumably bounds the irrelevant
                rules added by ``_synthesize_irrelevant`` -- TODO confirm.
            depth: budget of rules the reasoning graph may synthesize
                (``self.left_depth`` is decremented once per synthesized rule).
            interval: inclusive ``(low, high)`` range for sampled constants.
            condition_num_interval: inclusive ``(min, max)`` number of conditions
                per rule; ``min`` must be at least 1 because every rule must
                mention its conclusion entities.
            expression_weights: mapping with ``"normal"`` and ``"binary"``
                weight maps (expression class -> non-negative weight). Both maps
                are copied so the shared module-level default is never aliased
                by an instance.

        Raises:
            AssertionError: if an argument is inconsistent.
        """
        self.template_factory = template_factory
        self.entities = entities
        self.attributes = attributes
        self.relations = relations
        self.fact_num = fact_num
        self.rule_num = rule_num
        self.depth = depth
        assert interval[0] <= interval[1], "Interval should be valid"
        self.interval = interval
        assert condition_num_interval[0] <= condition_num_interval[1], "Condition number interval should be valid"
        # Zero conditions can never satisfy the rule synthesizers, which require
        # the conclusion entities to appear among the conditions.
        assert condition_num_interval[0] >= 1, "Every rule needs at least one condition"
        self.condition_num_interval = condition_num_interval

        # Copy so that mutating an instance's weights cannot leak into the shared
        # default dict (or into the caller's dict).
        normal_weights = dict(expression_weights["normal"])
        binary_weights = dict(expression_weights["binary"])
        assert all(weight >= 0 for weight in normal_weights.values()), \
            "All normal expression weights should be non-negative"
        assert all(weight >= 0 for weight in binary_weights.values()), \
            "All binary expression weights should be non-negative"
        assert BinaryExpression not in binary_weights or binary_weights[BinaryExpression] == 0, \
            "BinaryExpression should not be in binary expression weights if it is not allowed"
        # random.choices cannot sample from an all-zero weight vector; fail early
        # instead of deep inside graph construction.
        assert sum(normal_weights.values()) > 0, \
            "At least one normal expression weight should be positive"
        # The binary map is only sampled for the operands of a BinaryExpression.
        assert normal_weights.get(BinaryExpression, 0) == 0 or sum(binary_weights.values()) > 0, \
            "At least one binary expression weight should be positive when BinaryExpression can be sampled"
        self.normal_expression_weights = normal_weights
        self.binary_expression_weights = binary_weights

        # Rebuild the graph from scratch until rule triggering is conflict-free.
        while True:
            self._construct_reasoning_graph()
            self._do_topo_reasoning()
            trigger_result = trigger_all_rules(
                self.facts, self.rules, self.values, self.exist_relations, self.debug_info())
            if trigger_result is not None:
                self.values, self.exist_relations = trigger_result
                break
            print("Conflict detected while construct reasoning graph, retrying...")

        self._synthesize_irrelevant()
        self.reasoning_process = self._process_reasoning_steps()

    def _construct_reasoning_graph(self):
        """Grow a fresh reasoning graph backwards from a randomly chosen query.

        Resets ``facts``, ``rules``, ``values``, ``exist_relations`` and the
        graph, picks the query ``(entity, attribute)`` (its value is marked
        ``PlaceHolder``), and then processes pending nodes in FIFO order. Each
        node is grounded as a plain fact when the rule budget
        ``self.left_depth`` is spent, or with probability 0.5 while other nodes
        are still pending. Otherwise it becomes the conclusion of a new rule,
        whose conditions are queued in turn.
        """
        self.facts = []
        self.rules = []
        self.values = {
            entity: dict.fromkeys(self.attributes)
            for entity in self.entities
        }
        self.exist_relations = {
            source: {target: [] for target in self.entities}
            for source in self.entities
        }
        self.reasoning_graph = ReasoningGraph()

        query_entity = random.choice(self.entities)
        query_attribute = random.choice(self.attributes)
        self.query = (query_entity, query_attribute)
        self.values[query_entity][query_attribute] = PlaceHolder

        root = ReasoningNode(
            AttributeFact(query_entity, query_attribute,
                          ConstantExpression(-1), self.template_factory),
            None)
        self.reasoning_graph.add_edge(root, None)
        self.unprocessed = [root]
        self.left_depth = self.depth

        def ground_here():
            # Budget exhausted -> must be a fact (no coin flip consumed).
            if self.left_depth == 0:
                return True
            # Flip a coin only while other nodes are still waiting.
            return len(self.unprocessed) > 0 and random.random() < 0.5

        while self.unprocessed:
            node = self.unprocessed.pop(0)
            assert node.support is None, "Support should be None for unprocessed node"
            target = node.conclusion
            if isinstance(target, AttributeFact):
                assert self.values[target.entity][target.attribute] is PlaceHolder, "Attribute value should be set before synthesis"
                assert isinstance(
                    target.expression, ConstantExpression), "Expression should be a ConstantExpression"
                as_fact, as_rule = self.synthesize_attribute_fact, self.synthesize_attribute_rule
            elif isinstance(target, RelationFact):
                assert self.exist_relations[target.entity1][target.entity2], "Relation should exist before synthesis"
                as_fact, as_rule = self.synthesize_relation_fact, self.synthesize_relation_rule
            else:
                continue
            handler = as_fact if ground_here() else as_rule
            handler(node)

    def synthesize_attribute_fact(self, reasoning_node):
        """Ground an attribute node as a fact holding a random constant.

        An integer is drawn uniformly from the inclusive ``self.interval`` and
        written into the conclusion's expression. The expression's value is then
        recorded in ``self.values``, the conclusion is appended to
        ``self.facts``, and the conclusion becomes its own support.
        """
        fact = reasoning_node.conclusion
        lower_bound, upper_bound = self.interval[0], self.interval[1]
        fact.expression.value = random.randint(lower_bound, upper_bound)
        entity_values = self.values[fact.entity]
        entity_values[fact.attribute] = fact.expression.value
        self.facts.append(fact)
        reasoning_node.support = fact

    def synthesize_relation_fact(self, reasoning_node):
        """Ground a relation node as a fact that is its own support.

        The relation is expected to be recorded in ``self.exist_relations``
        already; this only checks that the two entities differ, then appends
        the conclusion to ``self.facts``.
        """
        relation_fact = reasoning_node.conclusion
        distinct = relation_fact.entity1 != relation_fact.entity2
        assert distinct, "Entities in relation fact should be different"
        self.facts.append(relation_fact)
        reasoning_node.support = relation_fact

    def synthesize_attribute_rule(self, reasoning_node):
        """Support an attribute conclusion with a newly synthesized rule.

        Spends one unit of ``self.left_depth``. Concrete ("true") conditions are
        sampled around the conclusion entity and registered as children of
        ``reasoning_node`` in the reasoning graph. The conclusion's expression is
        replaced by a sampled expression over the condition entities. The stored
        ``Rule`` is the same rule with entities renamed to ``entity_1``,
        ``entity_2``, ... in order of first appearance in the conditions.

        Side effects: sets ``reasoning_node.support`` (the ``Rule``) and
        ``reasoning_node.true_conditions`` (the concrete conditions,
        index-aligned with the rule's conditions), and appends the rule to
        ``self.rules``.

        Args:
            reasoning_node: unprocessed node whose conclusion is an ``AttributeFact``.
        """

        def candidate_entities(entity_list):
            # Entities already used by this rule plus at most one unused entity,
            # so new entities enter the rule one at a time. Unused entities are
            # listed in ``self.entities`` order (not set order), so seeded runs
            # stay reproducible under string-hash randomization. An exhausted
            # pool yields ``entity_list`` instead of raising IndexError from
            # ``random.choice([])``.
            unused = [entity for entity in self.entities if entity not in entity_list]
            if unused:
                return entity_list + [random.choice(unused)]
            return list(entity_list)

        def synthesize_expression(condition_entities, from_binary=False):
            # Operands of a BinaryExpression use the "binary" weights, which
            # exclude BinaryExpression and so bound the nesting depth.
            if from_binary:
                expression_type = random.choices(list(self.binary_expression_weights.keys()),
                                                 weights=list(self.binary_expression_weights.values()), k=1)[0]
            else:
                expression_type = random.choices(list(self.normal_expression_weights.keys()),
                                                 weights=list(self.normal_expression_weights.values()), k=1)[0]
            if expression_type == ConstantExpression:
                return ConstantExpression.synthesize_from_interval(self.interval)
            elif expression_type == IdentityExpression:
                # The referenced attribute becomes a dependency of this node.
                expression_condition = self.synthesize_attribute_condition(
                    reasoning_node, target_entity=random.choice(condition_entities))
                entity = expression_condition.entity
                return IdentityExpression.synthesize_from_interval(
                    entity, expression_condition.attribute, self.template_factory)
            elif expression_type == LinearExpression:
                expression_condition = self.synthesize_attribute_condition(
                    reasoning_node, target_entity=random.choice(condition_entities))
                entity = expression_condition.entity
                return LinearExpression.synthesize_from_interval(
                    self.interval, entity, expression_condition.attribute, self.template_factory)
            elif expression_type == BinaryExpression:
                expression1 = synthesize_expression(
                    condition_entities, from_binary=True)
                expression2 = synthesize_expression(
                    condition_entities, from_binary=True)
                return BinaryExpression.synthesize_from_interval(
                    expression1, expression2, self.template_factory)
            else:
                raise ValueError("Invalid expression type")

        def synthesize_true_conditions(reasoning_node):
            conclusion_entity = reasoning_node.conclusion.entity
            condition_num = random.randint(
                self.condition_num_interval[0], self.condition_num_interval[1])
            conditions = []
            entity_list = [conclusion_entity]
            # A relation condition needs two distinct entities; with only one
            # entity, attribute conditions are the only option.
            relation_possible = len(set(self.entities)) >= 2
            # True once some condition mentions the conclusion entity, which the
            # rule must reference to bind its conclusion.
            necessity = False
            while len(conditions) < condition_num:
                condition_type = random.choice([AttributeFact, RelationFact])
                if condition_type == RelationFact and not relation_possible:
                    condition_type = AttributeFact
                # The last condition is forced onto the conclusion entity when
                # no earlier condition used it.
                force_conclusion_entity = len(conditions) == condition_num - 1 and not necessity
                if condition_type == AttributeFact:
                    if force_conclusion_entity:
                        condition_entity = conclusion_entity
                    else:
                        condition_entity = random.choice(candidate_entities(entity_list))
                    condition = self.synthesize_attribute_condition(
                        reasoning_node, target_entity=condition_entity)
                    conditions.append(condition)
                    necessity = necessity or (condition.entity == conclusion_entity)
                    if condition.entity not in entity_list:
                        entity_list.append(condition.entity)
                elif condition_type == RelationFact:
                    if force_conclusion_entity:
                        condition_entity1 = conclusion_entity
                        condition_entity2 = random.choice(
                            [entity for entity in candidate_entities(entity_list)
                             if entity != condition_entity1])
                    else:
                        condition_entity1, condition_entity2 = random.sample(
                            candidate_entities(entity_list), 2)
                    condition = self.synthesize_relation_condition(
                        reasoning_node, target_entity1=condition_entity1, target_entity2=condition_entity2)
                    conditions.append(condition)
                    necessity = necessity or (
                        condition.entity1 == conclusion_entity or condition.entity2 == conclusion_entity)
                    if condition.entity1 not in entity_list:
                        entity_list.append(condition.entity1)
                    if condition.entity2 not in entity_list:
                        entity_list.append(condition.entity2)
                else:
                    raise ValueError(
                        "Invalid condition type at synthesis_attribute_rule-synthesize_true_conditions")
            random.shuffle(conditions)
            return conditions

        assert self.left_depth > 0, "Depth should be greater than 0"
        assert isinstance(
            reasoning_node.conclusion, AttributeFact), "reasoning node's conclusion should be an AttributeFact"
        self.left_depth -= 1

        conditions = synthesize_true_conditions(reasoning_node)

        # Entities in order of first appearance across the (shuffled) conditions.
        condition_entities = list()
        for condition in conditions:
            if isinstance(condition, AttributeFact):
                if condition.entity not in condition_entities:
                    condition_entities.append(condition.entity)
            elif isinstance(condition, RelationFact):
                if condition.entity1 not in condition_entities:
                    condition_entities.append(condition.entity1)
                if condition.entity2 not in condition_entities:
                    condition_entities.append(condition.entity2)

        assert reasoning_node.conclusion.entity in condition_entities, \
            f"Conclusion entity should be in condition entities {condition_entities}, but got {reasoning_node.conclusion.entity}\nconditions: {conditions}"

        # Concrete entity -> rule variable name ("entity_1", "entity_2", ...).
        condition_entity_map = {
            entity: f"entity_{i+1}" for i, entity in enumerate(condition_entities)}

        # Abstract the concrete conditions, keeping them index-aligned with
        # ``conditions`` (``_do_topo_reasoning`` relies on that alignment).
        rule_conditions = []
        for condition in conditions:
            if isinstance(condition, AttributeFact):
                condition_entity = condition_entity_map[condition.entity]
                rule_condition = AttributeFact(
                    condition_entity, condition.attribute, ConstantExpression(-1), self.template_factory)
            elif isinstance(condition, RelationFact):
                condition_entity1 = condition_entity_map[condition.entity1]
                condition_entity2 = condition_entity_map[condition.entity2]
                rule_condition = RelationFact(
                    condition.relation, condition_entity1, condition_entity2, self.template_factory)
            rule_conditions.append(rule_condition)

        true_expression = synthesize_expression(condition_entities)
        reasoning_node.conclusion.expression = true_expression
        expression = true_expression.substitute_entity(
            condition_entity_map)
        rule_conclusion = AttributeFact(
            condition_entity_map[reasoning_node.conclusion.entity],
            reasoning_node.conclusion.attribute,
            expression,
            self.template_factory
        )
        rule = Rule(rule_conditions, rule_conclusion, self.template_factory)
        reasoning_node.support = rule
        reasoning_node.true_conditions = conditions
        self.rules.append(rule)

    def synthesize_relation_rule(self, reasoning_node):
        """Support a relation conclusion with a newly synthesized rule.

        Spends one unit of ``self.left_depth``. Concrete ("true") conditions are
        sampled so that, together, they mention both entities of the
        conclusion. They are registered as children of ``reasoning_node`` in the
        reasoning graph. The stored ``Rule`` is the same rule with entities
        renamed to ``entity_1``, ``entity_2``, ... in order of first appearance
        in the conditions.

        Side effects: sets ``reasoning_node.support`` (the ``Rule``) and
        ``reasoning_node.true_conditions`` (the concrete conditions,
        index-aligned with the rule's conditions), and appends the rule to
        ``self.rules``.

        Args:
            reasoning_node: unprocessed node whose conclusion is a ``RelationFact``.
        """

        def candidate_entities(entity_list):
            # Entities already used by this rule plus at most one unused entity.
            # Unused entities are listed in ``self.entities`` order (not set
            # order) for seed reproducibility. An exhausted pool -- e.g. exactly
            # two entities in total -- yields ``entity_list`` instead of raising
            # IndexError from ``random.choice([])``.
            unused = [entity for entity in self.entities if entity not in entity_list]
            if unused:
                return entity_list + [random.choice(unused)]
            return list(entity_list)

        def synthesize_true_conditions(reasoning_node):
            condition_num = random.randint(
                self.condition_num_interval[0], self.condition_num_interval[1])
            conditions = []
            entity_list = [reasoning_node.conclusion.entity1,
                           reasoning_node.conclusion.entity2]
            # Conclusion entities that no condition has mentioned yet.
            necessary_entities = set(entity_list)
            while len(conditions) < condition_num:
                condition_type = random.choice([AttributeFact, RelationFact])
                if len(conditions) == condition_num - 1 and len(necessary_entities) > 0:
                    # Last condition: cover whatever conclusion entities are still missing.
                    if len(necessary_entities) == 2:
                        condition = self.synthesize_relation_condition(
                            reasoning_node, target_entity1=reasoning_node.conclusion.entity1, target_entity2=reasoning_node.conclusion.entity2)
                        conditions.append(condition)
                        necessary_entities.difference_update(
                            {condition.entity1, condition.entity2})
                    elif len(necessary_entities) == 1 and condition_type == AttributeFact:
                        condition = self.synthesize_attribute_condition(
                            reasoning_node, target_entity=next(iter(necessary_entities)))
                        conditions.append(condition)
                        necessary_entities.discard(condition.entity)
                    elif len(necessary_entities) == 1 and condition_type == RelationFact:
                        condition_entity1 = next(iter(necessary_entities))
                        # entity_list always holds both conclusion entities, so
                        # at least one other candidate exists.
                        condition_entity2 = random.choice(
                            [entity for entity in candidate_entities(entity_list)
                             if entity != condition_entity1])
                        condition = self.synthesize_relation_condition(
                            reasoning_node, target_entity1=condition_entity1, target_entity2=condition_entity2)
                        conditions.append(condition)
                        necessary_entities.difference_update(
                            {condition.entity1, condition.entity2})
                        if condition.entity1 not in entity_list:
                            entity_list.append(condition.entity1)
                        if condition.entity2 not in entity_list:
                            entity_list.append(condition.entity2)
                    else:
                        raise ValueError(
                            "Invalid condition type at synthesize_relation_rule-synthesize_true_conditions")
                elif condition_type == AttributeFact:
                    condition_entity = random.choice(candidate_entities(entity_list))
                    condition = self.synthesize_attribute_condition(
                        reasoning_node, target_entity=condition_entity)
                    conditions.append(condition)
                    necessary_entities.discard(condition.entity)
                    if condition.entity not in entity_list:
                        entity_list.append(condition.entity)
                elif condition_type == RelationFact:
                    condition_entity1, condition_entity2 = random.sample(
                        candidate_entities(entity_list), 2)
                    condition = self.synthesize_relation_condition(
                        reasoning_node, target_entity1=condition_entity1, target_entity2=condition_entity2)
                    conditions.append(condition)
                    necessary_entities.difference_update(
                        {condition.entity1, condition.entity2})
                    if condition.entity1 not in entity_list:
                        entity_list.append(condition.entity1)
                    if condition.entity2 not in entity_list:
                        entity_list.append(condition.entity2)
                else:
                    raise ValueError(
                        "Invalid condition type at synthesize_relation_rule-synthesize_true_conditions")
            assert len(necessary_entities) == 0, \
                f"All entities in conclusion should be included in conditions, but got {necessary_entities}"
            random.shuffle(conditions)
            return conditions

        assert self.left_depth > 0, "Depth should be greater than 0"
        assert isinstance(
            reasoning_node.conclusion, RelationFact), "reasoning node's conclusion should be a RelationFact"
        self.left_depth -= 1
        conditions = synthesize_true_conditions(reasoning_node)

        # Entities in order of first appearance across the (shuffled) conditions.
        condition_entities = list()
        for condition in conditions:
            if isinstance(condition, AttributeFact):
                if condition.entity not in condition_entities:
                    condition_entities.append(condition.entity)
            elif isinstance(condition, RelationFact):
                if condition.entity1 not in condition_entities:
                    condition_entities.append(condition.entity1)
                if condition.entity2 not in condition_entities:
                    condition_entities.append(condition.entity2)

        assert reasoning_node.conclusion.entity1 in condition_entities and \
            reasoning_node.conclusion.entity2 in condition_entities, \
            f"Conclusion entities {reasoning_node.conclusion.entity1} and {reasoning_node.conclusion.entity2} should be in condition entities {condition_entities}\nconditions: {conditions}"

        assert len(
            condition_entities) >= 2, "At least two entities are required for relation rule synthesis"
        # Concrete entity -> rule variable name ("entity_1", "entity_2", ...).
        condition_entity_map = {
            entity: f"entity_{i+1}" for i, entity in enumerate(condition_entities)}
        # Abstract the concrete conditions, keeping them index-aligned with ``conditions``.
        rule_conditions = []
        for condition in conditions:
            if isinstance(condition, AttributeFact):
                condition_entity = condition_entity_map[condition.entity]
                rule_condition = AttributeFact(
                    condition_entity, condition.attribute, ConstantExpression(-1), self.template_factory)
            elif isinstance(condition, RelationFact):
                condition_entity1 = condition_entity_map[condition.entity1]
                condition_entity2 = condition_entity_map[condition.entity2]
                rule_condition = RelationFact(
                    condition.relation, condition_entity1, condition_entity2, self.template_factory)
            rule_conditions.append(rule_condition)
        rule_conclusion = RelationFact(
            reasoning_node.conclusion.relation,
            condition_entity_map[reasoning_node.conclusion.entity1],
            condition_entity_map[reasoning_node.conclusion.entity2],
            self.template_factory
        )
        rule = Rule(rule_conditions, rule_conclusion, self.template_factory)
        reasoning_node.support = rule
        reasoning_node.true_conditions = conditions
        self.rules.append(rule)

    def synthesize_attribute_condition(self, father_node, target_entity=None, target_attribute=None):
        """Attach an attribute condition below ``father_node`` and return it.

        Samples ``(entity, attribute)``, using ``target_entity`` /
        ``target_attribute`` when they are not ``None``. Sampling repeats until
        ``self.reasoning_graph.test_if_dag`` accepts the edge
        ``condition -> father_node``. A condition not yet in the graph is queued
        on ``self.unprocessed`` and its value is marked ``PlaceHolder``.

        Args:
            father_node: node the new condition supports.
            target_entity: fixed entity, or ``None`` to sample from ``self.entities``.
            target_attribute: fixed attribute, or ``None`` to sample from
                ``self.attributes``.

        Returns:
            The ``AttributeFact`` condition (expression ``ConstantExpression(-1)``).

        Raises:
            RuntimeError: if every admissible ``(entity, attribute)`` pair is
                rejected by ``test_if_dag``, instead of retrying forever.
        """
        # The number of distinct pairs the sampler can produce; once all of them
        # have been rejected no further draw can succeed. Assumes test_if_dag is
        # deterministic while the graph is unchanged -- the graph is not modified
        # between retries here.
        entity_pool = [target_entity] if target_entity is not None else self.entities
        attribute_pool = [target_attribute] if target_attribute is not None else self.attributes
        candidate_count = len(set(entity_pool)) * len(set(attribute_pool))
        rejected = set()
        while True:
            entity = target_entity if target_entity is not None else random.choice(
                self.entities)
            attribute = target_attribute if target_attribute is not None else random.choice(
                self.attributes)
            condition = AttributeFact(
                entity, attribute, ConstantExpression(-1), self.template_factory)
            condition_node = ReasoningNode(condition, None)
            if self.reasoning_graph.test_if_dag(condition_node, father_node):
                if condition_node not in self.reasoning_graph.nodes:
                    assert self.values[entity][attribute] is None, \
                        f"Attribute value {self.values[entity][attribute]} should be None before synthesis\n{self.debug_info()}"
                    self.unprocessed.append(condition_node)
                    self.values[entity][attribute] = PlaceHolder
                self.reasoning_graph.add_edge(condition_node, father_node)
                return condition
            rejected.add((entity, attribute))
            if len(rejected) >= candidate_count:
                raise RuntimeError(
                    f"No attribute condition (entity={target_entity}, attribute={target_attribute}) "
                    f"can be attached without creating a cycle")

    def synthesize_relation_condition(self, father_node, target_entity1=None, target_entity2=None, target_relation=None):
        """Attach a relation condition below ``father_node`` and return it.

        Samples ``(relation, entity1, entity2)``, using the ``target_*``
        arguments when they are not ``None``. The entities are always distinct
        and get a random orientation: they are swapped with probability 0.5, so
        a fixed ``target_entity1`` may end up as ``entity2``. Sampling repeats
        until ``self.reasoning_graph.test_if_dag`` accepts the edge
        ``condition -> father_node``. A condition not yet in the graph is queued
        on ``self.unprocessed``, and its relation is recorded in
        ``self.exist_relations[entity1][entity2]``.

        Args:
            father_node: node the new condition supports.
            target_entity1: fixed first entity, or ``None`` to sample one.
            target_entity2: fixed second entity, or ``None`` to sample one.
            target_relation: fixed relation, or ``None`` to sample from
                ``self.relations``.

        Returns:
            The ``RelationFact`` condition.

        Raises:
            RuntimeError: if no two distinct entities are available, or if every
                admissible ``(relation, entity1, entity2)`` is rejected by
                ``test_if_dag``, instead of retrying forever.
        """
        if target_entity1 is not None and target_entity2 is not None:
            assert target_entity1 != target_entity2, "Target entities in relation condition should be different"
        first_pool = [target_entity1] if target_entity1 is not None else self.entities
        second_pool = [target_entity2] if target_entity2 is not None else self.entities
        # Without a distinct pair the `entity1 == entity2` retry below never ends.
        if not any(first != second for first in first_pool for second in second_pool):
            raise RuntimeError(
                "No two distinct entities are available for a relation condition")

        def count_candidates():
            # Distinct (relation, entity1, entity2) triples the loop can produce,
            # counting both orientations of every entity pair.
            ordered_pairs = {
                pair
                for first in set(first_pool)
                for second in set(second_pool)
                if first != second
                for pair in ((first, second), (second, first))
            }
            if target_relation is not None:
                relation_count = 1
            else:
                # Deduplicate with == only: relations are not assumed hashable.
                distinct_relations = []
                for option in self.relations:
                    if option not in distinct_relations:
                        distinct_relations.append(option)
                relation_count = len(distinct_relations)
            return len(ordered_pairs) * relation_count

        # Distinct rejected triples, compared with == (relations may be unhashable).
        # Assumes test_if_dag is deterministic while the graph is unchanged.
        rejected = []
        candidate_count = None
        while True:
            entity1 = target_entity1 if target_entity1 is not None else random.choice(
                self.entities)
            entity2 = target_entity2 if target_entity2 is not None else random.choice(
                self.entities)
            if entity1 == entity2:
                continue
            if random.random() < 0.5:
                entity1, entity2 = entity2, entity1
            relation = target_relation if target_relation is not None else random.choice(
                self.relations)
            condition = RelationFact(
                relation, entity1, entity2, self.template_factory)
            condition_node = ReasoningNode(condition, None)
            if self.reasoning_graph.test_if_dag(condition_node, father_node):
                if condition_node not in self.reasoning_graph.nodes:
                    assert relation not in self.exist_relations[entity1][entity2], \
                        f"Relation {relation} should not exist before synthesis\n{self.debug_info()}"
                    self.unprocessed.append(condition_node)
                    self.exist_relations[entity1][entity2].append(relation)
                self.reasoning_graph.add_edge(condition_node, father_node)
                return condition
            key = (relation, entity1, entity2)
            if key not in rejected:
                rejected.append(key)
            if candidate_count is None:
                candidate_count = count_candidates()
            if len(rejected) >= candidate_count:
                raise RuntimeError(
                    f"No relation condition (entity1={target_entity1}, entity2={target_entity2}, "
                    f"relation={target_relation}) can be attached without creating a cycle")

    def _do_topo_reasoning(self):
        """Resolve every attribute in dependency order and push values upward.

        Walks ``self.reasoning_graph.topo_sort()``. Relation nodes need no
        evaluation and are skipped. For an attribute node supported by a
        ``Rule``, the conclusion's expression is evaluated against
        ``self.values`` and stored; the slot must still hold ``PlaceHolder``.
        Fact-supported attributes were already filled by
        ``synthesize_attribute_fact``. Each resolved value is then copied into
        the matching condition of every parent node, on both its
        ``true_conditions`` and its ``Rule`` conditions. The rule synthesizers
        build these two lists index-aligned.
        """
        for node in self.reasoning_graph.topo_sort():
            fact = node.conclusion
            if not isinstance(fact, AttributeFact):
                continue

            if isinstance(node.support, Rule):
                expression = fact.expression
                arguments = expression.parse_compute_args(self.values)
                result = expression.compute(**arguments)
                assert self.values[fact.entity][
                    fact.attribute] is PlaceHolder, f"Attribute value {self.values[fact.entity][fact.attribute]} should be PlaceHolder before topo reasoning\n{self.debug_info()}"
                self.values[fact.entity][fact.attribute] = result

            resolved = self.values[fact.entity][fact.attribute]
            assert resolved is not None and resolved is not PlaceHolder, \
                f"Attribute value {resolved} should not be None after topo reasoning\n{self.debug_info()}"

            for parent in self.reasoning_graph.edges[node]:
                for index, grounded in enumerate(parent.true_conditions):
                    if grounded != fact and not (grounded == fact):
                        continue
                    parent.true_conditions[index].expression.value = resolved
                    parent.support.conditions[index].expression.value = resolved

    def _synthesize_irrelevant(self):
        def validate_fact(fact):
            if fact in self.facts:
                return False
            trigger_result = trigger_all_rules(
                self.facts + [fact], self.rules, self.values, self.exist_relations, self.debug_info())
            if trigger_result is None:
                return False
            else:
                self.values, self.exist_relations = trigger_result
                return True

        def validate_rule(rule):
            trigger_result = trigger_all_rules(
                self.facts, self.rules + [rule], self.values, self.exist_relations, self.debug_info())
            if trigger_result is None:
                return False
            else:
                self.values, self.exist_relations = trigger_result
                return True

        def substitute_entity(fact, entity_map):
            if isinstance(fact, AttributeFact):
                return AttributeFact(
                    entity_map[fact.entity],
                    fact.attribute,
                    fact.expression.substitute_entity(entity_map),
                    self.template_factory
                )
            elif isinstance(fact, RelationFact):
                return RelationFact(
                    fact.relation,
                    entity_map[fact.entity1],
                    entity_map[fact.entity2],
                    self.template_factory
                )
            else:
                raise ValueError(f"Invalid fact type: {type(fact)}")

        def synthesize_irrelevant_expression(entity_list, conclusion, from_binary=False):
            if from_binary:
                expression_type = random.choices(list(self.binary_expression_weights.keys()),
                                                 weights=list(self.binary_expression_weights.values()), k=1)[0]
            else:
                expression_type = random.choices(list(self.normal_expression_weights.keys()),
                                                 weights=list(self.normal_expression_weights.values()), k=1)[0]
            if expression_type == ConstantExpression:
                return ConstantExpression.synthesize_from_interval(self.interval)
            elif expression_type == IdentityExpression:
                entity = random.choice(entity_list)
                if entity == conclusion.entity:
                    attribute = random.choice(
                        list(set(self.attributes) - {conclusion.attribute}))
                else:
                    attribute = random.choice(self.attributes)
                return IdentityExpression.synthesize_from_interval(
                    entity, attribute, self.template_factory)
            elif expression_type == LinearExpression:
                entity = random.choice(entity_list)
                if entity == conclusion.entity:
                    attribute = random.choice(
                        list(set(self.attributes) - {conclusion.attribute}))
                else:
                    attribute = random.choice(self.attributes)
                return LinearExpression.synthesize_from_interval(
                    self.interval, entity, attribute, self.template_factory)
            elif expression_type == BinaryExpression:
                expression1 = synthesize_irrelevant_expression(
                    entity_list, conclusion, from_binary=True)
                expression2 = synthesize_irrelevant_expression(
                    entity_list, conclusion, from_binary=True)
                return BinaryExpression.synthesize_from_interval(
                    expression1, expression2, self.template_factory)
            else:
                raise ValueError("Invalid expression type")

        def synthesize_irrelevant_attribute_rule():
            conclusion = AttributeFact(
                entity="x_1",
                attribute=random.choice(self.attributes),
                expression=ConstantExpression(-1),
                template_factory=self.template_factory
            )
            entity_list = [conclusion.entity]
            condition_num = random.randint(
                self.condition_num_interval[0], self.condition_num_interval[1])
            conditions = []
            necessity = False
            while len(conditions) < condition_num:
                condition_type = random.choice(
                    [AttributeFact, RelationFact])
                if condition_type == AttributeFact:
                    condition_entity = random.choice(
                        entity_list + [f"x_{len(entity_list)+1}"])
                    if len(conditions) == condition_num - 1 and not necessity:
                        condition_entity = conclusion.entity
                    if condition_entity == conclusion.entity:
                        condition_attribute = random.choice(
                            list(set(self.attributes) - {conclusion.attribute}))
                    else:
                        condition_attribute = random.choice(self.attributes)
                    condition = AttributeFact(
                        condition_entity, condition_attribute, ConstantExpression.synthesize_from_interval(self.interval), self.template_factory)
                    if condition not in conditions and condition != conclusion:
                        conditions.append(condition)
                        necessity = necessity or (
                            condition_entity == conclusion.entity)
                        if condition_entity not in entity_list:
                            entity_list.append(condition_entity)
                elif condition_type == RelationFact:
                    condition_entity1, condition_entity2 = random.sample(
                        entity_list + [f"x_{len(entity_list)+1}"], 2)
                    if len(conditions) == condition_num - 1 and not necessity and condition_entity1 != conclusion.entity and condition_entity2 != conclusion.entity:
                        if random.random() < 0.5:
                            condition_entity1 = conclusion.entity
                        else:
                            condition_entity2 = conclusion.entity
                    assert condition_entity1 != condition_entity2, \
                        "Entities in relation condition should be different"
                    condition_relation = random.choice(self.relations)
                    condition = RelationFact(
                        condition_relation, condition_entity1, condition_entity2, self.template_factory)
                    if condition not in conditions:
                        conditions.append(condition)
                        necessity = necessity or (
                            condition_entity1 == conclusion.entity or condition_entity2 == conclusion.entity)
                        if condition_entity1 not in entity_list:
                            entity_list.append(condition_entity1)
                        if condition_entity2 not in entity_list:
                            entity_list.append(condition_entity2)
                else:
                    raise ValueError(
                        f"Invalid condition type: {condition_type}")
            expression = synthesize_irrelevant_expression(
                entity_list, conclusion)
            conclusion.expression = expression
            random.shuffle(conditions)
            return conditions, conclusion

        def synthesize_irrelevant_relation_rule():
            """Synthesize one random rule whose conclusion is a relation fact.

            The conclusion is ``relation(x_1, x_2)`` with a random relation.
            Conditions are random attribute facts (constant values drawn from
            ``self.interval``) or relation facts over the variables seen so
            far plus one fresh ``x_k``. When the last condition slot is reached
            and some conclusion variable is still unmentioned, that slot is
            forced to mention it, so every conclusion variable occurs in some
            condition.

            Note: if ``self.relations`` holds a single relation, the branches
            that must avoid the conclusion's relation have no candidates and
            ``random.choice`` raises ``IndexError``.

            Returns:
                tuple: ``(conditions, conclusion)`` with ``conditions`` shuffled.
            """
            def choices_excluding(items, excluded):
                # Distinct ``items`` other than ``excluded``, in first-seen order.
                # A ``set`` is deliberately avoided: its iteration order for
                # strings depends on the per-process hash seed, so
                # ``random.choice`` over it would not be reproducible under
                # ``random.seed``.
                return [item for item in dict.fromkeys(items) if item != excluded]

            conclusion = RelationFact(
                relation=random.choice(self.relations),
                entity1="x_1",
                entity2="x_2",
                template_factory=self.template_factory
            )
            entity_list = [conclusion.entity1, conclusion.entity2]
            condition_num = random.randint(
                self.condition_num_interval[0], self.condition_num_interval[1])
            # Conclusion variables not yet mentioned by any accepted condition.
            necessary_entities = set(entity_list)
            conditions = []
            while len(conditions) < condition_num:
                condition_type = random.choice(
                    [AttributeFact, RelationFact])
                if len(conditions) == condition_num - 1 and len(necessary_entities) > 0:
                    # Last slot: ensure all entities in conclusion are included in conditions
                    if len(necessary_entities) == 2:
                        # Relate x_1 and x_2 directly; a relation different from the
                        # conclusion's keeps this condition distinct from it.
                        condition_relation = random.choice(
                            choices_excluding(self.relations, conclusion.relation))
                        condition_entity1, condition_entity2 = conclusion.entity1, conclusion.entity2
                        if random.random() < 0.5:
                            condition_entity1, condition_entity2 = condition_entity2, condition_entity1
                        condition = RelationFact(
                            condition_relation,
                            condition_entity1,
                            condition_entity2,
                            self.template_factory)
                        if condition not in conditions and condition != conclusion:
                            conditions.append(condition)
                            necessary_entities.difference_update(
                                {condition.entity1, condition.entity2})
                    elif len(necessary_entities) == 1 and condition_type == AttributeFact:
                        condition_entity = next(iter(necessary_entities))
                        condition_attribute = random.choice(self.attributes)
                        condition = AttributeFact(
                            condition_entity, condition_attribute, ConstantExpression.synthesize_from_interval(self.interval), self.template_factory)
                        if condition not in conditions:
                            conditions.append(condition)
                            necessary_entities.discard(condition.entity)
                    elif len(necessary_entities) == 1 and condition_type == RelationFact:
                        condition_entity1 = next(iter(necessary_entities))
                        # Partner is any known variable or a fresh one, never itself.
                        condition_entity2 = random.choice(choices_excluding(
                            entity_list + [f"x_{len(entity_list)+1}"], condition_entity1))
                        if random.random() < 0.5:
                            condition_entity1, condition_entity2 = condition_entity2, condition_entity1
                        assert condition_entity1 != condition_entity2, \
                            "Entities in relation condition should be different"
                        if condition_entity1 == conclusion.entity1 and condition_entity2 == conclusion.entity2:
                            # Same ordered pair as the conclusion: must use another relation.
                            condition_relation = random.choice(
                                choices_excluding(self.relations, conclusion.relation))
                        else:
                            condition_relation = random.choice(self.relations)
                        condition = RelationFact(
                            condition_relation, condition_entity1, condition_entity2, self.template_factory)
                        if condition not in conditions and condition != conclusion:
                            conditions.append(condition)
                            necessary_entities.difference_update(
                                {condition.entity1, condition.entity2})
                            if condition_entity1 not in entity_list:
                                entity_list.append(condition_entity1)
                            if condition_entity2 not in entity_list:
                                entity_list.append(condition_entity2)
                    else:
                        raise ValueError(
                            "Invalid condition type at synthesize_irrelevant_relation_rule-synthesize_irrelevant_conditions")
                elif condition_type == AttributeFact:
                    condition_entity = random.choice(
                        entity_list + [f"x_{len(entity_list)+1}"])
                    condition_attribute = random.choice(self.attributes)
                    condition = AttributeFact(
                        condition_entity, condition_attribute, ConstantExpression.synthesize_from_interval(self.interval), self.template_factory)
                    if condition not in conditions:
                        conditions.append(condition)
                        necessary_entities.discard(condition.entity)
                        if condition_entity not in entity_list:
                            entity_list.append(condition_entity)
                elif condition_type == RelationFact:
                    condition_entity1, condition_entity2 = random.sample(
                        entity_list + [f"x_{len(entity_list)+1}"], 2)
                    assert condition_entity1 != condition_entity2, \
                        "Entities in relation condition should be different"
                    if condition_entity1 == conclusion.entity1 and condition_entity2 == conclusion.entity2:
                        # Same ordered pair as the conclusion: must use another relation.
                        condition_relation = random.choice(
                            choices_excluding(self.relations, conclusion.relation))
                    else:
                        condition_relation = random.choice(self.relations)
                    condition = RelationFact(
                        condition_relation, condition_entity1, condition_entity2, self.template_factory)
                    if condition not in conditions:
                        conditions.append(condition)
                        necessary_entities.difference_update(
                            {condition.entity1, condition.entity2})
                        if condition_entity1 not in entity_list:
                            entity_list.append(condition_entity1)
                        if condition_entity2 not in entity_list:
                            entity_list.append(condition_entity2)
                else:
                    raise ValueError(
                        f"Invalid condition type: {condition_type}")
            random.shuffle(conditions)
            return conditions, conclusion

        while len(self.facts) < self.fact_num:
            fact_type = random.choice([AttributeFact, RelationFact])
            if fact_type == AttributeFact:
                entity = random.choice(self.entities)
                attribute = random.choice(self.attributes)
                if self.values[entity][attribute] is not None:
                    continue
                expression = ConstantExpression.synthesize_from_interval(
                    self.interval)
                fact = AttributeFact(
                    entity, attribute, expression, self.template_factory)
                if validate_fact(fact):
                    self.facts.append(fact)
            elif fact_type == RelationFact:
                entity1, entity2 = random.sample(self.entities, 2)
                relation = random.choice(self.relations)
                if relation in self.exist_relations[entity1][entity2]:
                    continue
                fact = RelationFact(
                    relation, entity1, entity2, self.template_factory)
                if validate_fact(fact):
                    self.facts.append(fact)
            else:
                raise ValueError(f"Invalid fact type: {fact_type}")

        while len(self.rules) < self.rule_num:
            conclusion_type = random.choice([AttributeFact, RelationFact])
            if conclusion_type == AttributeFact:
                conditions, conclusion = synthesize_irrelevant_attribute_rule()
            elif conclusion_type == RelationFact:
                conditions, conclusion = synthesize_irrelevant_relation_rule()
            else:
                raise ValueError(f"Invalid conclusion type: {conclusion_type}")
            entity_list = []
            for condition in conditions:
                if isinstance(condition, AttributeFact):
                    if condition.entity not in entity_list:
                        entity_list.append(condition.entity)
                elif isinstance(condition, RelationFact):
                    if condition.entity1 not in entity_list:
                        entity_list.append(condition.entity1)
                    if condition.entity2 not in entity_list:
                        entity_list.append(condition.entity2)
            entity_map = {entity: f"entity_{i+1}" for i,
                          entity in enumerate(entity_list)}
            conditions = [substitute_entity(condition, entity_map)
                          for condition in conditions]
            conclusion = substitute_entity(conclusion, entity_map)
            rule = Rule(conditions, conclusion, self.template_factory)
            if validate_rule(rule):
                self.rules.append(rule)

        random.shuffle(self.facts)
        random.shuffle(self.rules)

    def _process_reasoning_steps(self):
        def index_rule(rule):
            for i, r in enumerate(self.rules):
                if r == rule:
                    return ("rule", i + 1)
            assert False, f"Rule {rule} not found in rules list {self.rules}"

        def index_fact(fact):
            for i, f in enumerate(self.facts):
                if f == fact:
                    return ("fact", i + 1)
            for i, it in enumerate([step[1] for step in reasoning_steps]):
                if it == fact:
                    return ("int", i + 1)
            assert False, f"Fact {fact} not found in facts list {self.facts} or reasoning steps {reasoning_steps}"

        topological_order = self.reasoning_graph.topo_sort()
        reasoning_steps = []
        for current_node in topological_order:
            support = current_node.support
            if not isinstance(support, Rule):
                continue
            condition_list = [index_rule(support)]
            for condition in [pre_node.conclusion for pre_node in self.reasoning_graph.redges[current_node]]:
                condition_list.append(index_fact(condition))
            step = (condition_list, current_node.conclusion)
            reasoning_steps.append(step)
        return reasoning_steps

    def debug_info(self):
        import json
        return json.dumps({
            "entities": self.entities,
            "attributes": self.attributes,
            "relations": self.relations,
            "facts": [str(fact) for fact in self.facts],
            "rules": [str(rule) for rule in self.rules],
            "query": self.query,
            "values": self.values,
            "exist_relations": self.exist_relations
        })

    def to_json(self):
        def process_sentence(x): return x[0].upper() + x[1:] + "."
        facts_nl = list(
            map(process_sentence, [fact.nl() for fact in self.facts]))
        rules_nl = list(
            map(process_sentence, [rule.nl() for rule in self.rules]))
        apply_condition_num = sum(len(
            self.rules[process[0][0][1] - 1].conditions) for process in self.reasoning_process)
        all_condition_num = sum([len(rule.conditions) for rule in self.rules])
        return {
            "configs": {
                "entity_num": len(self.entities),
                "attribute_num": len(self.attributes),
                "relation_num": len(self.relations),
                "fact_num": self.fact_num,
                "rule_num": self.rule_num,
                "depth": self.depth
            },
            "entities": self.entities,
            "attributes": self.attributes,
            "relations": self.relations,
            "facts-nl": facts_nl,
            "rules-nl": rules_nl,
            "facts-str": [str(fact) for fact in self.facts],
            "rules-str": [str(rule) for rule in self.rules],
            "facts-repr": [repr(fact) for fact in self.facts],
            "rules-repr": [repr(rule) for rule in self.rules],
            "query": self.query,
            "answer": self.values[self.query[0]][self.query[1]],
            "apply_condition_num": apply_condition_num,
            "all_condition_num": all_condition_num,
            "reasoning_process_nl": self.nl_reasoning_process(),
            "reasoning_process": [(process[0], repr(process[1]), str(process[1]))
                                  for process in self.reasoning_process],
            "values": self.values,
            "exist_relations": self.exist_relations,
        }

    def nl_reasoning_process(self):
        def standard_nl(fact):
            if isinstance(fact, AttributeFact):
                args = fact.expression.parse_compute_args(self.values)
                computation_process = fact.expression.computation_process(
                    **args)
                return f"{fact.entity}'s {fact.attribute} is {computation_process}"
            elif isinstance(fact, RelationFact):
                return f"{fact.relation} exists between {fact.entity1} and {fact.entity2}"
            else:
                raise ValueError(f"Invalid fact type: {type(fact)}")
        reasoning_steps = []
        for i, step in enumerate(self.reasoning_process):
            conditions = step[0]
            conclusion = step[1]
            conditions_str = " & ".join(
                [f"{c[0]}_{c[1]}" for c in conditions])
            step = f"{conditions_str} => int_{i+1}: {standard_nl(conclusion)}."
            reasoning_steps.append(step)
        return "\n".join(reasoning_steps)
