import json
import networkx as nx
from collections import Counter
import argparse

def load_data(file_path):
    """Load annotated traces from a JSON Lines (.jsonl) file.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (such as a trailing empty line) are skipped instead
    of being treated as malformed JSON.

    Args:
        file_path: Path to the .jsonl file.

    Returns:
        A list containing one parsed JSON value per non-blank line.

    Exits the process with status 1 (after printing an error to stderr) if
    the file does not exist or a line is not valid JSON.
    """
    import sys  # local import, matching the style used in main()

    records = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    # Tolerate blank lines rather than failing the whole load.
                    continue
                try:
                    records.append(json.loads(line))
                except json.JSONDecodeError as err:
                    print(
                        f"Error: Could not decode JSON from {file_path} "
                        f"(line {line_no}: {err.msg}). Please ensure it's a valid .jsonl file.",
                        file=sys.stderr,
                    )
                    sys.exit(1)
    except FileNotFoundError:
        print(f"Error: Input file not found at {file_path}", file=sys.stderr)
        sys.exit(1)
    return records

def build_graph(trace):
    """Build a directed dependency graph from a single trace.

    Each step becomes a node keyed by its ``step_index`` (falling back to the
    step's position in ``trace``). A node's ``type`` is taken from
    ``annotation['type']`` if present, otherwise from the step's own
    ``type``, otherwise ``'Unknown'``. Every key of
    ``annotation['attributes']`` is then added as a node attribute.

    An edge ``parent -> step`` is added when the step's
    ``trace_dependency.dependent_on`` names another node id of this trace.

    Args:
        trace: List of step dicts.

    Returns:
        A ``networkx.DiGraph``.
    """
    # Resolve every node id up front so parent references are validated
    # against the ids actually used for nodes (step_index values) rather than
    # list positions. Otherwise traces whose step_index values are not
    # 0..len-1 would drop valid edges or create phantom nodes with no 'type'.
    node_ids = [step.get('step_index', i) for i, step in enumerate(trace)]
    known_ids = set(node_ids)

    dag = nx.DiGraph()
    for node_id, step in zip(node_ids, trace):
        # `or {}` also guards against explicit JSON nulls.
        annotation = step.get('annotation') or {}
        node_type = annotation.get('type', step.get('type', 'Unknown'))
        attributes = annotation.get('attributes') or {}

        # Start with the basic type, then add all annotation attributes so new
        # ontology terms are picked up automatically.
        attrs = {'type': node_type}
        attrs.update(attributes)
        dag.add_node(node_id, **attrs)

        parent_id = (step.get('trace_dependency') or {}).get('dependent_on')
        # Skip unknown parents (they would create untyped nodes) and
        # self-references (they would create a self-loop, so the node is
        # neither a root nor a leaf).
        if parent_id is not None and parent_id in known_ids and parent_id != node_id:
            dag.add_edge(parent_id, node_id)

    return dag

def get_path_str(dag, path_indices):
    """Render a path as ``"Type -> Type(attr=Abc,...) -> ..."``.

    For each node, every tracked knowledge attribute whose value differs from
    its "healthy" baseline is listed in parentheses after the node type. The
    value is abbreviated to its first three characters.
    """
    # (attribute key, short label, baseline value that is NOT reported)
    flagged_attrs = (
        ('premise_grounding', 'premise', 'Directly Grounded'),  # Not, Con
        ('information_clarity', 'clarity', 'Clear'),            # Amb, Con
        ('information_quality', 'quality', 'Sufficient'),       # Ins
    )

    def describe(node_id):
        node_attrs = dag.nodes[node_id]
        label = node_attrs.get('type', 'Unknown')
        flags = [
            f"{short}={node_attrs[key][:3]}"
            for key, short, baseline in flagged_attrs
            if key in node_attrs and node_attrs[key] != baseline
        ]
        return f"{label}({','.join(flags)})" if flags else label

    return " -> ".join(describe(node_id) for node_id in path_indices)

def find_all_paths(dag, trace):
    """Enumerate every root-to-leaf path in ``dag``.

    Roots are nodes with in-degree 0 and leaves are nodes with out-degree 0.
    An isolated node (both a root and a leaf) yields a single-node path.

    Args:
        dag: A ``networkx.DiGraph`` as produced by ``build_graph``.
        trace: Unused; kept for backward compatibility with existing callers.

    Returns:
        A list of dicts with keys ``node_ids`` (list of node ids), ``str``
        (the path rendered by ``get_path_str``) and ``end_node_type``.
    """
    roots = [node for node, in_degree in dag.in_degree() if in_degree == 0]
    leaves = [node for node, out_degree in dag.out_degree() if out_degree == 0]

    all_paths = []
    for root in roots:
        for leaf in leaves:
            if root == leaf:
                # Isolated node (e.g. a single-step trace). Handled explicitly
                # because nx.all_simple_paths with source == target has not
                # behaved consistently across NetworkX releases (some yield
                # nothing). Without this, such traces could produce no path.
                paths = [[root]]
            else:
                # all_simple_paths yields nothing (rather than raising) when
                # no path connects root to leaf.
                paths = nx.all_simple_paths(dag, source=root, target=leaf)
            for p in paths:
                all_paths.append({
                    "node_ids": p,
                    "str": get_path_str(dag, p),
                    "end_node_type": dag.nodes[p[-1]].get('type', 'Unknown'),
                })
    return all_paths

def extract_motifs(path_str, n=3):
    """Return the consecutive ``n``-step windows (motifs) of a rendered path.

    Args:
        path_str: A path string whose steps are separated by ``" -> "``.
        n: Window length; must be a positive integer.

    Returns:
        A list of ``" -> "``-joined windows in path order. It is empty when
        the path has fewer than ``n`` steps.

    Raises:
        ValueError: If ``n`` is less than 1. Such sizes would otherwise
            produce meaningless empty or irregular fragments.
    """
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}")
    nodes = path_str.split(" -> ")
    return [" -> ".join(nodes[i:i + n]) for i in range(len(nodes) - n + 1)]

def analyze_traces(data, n_gram_sizes=(3, 4, 5)):
    """Analyze traces for path patterns, step transitions and anomalies.

    Args:
        data: Sized sequence of records. Each record's ``'trace'`` list is
            analyzed. Records with a missing or empty trace are skipped, but
            they still count towards ``total_traces``.
        n_gram_sizes: Motif lengths to extract from every root-to-leaf path.

    Returns:
        A stats dict containing:
        - trace-level outcome counts;
        - Counters of full path strings per outcome;
        - step transitions and root node types;
        - search_result quality transitions;
        - per-size success/failure motif Counters;
        - ``motif_scores`` (relative success frequency minus relative
          failure frequency);
        - ``divergence_points`` and ``knowledge_failures``.
    """
    # Iterated many times below, so materialize any iterable once.
    n_gram_sizes = tuple(n_gram_sizes)

    stats = {
        "total_traces": len(data),
        "successful_traces": 0,
        "failed_traces": 0,
        "failed_traces_incorrect_answer": 0,
        "failed_traces_dead_end": 0,
        "successful_paths": Counter(),
        "failed_paths_incorrect_answer": Counter(),
        "failed_paths_dead_end": Counter(),
        "transitions": Counter(),
        "root_node_types": Counter(),
        "info_quality_transitions": Counter(),
        "motif_scores": {n: {} for n in n_gram_sizes},
        "success_motifs": {n: Counter() for n in n_gram_sizes},
        "failure_motifs": {n: Counter() for n in n_gram_sizes}
    }

    all_dags = []
    for record in data:
        trace = record.get('trace', [])
        if not trace:
            continue

        # Determine success based on whether a 'CorrectAnswer' node exists in the trace.
        # This correctly handles cases where the LLM judge has relabeled an 'IncorrectAnswer'.
        is_success = any(step.get('type') == 'CorrectAnswer' for step in trace)

        dag = build_graph(trace)
        all_dags.append(dag)

        # Enumerate root-to-leaf paths once. They drive both the outcome
        # classification of failed traces and the path/motif statistics below.
        paths = find_all_paths(dag, trace)

        if is_success:
            stats["successful_traces"] += 1
        else:
            stats["failed_traces"] += 1
            # A failed trace counts as "incorrect answer" if any path ends in
            # one; otherwise it never reached an answer (dead end).
            if any(p['end_node_type'] == "IncorrectAnswer" for p in paths):
                stats["failed_traces_incorrect_answer"] += 1
            else:
                stats["failed_traces_dead_end"] += 1

        # Root node analysis
        for node_id, in_degree in dag.in_degree():
            if in_degree == 0:
                stats["root_node_types"][dag.nodes[node_id].get('type', 'Unknown')] += 1

        # Transition analysis (bigrams), plus search_result transitions
        # annotated with the result's information quality.
        for u, v in dag.edges():
            u_attrs = dag.nodes[u]
            u_type = u_attrs.get('type', 'Unknown')
            v_type = dag.nodes[v].get('type', 'Unknown')
            stats["transitions"][f"{u_type} -> {v_type}"] += 1

            if u_type == 'search_result':
                quality = u_attrs.get('information_quality', 'N/A')
                stats["info_quality_transitions"][f"search_result({quality}) -> {v_type}"] += 1

        # Path analysis. Each path is counted exactly once under its outcome,
        # independently of how many n-gram sizes are requested.
        for p in paths:
            path_str = p["str"]
            end_type = p["end_node_type"]

            if end_type == "CorrectAnswer":
                path_counter = stats["successful_paths"]
                motif_counters = stats["success_motifs"]
            elif end_type == "IncorrectAnswer":
                path_counter = stats["failed_paths_incorrect_answer"]
                motif_counters = stats["failure_motifs"]
            else:  # Dead-end paths
                path_counter = stats["failed_paths_dead_end"]
                motif_counters = stats["failure_motifs"]

            path_counter[path_str] += 1
            for n in n_gram_sizes:
                motif_counters[n].update(extract_motifs(path_str, n=n))

    # Calculate motif correlation scores for each n-gram size
    for n in n_gram_sizes:
        success_counts = stats["success_motifs"][n]
        failure_counts = stats["failure_motifs"][n]
        all_motifs = set(success_counts.keys()) | set(failure_counts.keys())
        total_success_motifs = sum(success_counts.values())
        total_failure_motifs = sum(failure_counts.values())

        motif_scores = {}
        for motif in all_motifs:
            success_freq = success_counts.get(motif, 0) / total_success_motifs if total_success_motifs else 0
            failure_freq = failure_counts.get(motif, 0) / total_failure_motifs if total_failure_motifs else 0
            # Simple correlation score: difference in relative frequencies
            motif_scores[motif] = success_freq - failure_freq

        stats["motif_scores"][n] = motif_scores

    # --- Divergence Analysis ---
    all_failed_paths = stats["failed_paths_incorrect_answer"] + stats["failed_paths_dead_end"]
    stats["divergence_points"] = analyze_divergence_points(
        stats["successful_paths"],
        all_failed_paths
    )

    # --- Knowledge-Work Failure Analysis ---
    stats["knowledge_failures"] = analyze_knowledge_failures(all_dags)

    return stats

def analyze_divergence_points(successful_paths, failed_paths):
    """
    Find points where paths sharing a common prefix diverge to different outcomes.

    Every proper prefix (all but the last step) of every path is tallied,
    together with the step that follows it. A prefix is reported when it
    occurs in at least one successful and one failed path and more than two
    times in total.

    Args:
        successful_paths: Mapping (e.g. Counter) of path string -> count for
            paths ending in a correct answer.
        failed_paths: Mapping of path string -> count for failed paths.

    Returns:
        A list of dicts with keys ``prefix``, ``total_occurrences``,
        ``success_count``, ``fail_count``, ``next_step_success`` and
        ``next_step_fail`` (the last two as ``Counter.most_common()`` lists).
        The list is sorted by ``total_occurrences``, highest first.
    """
    # prefix -> Counter({"success": n, "fail": m}), in first-seen order
    path_prefixes = {}
    # prefix -> {"success": Counter(next_step), "fail": Counter(next_step)}
    next_steps = {}

    def _tally(paths, outcome):
        # Record each proper prefix together with the step that follows it in
        # one pass. This avoids rescanning every path for every prefix later,
        # which was quadratic in the number of paths.
        for path_str, count in paths.items():
            nodes = path_str.split(" -> ")
            for i in range(1, len(nodes)):
                prefix = " -> ".join(nodes[:i])
                if prefix not in path_prefixes:
                    path_prefixes[prefix] = Counter()
                    next_steps[prefix] = {"success": Counter(), "fail": Counter()}
                path_prefixes[prefix][outcome] += count
                next_steps[prefix][outcome][nodes[i]] += count

    _tally(successful_paths, "success")
    _tally(failed_paths, "fail")

    divergences = []
    for prefix, outcomes in path_prefixes.items():
        total = outcomes["success"] + outcomes["fail"]
        # Consider prefixes that lead to both outcomes and appear reasonably often
        if outcomes["success"] > 0 and outcomes["fail"] > 0 and total > 2:
            divergences.append({
                "prefix": prefix,
                "total_occurrences": total,
                "success_count": outcomes["success"],
                "fail_count": outcomes["fail"],
                "next_step_success": next_steps[prefix]["success"].most_common(),
                "next_step_fail": next_steps[prefix]["fail"].most_common()
            })

    # Sort by how frequently the divergence occurs
    return sorted(divergences, key=lambda x: x["total_occurrences"], reverse=True)

def analyze_knowledge_failures(analyzed_traces_dags):
    """
    Tally the 'Glass Box' knowledge-work failure modes across trace graphs.

    The following are counted over every node of every graph:
      * ``ungrounded_leaps``: nodes whose ``premise_grounding`` is
        ``'Not Grounded'``;
      * ``ambiguity_traps``: ``search_result`` nodes whose
        ``information_clarity`` is ``'Ambiguous'``;
      * ``unsuccessful_searches``: ``search_result`` nodes whose
        ``information_quality`` is ``'Insufficient'``.

    For ambiguity traps in a failed graph (one with no ``CorrectAnswer``
    node), the type of the node's first successor is tallied in the
    ``ambiguity_trap_outcomes`` Counter.
    """
    totals = {
        "ungrounded_leaps": 0,
        "ambiguity_traps": 0,
        "unsuccessful_searches": 0,
        "ambiguity_trap_outcomes": Counter(),
    }

    for graph in analyzed_traces_dags:
        # A graph is a failure when none of its nodes is a CorrectAnswer.
        failed = all(
            node_data.get('type') != 'CorrectAnswer'
            for _, node_data in graph.nodes(data=True)
        )

        for node, node_data in graph.nodes(data=True):
            if node_data.get('premise_grounding') == 'Not Grounded':
                totals["ungrounded_leaps"] += 1

            # The remaining checks only concern search results.
            if node_data.get('type') != 'search_result':
                continue

            if node_data.get('information_clarity') == 'Ambiguous':
                totals["ambiguity_traps"] += 1
                if failed:
                    # Only the first successor is considered the "next step".
                    first_next = next(iter(graph.successors(node)), None)
                    if first_next is not None:
                        totals["ambiguity_trap_outcomes"][graph.nodes[first_next]['type']] += 1

            if node_data.get('information_quality') == 'Insufficient':
                totals["unsuccessful_searches"] += 1

    return totals


def _print_motif_rankings(n, motif_scores, top_n):
    """Print the top success- and failure-correlated ``n``-step motifs."""
    # Positive score: relatively more frequent in successful paths.
    # Negative score: relatively more frequent in failed paths.
    # Zero scores carry no signal and are omitted.
    success_side = sorted(
        ((motif, score) for motif, score in motif_scores.items() if score > 0),
        key=lambda item: item[1],
        reverse=True,
    )
    failure_side = sorted(
        ((motif, score) for motif, score in motif_scores.items() if score < 0),
        key=lambda item: item[1],
    )
    if not success_side and not failure_side:
        return

    print(f"\n--- Top {top_n} {n}-Step MOTIFS CORRELATED WITH SUCCESS ---")
    print(f"(These {n}-step patterns are more likely to appear in successful traces.)")
    if not success_side:
        print("  (No motifs are more frequent in successful traces.)")
    for motif, score in success_side[:top_n]:
        print(f"Score: {score:+.4f} | Motif: {motif}")

    print(f"\n--- Top {top_n} {n}-Step MOTIFS CORRELATED WITH FAILURE ---")
    print(f"(These {n}-step patterns are more likely to appear in failed traces, especially highlighting logical errors.)")
    if not failure_side:
        print("  (No motifs are more frequent in failed traces.)")
    for motif, score in failure_side[:top_n]:
        print(f"Score: {score:+.4f} | Motif: {motif}")


def print_report(stats, top_n):
    """Print a structured, insight-focused analysis report to stdout.

    Args:
        stats: Statistics dictionary as returned by ``analyze_traces``.
        top_n: Maximum number of entries shown in each ranked list.
            Negative values are treated as 0.
    """
    top_n = max(top_n, 0)
    total = stats['total_traces']
    success = stats['successful_traces']
    fail = stats['failed_traces']
    success_rate = (success / total * 100) if total > 0 else 0

    print("--- Error Pattern Analysis Report ---")
    print("\n--- I. Overall Performance ---")
    print(f"Total Traces Analyzed: {total}")
    print(f"Successful Traces: {success} ({success_rate:.2f}%)")
    print(f"Failed Traces: {fail}")
    incorrect_traces = stats.get("failed_traces_incorrect_answer", 0)
    dead_end_traces = stats.get("failed_traces_dead_end", 0)
    print(f"  - Concluded with Incorrect Answer: {incorrect_traces} traces")
    print(f"  - Concluded with a Dead End: {dead_end_traces} traces")

    print("\n\n--- II. Key Behavioral Insights (Motif Analysis) ---")
    print("This section highlights the most significant patterns (motifs) of varying lengths that correlate with success or failure.")
    for n in sorted(stats["motif_scores"]):
        _print_motif_rankings(n, stats["motif_scores"][n], top_n)

    print("\n\n--- III. Detailed Anomaly & Pattern Analysis ---")
    print("(This section provides detailed statistics for deeper investigation.)")

    print("\n1. Information Quality Impact Analysis:")
    if not stats["info_quality_transitions"]:
        print("  No transitions involving information quality were found.")
    else:
        for transition, count in stats["info_quality_transitions"].most_common(top_n):
            print(f"- {transition}: {count} occurrences")

    print("\n2. Common Step Transitions (Top N Bigrams):")
    for transition, count in stats["transitions"].most_common(top_n):
        print(f"- {transition}: {count} occurrences")

    print("\n3. Trace Starting Points (Root Node Types):")
    for node_type, count in stats["root_node_types"].most_common():
        print(f"- Started with '{node_type}': {count} times")

    print("\n\n--- IV. Appendix: Full Path Examples ---")
    path_sections = (
        ("successful_paths",
         f"\n--- Top {top_n} Most Common SUCCESSFUL Paths (Ending in CorrectAnswer) ---",
         "No successful paths found."),
        ("failed_paths_incorrect_answer",
         f"\n--- Top {top_n} Most Common FAILED Paths (Ending in IncorrectAnswer) ---",
         "No paths ending in an incorrect answer were found."),
        ("failed_paths_dead_end",
         f"\n--- Top {top_n} Most Common DEAD-END FAILED Paths (Never Reaching an Answer) ---",
         "No dead-end paths were found."),
    )
    for key, header, empty_message in path_sections:
        print(header)
        if not stats[key]:
            print(empty_message)
            continue
        for path, count in stats[key].most_common(top_n):
            print(f"Count: {count:<5} Path: {path}")

    print("\n\n--- V. Key Divergence Points ---")
    print("This section highlights critical points where similar paths diverge towards success or failure.")
    divergence_points = stats.get("divergence_points") or []
    if not divergence_points:
        print("\nNo significant divergence points found.")
    else:
        for i, div in enumerate(divergence_points[:top_n], start=1):
            occurrences = div['total_occurrences']
            div_success_rate = div['success_count'] / occurrences * 100
            div_fail_rate = div['fail_count'] / occurrences * 100
            print(f"\n--- Divergence Point #{i} (Occurs {occurrences} times) ---")
            print(f"  Prefix: {div['prefix']}")
            print(f"  Outcome Split: {div_success_rate:.1f}% SUCCESS vs. {div_fail_rate:.1f}% FAILURE")

            for label, key in (("SUCCESS", 'next_step_success'), ("FAILURE", 'next_step_fail')):
                print(f"  Next steps leading to {label}:")
                if div[key]:
                    for step, count in div[key]:
                        print(f"    - {step} ({count} times)")
                else:
                    print("    - (End of path)")

    print("\n\n--- VI. Knowledge-Work Failure Analysis ---")
    print("This section quantifies failures related to the agent's 'knowing' abilities, based on our Taxonomy of Failure.")

    k_stats = stats.get("knowledge_failures", {})
    if not k_stats:
        print("\nNo knowledge-work failure data available.")
        return

    # 1. Ungrounded Leaps
    total_leaps = k_stats.get('ungrounded_leaps', 0)
    print("\n1. Ungrounded Leaps (Acting on Not Grounded Premises):")
    print(f"  - Total Detected: {total_leaps} instances of 'premise_grounding: Not Grounded'.")

    # 2. Ambiguity Traps
    total_ambiguity = k_stats.get('ambiguity_traps', 0)
    print("\n2. Ambiguity Traps (Receiving Ambiguous Search Results):")
    print(f"  - Total Detected: {total_ambiguity} instances of 'information_clarity: Ambiguous'.")
    if total_ambiguity > 0 and stats['failed_traces'] > 0:
        print("  - Subsequent actions in failed traces:")
        outcomes = k_stats.get('ambiguity_trap_outcomes') or {}
        if not outcomes:
            print("    - No subsequent actions recorded in failed traces.")
        for outcome, count in sorted(outcomes.items(), key=lambda item: item[1], reverse=True):
            # Counts are ambiguous search_result nodes, not distinct traces.
            print(f"    - {count} ambiguous result(s) in failed traces were followed by a '{outcome}' step.")

    # 3. Unsuccessful Searches
    total_unsuccessful = k_stats.get('unsuccessful_searches', 0)
    print("\n3. Unsuccessful Searches (Receiving Insufficient Information):")
    print(f"  - Total Detected: {total_unsuccessful} instances of 'information_quality: Insufficient'.")


def main():
    """Command-line entry point: parse arguments, analyze traces, then print or save the report."""
    import contextlib
    import sys

    parser = argparse.ArgumentParser(description="Analyze error patterns from structured logs.")
    parser.add_argument("input_file", help="Path to the .jsonl file with annotated traces.")
    parser.add_argument("-o", "--output_file", help="Path to save the analysis report.", default=None)
    parser.add_argument("-n", "--top_n", type=int, default=10, help="Number of top items to display in lists.")
    args = parser.parse_args()

    data = load_data(args.input_file)
    if not data:
        # Previously an empty input exited without any output; explain why.
        print(f"No traces found in {args.input_file}; nothing to analyze.", file=sys.stderr)
        return

    stats = analyze_traces(data)

    if args.output_file:
        # redirect_stdout restores sys.stdout even if print_report raises.
        with open(args.output_file, 'w', encoding='utf-8') as f, contextlib.redirect_stdout(f):
            print_report(stats, args.top_n)
        print(f"Analysis report saved to {args.output_file}")
    else:
        print_report(stats, args.top_n)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()