from synthesizer.expression import ConstantExpression, IdentityExpression, LinearExpression, BinaryExpression
from synthesizer.fact import AttributeFact, RelationFact
from synthesizer.rule import Rule
from synthesizer.template import TemplateFactory
from synthesizer.reasoning_graph import ReasoningGraph, ReasoningNode
import sys
import os
import re
import copy
import json
import logging
from typing import Dict, List, Tuple, Optional, Union, Any
from datetime import datetime
import spacy
from functools import lru_cache

# Load the small English spaCy pipeline once at import time; it is the
# pipeline that lemmatize_word_spacy() below runs words through.
nlp = spacy.load("en_core_web_sm")


@lru_cache(maxsize=1000)
def lemmatize_word_spacy(word: str) -> str:
    """Return the lower-cased spaCy lemma of the first token of ``word``.

    Results are memoised in an LRU cache holding up to 1000 entries.

    Args:
        word: Text to lemmatize. Only its first token is considered.

    Returns:
        The lower-cased lemma of the first token. ``word.lower()`` is
        returned instead when no pipeline is loaded or when the pipeline
        yields no tokens at all (for example for an empty string).
    """
    if nlp is None:
        # No spaCy pipeline available: fall back to plain lower-casing.
        return word.lower()
    doc = nlp(word)
    if len(doc) == 0:
        # Indexing an empty result would raise IndexError; degrade to the
        # same fallback used when no pipeline is loaded.
        return word.lower()
    return doc[0].lemma_.lower()


# Put the parent of this file's directory at the front of sys.path.
# NOTE(review): this runs after the `synthesizer` imports at the top of the
# file, so it cannot influence how those imports are resolved - confirm the
# intended import order.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)


# Module-level state shared by every setup_logger() call:
#   _logger_configured - set to True once the shared handlers have been built
#   _global_handlers   - the handler list attached to each configured logger
_logger_configured = False
_global_handlers = None

# Configure logger


def setup_logger(name: str = __name__, level: int = logging.INFO) -> logging.Logger:
    """
    Return the logger called ``name``, wired to the shared handlers

    The console and file handlers are built on the first call only and are
    then reused for every logger that does not have handlers yet.

    Args:
        name: Logger name, usually __name__
        level: Level applied to the logger on every call; on the first call
            it is also used as the console handler level

    Returns:
        The configured logger instance
    """
    global _logger_configured, _global_handlers

    target = logging.getLogger(name)

    # Keep records away from the root logger so nothing is emitted twice
    target.propagate = False

    # Handlers are attached only to loggers that have none yet
    if not target.handlers:
        if not _logger_configured:
            # First configuration: build the shared console handler
            stream_handler = logging.StreamHandler()
            stream_handler.setLevel(level)

            # ... and the shared file handler under ./logs next to this file
            logs_path = os.path.join(os.path.dirname(__file__), 'logs')
            os.makedirs(logs_path, exist_ok=True)
            # The timestamp in the file name is computed once and remembered
            if not hasattr(setup_logger, "_log_date"):
                setup_logger._log_date = datetime.now().strftime("%Y%m%d-%H%M")
            disk_handler = logging.FileHandler(
                os.path.join(logs_path, f'parser-{setup_logger._log_date}.log'),
                encoding='utf-8'
            )
            disk_handler.setLevel(logging.DEBUG)

            # Both handlers share one record layout
            shared_format = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            for shared_handler in (stream_handler, disk_handler):
                shared_handler.setFormatter(shared_format)

            # Remember the handlers for every later call
            _global_handlers = [stream_handler, disk_handler]
            _logger_configured = True

        for shared_handler in _global_handlers:
            target.addHandler(shared_handler)

    # The level is refreshed on every call so it can be adjusted later
    target.setLevel(level)

    return target


# Module-level logger, created with setup_logger()'s defaults
# (name=__name__, level=logging.INFO).
logger = setup_logger()


class ExpressionParser:
    """Rebuilds expression objects from their textual representation"""

    def __init__(self):
        """Attach a logger named after this module and class"""
        logger_name = f"{__name__}.{self.__class__.__name__}"
        self.logger = setup_logger(logger_name)

    def parse_expression(self, expr_str: str, template_factory: Optional[TemplateFactory] = None) -> Union[ConstantExpression, IdentityExpression, LinearExpression, BinaryExpression]:
        """
        Build an expression object from a string such as "ConstantExpression(6)"

        The string is tried against the Constant, Identity, Linear and Binary
        forms, in that order. Operands of a BinaryExpression are parsed
        recursively.

        Args:
            expr_str: Textual expression to parse
            template_factory: Factory handed to every created expression; a
                new TemplateFactory is created when it is omitted

        Returns:
            The expression object for the first form that matches

        Raises:
            ValueError: If no form matches or a numeric field is not an integer
        """
        self.logger.debug(f"Starting to parse expression: {expr_str}")

        factory = TemplateFactory() if template_factory is None else template_factory

        # Form 1: ConstantExpression(value)
        matched = re.match(r'ConstantExpression\(([^)]+)\)', expr_str)
        if matched:
            value = int(matched.group(1).strip())
            self.logger.debug(f"Parsed as ConstantExpression, value: {value}")
            return ConstantExpression(value, factory)

        # Form 2: IdentityExpression(entity, attribute)
        matched = re.match(
            r'IdentityExpression\(([^,]+),\s*([^)]+)\)', expr_str)
        if matched:
            entity, attribute = (piece.strip() for piece in matched.groups())
            self.logger.debug(
                f"Parsed as IdentityExpression, entity: {entity}, attribute: {attribute}")
            return IdentityExpression(entity, attribute, factory)

        # Form 3: LinearExpression(coefficient, constant, entity, attribute)
        matched = re.match(
            r'LinearExpression\(([^,]+),\s*([^,]+),\s*([^,]+),\s*([^)]+)\)', expr_str)
        if matched:
            raw_coeff, raw_const, entity, attribute = (
                piece.strip() for piece in matched.groups())
            coeff = int(raw_coeff)
            const = int(raw_const)
            self.logger.debug(
                f"Parsed as LinearExpression, coefficient: {coeff}, constant: {const}, entity: {entity}, attribute: {attribute}")
            return LinearExpression(coeff, const, entity, attribute, factory)

        # Form 4: BinaryExpression(operation, expr1, expr2)
        matched = re.match(
            r'BinaryExpression\(([^,]+),\s*(.+),\s*(.+)\)', expr_str)
        if matched:
            operation = matched.group(1).strip()
            # The regex cannot tell nested commas apart, so glue the two
            # captured pieces back together and split on the top-level comma
            operands_text = ', '.join(matched.group(2, 3))
            left_text, right_text = self._split_binary_expressions(operands_text)
            left_text = left_text.strip()
            right_text = right_text.strip()

            left_operand = self.parse_expression(left_text, factory)
            right_operand = self.parse_expression(right_text, factory)

            self.logger.debug(
                f"Parsed as BinaryExpression, operation: {operation}, expression1: {left_text}, expression2: {right_text}")
            return BinaryExpression(left_operand, right_operand, operation, factory)

        self.logger.error(f"Unable to parse expression: {expr_str}")
        raise ValueError(f"Unable to parse expression: {expr_str}")

    def _split_binary_expressions(self, expr_part: str) -> Tuple[str, str]:
        """
        Cut the operand text of a BinaryExpression at its first top-level comma

        Args:
            expr_part: Both operands in one string, e.g. "ConstantExpression(5), LinearExpression(2, 3, entity, attr)"

        Returns:
            Tuple[str, str]: The two operand strings, stripped of whitespace

        Raises:
            ValueError: If there is no comma outside of parentheses
        """
        depth = 0
        for position, symbol in enumerate(expr_part):
            if symbol == ',' and depth == 0:
                # First comma that is not nested inside parentheses
                return expr_part[:position].strip(), expr_part[position + 1:].strip()
            if symbol == '(':
                depth += 1
            elif symbol == ')':
                depth -= 1

        # No top-level comma: the operand text is not in the expected shape
        raise ValueError(f"Unable to split binary expression: {expr_part}")

    @staticmethod
    def get_expression_value(expression) -> Any:
        """Return ``expression.value`` if present, else ``expression.compute()``, else ``str(expression)``"""
        can_compute = hasattr(expression, 'compute')
        if hasattr(expression, 'value'):
            return expression.value
        if can_compute:
            return expression.compute()
        return str(expression)


class FactParser:
    """Builds fact objects from natural-language descriptions and from reprs"""

    def __init__(self, template_factory: TemplateFactory):
        """Keep the template factory and create the expression parser and logger"""
        self.template_factory = template_factory
        self.expression_parser = ExpressionParser()
        logger_name = f"{__name__}.{self.__class__.__name__}"
        self.logger = setup_logger(logger_name, level=logging.DEBUG)

    def parse_fact_from_description(self, description: str) -> Union[AttributeFact, RelationFact]:
        """
        Parse a sentence-like fact description

        Two shapes are recognised: "<entity>'s <attribute> is <value>" and
        "<relation> exists between <entity1> and <entity2>". The relation
        shape is only tried when the description does not look like the
        attribute shape.

        Args:
            description: Fact description, e.g. "Pooh's metric is 7"

        Returns:
            AttributeFact or RelationFact

        Raises:
            ValueError: If the description cannot be turned into a fact
        """
        self.logger.debug(f"Starting to parse fact description: {description}")

        if re.match(r".+'s .+ is .+", description):
            # Attribute shape: exactly one " is " separating subject and value
            halves = description.split(" is ")
            if len(halves) == 2:
                subject, value = (half.strip() for half in halves)
                owner, marker, owned = subject.partition("'s ")
                if marker:
                    entity = owner.strip()
                    attribute = owned.strip()

                    try:
                        numeric_value = FactParser._expression_to_int(value)
                        expression = ConstantExpression(numeric_value)
                        self.logger.debug(
                            f"Parsed as AttributeFact: entity={entity}, attribute={attribute}, value={numeric_value}")
                        return AttributeFact(entity, attribute, expression, self.template_factory)
                    except ValueError:
                        # Fall through to the common failure path below
                        self.logger.debug(
                            f"Unable to convert value to number: {value}")
        else:
            self.logger.debug(f"Trying to parse {description} as RelationFact")
            # e.g. "hail exists between Kelila and Zoe"
            if " exists between " in description:
                halves = description.split(" exists between ")
                if len(halves) == 2:
                    relation = halves[0].strip()
                    entities = halves[1].strip().split(" and ")
                    if len(entities) == 2:
                        entity1, entity2 = (name.strip() for name in entities)
                        self.logger.debug(
                            f"Parsed as RelationFact: relation={relation}, entity1={entity1}, entity2={entity2}")
                        return RelationFact(relation, entity1, entity2, self.template_factory)
                # TODO: Milzie hails Randa

        self.logger.error(f"Unable to parse fact description: {description}")
        raise ValueError(f"Unable to parse fact description: {description}")

    @staticmethod
    def _expression_to_int(expression: str) -> int:
        """Convert a value string to int, keeping only the text after the last '=' if one is present"""
        text = expression
        if '=' in text:
            # "3 + 4 = 7." -> "7": take the final result and drop a trailing period
            text = text.rsplit('=', 1)[-1].strip().strip('.')
        return int(text)

    def parse_fact_repr(self, fact_repr: str) -> Union[AttributeFact, RelationFact]:
        """
        Parse the repr-style string of a fact

        Args:
            fact_repr: "AttributeFact(entity, attribute, expression)" or
                "RelationFact(relation, entity1, entity2)"

        Returns:
            AttributeFact or RelationFact

        Raises:
            ValueError: If the string matches neither form
        """
        attribute_form = re.match(
            r'AttributeFact\(([^,]+),\s*([^,]+),\s*(.+)\)', fact_repr)
        if attribute_form:
            entity, attribute, expression_text = (
                piece.strip() for piece in attribute_form.groups())
            expression = self.expression_parser.parse_expression(
                expression_text, self.template_factory)
            return AttributeFact(entity, attribute, expression, self.template_factory)

        relation_form = re.match(
            r'RelationFact\(([^,]+),\s*([^,]+),\s*([^)]+)\)', fact_repr)
        if relation_form:
            relation, entity1, entity2 = (
                piece.strip() for piece in relation_form.groups())
            return RelationFact(relation, entity1, entity2, self.template_factory)

        raise ValueError(f"Unable to parse fact representation: {fact_repr}")

    @staticmethod
    def facts_equal_strict(fact1, fact2) -> bool:
        """Compare two facts field by field, including the expression value of attribute facts"""
        if type(fact1) != type(fact2):
            return False

        if isinstance(fact1, AttributeFact):
            if not (fact1.entity == fact2.entity):
                return False
            if not (fact1.attribute.lower() == fact2.attribute.lower()):
                return False
            left_value = ExpressionParser.get_expression_value(fact1.expression)
            right_value = ExpressionParser.get_expression_value(fact2.expression)
            return left_value == right_value

        if isinstance(fact1, RelationFact):
            # Relation words are compared by lemma rather than by surface form
            if not (lemmatize_word_spacy(fact1.relation) == lemmatize_word_spacy(fact2.relation)):
                return False
            if not (fact1.entity1 == fact2.entity1):
                return False
            return fact1.entity2 == fact2.entity2

        return fact1 == fact2


class RuleParser:
    """Builds Rule objects from their repr-style string"""

    def __init__(self, template_factory: TemplateFactory):
        """Keep the template factory and create the fact parser used for conditions and conclusions"""
        self.template_factory = template_factory
        self.fact_parser = FactParser(template_factory)

    def parse_rule_repr(self, rule_repr: str) -> Rule:
        """
        Parse a string of the form "Rule([cond1, cond2, ...], conclusion)"

        Args:
            rule_repr: String representation of rule

        Returns:
            Rule object built from the parsed conditions and conclusion

        Raises:
            ValueError: If the string does not have the expected form, or a
                condition or the conclusion cannot be parsed as a fact
        """
        matched = re.match(r'Rule\(\[([^\]]+)\],\s*(.+)\)', rule_repr)
        if not matched:
            raise ValueError(f"Unable to parse rule representation: {rule_repr}")

        conditions_text, conclusion_text = (
            piece.strip() for piece in matched.groups())

        # Conditions first, then the conclusion
        conditions = [
            self.fact_parser.parse_fact_repr(chunk.strip())
            for chunk in self._split_by_comma_respecting_brackets(conditions_text)
        ]
        conclusion_fact = self.fact_parser.parse_fact_repr(conclusion_text)
        return Rule(conditions, conclusion_fact, self.template_factory)

    @staticmethod
    def _split_by_comma_respecting_brackets(text: str) -> List[str]:
        """Split text on commas that are not nested inside parentheses; pieces are stripped"""
        pieces = []
        piece_start = 0
        depth = 0

        for position, symbol in enumerate(text):
            if symbol == '(':
                depth += 1
            elif symbol == ')':
                depth -= 1
            elif symbol == ',' and depth == 0:
                pieces.append(text[piece_start:position].strip())
                piece_start = position + 1

        # The text after the last separator is kept only when it is non-blank
        tail = text[piece_start:].strip()
        if tail:
            pieces.append(tail)

        return pieces


class ReasoningGraphBuilder:
    """Builds reasoning graphs from textual reasoning steps and validates each step"""

    def __init__(self, template_factory: TemplateFactory):
        """Keep the template factory and create the fact parser and logger"""
        self.template_factory = template_factory
        self.fact_parser = FactParser(template_factory)
        self.logger = setup_logger(f"{__name__}.{self.__class__.__name__}")

    @staticmethod
    def _reference_index(reference: str, prefix: str, size: int) -> Optional[int]:
        """Turn a 1-based reference such as "rule_3" into a valid 0-based index, or None"""
        try:
            index = int(reference.replace(prefix, '')) - 1
        except ValueError:
            return None
        # Reject out-of-range values, including 0 which would wrap around to -1
        return index if 0 <= index < size else None

    def parse_reasoning_steps_to_graph(self, reasoning_text: str, facts: List, rules: List) -> ReasoningGraph:
        """
        Parse reasoning steps text to ReasoningGraph

        Each line of the form "rule_i & fact_j & int_k => int_n: <fact description>"
        becomes one conclusion node, provided verify_reasoning_process accepts it.
        Lines without "=>" or without ":" after the arrow are ignored. References
        that are malformed or out of range are ignored, so a step whose rule
        reference is unusable fails verification and is recorded as invalid.

        Args:
            reasoning_text: Reasoning steps text
            facts: Fact list, referenced 1-based as fact_1, fact_2, ...
            rules: Rule list, referenced 1-based as rule_1, rule_2, ...

        Returns:
            ReasoningGraph: Built reasoning graph

        Raises:
            ValueError: If a conclusion description cannot be parsed into a fact
        """
        graph = ReasoningGraph()
        # Maps a step id to its node, or to None when the step failed verification
        intermediate_facts = {}

        lines = reasoning_text.strip().split('\n')

        for line in lines:
            if '=>' not in line:
                continue

            condition_part, conclusion_part = line.split('=>', 1)
            conditions = [cond.strip() for cond in condition_part.strip().split('&')]
            conclusion_part = conclusion_part.strip()

            if ':' not in conclusion_part:
                continue

            int_id, fact_desc = conclusion_part.split(':', 1)
            int_id = int_id.strip()
            fact_desc = fact_desc.strip().rstrip('.')

            conclusion_fact = self.fact_parser.parse_fact_from_description(
                fact_desc)
            conclusion_node = ReasoningNode(conclusion_fact, None)

            support_rule = None
            condition_nodes = []

            for condition in conditions:
                if condition.startswith('rule_'):
                    rule_index = self._reference_index(
                        condition, 'rule_', len(rules))
                    if rule_index is not None:
                        support_rule = rules[rule_index]
                elif condition.startswith('fact_'):
                    fact_index = self._reference_index(
                        condition, 'fact_', len(facts))
                    if fact_index is not None:
                        fact = facts[fact_index]
                        condition_nodes.append(ReasoningNode(fact, fact))
                elif condition.startswith('int_'):
                    # Unknown or previously rejected steps contribute nothing
                    int_node = intermediate_facts.get(condition)
                    if int_node is not None:
                        condition_nodes.append(int_node)

            conclusion_node.support = support_rule

            if not self.verify_reasoning_process(condition_nodes, conclusion_node):
                self.logger.debug(f"Invalid reasoning process: {line}")
                intermediate_facts[int_id] = None
                continue

            graph.add_node(conclusion_node)
            for cond_node in condition_nodes:
                graph.add_node(cond_node)
                graph.add_edge(cond_node, conclusion_node)

            # A step whose id text appears in no other line becomes the root
            # (plain substring containment is used for this check)
            if not any(int_id in other_line for other_line in lines if other_line != line):
                graph.root = conclusion_node

            intermediate_facts[int_id] = conclusion_node

        return graph

    @staticmethod
    def extract_final_conclusion(graph: Optional[ReasoningGraph]) -> Optional[ReasoningNode]:
        """Return the root node of the graph, or None when there is no graph or no root"""
        if not graph or not graph.root:
            return None
        return graph.root

    @staticmethod
    def extract_key_reasoning_nodes(graph: Optional[ReasoningGraph]) -> List:
        """Return the conclusions of all nodes supported by a Rule (derived results, not basic facts)"""
        if not graph or not graph.nodes:
            return []

        return [
            node.conclusion
            for node in graph.nodes
            if node.support is not None and isinstance(node.support, Rule)
        ]

    @staticmethod
    def condition_triggered(condition, fact):
        """
        Tell whether a fact satisfies a rule condition, ignoring entity names

        Attribute facts must agree on attribute (case-insensitive) and on
        expression value; relation facts must agree on the relation lemma.
        Facts of different types, or of any other type, never match.
        """
        if type(condition) != type(fact):
            return False
        if isinstance(fact, AttributeFact):
            return (condition.attribute.lower() == fact.attribute.lower() and
                    condition.expression.value == fact.expression.value)
        if isinstance(fact, RelationFact):
            # Use lemmatization to compare relation words
            return (lemmatize_word_spacy(condition.relation) == lemmatize_word_spacy(fact.relation))
        return False

    def verify_reasoning_process(self, condition_nodes: List[ReasoningNode], conclusion_node: ReasoningNode) -> bool:
        """
        Verify validity of reasoning process

        The rule stored in ``conclusion_node.support`` is matched against the
        facts of ``condition_nodes``. Every rule condition must be triggered by
        at least one fact; the matches are combined into consistent mappings
        from the rule's entity names to the entities of the matched facts.
        The step is valid when, under one of those mappings, the rule's
        conclusion equals the conclusion of ``conclusion_node``.

        Args:
            condition_nodes: List of condition nodes
            conclusion_node: Conclusion node

        Returns:
            bool: True if reasoning process is valid, False otherwise
                (including when the node has no supporting rule)
        """
        rule_node = conclusion_node.support
        # The rule must be checked before its attributes are read: a step that
        # names no rule leaves support as None
        if not rule_node or not hasattr(rule_node, 'conditions'):
            return False
        if type(conclusion_node.conclusion) != type(rule_node.conclusion):
            return False

        # entity -> attribute -> value, used to evaluate the rule's expression
        values = {}
        for cond_node in condition_nodes:
            fact = cond_node.conclusion
            if isinstance(fact, AttributeFact):
                values.setdefault(fact.entity, {})[fact.attribute] = fact.expression.value

        solutions = [{}]
        for rule_condition in rule_node.conditions:
            partial_solutions = []
            for cond_node in condition_nodes:
                fact = cond_node.conclusion
                if not ReasoningGraphBuilder.condition_triggered(rule_condition, fact):
                    continue
                if isinstance(rule_condition, AttributeFact):
                    partial_solutions.append(
                        {rule_condition.entity: fact.entity})
                elif isinstance(rule_condition, RelationFact):
                    partial_solutions.append({
                        rule_condition.entity1: fact.entity1,
                        rule_condition.entity2: fact.entity2
                    })
            if not partial_solutions:
                return False

            # Keep only combinations that bind every rule entity consistently
            merged_solutions = []
            for old_solution in solutions:
                for partial_solution in partial_solutions:
                    conflict = any(
                        key in partial_solution and partial_solution[key] != value
                        for key, value in old_solution.items())
                    if not conflict:
                        merged_solutions.append(
                            {**old_solution, **partial_solution})
            solutions = merged_solutions

        for solution in solutions:
            if isinstance(conclusion_node.conclusion, AttributeFact):
                applyed_expression = rule_node.conclusion.expression.substitute_entity(
                    solution)
                try:
                    compute_args = applyed_expression.parse_compute_args(
                        values)
                    applyed_value = applyed_expression.compute(**compute_args)
                except Exception as e:
                    # A mapping whose expression cannot be evaluated is skipped
                    self.logger.error(
                        f"Error computing expression: {e}, applyed_expression: {applyed_expression}, values: {values}, rule_conclusion: {rule_node.conclusion}, solution: {solution}, solutions: {len(solutions)}, conclusion: {conclusion_node.conclusion}, conditions: {[condition_node.conclusion for condition_node in condition_nodes]}")
                    continue

                if conclusion_node.conclusion.entity == solution[rule_node.conclusion.entity] and \
                   conclusion_node.conclusion.attribute.lower() == rule_node.conclusion.attribute.lower() and \
                   applyed_value == conclusion_node.conclusion.expression.value:
                    return True
            else:
                if (lemmatize_word_spacy(conclusion_node.conclusion.relation) == lemmatize_word_spacy(rule_node.conclusion.relation) and
                        conclusion_node.conclusion.entity1 == solution[rule_node.conclusion.entity1] and
                        conclusion_node.conclusion.entity2 == solution[rule_node.conclusion.entity2]):
                    return True
        return False


class LLMOutputParser:
    """LLM output parser class"""

    def __init__(self, template_factory: TemplateFactory):
        """
        Keep the template factory and create the graph builder and logger

        Args:
            template_factory: Template factory stored on the instance and
                handed to the ReasoningGraphBuilder
        """
        self.template_factory = template_factory
        self.graph_builder = ReasoningGraphBuilder(template_factory)
        logger_name = f"{__name__}.{self.__class__.__name__}"
        self.logger = setup_logger(logger_name, level=logging.DEBUG)

    def parse_llm_reasoning(self, llm_output: str, facts: List, rules: List) -> Optional[ReasoningGraph]:
        """
        Parse LLM output reasoning process

        Args:
            llm_output: LLM output text
            facts: Fact list
            rules: Rule list

        Returns:
            ReasoningGraph or None: Parsed reasoning graph
        """
        self.logger.debug("Starting to parse LLM reasoning process")
        try:
            # First try to parse as JSON format
            json_graph = self._try_parse_json_reasoning(
                llm_output, facts, rules)
            if json_graph:
                self.logger.debug(
                    "Successfully parsed JSON format reasoning process")
                return json_graph

            # If not JSON format, parse according to original text format
            if "Reasoning:" in llm_output:
                reasoning_part = llm_output.split("Reasoning:")[1]
                if "Answer:" in reasoning_part:
                    reasoning_part = reasoning_part.split("Answer:")[0]

                reasoning_part = reasoning_part.strip()
                self.logger.debug(f"llm output: {llm_output}")
                self.logger.debug(
                    f"Extracted reasoning part: {reasoning_part}")

                if "=>" in reasoning_part:
                    self.logger.debug(
                        "Found reasoning step identifier '=>', starting to build reasoning graph")
                    graph = self.graph_builder.parse_reasoning_steps_to_graph(
                        reasoning_part, facts, rules)
                    if graph:
                        self.logger.debug("Successfully built reasoning graph")
                    else:
                        self.logger.debug(
                            "Reasoning graph build result is empty")
                    return graph
                else:
                    self.logger.debug(
                        "No '=>' identifier found in reasoning part")

            else:
                self.logger.debug("No 'Reasoning:' part found in LLM output")

            return None
        except Exception as e:
            self.logger.error(f"Error parsing LLM reasoning process: {e}")
            return None

    def _try_parse_json_reasoning(self, llm_output: str, facts: List, rules: List) -> Optional[ReasoningGraph]:
        """
        Try to parse LLM output as JSON format reasoning process

        Args:
            llm_output: LLM output text
            facts: Fact list
            rules: Rule list

        Returns:
            ReasoningGraph or None: Parsed reasoning graph
        """
        try:
            # Try to parse entire output as JSON directly
            json_data = json.loads(llm_output.strip())
            return self._build_graph_from_json(json_data, facts, rules)
        except json.JSONDecodeError:
            try:
                # Try to find JSON block (may be surrounded by other text)
                json_match = re.search(r'\{.*\}', llm_output, re.DOTALL)
                if json_match:
                    json_str = json_match.group(0)
                    json_data = json.loads(json_str)
                    return self._build_graph_from_json(json_data, facts, rules)
            except json.JSONDecodeError:
                pass
        except Exception as e:
            self.logger.debug(f"JSON parsing failed: {e}")

        return None

    def _build_graph_from_json(self, json_data: dict, facts: List, rules: List) -> Optional[ReasoningGraph]:
        """
        Build reasoning graph from JSON data

        Args:
            json_data: JSON format reasoning data
            facts: Fact list
            rules: Rule list

        Returns:
            ReasoningGraph or None: Built reasoning graph
        """
        if "reasoning_steps" not in json_data:
            self.logger.debug("JSON data missing reasoning_steps field")
            return None

        reasoning_steps = json_data["reasoning_steps"]
        if not reasoning_steps:
            self.logger.debug("reasoning_steps is empty")
            return None

        graph = ReasoningGraph()
        intermediate_facts = {}

        self.logger.debug(
            f"Starting to process {len(reasoning_steps)} JSON reasoning steps")

        for step in reasoning_steps:
            step_id = step.get("step_id", "")
            rule_facts = step.get("rule_facts", "")
            conclusion_data = step.get("conclusion", {})

            # Create fact object from conclusion data
            conclusion_fact = self._create_fact_from_json(conclusion_data)
            if not conclusion_fact:
                self.logger.debug(
                    f"Unable to create conclusion fact: {conclusion_data}")
                continue

            conclusion_node = ReasoningNode(conclusion_fact, None)

            # Parse rule_facts string
            conditions = [cond.strip() for cond in rule_facts.split('&')]
            support_rule = None
            condition_nodes = []

            for condition in conditions:
                if condition.startswith('rule_'):
                    rule_index = int(condition.replace('rule_', '')) - 1
                    if 0 <= rule_index < len(rules):
                        support_rule = rules[rule_index]
                elif condition.startswith('fact_'):
                    fact_index = int(condition.replace('fact_', '')) - 1
                    if 0 <= fact_index < len(facts):
                        fact = facts[fact_index]
                        fact_node = ReasoningNode(fact, fact)
                        condition_nodes.append(fact_node)
                elif condition.startswith('int_'):
                    if condition in intermediate_facts:
                        int_node = intermediate_facts[condition]
                        if int_node is not None:
                            condition_nodes.append(int_node)

            conclusion_node.support = support_rule

            # Verify reasoning process
            if support_rule and not self.graph_builder.verify_reasoning_process(condition_nodes, conclusion_node):
                self.logger.debug(f"Invalid reasoning process: {step_id}")
                intermediate_facts[step_id] = None
                continue

            # Add to graph
            graph.add_node(conclusion_node)
            for cond_node in condition_nodes:
                graph.add_node(cond_node)
                graph.add_edge(cond_node, conclusion_node)

            # Save intermediate results
            intermediate_facts[step_id] = conclusion_node

        # Set root node (last valid reasoning step)
        if reasoning_steps:
            last_step_id = reasoning_steps[-1].get("step_id", "")
            if last_step_id in intermediate_facts and intermediate_facts[last_step_id] is not None:
                graph.root = intermediate_facts[last_step_id]

        self.logger.debug(
            f"Reasoning graph built from JSON has {len(graph.nodes)} nodes")
        return graph if graph.nodes else None

    def _create_fact_from_json(self, conclusion_data: dict) -> Union[AttributeFact, RelationFact, None]:
        """
        Create fact object from JSON data

        Args:
            conclusion_data: Conclusion data dictionary. The "type" key selects
                the fact kind; an "AttributeFact" additionally reads "entity",
                "attribute" and "value", a "RelationFact" reads "relation",
                "entity1" and "entity2".

        Returns:
            AttributeFact or RelationFact object, returns None if creation fails
            (unknown type, missing/empty field, or a value int() cannot convert)
        """
        fact_type = conclusion_data.get("type", "")

        if fact_type == "AttributeFact":
            entity = conclusion_data.get("entity", "")
            attribute = conclusion_data.get("attribute", "")
            value = conclusion_data.get("value", "")

            # A numeric zero is a legitimate value even though it is falsy, so
            # it must not be mistaken for a missing field.
            is_numeric_zero = (
                isinstance(value, (int, float))
                and not isinstance(value, bool)
                and value == 0
            )

            if entity and attribute and (value or is_numeric_zero):
                try:
                    # Only integer-convertible values are supported
                    numeric_value = int(value)
                    expression = ConstantExpression(numeric_value)
                    return AttributeFact(entity, attribute, expression, self.template_factory)
                except (ValueError, TypeError):
                    # ValueError: non-numeric string; TypeError: value of a
                    # type int() does not accept (e.g. list or dict).
                    # No fact is created in either case.
                    self.logger.debug(
                        f"AttributeFact value is not numeric: {value}")
                    return None

        elif fact_type == "RelationFact":
            relation = conclusion_data.get("relation", "")
            entity1 = conclusion_data.get("entity1", "")
            entity2 = conclusion_data.get("entity2", "")

            if relation and entity1 and entity2:
                return RelationFact(relation, entity1, entity2, self.template_factory)

        # Reached for unknown types and for known types with missing fields
        self.logger.debug(
            f"Unable to create fact, type: {fact_type}, data: {conclusion_data}")
        return None

    @staticmethod
    def extract_llm_answer(llm_output: str) -> Optional[str]:
        """Extract answer from LLM output

        Strategies are tried in order and the first hit wins:

        1. The "final_answer" key of the output parsed as a JSON object: the
           whole (stripped) text first and, if that is not valid JSON, the
           span from the first "{" to the last "}".
        2. The content of the first LaTeX \\boxed{...} expression.
        3. The first line following the last "Answer:" marker: its first
           integer (keeping a leading minus sign) when it contains one,
           otherwise the whole line.

        Args:
            llm_output: Raw text produced by the LLM

        Returns:
            The extracted answer as a string, or None if no strategy matched
        """
        logger = setup_logger(f"{__name__}.LLMOutputParser.extract_llm_answer")
        logger.debug("Starting to extract LLM answer")

        # First try to parse JSON format
        try:
            json_data = json.loads(llm_output.strip())
            # Valid JSON is not necessarily an object (could be a number,
            # string or list); only an object can carry "final_answer".
            if isinstance(json_data, dict) and "final_answer" in json_data:
                answer = str(json_data["final_answer"]).strip()
                logger.debug(f"Extracted answer from JSON format: {answer}")
                return answer
        except json.JSONDecodeError:
            try:
                # Try to find JSON block (greedy: first "{" to last "}")
                json_match = re.search(r'\{.*\}', llm_output, re.DOTALL)
                if json_match:
                    json_str = json_match.group(0)
                    json_data = json.loads(json_str)
                    if isinstance(json_data, dict) and "final_answer" in json_data:
                        answer = str(json_data["final_answer"]).strip()
                        logger.debug(
                            f"Extracted answer from JSON block: {answer}")
                        return answer
            except json.JSONDecodeError:
                pass
        except Exception as e:
            logger.debug(f"JSON answer extraction failed: {e}")

        # Look for \boxed{...} format answer
        boxed_match = re.search(r'\\boxed\{([^}]+)\}', llm_output)
        if boxed_match:
            answer = boxed_match.group(1).strip()
            logger.debug(f"Extracted answer from boxed format: {answer}")
            return answer

        # Look for content after Answer:
        if "Answer:" in llm_output:
            answer_part = llm_output.split("Answer:")[-1].strip()
            first_line = answer_part.split('\n')[0].strip()
            # Keep a minus sign that directly precedes the digits, but only
            # when it is not glued to a preceding word character, so "-5"
            # yields "-5" while "item-5" and "10-5" still yield "5" and "10".
            number_match = re.search(r'(?<!\w)-\d+|\d+', first_line)
            if number_match:
                return number_match.group()
            return first_line

        return None

    def extract_llm_facts_and_rules(self, llm_output: str, all_facts: List, all_rules: List) -> Tuple[List, List]:
        """Extract used facts and rules from LLM output

        Only the text between the first "Reasoning:" marker and the following
        "Answer:" marker (if any) is inspected. On every line containing "=>",
        the left-hand side is split on "&" and each item of the form
        "fact_<n>" / "rule_<n>" is resolved as a 1-based index into
        all_facts / all_rules.

        Args:
            llm_output: Raw text produced by the LLM
            all_facts: Facts that "fact_<n>" references index into
            all_rules: Rules that "rule_<n>" references index into

        Returns:
            (used_facts, used_rules), each in order of first appearance and
            without duplicates. Out-of-range and malformed references are
            ignored.
        """
        used_facts: List = []
        used_rules: List = []

        if "Reasoning:" not in llm_output:
            return used_facts, used_rules

        reasoning_part = llm_output.split("Reasoning:")[1]
        reasoning_part = reasoning_part.split("Answer:")[0].strip()

        # (reference prefix, pool the index points into, output list)
        sources = (
            ("rule_", all_rules, used_rules),
            ("fact_", all_facts, used_facts),
        )

        for line in reasoning_part.split('\n'):
            if '=>' not in line:
                continue

            condition_part = line.split('=>')[0]
            for condition in (part.strip() for part in condition_part.split('&')):
                source = next(
                    (s for s in sources if condition.startswith(s[0])), None)
                if source is None:
                    continue

                prefix, pool, used = source
                try:
                    # References are 1-based in the LLM output
                    index = int(condition.replace(prefix, '')) - 1
                except ValueError:
                    # One malformed reference (e.g. "rule_x") must not abort
                    # extraction of all the well-formed ones.
                    self.logger.debug(
                        f"Skipping malformed reference in reasoning: {condition}")
                    continue

                if 0 <= index < len(pool) and pool[index] not in used:
                    used.append(pool[index])

        return used_facts, used_rules


class ReasoningAnalyzer:
    """Reasoning analyzer main class

    Creates one TemplateFactory and passes it to the fact parser, rule parser,
    reasoning graph builder and LLM output parser, then uses those components
    to score an LLM's reasoning against the reference reasoning process.
    """

    def __init__(self):
        self.template_factory = TemplateFactory()
        self.fact_parser = FactParser(self.template_factory)
        self.rule_parser = RuleParser(self.template_factory)
        self.graph_builder = ReasoningGraphBuilder(self.template_factory)
        self.llm_parser = LLMOutputParser(self.template_factory)
        self.logger = setup_logger(f"{__name__}.{self.__class__.__name__}")

    def _parse_reprs(self, reprs: List, parse_one, kind: str) -> List:
        """Parse every representation in reprs with parse_one, logging each outcome.

        Args:
            reprs: Representations to parse
            parse_one: Callable turning one representation into an object
            kind: Label used in log messages ("fact" or "rule")

        Returns:
            Parsed objects, in input order

        Raises:
            Whatever parse_one raises; the error is logged and re-raised
        """
        parsed = []
        for position, item_repr in enumerate(reprs, start=1):
            try:
                parsed.append(parse_one(item_repr))
            except Exception as e:
                self.logger.error(
                    f"Failed to parse {kind} {position}: {item_repr}, error: {e}")
                raise
            else:
                self.logger.debug(
                    f"Successfully parsed {kind} {position}: {item_repr}")
        return parsed

    def parse_data_to_reasoning_components(self, data: Dict) -> Tuple[List, List]:
        """Parse facts and rules from JSON data

        Args:
            data: Dictionary providing the "facts-repr" and "rules-repr" entries

        Returns:
            (facts, rules) as two lists of parsed objects

        Raises:
            KeyError: If "facts-repr" or "rules-repr" is missing
            Exception: Any parser error, re-raised after being logged
        """
        self.logger.debug("Starting to parse reasoning component data")

        facts = self._parse_reprs(
            data["facts-repr"], self.fact_parser.parse_fact_repr, "fact")
        rules = self._parse_reprs(
            data["rules-repr"], self.rule_parser.parse_rule_repr, "rule")

        self.logger.debug(
            f"Successfully parsed {len(facts)} facts and {len(rules)} rules")
        return facts, rules

    def analyze_reasoning_correctness(self, llm_graph: Optional[ReasoningGraph],
                                      correct_graph: Optional[ReasoningGraph]) -> bool:
        """Analyze reasoning correctness

        Compares the root conclusions of both graphs with the fact parser's
        strict equality.

        Args:
            llm_graph: Graph built from the LLM output, may be None/empty
            correct_graph: Reference graph, must not be None/empty

        Returns:
            True if both graphs have a root and the root conclusions are
            strictly equal; False otherwise (including a missing llm_graph)

        Raises:
            ValueError: If correct_graph is falsy
        """
        if not llm_graph:
            return False
        if not correct_graph:
            raise ValueError("Correct reasoning graph cannot be empty")

        if llm_graph.root and correct_graph.root:
            return self.fact_parser.facts_equal_strict(
                llm_graph.root.conclusion, correct_graph.root.conclusion)
        return False

    def process_reasoning_data(self, data: Dict) -> float:
        """Complete reasoning data processing function

        Builds the reference reasoning graph and the LLM reasoning graph for
        one record and scores the LLM reasoning.

        Args:
            data: Record providing "facts-repr", "rules-repr",
                "reasoning_process_nl", "llm_output" and "answer";
                "id" is optional and only used for logging

        Returns:
            1.0 when the LLM graph's final value equals the reference one;
            otherwise the fraction of reference intermediate conclusions that
            also appear among the LLM's (0.0 when the reference has none)

        Raises:
            AssertionError: If data["answer"] disagrees with the reference
                graph's root value, or if the final values match but the root
                conclusions are not strictly equal
        """
        data_id = data.get("id", "unknown")
        self.logger.debug(f"Starting to process reasoning data, ID: {data_id}")

        # Parse basic components
        facts, rules = self.parse_data_to_reasoning_components(data)

        # Get correct reasoning process
        correct_process = data["reasoning_process_nl"]
        llm_output = data["llm_output"]

        # Build correct reasoning graph
        self.logger.debug("Starting to build correct reasoning graph")
        correct_graph = self.graph_builder.parse_reasoning_steps_to_graph(
            correct_process, facts, rules)
        if correct_graph:
            self.logger.debug(
                f"Successfully built correct reasoning graph, node count: {len(correct_graph.nodes)}, reasoning node count: {len(self.graph_builder.extract_key_reasoning_nodes(correct_graph))}")
        else:
            self.logger.debug("Failed to build correct reasoning graph")

        # Parse LLM reasoning process
        self.logger.debug("Starting to parse LLM reasoning process")
        llm_graph = self.llm_parser.parse_llm_reasoning(
            llm_output, facts, rules)
        if llm_graph:
            self.logger.debug(
                f"Successfully parsed LLM reasoning graph, node count: {len(llm_graph.nodes)}, reasoning node count: {len(self.graph_builder.extract_key_reasoning_nodes(llm_graph))}")
        else:
            self.logger.debug("Failed to parse LLM reasoning graph")

        # Extract LLM answer
        llm_answer = self.llm_parser.extract_llm_answer(llm_output)
        graph_root = self.graph_builder.extract_final_conclusion(llm_graph)
        graph_answer = graph_root.conclusion.expression.value if graph_root and graph_root.conclusion else None
        self.logger.debug(
            f"Extracted LLM answer: {llm_answer}, graph root conclusion: {graph_answer}")
        correct_answer = data['answer']
        correct_graph_answer = correct_graph.root.conclusion.expression.value if correct_graph and correct_graph.root else None
        # Raised explicitly (rather than via `assert`) so the consistency
        # check still runs when Python is started with -O.
        if not (correct_answer == correct_graph_answer):
            raise AssertionError(
                f"Correct answer {correct_answer} does not match reasoning graph root conclusion {correct_graph_answer}")

        # Perform detailed comparison
        self.logger.debug("Starting detailed comparison analysis")
        if graph_answer == correct_graph_answer:
            if not self.analyze_reasoning_correctness(llm_graph, correct_graph):
                raise AssertionError(
                    "LLM reasoning graph does not match correct reasoning graph, which is inconsistent with separate answer comparison")
            self.logger.debug(
                "LLM answer matches correct answer, score 1 point")
            score = 1.0
        else:
            # NOTE(review): llm_graph can be None on this path; this assumes
            # extract_key_reasoning_nodes tolerates that - confirm.
            llm_intermediate_int = self.graph_builder.extract_key_reasoning_nodes(
                llm_graph)
            correct_intermediate_int = self.graph_builder.extract_key_reasoning_nodes(
                correct_graph)
            # Count reference intermediate conclusions matched by any LLM one
            cnt = sum(
                1
                for correct_int in correct_intermediate_int
                if any(self.fact_parser.facts_equal_strict(correct_int, llm_int)
                       for llm_int in llm_intermediate_int)
            )
            total = len(correct_intermediate_int)
            # No reference intermediate conclusions means nothing can be
            # credited; score 0 instead of dividing by zero.
            score = cnt / total if total else 0.0
            self.logger.debug(
                f"LLM answer does not match correct answer, matched intermediate conclusion count: {cnt}, score {cnt} / {total} = {score:.2f} points")
        return score


if __name__ == "__main__":
    import json

    # Script entry point: score the record stored in example.json and print
    # the resulting score.
    demo_analyzer = ReasoningAnalyzer()
    with open('example.json', 'r', encoding='utf-8') as example_file:
        example_record = json.loads(example_file.read())
    print(demo_analyzer.process_reasoning_data(example_record))
