#!/usr/bin/env python3
import os
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
from pathlib import Path


# Optional project dependency: AUTOMATA_REGISTER from the intervention_sampling package.
# When that package cannot be imported, a warning is printed and an empty register is
# used instead so the rest of the script can still run.
try:
    from intervention_sampling.automata_register import AUTOMATA_REGISTER
except ImportError:
    # Fallback: empty register (presumably state/symbol maps must then come from other
    # inputs, as the warning suggests — TODO confirm against main()).
    print("Warning: Could not import AUTOMATA_REGISTER. Will use state/symbol maps from files if provided.")
    AUTOMATA_REGISTER = {}


def set_plot_style():
    """Apply the figure look shared by every plot produced by this script.

    Sets matplotlib to a serif font family (Times New Roman), fixes the font sizes for
    body text, axes titles/labels, tick labels, legends and figure titles, and selects
    seaborn's "ticks" style.
    """
    typography = {
        'font.family': 'serif',
        'font.serif': ['Times New Roman'],
    }
    text_sizes = {
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 16,
    }
    plt.rcParams.update({**typography, **text_sizes})
    sns.set_style("ticks")


def load_target_kl_data(base_dir, automaton, intervention_type, target, architecture, num_seeds=10,
                      intervention_start=50, intervention_end=2000, intervention_step=100):
    """Load KL data for a specific target from multiple seeds and calculate statistics.

    For every intervention count ``i`` in
    ``range(intervention_start, intervention_end + 1, intervention_step)`` and every seed in
    ``1..num_seeds`` this reads
    ``<base_dir>/<automaton>/<intervention_type>/target_<target>/seed_<seed>/<architecture>_<i>/eval/decomposed_kls.json``.
    Missing files are skipped silently.

    The target's contribution is read from ``state_contributions`` (``'state'``),
    ``symbol_contributions`` (``'symbol'``) or ``transition_contributions``
    (``'arc'``/``'transition'``), keyed by ``str(target)``. For transitions, when no key
    equals ``str(target)`` exactly, the first key containing it as a substring is used.
    A missing target (or a missing contributions section) prints a warning.

    Returns:
        dict with
        - ``'intervention_counts'``: the counts for which at least one seed provided the
          target's KL contribution;
        - ``'target_kl'`` and ``'total_kl'``: dicts of lists parallel to
          ``'intervention_counts'``: ``'values'`` (raw per-seed values), ``'mean'`` and
          ``'std'`` (sample std with ``ddof=1``; 0 for a single value).
        ``total_kl`` includes every seed whose file was read, even where the target was missing.
    """
    # Section of decomposed_kls.json holding the per-target contributions.
    contribution_sections = {
        'state': 'state_contributions',
        'symbol': 'symbol_contributions',
        'arc': 'transition_contributions',
        'transition': 'transition_contributions',
    }
    section_name = contribution_sections.get(intervention_type)
    target_key = str(target)

    aggregated_data = {
        'intervention_counts': [],
        'target_kl': {'mean': [], 'std': [], 'values': []},
        'total_kl': {'mean': [], 'std': [], 'values': []}
    }

    for i in range(intervention_start, intervention_end + 1, intervention_step):
        target_kl_values = []
        total_kl_values = []

        for seed in range(1, num_seeds + 1):
            kl_file = os.path.join(
                base_dir,
                automaton,
                intervention_type,
                f"target_{target}",
                f"seed_{seed}",
                f"{architecture}_{i}",
                "eval",
                "decomposed_kls.json"
            )
            if not os.path.exists(kl_file):
                continue  # runs that were not produced are skipped silently

            with open(kl_file, 'r') as f:
                data = json.load(f)

            # Total KL is kept for every readable seed, even when the target is missing.
            total_kl_values.append(data['total_kl'])

            if section_name is None:
                continue  # unknown intervention type: only total KL can be collected

            # A missing section is reported like a missing target instead of raising KeyError.
            contributions = data.get(section_name, {})
            if target_key in contributions:
                matched_key = target_key
            elif section_name == 'transition_contributions':
                # Transition keys may embed the target; fall back to the first key containing it.
                matched_key = next((key for key in contributions if target_key in key), None)
            else:
                matched_key = None

            if matched_key is not None:
                target_kl_values.append(contributions[matched_key])
            elif intervention_type == 'state':
                print(f"Warning: Target state {target} not found in data for seed {seed}, count {i}")
            elif intervention_type == 'symbol':
                print(f"Warning: Target symbol {target} not found in data for seed {seed}, count {i}")
            else:
                print(f"Warning: Target transition with index {target} not found for seed {seed}, count {i}")

        # Only add this intervention count if at least one seed provided the target's KL.
        if not target_kl_values:
            continue

        aggregated_data['intervention_counts'].append(i)
        for stat_name, values in (('target_kl', target_kl_values), ('total_kl', total_kl_values)):
            stats = aggregated_data[stat_name]
            stats['values'].append(values)
            stats['mean'].append(np.mean(values))
            # Sample standard deviation; a single seed has no spread.
            stats['std'].append(np.std(values, ddof=1) if len(values) > 1 else 0)

    return aggregated_data


def parse_arcs_file(file_path, at_least_once=False):
    """Tally source states, symbols and transitions listed in an ``arcs.txt`` file.

    Only lines that start with ``[`` and, once stripped, end with ``]`` are parsed; each
    is evaluated as a list of ``(src_state, tgt_state, symbol)`` tuples, e.g.
    ``[(0, 1, 'a'), (1, 0, 'b')]``. All other lines are ignored.

    Args:
        file_path: Path of the ``arcs.txt`` file to read.
        at_least_once: When False every arc is counted. When True a given state, symbol
            or transition contributes at most 1 per line.

    Returns:
        ``(state_counts, symbol_counts, transition_counts)``: three ``Counter`` objects
        keyed by ``str(src_state)``, ``str(symbol)`` and
        ``f"{src_state}-{symbol}-s{tgt_state}"`` (note the ``s`` before the target state).
        Any exception while reading or parsing is printed and three empty ``Counter``
        objects are returned instead.
    """
    try:
        with open(file_path, 'r') as handle:
            raw_lines = handle.read().strip().split('\n')

        states = Counter()
        symbols = Counter()
        transitions = Counter()

        for raw in raw_lines:
            # A relevant line begins with '[' (indented lines are deliberately skipped).
            if not raw.startswith('['):
                continue
            text = raw.strip()
            if not text.endswith(']'):
                continue

            # NOTE(review): eval() runs arbitrary code from the file, so only use this on
            # trusted arcs files. ast.literal_eval would be safe but rejects numpy scalar
            # reprs such as "np.int64(0)", which eval() resolves via this module's `np`.
            arcs = eval(text)

            # Per-line memory used for the at_least_once de-duplication.
            line_states = set()
            line_symbols = set()
            line_transitions = set()

            for src, tgt, sym in arcs:
                key = f"{src}-{sym}-s{tgt}"
                if not (at_least_once and src in line_states):
                    states[str(src)] += 1
                if not (at_least_once and sym in line_symbols):
                    symbols[str(sym)] += 1
                if not (at_least_once and key in line_transitions):
                    transitions[key] += 1
                line_states.add(src)
                line_symbols.add(sym)
                line_transitions.add(key)

        return states, symbols, transitions

    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
        return Counter(), Counter(), Counter()


def find_arcs_files(data_dir, automaton, seed=None):
    """Find all training ``arcs.txt`` files for a given automaton in the data directory.

    Searches ``<data_dir>/datasets/<automaton>/vanilla/0`` (or ``.../0/<seed>`` when
    ``seed`` is given) recursively. An ``arcs.txt`` is kept when the path of its directory
    *relative to that search root* contains ``"train"``.

    Args:
        data_dir: Base data directory.
        automaton: Automaton name (a directory below ``datasets``).
        seed: Optional seed subdirectory to restrict the search to.

    Returns:
        Sorted list of matching file paths (empty if the search root does not exist).
        The order is deterministic so callers that pair files with run IDs by position
        get the same pairing on every filesystem (``os.walk`` order is not guaranteed).
    """
    base_path = os.path.join(data_dir, "datasets", automaton, "vanilla", "0")
    if seed is not None:
        base_path = os.path.join(base_path, str(seed))

    arcs_files = []
    for root, _dirs, files in os.walk(base_path):
        # Test only the part below base_path: a data_dir or automaton name that merely
        # contains "train" (e.g. "constrained") must not select every directory.
        relative_root = os.path.relpath(root, base_path)
        if "train" in relative_root and "arcs.txt" in files:
            arcs_files.append(os.path.join(root, "arcs.txt"))

    return sorted(arcs_files)


def load_vanilla_data_with_occurrences(model_dir, data_dir, automaton, architecture, intervention_type, 
                                       targets, seeds=range(1, 11), run_ids=None, at_least_once=False):
    """
    Load KL data from vanilla runs with natural occurrence counts from arcs.txt files.

    For each seed, every training ``arcs.txt`` returned by ``find_arcs_files`` is paired by
    position with a run ID (``run_ids[file_idx % len(run_ids)]``, or 50 when ``run_ids`` is
    empty). The KL decomposition of that run is read from
    ``<model_dir>/<automaton>/vanilla/target_0/seed_<seed>/<architecture>_<run_id>/eval/decomposed_kls.json``.
    Pairs whose KL file does not exist are skipped without parsing the arcs file.

    Counts and contributions are looked up with ``str(target)``. For ``'arc'`` /
    ``'transition'`` an exact contribution key is preferred; otherwise the first key that
    contains ``str(target)`` as a substring is used, and the target is skipped if none does.

    Args:
        model_dir: Base directory of the trained models.
        data_dir: Base directory of the datasets.
        automaton: Automaton name.
        architecture: Architecture name used in the run directory.
        intervention_type: 'state', 'symbol', or 'transition'/'arc'.
        targets: Target identifiers.
        seeds: Seeds to scan.
        run_ids: Run IDs paired with arcs files; defaults to 50, 150, ..., 1950.
        at_least_once: Forwarded to ``parse_arcs_file``.

    Returns a dictionary with targets as keys, each containing:
    - 'occurrence_counts': List of natural occurrence counts
    - 'kl_values': List of corresponding KL values
    - 'total_kl_values': List of the total KL of the same runs
    Only pairs with both a positive count and a positive KL contribution are recorded.
    """
    if run_ids is None:
        run_ids = list(range(50, 2000, 100))  # 50, 150, ..., 1950

    results = {target: {'occurrence_counts': [], 'kl_values': [], 'total_kl_values': []} for target in targets}
    is_transition = intervention_type in ('arc', 'transition')

    for seed in seeds:
        for file_idx, arcs_file in enumerate(find_arcs_files(data_dir, automaton, seed)):
            # Pair files with run IDs by position, cycling when there are more files than IDs.
            run_id = run_ids[file_idx % len(run_ids)] if run_ids else 50

            kl_file = os.path.join(
                model_dir,
                automaton,
                "vanilla",
                "target_0",
                f"seed_{seed}",
                f"{architecture}_{run_id}",
                "eval",
                "decomposed_kls.json"
            )
            # Without a KL decomposition there is nothing to pair the counts with.
            if not os.path.exists(kl_file):
                continue

            state_counts, symbol_counts, transition_counts = parse_arcs_file(arcs_file, at_least_once=at_least_once)
            with open(kl_file, 'r') as f:
                data = json.load(f)

            if intervention_type == 'state':
                counts = state_counts
                kl_contributions = data.get('state_contributions', {})
            elif intervention_type == 'symbol':
                counts = symbol_counts
                kl_contributions = data.get('symbol_contributions', {})
            else:  # transition or arc
                counts = transition_counts
                kl_contributions = data.get('transition_contributions', {})

            total_kl = data.get('total_kl', 0)

            for target in targets:
                target_key = str(target)
                if is_transition and target_key not in kl_contributions:
                    # No exact key: fall back to the first key that contains the target.
                    target_key = next((key for key in kl_contributions if target_key in key), None)
                    if target_key is None:
                        continue

                count = counts.get(target_key, 0)
                kl_value = kl_contributions.get(target_key, 0)

                # Only record pairs where both the count and the KL value are positive.
                if count > 0 and kl_value > 0:
                    entry = results[target]
                    entry['occurrence_counts'].append(count)
                    entry['kl_values'].append(kl_value)
                    entry['total_kl_values'].append(total_kl)
    return results

def plot_target_comparison(targets, architectures, intervention_data, vanilla_data, 
                          intervention_type, output_dir, automaton, state_map=None, symbol_map=None):
    """
    Plot comparison of KL trends: intervention counts vs natural occurrences.

    One subplot is drawn per target (sorted, at most three per row), all sharing the same
    y-axis limits. For each target and architecture, two lines in the architecture's color:
    1. Mean KL contribution vs intervention count (solid, circles) with a shaded +/- SEM
       band, from ``intervention_data[arch][target]`` (keys ``'intervention_counts'`` and
       ``'target_kl'`` -> ``'mean'``/``'std'``/``'values'``).
    2. KL contribution vs natural occurrence count (dashed, crosses), averaged over entries
       sharing the same count, from ``vanilla_data[arch][target]`` (keys
       ``'occurrence_counts'`` and ``'kl_values'``).

    Args:
        targets: Target identifiers, one subplot each.
        architectures: Architecture names. 'transformer', 'lstm' and 'rnn' use fixed
            colors; any other name gets the next matplotlib cycle color (C3, C4, ...).
        intervention_data: ``{arch: {target: stats}}`` for the controlled runs.
        vanilla_data: ``{arch: {target: occurrences}}`` for the vanilla runs.
        intervention_type: 'state', 'symbol', or 'transition'/'arc'; selects subplot titles.
        output_dir: Directory the figures are written to.
        automaton: Automaton name shown in the figure title.
        state_map, symbol_map: Optional display-name maps; the target is looked up as an
            int when it converts to one, otherwise as-is.

    Writes ``<intervention_type>_intervention_vs_natural.png`` (300 dpi) and ``.pdf`` into
    ``output_dir``. Prints a warning and returns without plotting when ``targets`` is empty.
    If no data exists at all, the y-axis is left to autoscale.
    """
    n_targets = len(targets)
    if n_targets == 0:
        # plt.subplots cannot build a grid with zero columns.
        print(f"Warning: No targets to plot for {intervention_type}")
        return

    sorted_targets = sorted(targets)
    n_cols = min(n_targets, 3)
    n_rows = (n_targets + 2) // 3  # Ceiling division

    # squeeze=False always returns a 2-D array, so ravel() yields a flat array of axes.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    known_colors = {
        'transformer': '#1f77b4',  # blue
        'lstm': '#ff7f0e',         # orange
        'rnn': '#2ca02c'           # green
    }
    # Unknown architectures get cycle colors after the three used above instead of a KeyError.
    arch_colors = {}
    n_extra = 0
    for arch in architectures:
        if arch in arch_colors:
            continue
        if arch in known_colors:
            arch_colors[arch] = known_colors[arch]
        else:
            arch_colors[arch] = f"C{3 + n_extra}"
            n_extra += 1

    line_styles = {'intervention': '-', 'natural': '--'}  # solid vs dashed
    markers = {'intervention': 'o', 'natural': 'x'}       # circle vs cross

    def subplot_title(target):
        """Return the subplot title, adding the mapped display name when one exists."""
        if intervention_type == 'state':
            prefix, name_map = 'State', state_map
        elif intervention_type == 'symbol':
            prefix, name_map = 'Symbol', symbol_map
        else:  # transition or arc
            return f'Transition {target}'
        # Maps may be keyed by integer ids while targets arrive as strings (e.g. "3").
        try:
            lookup_key = int(target)
        except (TypeError, ValueError):
            lookup_key = target
        if name_map and lookup_key in name_map:
            return f'{prefix} {target}: {name_map[lookup_key]}'
        return f'{prefix} {target}'

    def intervention_series(arch, target):
        """Return (counts, mean KL, SEM) of the controlled runs, or None when absent/empty."""
        if arch not in intervention_data or target not in intervention_data[arch]:
            return None
        arch_data = intervention_data[arch][target]
        if not arch_data['intervention_counts']:
            return None
        stats = arch_data['target_kl']
        sem = [std / np.sqrt(len(vals)) for std, vals in zip(stats['std'], stats['values'])]
        return arch_data['intervention_counts'], stats['mean'], sem

    def natural_entry(arch, target):
        """Return the vanilla entry for (arch, target) if it has KL values, else None."""
        if arch not in vanilla_data or target not in vanilla_data[arch]:
            return None
        entry = vanilla_data[arch][target]
        return entry if entry['kl_values'] else None

    # Shared y-limits: intervention mean +/- SEM and every raw natural KL value.
    y_low, y_high = float('inf'), float('-inf')
    for target in sorted_targets:
        for arch in architectures:
            series = intervention_series(arch, target)
            if series is not None:
                _, mean, sem = series
                y_low = min(y_low, min(m - s for m, s in zip(mean, sem)))
                y_high = max(y_high, max(m + s for m, s in zip(mean, sem)))
            natural = natural_entry(arch, target)
            if natural is not None:
                y_low = min(y_low, min(natural['kl_values']))
                y_high = max(y_high, max(natural['kl_values']))

    # With no data, y_low/y_high stay infinite and set_ylim would raise; autoscale instead.
    y_limits = None
    if y_low <= y_high:
        y_range = y_high - y_low
        lower = max(0, y_low - 0.05 * y_range)  # Ensure we don't go below 0
        upper = y_high + 0.05 * y_range
        if lower < upper:
            y_limits = (lower, upper)

    for ax, target in zip(axes, sorted_targets):
        ax.set_title(subplot_title(target))
        ax.set_xlabel('Count (Intervention or Natural)')
        ax.set_ylabel('KL Contribution')
        if y_limits is not None:
            ax.set_ylim(*y_limits)

        for arch in architectures:
            color = arch_colors[arch]

            # 1. Controlled interventions: mean line plus SEM band.
            series = intervention_series(arch, target)
            if series is not None:
                x, mean, sem = series
                ax.plot(x, mean,
                        marker=markers['intervention'],
                        linestyle=line_styles['intervention'],
                        linewidth=2,
                        color=color,
                        label=f"{arch.capitalize()} (Intervention)")
                ax.fill_between(
                    x,
                    [m - s for m, s in zip(mean, sem)],
                    [m + s for m, s in zip(mean, sem)],
                    alpha=0.15,
                    color=color
                )

            # 2. Natural occurrences: average KL per distinct occurrence count.
            natural = natural_entry(arch, target)
            if natural is not None and natural['occurrence_counts']:
                kl_by_count = defaultdict(list)
                for count, kl in zip(natural['occurrence_counts'], natural['kl_values']):
                    kl_by_count[count].append(kl)
                unique_counts = sorted(kl_by_count)
                ax.plot(unique_counts, [np.mean(kl_by_count[count]) for count in unique_counts],
                        marker=markers['natural'],
                        linestyle=line_styles['natural'],
                        linewidth=2,
                        color=color,
                        label=f"{arch.capitalize()} (Natural)")

        ax.grid(True, linestyle='--', alpha=0.7)
        sns.despine(ax=ax)

    # Hide unused subplots
    for ax in axes[n_targets:]:
        ax.set_visible(False)

    # A single legend for the entire figure, built from proxy artists.
    handles = []
    for arch in architectures:
        for kind, suffix in (('intervention', 'Intervention'), ('natural', 'Natural')):
            handles.append(plt.Line2D([0], [0],
                                      color=arch_colors[arch],
                                      linestyle=line_styles[kind],
                                      marker=markers[kind],
                                      linewidth=2,
                                      markersize=6,
                                      label=f"{arch.capitalize()} ({suffix})"))
    labels = [handle.get_label() for handle in handles]

    if n_targets > 3:
        fig.legend(handles, labels, loc='upper center',
                   bbox_to_anchor=(0.5, 0),
                   fancybox=True, shadow=True, ncol=max(1, min(3, len(handles))))
        fig.tight_layout(rect=[0, 0.1, 1, 0.95])  # Leave more space at bottom for legend
    else:
        fig.legend(handles, labels, loc='center right',
                   bbox_to_anchor=(1.05, 0.5),
                   fancybox=True, shadow=True)
        fig.tight_layout(rect=[0, 0, 0.85, 0.95])  # Leave more space at right for legend

    fig.suptitle(f'{automaton.title()} - {intervention_type.title()} - Intervention vs Natural KL')

    # The legend is anchored outside the figure area (below it, or past the right edge);
    # bbox_inches='tight' grows the saved canvas so it is not cropped.
    stem = os.path.join(output_dir, f'{intervention_type}_intervention_vs_natural')
    fig.savefig(f'{stem}.png', dpi=300, bbox_inches='tight')
    fig.savefig(f'{stem}.pdf', bbox_inches='tight')
    plt.close(fig)


def plot_per_architecture_comparison(targets, architecture, intervention_data, vanilla_data, 
                                  intervention_type, output_dir, automaton, state_map=None, symbol_map=None):
    """
    Create separate plots for each architecture showing both intervention and natural occurrence trends.

    For the single ``architecture``, one subplot is drawn per target (sorted, at most three
    per row), all sharing the same y-axis limits:
    - blue solid line: mean KL contribution vs intervention count with a shaded +/- SEM band,
      from ``intervention_data[architecture][target]``;
    - red dashed line: KL contribution vs natural occurrence count, averaged over entries
      with the same count, from ``vanilla_data[architecture][target]``.
    The legend goes on the first subplot that actually has plotted lines.

    Args:
        targets: Target identifiers, one subplot each.
        architecture: Architecture whose data is plotted.
        intervention_data: ``{arch: {target: stats}}`` for the controlled runs.
        vanilla_data: ``{arch: {target: occurrences}}`` for the vanilla runs.
        intervention_type: 'state', 'symbol', or 'transition'/'arc'; selects subplot titles.
        output_dir: Directory the figures are written to.
        automaton: Automaton name shown in the figure title.
        state_map, symbol_map: Optional display-name maps; the target is looked up as an
            int when it converts to one, otherwise as-is.

    Writes ``<architecture>_<intervention_type>_comparison.png`` (300 dpi) and ``.pdf`` into
    ``output_dir``. Prints a warning and returns without plotting when ``targets`` is empty.
    If no data exists at all, the y-axis is left to autoscale.
    """
    n_targets = len(targets)
    if n_targets == 0:
        # plt.subplots cannot build a grid with zero columns.
        print(f"Warning: No targets to plot for {architecture} ({intervention_type})")
        return

    sorted_targets = sorted(targets)
    n_cols = min(n_targets, 3)
    n_rows = (n_targets + 2) // 3  # Ceiling division

    # squeeze=False always returns a 2-D array, so ravel() yields a flat array of axes.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    intervention_color = '#1f77b4'  # blue for intervention
    natural_color = '#d62728'       # red for natural
    intervention_style = '-'        # solid for intervention
    natural_style = '--'            # dashed for natural
    intervention_marker = 'o'       # circle for intervention
    natural_marker = 'x'            # x for natural

    arch_intervention = intervention_data.get(architecture, {})
    arch_vanilla = vanilla_data.get(architecture, {})

    def subplot_title(target):
        """Return the subplot title, adding the mapped display name when one exists."""
        if intervention_type == 'state':
            prefix, name_map = 'State', state_map
        elif intervention_type == 'symbol':
            prefix, name_map = 'Symbol', symbol_map
        else:  # transition or arc
            return f'Transition {target}'
        # Maps may be keyed by integer ids while targets arrive as strings (e.g. "3").
        try:
            lookup_key = int(target)
        except (TypeError, ValueError):
            lookup_key = target
        if name_map and lookup_key in name_map:
            return f'{prefix} {target}: {name_map[lookup_key]}'
        return f'{prefix} {target}'

    def intervention_series(target):
        """Return (counts, mean KL, SEM) of the controlled runs, or None when absent/empty."""
        if target not in arch_intervention:
            return None
        arch_data = arch_intervention[target]
        if not arch_data['intervention_counts']:
            return None
        stats = arch_data['target_kl']
        sem = [std / np.sqrt(len(vals)) for std, vals in zip(stats['std'], stats['values'])]
        return arch_data['intervention_counts'], stats['mean'], sem

    # Shared y-limits: intervention mean +/- SEM and every raw natural KL value.
    y_low, y_high = float('inf'), float('-inf')
    for target in sorted_targets:
        series = intervention_series(target)
        if series is not None:
            _, mean, sem = series
            y_low = min(y_low, min(m - s for m, s in zip(mean, sem)))
            y_high = max(y_high, max(m + s for m, s in zip(mean, sem)))
        if target in arch_vanilla and arch_vanilla[target]['kl_values']:
            kl_values = arch_vanilla[target]['kl_values']
            y_low = min(y_low, min(kl_values))
            y_high = max(y_high, max(kl_values))

    # With no data, y_low/y_high stay infinite and set_ylim would raise; autoscale instead.
    y_limits = None
    if y_low <= y_high:
        y_range = y_high - y_low
        lower = max(0, y_low - 0.05 * y_range)  # Ensure we don't go below 0
        upper = y_high + 0.05 * y_range
        if lower < upper:
            y_limits = (lower, upper)

    legend_placed = False
    for ax, target in zip(axes, sorted_targets):
        ax.set_title(subplot_title(target))
        ax.set_xlabel('Count (Intervention or Natural)')
        ax.set_ylabel('KL Contribution')
        if y_limits is not None:
            ax.set_ylim(*y_limits)

        # 1. Controlled interventions: mean line plus SEM band.
        series = intervention_series(target)
        if series is not None:
            x, mean, sem = series
            ax.plot(x, mean,
                    marker=intervention_marker,
                    linestyle=intervention_style,
                    linewidth=2,
                    color=intervention_color,
                    label="Controlled Intervention")
            ax.fill_between(
                x,
                [m - s for m, s in zip(mean, sem)],
                [m + s for m, s in zip(mean, sem)],
                alpha=0.15,
                color=intervention_color
            )

        # 2. Natural occurrences: average KL per distinct occurrence count.
        natural = arch_vanilla.get(target)
        if natural is not None and natural['kl_values'] and natural['occurrence_counts']:
            kl_by_count = defaultdict(list)
            for count, kl in zip(natural['occurrence_counts'], natural['kl_values']):
                kl_by_count[count].append(kl)
            unique_counts = sorted(kl_by_count)
            ax.plot(unique_counts, [np.mean(kl_by_count[count]) for count in unique_counts],
                    marker=natural_marker,
                    linestyle=natural_style,
                    linewidth=2,
                    color=natural_color,
                    label="Natural Occurrence")

        ax.grid(True, linestyle='--', alpha=0.7)
        sns.despine(ax=ax)

        # Legend on the first subplot that has labelled lines (an empty legend only warns).
        if not legend_placed and ax.get_legend_handles_labels()[0]:
            ax.legend(loc='best')
            legend_placed = True

    # Hide unused subplots
    for ax in axes[n_targets:]:
        ax.set_visible(False)

    fig.suptitle(f'{automaton.title()} - {architecture.capitalize()} - {intervention_type.title()} Comparison')

    fig.tight_layout(rect=[0, 0, 1, 0.95])  # Leave space for title
    stem = os.path.join(output_dir, f'{architecture}_{intervention_type}_comparison')
    fig.savefig(f'{stem}.png', dpi=300)
    fig.savefig(f'{stem}.pdf')
    plt.close(fig)


def main():
    """Compare KL decomposition between interventions and natural occurrences.

    Command-line entry point. Steps:

    1. Parse the CLI arguments and validate the numeric ones.
    2. Resolve the output directory (default:
       ``<model-dir>/../plots/<automaton>/kl_vs_natural_occurrence``) and
       create it.
    3. If ``--automata-name`` is given, take ``state_map`` / ``symbol_map``
       from ``AUTOMATA_REGISTER``; otherwise both stay ``None``.
    4. If ``--targets`` is not given, discover targets from the
       ``target_*`` sub-directories of
       ``<model-dir>/<automaton>/<intervention-type>``.
    5. For every architecture, load intervention KL data per target
       (``load_target_kl_data``) and vanilla KL-vs-occurrence data
       (``load_vanilla_data_with_occurrences``).
    6. Write the combined comparison plot and one plot per architecture.

    Invalid numeric arguments terminate via ``parser.error`` (SystemExit).
    When no targets are available the function prints a message and
    returns without plotting.
    """
    parser = argparse.ArgumentParser(description="Compare KL decomposition between interventions and natural occurrences.")
    parser.add_argument('--model-dir', default='experiments_mix_atleastonce/models', help='Base directory containing model folders')
    parser.add_argument('--data-dir', default='experiments_mix_atleastonce/data', help='Base directory containing dataset files')
    parser.add_argument('--automaton', default='parity_free', help='Automaton name')
    parser.add_argument('--intervention-type', default='state', help='Type of intervention (state, symbol, transition/arc)')
    parser.add_argument('--targets', nargs='+', default=None, help='List of targets to compare')
    parser.add_argument('--architectures', nargs='+', default=['transformer', 'lstm', 'rnn'], help='List of architectures to compare')
    parser.add_argument('--num-seeds', type=int, default=10, help='Number of seeds to aggregate')
    parser.add_argument('--intervention-start', type=int, default=50, help='Starting intervention count')
    parser.add_argument('--intervention-end', type=int, default=2000, help='Ending intervention count')
    parser.add_argument('--intervention-step', type=int, default=100, help='Intervention count step size')
    parser.add_argument('--vanilla-run-ids', nargs='+', type=int, default=None, help='Run IDs to use for vanilla data')
    parser.add_argument('--output-dir', default=None, help='Directory to save plots')
    parser.add_argument('--automata-name', default=None, help='Name of automaton in register to use for state/symbol mapping')
    # Hyphenated spelling matches the other flags; the original underscore
    # spelling is kept as an alias so existing invocations keep working.
    parser.add_argument('--at-least-once', '--at_least_once', dest='at_least_once', action='store_true',
                        help='Passed through to the vanilla-data loader')
    args = parser.parse_args()

    # Fail fast instead of letting range() raise (step == 0) or silently
    # produce no data (negative step, zero seeds) inside the loaders.
    if args.num_seeds < 1:
        parser.error('--num-seeds must be at least 1')
    if args.intervention_step <= 0:
        parser.error('--intervention-step must be a positive integer')

    # Set output directory if not provided
    if args.output_dir is None:
        args.output_dir = os.path.join(args.model_dir, "..", "plots", args.automaton, "kl_vs_natural_occurrence")

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Default vanilla run IDs: 50, 150, ..., 1950 (20 runs).
    if args.vanilla_run_ids is None:
        args.vanilla_run_ids = list(range(50, 2000, 100))

    # State/symbol maps stay None unless the registry supplies both.
    state_map = None
    symbol_map = None

    if args.automata_name is not None:
        try:
            automata = AUTOMATA_REGISTER[args.automata_name]
            # Read both attributes before assigning either, so a missing
            # symbol_map does not leave state_map half-applied.
            registry_state_map = automata.state_map
            registry_symbol_map = automata.symbol_map
        except KeyError:
            print(f"Warning: Automaton '{args.automata_name}' not found in register.")
        except AttributeError:
            print(f"Warning: Automaton '{args.automata_name}' does not have mapping attributes.")
        else:
            state_map, symbol_map = registry_state_map, registry_symbol_map
            print(f"Using state and symbol mappings from {args.automata_name} automaton in registry")

    # Set plot style
    set_plot_style()

    # If targets not provided, try to discover available targets
    if args.targets is None:
        # NOTE(review): this path does not include the architecture, so it is
        # scanned once (the previous per-architecture loop rescanned the same
        # directory). Confirm whether a per-architecture layout was intended.
        base_path = os.path.join(args.model_dir, args.automaton, args.intervention_type)
        discovered = set()
        if os.path.isdir(base_path):
            for entry in os.listdir(base_path):
                if entry.startswith('target_') and os.path.isdir(os.path.join(base_path, entry)):
                    # Split only once so target names that contain '_' stay intact.
                    discovered.add(entry.split('_', 1)[1])

        args.targets = sorted(discovered)
        print(f"Discovered targets: {args.targets}")

    # Convert targets to strings for consistent handling
    args.targets = [str(t) for t in args.targets]

    # The plotting functions size their figures from the target list; an
    # empty list would crash them, so stop here instead.
    if not args.targets:
        print("No targets specified or discovered; nothing to plot.")
        return

    # Dictionary to store intervention data by architecture and target
    intervention_data = {arch: {} for arch in args.architectures}

    # Dictionary to store vanilla data with natural occurrences by architecture and target
    vanilla_data = {arch: {} for arch in args.architectures}

    # Load intervention data for each architecture and target
    for arch in args.architectures:
        for target in args.targets:
            print(f"Processing {arch} intervention data for {args.intervention_type} target {target}...")

            data = load_target_kl_data(
                args.model_dir,
                args.automaton,
                args.intervention_type,
                target,
                arch,
                args.num_seeds,
                args.intervention_start,
                args.intervention_end,
                args.intervention_step
            )

            if data['intervention_counts']:
                intervention_data[arch][target] = data
            else:
                print(f"Warning: No valid intervention data found for {arch}, target {target}")

    # Load vanilla data with natural occurrences for each architecture
    for arch in args.architectures:
        print(f"Processing {arch} vanilla data with natural occurrences...")

        vanilla_arch_data = load_vanilla_data_with_occurrences(
            args.model_dir,
            args.data_dir,
            args.automaton,
            arch,
            args.intervention_type,
            args.targets,
            range(1, args.num_seeds + 1),
            args.vanilla_run_ids,
            args.at_least_once
        )

        # Keep only targets that produced at least one KL value
        has_data = False
        for target, data in vanilla_arch_data.items():
            if data['kl_values']:
                has_data = True
                vanilla_data[arch][target] = data

        if not has_data:
            print(f"Warning: No valid vanilla data with occurrences found for {arch}")

    # Create plots comparing intervention data and natural occurrences
    print(f"Generating plots for {args.intervention_type} intervention vs. natural occurrences...")
    plot_target_comparison(
        args.targets,
        args.architectures,
        intervention_data,
        vanilla_data,
        args.intervention_type,
        args.output_dir,
        args.automaton,
        state_map,
        symbol_map
    )

    # Create per-architecture scatter plots
    print("Generating per-architecture scatter plots...")
    for arch in args.architectures:
        plot_per_architecture_comparison(
            args.targets,
            arch,
            intervention_data,
            vanilla_data,
            args.intervention_type,
            args.output_dir,
            args.automaton,
            state_map,
            symbol_map
        )

    print(f"All plots saved to {args.output_dir}")


if __name__ == "__main__":
    main()