#!/usr/bin/env python3
import os
import json
import glob
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties
from collections import defaultdict
from pathlib import Path

from intervention_sampling.automata_register import AUTOMATA_REGISTER


def _append_seed_stats(entry, values):
    """Append one intervention count's per-seed values, mean and sample std to *entry*.

    *entry* is a dict with 'values', 'mean' and 'std' lists. When *values* is empty (no seed
    reported this key at this count) the placeholders ``[nan]``/``nan``/``nan`` are appended so
    the series stays index-aligned with 'intervention_counts'. A single seed gets std 0.
    """
    if not values:
        entry['values'].append([np.nan])
        entry['mean'].append(np.nan)
        entry['std'].append(np.nan)
        return
    entry['values'].append(values)
    entry['mean'].append(np.mean(values))
    entry['std'].append(np.std(values, ddof=1) if len(values) > 1 else 0)


def load_multi_seed_kl_data(base_dir, automaton, intervention_type, target, num_seeds=10, 
                           intervention_start=50, intervention_end=2000, intervention_step=100,
                           architecture="transformer"):
    """Load KL decomposition data from multiple seeds and calculate statistics.

    For every seed ``s`` in ``1..num_seeds`` and every intervention count ``i`` in
    ``range(intervention_start, intervention_end + 1, intervention_step)`` this reads
    ``base_dir/automaton/intervention_type/target_<target>/seed_<s>/<architecture>_<i>/eval/decomposed_kls.json``.
    Each file must provide 'total_kl' and the dicts 'state_contributions',
    'symbol_contributions' and 'transition_contributions'. Missing files are reported and
    skipped; an intervention count is kept only if at least one seed has a file.

    Returns:
        ``(aggregated_data, all_states, all_symbols, all_transitions)``.
        ``aggregated_data`` contains 'intervention_counts' and, for 'total_kl' and for every key
        under 'states' / 'symbols' / 'transitions', parallel lists 'values' (per-seed values),
        'mean' and 'std' (sample std, ``ddof=1``; 0 for a single seed). Every list is aligned
        with 'intervention_counts'; counts at which a key was not reported hold NaN. The three
        sets contain every key seen in any loaded file.
    """
    # (aggregated category name, key in the JSON file)
    categories = (
        ('states', 'state_contributions'),
        ('symbols', 'symbol_contributions'),
        ('transitions', 'transition_contributions'),
    )

    aggregated_data = {
        'intervention_counts': [],
        'total_kl': {'mean': [], 'std': [], 'values': []},
        'states': defaultdict(lambda: {'mean': [], 'std': [], 'values': []}),
        'symbols': defaultdict(lambda: {'mean': [], 'std': [], 'values': []}),
        'transitions': defaultdict(lambda: {'mean': [], 'std': [], 'values': []})
    }
    seen_keys = {category: set() for category, _ in categories}

    # Pass 1: read every available file, grouping raw per-seed values by intervention count.
    collected = []
    for i in range(intervention_start, intervention_end + 1, intervention_step):
        intervention_data = {'total_kl': []}
        for category, _ in categories:
            intervention_data[category] = defaultdict(list)

        for seed in range(1, num_seeds + 1):
            # base_dir/automaton/intervention_type/target_X/seed_Y/architecture_Z/eval/decomposed_kls.json
            eval_dir = os.path.join(
                base_dir,
                automaton,
                intervention_type,
                f"target_{target}",
                f"seed_{seed}",
                f"{architecture}_{i}",
                "eval"
            )
            kl_file = os.path.join(eval_dir, "decomposed_kls.json")

            if not os.path.exists(kl_file):
                print(f"Warning: No KL data found for intervention count {i}, seed {seed}")
                print(f"Missing file: {kl_file}")
                continue

            with open(kl_file, 'r') as f:
                seed_data = json.load(f)

            intervention_data['total_kl'].append(seed_data['total_kl'])
            for category, json_key in categories:
                for key, value in seed_data[json_key].items():
                    seen_keys[category].add(key)
                    intervention_data[category][key].append(value)

        # Keep this count only if at least one seed provided data for it
        if intervention_data['total_kl']:
            collected.append((i, intervention_data))

    # Pass 2: aggregate. Only now is the full key set known, so a key first reported at a later
    # count still receives NaN placeholders for earlier counts and stays aligned with
    # 'intervention_counts'. Keys are visited in sorted order for deterministic output.
    for i, intervention_data in collected:
        aggregated_data['intervention_counts'].append(i)
        _append_seed_stats(aggregated_data['total_kl'], intervention_data['total_kl'])
        for category, _ in categories:
            for key in sorted(seen_keys[category]):
                _append_seed_stats(
                    aggregated_data[category][key],
                    intervention_data[category].get(key, []),
                )

    return aggregated_data, seen_keys['states'], seen_keys['symbols'], seen_keys['transitions']


def set_plot_style():
    """Apply the global figure look shared by every plot in this script.

    Switches Matplotlib to a serif (Times New Roman) font with fixed sizes for body text,
    titles, axis labels, tick labels and legends, then selects seaborn's "ticks" style.
    Spines are removed separately by the ``sns.despine`` calls in each plotting function.
    """
    sizes = {
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 16,
    }
    typeface = {'font.family': 'serif', 'font.serif': ['Times New Roman']}
    plt.rcParams.update({**typeface, **sizes})

    sns.set_style("ticks")


def plot_total_kl_with_error(data, output_dir, title):
    """Draw mean total KL against intervention count with a +/- 1 SEM shaded band.

    Args:
        data: aggregated dict as produced by ``load_multi_seed_kl_data``; reads
            'intervention_counts' and the 'total_kl' mean/std/values lists.
        output_dir: directory that receives total_kl_plot.png (300 dpi) and total_kl_plot.pdf.
        title: text appended to the axes title.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    colour = '#1f77b4'

    counts = data['intervention_counts']
    total = data['total_kl']
    means = total['mean']
    # Standard error of the mean: sample std divided by sqrt(number of seeds at that count)
    errors = [sd / np.sqrt(len(seed_vals)) for sd, seed_vals in zip(total['std'], total['values'])]
    band_low = [m - e for m, e in zip(means, errors)]
    band_high = [m + e for m, e in zip(means, errors)]

    ax.plot(counts, means, marker='o', markersize=8, linewidth=2, color=colour)
    ax.fill_between(counts, band_low, band_high, alpha=0.3, color=colour, label='Standard Error')

    ax.set_xlabel('Intervention Count')
    ax.set_ylabel('Total KL Divergence')
    ax.set_title(f'Total KL Divergence vs. Intervention Count - {title}')
    ax.grid(True, linestyle='--', alpha=0.7)
    sns.despine()

    fig.tight_layout(rect=[0, 0, 1, 0.95])

    fig.savefig(os.path.join(output_dir, 'total_kl_plot.png'), dpi=300)
    fig.savefig(os.path.join(output_dir, 'total_kl_plot.pdf'))
    plt.close(fig)


def plot_target_site_with_error(data, output_dir, target, intervention_type, state_map=None, symbol_map=None, title=""):
    """Plot the KL contribution of only the specifically targeted site (state, symbol, or transition)."""
    fig, ax = plt.subplots(figsize=(8, 6))
    
    x = data['intervention_counts']
    
    # Determine which data to plot based on intervention type
    if intervention_type == 'state':
        target_key = str(target)
        if target_key in data['states']:
            y = data['states'][target_key]['mean']
            yerr = data['states'][target_key]['std']
            site_values = data['states'][target_key]['values']
            
            # Get label from state map if available
            if state_map and target in state_map:
                site_label = f"State {target}: {state_map[target]}"
            else:
                site_label = f"State {target}"
        else:
            print(f"Warning: Target state {target} not found in data")
            return
    
    elif intervention_type == 'symbol':
        target_key = str(target)
        if target_key in data['symbols']:
            y = data['symbols'][target_key]['mean']
            yerr = data['symbols'][target_key]['std']
            site_values = data['symbols'][target_key]['values']
            
            # Get label from symbol map if available
            if symbol_map and target in symbol_map:
                site_label = f"Symbol {target}: {symbol_map[target]}"
            else:
                site_label = f"Symbol {target}"
        else:
            print(f"Warning: Target symbol {target} not found in data")
            return
    
    elif intervention_type == 'arc' or intervention_type == 'transition':
        # For transitions, we need to find the right key that contains this arc/transition index
        # This is more complex as the keys might be formatted differently
        found = False
        for transition_key in data['transitions']:
            # Try to extract transition index from various possible formats
            if "-" in transition_key:
                # For format like "state-symbol-state"
                try:
                    parts = transition_key.split("-")
                    # Check if any part contains the target number
                    if any(str(target) in part for part in parts):
                        target_key = transition_key
                        found = True
                        break
                except:
                    continue
            elif ":" in transition_key and ">" in transition_key:
                # For format like "state:symbol>state"
                try:
                    # Try to extract indices from this format
                    if str(target) in transition_key:
                        target_key = transition_key
                        found = True
                        break
                except:
                    continue
        
        if found:
            y = data['transitions'][target_key]['mean']
            yerr = data['transitions'][target_key]['std']
            site_values = data['transitions'][target_key]['values']
            site_label = f"Transition {target_key}"
        else:
            print(f"Warning: Target transition {target} not found in data")
            return
    
    else:
        print(f"Warning: Unknown intervention type: {intervention_type}")
        return
    
    # Calculate standard error of the mean
    sem = [std / np.sqrt(len(values)) for std, values in zip(yerr, site_values)]
    
    # Plot the mean line
    line = ax.plot(x, y, marker='o', markersize=8, linewidth=2, color='#1f77b4', label=site_label)
    
    # Add total KL for reference
    total_y = data['total_kl']['mean']
    total_yerr = data['total_kl']['std']
    total_sem = [std / np.sqrt(len(values)) for std, values in zip(total_yerr, data['total_kl']['values'])]
    
    ax.plot(x, total_y, marker='s', markersize=6, linewidth=2, linestyle='--', color='#ff7f0e', label='Total KL')
    
    # Add shaded error region for the target site
    ax.fill_between(
        x, 
        [m - s for m, s in zip(y, sem)], 
        [m + s for m, s in zip(y, sem)], 
        alpha=0.3, 
        color=line[0].get_color()
    )
    
    # Add shaded error region for total KL
    ax.fill_between(
        x, 
        [m - s for m, s in zip(total_y, total_sem)], 
        [m + s for m, s in zip(total_y, total_sem)], 
        alpha=0.2, 
        color='#ff7f0e'
    )
    
    ax.set_xlabel('Intervention Count')
    ax.set_ylabel('KL Contribution')
    ax.set_title(f'Target Site KL Contribution vs. Intervention Count - {title}')
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend(loc='best')
    sns.despine()
    
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    plt.savefig(os.path.join(output_dir, 'target_site_kl_plot.png'), dpi=300)
    plt.savefig(os.path.join(output_dir, 'target_site_kl_plot.pdf'))
    plt.close()


def plot_state_contributions_with_error(data, states, output_dir, state_map=None, main_title=""):
    """Plot the state contributions with error bars as subplots with consistent y-axis scale.

    Each panel shows one state's mean KL contribution against intervention count with a
    +/- 1 SEM band. Panels are laid out three per row and share one y-range computed over
    every state, so they are directly comparable.

    Args:
        data: aggregated dict from ``load_multi_seed_kl_data`` (reads 'intervention_counts'
            and 'states').
        states: iterable of state keys (strings) to plot.
        output_dir: directory receiving state_contributions_plot.png (300 dpi) and .pdf.
        state_map: optional mapping from integer state id to a display name; only numeric
            state keys are looked up.
        main_title: figure-level title.

    Returns None. If *states* is empty a warning is printed and nothing is drawn.
    """
    # Numeric states in numeric order, then non-numeric keys alphabetically (deterministic).
    sorted_states = sorted(states, key=lambda s: (0, int(s), '') if s.isdigit() else (1, 0, s))
    n_states = len(sorted_states)
    if n_states == 0:
        # plt.subplots(1, 0) would raise; nothing to draw.
        print("Warning: No state contributions to plot")
        return

    n_cols = min(n_states, 3)
    n_rows = -(-n_states // n_cols)  # ceiling division

    # Shared y-limits: extremes of mean -/+ SEM over every non-missing point of every state.
    lows, highs = [], []
    for state in sorted_states:
        entry = data['states'][state]
        for mean, std, seed_values in zip(entry['mean'], entry['std'], entry['values']):
            if np.isnan(mean):
                continue
            err = std / np.sqrt(len(seed_values))
            if np.isnan(err):
                err = 0.0
            lows.append(mean - err)
            highs.append(mean + err)

    y_lim = None  # None -> keep Matplotlib autoscaling (no finite data at all)
    if lows:
        y_min, y_max = min(lows), max(highs)
        y_range = y_max - y_min
        # 5% padding. A flat series is padded by 5% of its magnitude (0.05 when it is zero) so
        # the limits never collapse to a point or invert for negative values.
        pad = 0.05 * y_range if y_range > 0 else (0.05 * abs(y_max) or 0.05)
        y_lim = (y_min - pad, y_max + pad)

    # squeeze=False always yields a 2-D array of Axes, whatever the grid shape
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows),
                             sharex=True, sharey=True, squeeze=False)
    axes_flat = axes.ravel()
    x = data['intervention_counts']

    for ax, state in zip(axes_flat, sorted_states):
        entry = data['states'][state]
        y = entry['mean']
        # Standard error of the mean = sample std / sqrt(number of seeds)
        sem = [std / np.sqrt(len(values)) for std, values in zip(entry['std'], entry['values'])]

        line, = ax.plot(x, y, marker='o', markersize=6, linewidth=2)
        ax.fill_between(
            x,
            [m - s for m, s in zip(y, sem)],
            [m + s for m, s in zip(y, sem)],
            alpha=0.3,
            color=line.get_color()
        )

        # Only numeric keys can be looked up in the int-keyed state_map
        if state_map and state.isdigit() and int(state) in state_map:
            ax.set_title(f'State {state}: {state_map[int(state)]}')
        else:
            ax.set_title(f'State {state}')
        ax.set_xlabel('Intervention Count')
        ax.set_ylabel('KL Contribution')
        if y_lim is not None:
            ax.set_ylim(*y_lim)
        ax.grid(True, linestyle='--', alpha=0.7)
        sns.despine(ax=ax)

    # Hide grid cells beyond the last state
    for ax in axes_flat[n_states:]:
        ax.set_visible(False)

    fig.suptitle(main_title)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(os.path.join(output_dir, 'state_contributions_plot.png'), dpi=300)
    fig.savefig(os.path.join(output_dir, 'state_contributions_plot.pdf'))
    plt.close(fig)


def plot_symbol_contributions_with_error(data, symbols, output_dir, symbol_map=None, main_title=""):
    """Plot the symbol contributions with error bars as subplots with consistent y-axis scale.

    Each panel shows one symbol's mean KL contribution against intervention count with a
    +/- 1 SEM band. Panels are laid out three per row and share one y-range computed over
    every symbol.

    Args:
        data: aggregated dict from ``load_multi_seed_kl_data`` (reads 'intervention_counts'
            and 'symbols').
        symbols: iterable of symbol keys (strings) to plot.
        output_dir: directory receiving symbol_contributions_plot.png (300 dpi) and .pdf.
        symbol_map: optional mapping used for panel titles; numeric symbol keys are looked up
            as ints, other keys as-is.
        main_title: figure-level title.

    Returns None. If *symbols* is empty a warning is printed and nothing is drawn.
    """
    def sort_key(symbol):
        # Numeric symbols in numeric order, then other names alphabetically, "<EOS>" last.
        # A tuple key never compares str with int/float (that raised TypeError before).
        if symbol == "<EOS>":
            return (2, 0, "")
        if symbol.isdigit():
            return (0, int(symbol), "")
        return (1, 0, symbol)

    sorted_symbols = sorted(symbols, key=sort_key)
    n_symbols = len(sorted_symbols)
    if n_symbols == 0:
        # plt.subplots(1, 0) would raise; nothing to draw.
        print("Warning: No symbol contributions to plot")
        return

    n_cols = min(n_symbols, 3)
    n_rows = -(-n_symbols // n_cols)  # ceiling division

    # Shared y-limits: extremes of mean -/+ SEM over every non-missing point of every symbol.
    lows, highs = [], []
    for symbol in sorted_symbols:
        entry = data['symbols'][symbol]
        for mean, std, seed_values in zip(entry['mean'], entry['std'], entry['values']):
            if np.isnan(mean):
                continue
            err = std / np.sqrt(len(seed_values))
            if np.isnan(err):
                err = 0.0
            lows.append(mean - err)
            highs.append(mean + err)

    y_lim = None  # None -> keep Matplotlib autoscaling (no finite data at all)
    if lows:
        y_min, y_max = min(lows), max(highs)
        y_range = y_max - y_min
        # 5% padding. A flat series is padded by 5% of its magnitude (0.05 when it is zero) so
        # the limits never collapse to a point or invert for negative values.
        pad = 0.05 * y_range if y_range > 0 else (0.05 * abs(y_max) or 0.05)
        y_lim = (y_min - pad, y_max + pad)

    # squeeze=False always yields a 2-D array of Axes, whatever the grid shape
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows),
                             sharex=True, sharey=True, squeeze=False)
    axes_flat = axes.ravel()
    x = data['intervention_counts']

    for ax, symbol in zip(axes_flat, sorted_symbols):
        entry = data['symbols'][symbol]
        y = entry['mean']
        # Standard error of the mean = sample std / sqrt(number of seeds)
        sem = [std / np.sqrt(len(values)) for std, values in zip(entry['std'], entry['values'])]

        line, = ax.plot(x, y, marker='o', markersize=6, linewidth=2)
        ax.fill_between(
            x,
            [m - s for m, s in zip(y, sem)],
            [m + s for m, s in zip(y, sem)],
            alpha=0.3,
            color=line.get_color()
        )

        lookup = int(symbol) if symbol.isdigit() else symbol
        if symbol_map and lookup in symbol_map:
            ax.set_title(f'Symbol {symbol}: {symbol_map[lookup]}')
        else:
            ax.set_title(f'Symbol {symbol}')
        ax.set_xlabel('Intervention Count')
        ax.set_ylabel('KL Contribution')
        if y_lim is not None:
            ax.set_ylim(*y_lim)
        ax.grid(True, linestyle='--', alpha=0.7)
        sns.despine(ax=ax)

    # Hide grid cells beyond the last symbol
    for ax in axes_flat[n_symbols:]:
        ax.set_visible(False)

    fig.suptitle(main_title)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(os.path.join(output_dir, 'symbol_contributions_plot.png'), dpi=300)
    fig.savefig(os.path.join(output_dir, 'symbol_contributions_plot.pdf'))
    plt.close(fig)


def plot_transition_contributions_with_error(data, transitions, output_dir, state_map=None, symbol_map=None, main_title=""):
    """Plot the transition contributions with error bars as subplots with consistent y-axis scale."""
    n_transitions = len(transitions)
    
    # Determine optimal layout based on number of transitions
    if n_transitions <= 3:
        # For 1-3 transitions, use a single row
        n_cols = n_transitions
        n_rows = 1
        figsize = (5*n_cols, 5)  # Wider, single-row figure
    else:
        # For many transitions, use a grid layout
        n_cols = 3
        n_rows = min(4, (n_transitions + 2) // 3)
        figsize = (15, 12)
    
    max_transitions = n_rows * n_cols
    
    # If we have too many transitions, create multiple figures
    n_figures = (n_transitions + max_transitions - 1) // max_transitions
    
    # Sort transitions for consistent ordering
    sorted_transitions = sorted(transitions)
    
    # Find global y-axis limits for all transitions
    y_min = float('inf')
    y_max = float('-inf')
    
    for transition in sorted_transitions:
        values = [v for v in data['transitions'][transition]['mean'] if not np.isnan(v)]
        stderr = [s / np.sqrt(len(vals)) for s, vals in zip(data['transitions'][transition]['std'], data['transitions'][transition]['values']) if not np.isnan(s)]
        
        if values:
            min_with_err = min([v - e for v, e in zip(values, stderr)]) if stderr else min(values)
            max_with_err = max([v + e for v, e in zip(values, stderr)]) if stderr else max(values)
            y_min = min(y_min, min_with_err)
            y_max = max(y_max, max_with_err)
    
    # Add a small buffer to the limits for visual clarity
    y_range = y_max - y_min
    y_min = y_min - 0.05 * y_range if y_range > 0 else y_min * 0.95
    y_max = y_max + 0.05 * y_range if y_range > 0 else y_max * 1.05
    
    for fig_idx in range(n_figures):
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, sharex=True, sharey=True)
        
        # Handle different subplot array shapes
        if n_rows == 1 and n_cols == 1:
            axes = np.array([axes])
        elif n_rows == 1:
            axes = np.array(axes).reshape(1, -1)
        
        # Convert to flattened array for easier indexing
        axes_flat = axes.flatten()
        
        start_idx = fig_idx * max_transitions
        end_idx = min((fig_idx + 1) * max_transitions, n_transitions)
        
        for i, transition_idx in enumerate(range(start_idx, end_idx)):
            transition = sorted_transitions[transition_idx]
            ax = axes_flat[i]
            
            x = data['intervention_counts']
            y = data['transitions'][transition]['mean']
            yerr = data['transitions'][transition]['std']
            
            # Calculate standard error of the mean
            sem = [std / np.sqrt(len(values)) for std, values in zip(yerr, data['transitions'][transition]['values'])]
            
            # Plot the mean line
            line = ax.plot(x, y, marker='o', markersize=6, linewidth=2)
            
            # Add shaded error region
            ax.fill_between(
                x, 
                [m - s for m, s in zip(y, sem)], 
                [m + s for m, s in zip(y, sem)], 
                alpha=0.3, 
                color=line[0].get_color()
            )

            # Format transition title with state and symbol maps if available
            # Transitions are typically in format "state1-symbol-state2" or might be "state1:symbol>state2"
            if "-" in transition:
                tspl = transition.split("-")
                state_part = tspl[0]
                next_state = tspl[2][1:]  # Remove the "s" prefix if present
                symbol = tspl[1]
                
                # Apply maps if available
                symbol_alt = symbol
                if symbol.isnumeric():
                    symbol_alt = int(symbol)
                next_state_alt = next_state
                if next_state.isnumeric():
                    next_state_alt = int(next_state)

                if state_map:
                    state_label = state_map.get(int(state_part), state_part)
                    next_state_label = state_map.get(int(next_state_alt) if isinstance(next_state_alt, str) and next_state_alt.isdigit() else next_state_alt, next_state_alt)
                else:
                    state_label = state_part
                    next_state_label = next_state
                
                if symbol_map:
                    symbol_label = symbol_map.get(symbol_alt, symbol)
                else:
                    symbol_label = symbol
                
                title = f'{state_label}:{symbol_label}→{next_state_label}'
            elif ">" in transition:
                # Parse transition string for alternate format
                state_part, next_part = transition.split(":")
                symbol, next_state = next_part.split(">")
                
                # Apply maps if available
                symbol_alt = symbol
                if symbol.isnumeric():
                    symbol_alt = int(symbol)
                next_state_alt = next_state
                if next_state.isnumeric():
                    next_state_alt = int(next_state)

                if state_map:
                    state_label = state_map.get(int(state_part), state_part)
                    next_state_label = state_map.get(int(next_state_alt) if isinstance(next_state_alt, str) and next_state_alt.isdigit() else next_state_alt, next_state_alt)
                else:
                    state_label = state_part
                    next_state_label = next_state
                
                if symbol_map:
                    symbol_label = symbol_map.get(symbol_alt, symbol)
                else:
                    symbol_label = symbol
                
                title = f'{state_label}:{symbol_label}→{next_state_label}'
            else:
                title = f'Transition {transition}'
            
            ax.set_title(title)
            ax.set_xlabel('Intervention Count')
            ax.set_ylabel('KL Contribution')
            ax.set_ylim(y_min, y_max)
            ax.grid(True, linestyle='--', alpha=0.7)
            sns.despine(ax=ax)
        
        # Hide unused subplots
        for j in range(i+1, len(axes_flat)):
            axes_flat[j].set_visible(False)
        
        fig.suptitle(main_title)
        plt.tight_layout(rect=[0, 0, 1, 0.95])

        filename_suffix = f"_part{fig_idx+1}" if n_figures > 1 else ""
        plt.savefig(os.path.join(output_dir, f'transition_contributions_plot{filename_suffix}.png'), dpi=300)
        plt.savefig(os.path.join(output_dir, f'transition_contributions_plot{filename_suffix}.pdf'))
        plt.close()


def create_summary_dataframe(data, states, symbols, transitions, state_map=None, symbol_map=None):
    """Create a summary DataFrame with all the data including means and standard errors.

    The result has one row per intervention count. Its columns are, in order:

    * 'intervention_count', then 'total_kl_mean', 'total_kl_std' and 'total_kl_sem'.
    * For each state, 'state_<key>_<name>_{mean,std,sem}'. States come numeric keys first in
      numeric order, then other keys alphabetically.
    * For each symbol, 'symbol_<key>_<name>_{mean,std,sem}'. Symbols come numeric first, then
      other names alphabetically, then "<EOS>".
    * For each transition (lexicographic order), 'transition_<cleaned key>_{mean,std,sem}'.
      In the cleaned key '-', '>', '<' and ':' are replaced.

    <name> comes from state_map / symbol_map when available, otherwise the key itself.
    SEM = std / sqrt(number of per-seed values at that count).
    """
    def sem(entry):
        return [std / np.sqrt(len(vals)) for std, vals in zip(entry['std'], entry['values'])]

    def state_key(state):
        # Tuple key: deterministic and never compares int with str
        return (0, int(state), '') if state.isdigit() else (1, 0, state)

    def symbol_key(symbol):
        # Tuple key avoids the str-vs-int TypeError of a mixed scalar key; "<EOS>" sorts last
        if symbol == "<EOS>":
            return (2, 0, '')
        if symbol.isdigit():
            return (0, int(symbol), '')
        return (1, 0, symbol)

    # Base data with intervention counts and total KL
    summary_data = {
        'intervention_count': data['intervention_counts'],
        'total_kl_mean': data['total_kl']['mean'],
        'total_kl_std': data['total_kl']['std'],
        'total_kl_sem': sem(data['total_kl'])
    }

    # State data with mapped names in column descriptions
    for state in sorted(states, key=state_key):
        entry = data['states'][state]
        state_name = state_map.get(int(state), state) if state_map and state.isdigit() else state
        prefix = f'state_{state}_{state_name}'
        summary_data[f'{prefix}_mean'] = entry['mean']
        summary_data[f'{prefix}_std'] = entry['std']
        summary_data[f'{prefix}_sem'] = sem(entry)

    # Symbol data with mapped names (numeric keys are looked up as ints)
    for symbol in sorted(symbols, key=symbol_key):
        entry = data['symbols'][symbol]
        lookup = int(symbol) if symbol.isdigit() else symbol
        symbol_name = symbol_map.get(lookup, symbol) if symbol_map else symbol
        prefix = f'symbol_{symbol}_{symbol_name}'
        summary_data[f'{prefix}_mean'] = entry['mean']
        summary_data[f'{prefix}_std'] = entry['std']
        summary_data[f'{prefix}_sem'] = sem(entry)

    # Transition data; special characters are replaced in column names
    for transition in sorted(transitions):
        entry = data['transitions'][transition]
        clean_transition = transition.replace('-', '_').replace('>', '_to_').replace('<', '').replace(':', '_')
        prefix = f'transition_{clean_transition}'
        summary_data[f'{prefix}_mean'] = entry['mean']
        summary_data[f'{prefix}_std'] = entry['std']
        summary_data[f'{prefix}_sem'] = sem(entry)

    return pd.DataFrame(summary_data)


def _build_arg_parser():
    """Create the command-line parser for this plotting script."""
    parser = argparse.ArgumentParser(description="Plot KL decomposition results with multiple seeds.")
    option_specs = [
        ('--base-dir', dict(default='experiments_mix_atleastonce/models', help='Base directory containing model folders')),
        ('--automaton', dict(default='parity_free', help='Automaton name')),
        ('--intervention-type', dict(default='state', help='Type of intervention (state, symbol, arc)')),
        ('--target', dict(type=int, default=1, help='Target for intervention')),
        ('--num-seeds', dict(type=int, default=10, help='Number of seeds to aggregate')),
        ('--intervention-start', dict(type=int, default=50, help='Starting intervention count')),
        ('--intervention-end', dict(type=int, default=2000, help='Ending intervention count')),
        ('--intervention-step', dict(type=int, default=100, help='Intervention count step size')),
        ('--architecture', dict(default='transformer', help='Model architecture (transformer, lstm, etc.)')),
        ('--output-dir', dict(default=None, help='Directory to save plots (defaults to a subdirectory of base-dir)')),
        ('--automata-name', dict(default=None, help='Name of automaton in register to use for state/symbol mapping')),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser


def _stats_for_json(entry):
    """Copy a mean/std/values stats entry into plain lists suitable for json.dump."""
    return {
        'mean': entry['mean'],
        'std': entry['std'],
        'values': [list(seed_vals) for seed_vals in entry['values']],
    }


def main():
    """Command-line entry point: load multi-seed KL data, plot it and save summaries.

    Plots go to --output-dir. The default is
    ``<base-dir>/../plots/<automaton>/<intervention-type>/target_<target>/<architecture>_plots``.
    kl_decomposition_summary.csv and kl_raw_data.json are written there too. When
    --automata-name is given, state/symbol display names are taken from that
    AUTOMATA_REGISTER entry's ``state_map``/``symbol_map``.
    """
    args = _build_arg_parser().parse_args()

    if args.output_dir is None:
        args.output_dir = os.path.join(
            args.base_dir, "..", "plots", args.automaton,
            args.intervention_type, f"target_{args.target}",
            f"{args.architecture}_plots",
        )

    state_map = symbol_map = None
    if args.automata_name is not None:
        try:
            automata = AUTOMATA_REGISTER[args.automata_name]
            # Assigned one after the other: if only symbol_map is missing, state_map is kept
            state_map = automata.state_map
            symbol_map = automata.symbol_map
            print(f"Using state and symbol mappings from {args.automata_name} automaton")
        except KeyError:
            print(f"Warning: Automaton '{args.automata_name}' not found in register. Using default labels.")
        except AttributeError:
            print(f"Warning: Automaton '{args.automata_name}' does not have mapping attributes. Using default labels.")

    out_dir = args.output_dir
    os.makedirs(out_dir, exist_ok=True)
    set_plot_style()
    title = f"{args.automaton} - {args.intervention_type} - Target {args.target} - {args.architecture}"

    print(f"Loading KL data from {args.base_dir}...")
    data, states, symbols, transitions = load_multi_seed_kl_data(
        args.base_dir, args.automaton, args.intervention_type, args.target,
        args.num_seeds, args.intervention_start, args.intervention_end,
        args.intervention_step, args.architecture,
    )
    if not data['intervention_counts']:
        print("No data found. Check if the KL decomposition files exist.")
        return

    print("Generating total KL plot with error bands...")
    plot_total_kl_with_error(data, out_dir, title)
    print("Generating state contribution plots with error bands...")
    plot_state_contributions_with_error(data, states, out_dir, state_map, title)
    print("Generating symbol contribution plots with error bands...")
    plot_symbol_contributions_with_error(data, symbols, out_dir, symbol_map, title)
    print("Generating transition contribution plots with error bands...")
    plot_transition_contributions_with_error(data, transitions, out_dir, state_map, symbol_map, title)
    print(f"Generating special plot for target {args.target} of type {args.intervention_type}...")
    plot_target_site_with_error(data, out_dir, args.target, args.intervention_type,
                                state_map, symbol_map, title)
    print(f"All plots saved to {out_dir}")

    # Flat CSV summary: means, stds and standard errors per intervention count
    summary = create_summary_dataframe(data, states, symbols, transitions, state_map, symbol_map)
    csv_path = os.path.join(out_dir, 'kl_decomposition_summary.csv')
    summary.to_csv(csv_path, index=False)
    print(f"Summary data saved to {csv_path}")

    # Raw aggregated statistics for later reuse
    raw_data = {
        'intervention_counts': data['intervention_counts'],
        'total_kl': _stats_for_json(data['total_kl']),
    }
    for category in ('states', 'symbols', 'transitions'):
        raw_data[category] = {key: _stats_for_json(entry) for key, entry in data[category].items()}

    raw_path = os.path.join(out_dir, 'kl_raw_data.json')
    with open(raw_path, 'w') as f:
        json.dump(raw_data, f, indent=2)
    print(f"Raw data saved to {raw_path}")


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()