#!/usr/bin/env python3
import ast
import glob
import json
import os
import re
from collections import Counter

import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.gridspec import GridSpec
from tqdm import tqdm

# Directory paths
BASE_DIR = "experiments_random_final"
IDENTIFIER = "40st_10sym_redo"
MAIN_SEED = 123
ACCEPTANCE_PROB = 0.3
MODELS_ROOT = f"{BASE_DIR}/models_{IDENTIFIER}_{ACCEPTANCE_PROB}_{MAIN_SEED}"
DATASETS_ROOT = f"{BASE_DIR}/data/datasets"
# NOTE(review): unlike MODELS_ROOT, the plots directory is not keyed on
# ACCEPTANCE_PROB, so runs that differ only in that value overwrite each
# other's figures — confirm this is intended.
PLOTS_ROOT = f"{BASE_DIR}/plots_{IDENTIFIER}_{MAIN_SEED}"

# Create output directories
os.makedirs(PLOTS_ROOT, exist_ok=True)

# Pickled DataFrame cache. Its name must carry every component of
# MODELS_ROOT (including ACCEPTANCE_PROB): the cache is built from the files
# under MODELS_ROOT, so changing any of them must not reload a stale cache.
CACHE_DIR = f"{BASE_DIR}/cache"
os.makedirs(CACHE_DIR, exist_ok=True)
DF_CACHE = f"{CACHE_DIR}/{IDENTIFIER}_{ACCEPTANCE_PROB}_{MAIN_SEED}_df.pkl"

# Plot configuration (figure size in inches; used by the symbol grid)
PLOT_DPI = 300
FIG_WIDTH = 24
FIG_HEIGHT = 20
# Fonts and styling: register the local Times TrueType files with matplotlib
# and make Times New Roman the default family for every figure.
import matplotlib as mpl
import pathlib

# Directory holding the Times font files.
font_dir = os.path.expanduser('~/.fonts/')

# Regular, bold, italic and bold-italic Times faces.
times_regular, times_bold, times_italic, times_bold_italic = (
    os.path.join(font_dir, file_name)
    for file_name in ('Times.TTF', 'Timesbd.TTF', 'Timesi.TTF', 'Timesbi.TTF')
)
font_files = [times_regular, times_bold, times_italic, times_bold_italic]

# Report (but tolerate) any face that is not installed.
for missing_path in (p for p in font_files if not os.path.exists(p)):
    print(f"Warning: Font file not found: {missing_path}")

# Resolve the default font once. Per matplotlib's findfont docs,
# rebuild_if_missing only rebuilds the font cache when the cached match
# points at a file that no longer exists; it does not clear the cache.
from matplotlib.font_manager import findfont, FontProperties
findfont(FontProperties(), rebuild_if_missing=True)

# Register every Times face that exists on disk.
for existing_path in font_files:
    if os.path.exists(existing_path):
        fm.fontManager.addfont(existing_path)

# List the registered families containing "time" so the setup can be checked.
font_names = sorted(entry.name for entry in fm.fontManager.ttflist)
print("Available font families:")
for family in (n for n in font_names if 'time' in n.lower()):
    print(f" - {family}")

sns.set(style="whitegrid", context="talk", font="Times New Roman")

# Times New Roman everywhere; sizes are smaller than in the original script
# (which used 30-38 pt).
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.serif': ['Times New Roman'],
    'font.size': 24,
    'axes.titlesize': 26,
    'axes.labelsize': 24,
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    'legend.fontsize': 22,
    'figure.titlesize': 28,
})
 
# Path layout of a decomposed-KL result file, with "/" separators:
#   .../<arch>/<type>/<seed>/<top_seed>/[<intervention>/<target>/<ic>/]<mseed>/eval/decomposed_kls.json
# Every component is captured as a named group; the bracketed intervention
# segment is optional (runs without an intervention have no state/symbol dir).
KL_PATTERN = re.compile(
    r".*/(?P<arch>lstm|transformer)/(?P<type>vanilla|alo)/"
    r"(?P<seed>\d+)/(?P<top_seed>\d+)/"
    r"((?P<intervention>state|symbol)/(?P<target>\d+)/(?P<ic>\d+)/)?"
    r"(?P<mseed>\d+)/eval/decomposed_kls\.json$"
)

def parse_metadata(file_path):
    """Parse experiment metadata from a ``decomposed_kls.json`` path.

    Args:
        file_path: Path to test against ``KL_PATTERN``.

    Returns:
        A dict holding ``KL_PATTERN``'s named groups (``arch``, ``type``,
        ``seed``, ``top_seed``, ``intervention``, ``target``, ``ic``,
        ``mseed``), with the numeric groups converted to ``int``, plus a
        ``semiring`` key that is always ``"alo"``. Paths without an
        intervention segment get ``intervention="none"`` and
        ``target = ic = 0``. Returns ``None`` (after printing a message) when
        the path does not match.
    """
    match = KL_PATTERN.match(file_path)
    if match is None:
        print(f"Failed to match path: {file_path}")
        return None

    metadata = match.groupdict()

    # These groups only ever capture \d+, so int() cannot fail; optional groups
    # that did not participate in the match are None and are left as None here.
    for key in ("seed", "top_seed", "target", "mseed", "ic"):
        if metadata[key] is not None:
            metadata[key] = int(metadata[key])

    # Defaults for runs without an intervention segment.
    if metadata["intervention"] is None:
        metadata.update(intervention="none", target=0, ic=0)

    # Every record is treated as using the 'alo' semiring.
    metadata["semiring"] = "alo"
    return metadata

def get_arcs_file_path(metadata):
    """Return the first existing ``arcs.txt`` for an experiment, or ``None``.

    Vanilla experiments are looked up under
    ``DATASETS_ROOT/vanilla/IDENTIFIER/<seed>/<top_seed>/`` in ``test``,
    ``train`` and ``validation`` (in that order). Any other type is treated
    as 'alo' and looked up under
    ``DATASETS_ROOT/alo/IDENTIFIER/<seed>/<top_seed>/<intervention>/<target>/``
    in ``test``, then ``train/<ic>``, then ``validation/<ic>``.
    """
    is_vanilla = metadata["type"] == "vanilla"
    experiment_dir = os.path.join(
        DATASETS_ROOT,
        "vanilla" if is_vanilla else "alo",
        IDENTIFIER,
        str(metadata["seed"]),
        str(metadata["top_seed"]),
    )

    if is_vanilla:
        candidates = [
            os.path.join(experiment_dir, split, "arcs.txt")
            for split in ("test", "train", "validation")
        ]
    else:
        target_dir = os.path.join(
            experiment_dir, metadata["intervention"], str(metadata["target"])
        )
        ic_dir = str(metadata["ic"])
        candidates = [
            os.path.join(target_dir, "test", "arcs.txt"),
            os.path.join(target_dir, "train", ic_dir, "arcs.txt"),
            os.path.join(target_dir, "validation", ic_dir, "arcs.txt"),
        ]

    # First candidate on disk wins; None when nothing exists.
    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)

def get_counts(metadata):
    """
    Get occurrence counts for states and symbols from the arcs file.

    Each non-blank line of the arcs file is parsed as a Python literal
    sequence of ``(src, tgt, sym)`` transitions, one sequence per line.
    Counts are per-sequence presence counts (the 'alo' convention): an
    element seen several times in one sequence counts once. The target of the
    last transition is also counted as a visited state.

    Args:
        metadata: Dict from ``parse_metadata``; used to locate ``arcs.txt``.

    Returns:
        ``(state_counts, symbol_counts)``: dicts mapping element -> number of
        sequences containing it. Both are empty when no arcs file is found.
        Malformed lines are reported and skipped; if the file cannot be read,
        the error is reported and the counts gathered so far are returned.
    """
    arcs_path = get_arcs_file_path(metadata)
    if not arcs_path:
        return {}, {}

    state_counts = Counter()
    symbol_counts = Counter()

    try:
        with open(arcs_path, encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    # Blank/trailing lines carry no sequence; don't report them as errors.
                    continue
                try:
                    # literal_eval accepts only Python literals, so unlike eval()
                    # a malformed or crafted arcs file cannot execute code.
                    transitions = ast.literal_eval(line)

                    seen_states = set()
                    seen_symbols = set()
                    for src, _tgt, sym in transitions:
                        seen_states.add(src)
                        seen_symbols.add(sym)

                    # Add final state
                    if transitions:
                        seen_states.add(transitions[-1][1])
                except (ValueError, TypeError, SyntaxError, LookupError, RecursionError) as e:
                    print(f"Error parsing line in {arcs_path}: {e}")
                    continue

                # Only fully parsed sequences contribute to the counts.
                state_counts.update(seen_states)
                symbol_counts.update(seen_symbols)
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error opening or reading {arcs_path}: {e}")

    return dict(state_counts), dict(symbol_counts)

def _load_cached_experiment_df():
    """Return the cached DataFrame if it is readable and has intervention rows, else None."""
    if not os.path.exists(DF_CACHE):
        return None

    print(f"Loading data from cache: {DF_CACHE}")
    try:
        # NOTE: unpickling can execute code. DF_CACHE is produced by
        # load_experiment_data below; never point it at untrusted files.
        df = pd.read_pickle(DF_CACHE)
        print(f"Loaded {len(df)} records from cache.")

        # Double-check the cache has intervention data
        intervention_count = len(df[df.intervention != "none"])
        print(f"Intervention records in cache: {intervention_count}")
    except Exception as e:  # any unreadable/incompatible cache just triggers a rebuild
        print(f"Error loading from cache: {e}. Regenerating...")
        return None

    if intervention_count > 0:
        return df
    print("Cache lacks intervention data, regenerating...")
    return None


def _symbol_element_id(symbol_id):
    """Convert an all-digit JSON symbol key to int; keep other keys (e.g. "<EOS>") unchanged."""
    try:
        return int(symbol_id) if symbol_id.isdigit() else symbol_id
    except (ValueError, AttributeError):
        # str.isdigit() accepts characters int() rejects (e.g. superscript digits).
        return symbol_id


def _records_from_kl_file(file_path):
    """Build one record per state/symbol contribution in a decomposed_kls.json file.

    Returns an empty list when the path does not match KL_PATTERN or the file
    cannot be read or parsed; individual contributions that cannot be
    converted are reported and skipped.
    """
    metadata = parse_metadata(file_path)
    if not metadata:
        return []

    # Load KL divergence data
    try:
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSONDecodeError and UnicodeDecodeError
        print(f"Failed to parse JSON from {file_path}: {e}")
        return []
    if not isinstance(data, dict):
        print(f"Failed to parse JSON from {file_path}: expected an object, got {type(data).__name__}")
        return []

    state_contributions = data.get("state_contributions") or {}
    symbol_contributions = data.get("symbol_contributions") or {}
    total_kl = data.get("total_kl", 0)

    # Per-sequence occurrence counts from the matching arcs file.
    state_counts, symbol_counts = get_counts(metadata)

    records = []
    for state_id, kl in state_contributions.items():
        try:
            element_id = int(state_id)
            records.append({
                **metadata,
                "kind": "state",
                "element": element_id,
                "kl": float(kl),
                "count": state_counts.get(element_id, 0),
                "total_kl": total_kl,
            })
        except (ValueError, TypeError) as e:
            print(f"Error processing state {state_id}: {e}")

    for symbol_id, kl in symbol_contributions.items():
        element_id = _symbol_element_id(symbol_id)
        # Guarded like the state loop: one bad value skips that symbol only
        # instead of aborting the whole load.
        try:
            kl_value = float(kl)
        except (ValueError, TypeError) as e:
            print(f"Error processing symbol {symbol_id}: {e}")
            continue
        records.append({
            **metadata,
            "kind": "symbol",
            "element": element_id,
            "kl": kl_value,
            "count": symbol_counts.get(element_id, 0),
            "total_kl": total_kl,
        })

    return records


def load_experiment_data():
    """Load and process all experiment data from the JSON files.

    Returns the cached DataFrame at DF_CACHE when it exists and contains
    intervention rows. Otherwise every ``decomposed_kls.json`` under
    MODELS_ROOT is turned into one record per state/symbol contribution
    (path metadata plus ``kind``, ``element``, ``kl``, ``count``,
    ``total_kl``). The result is summarised on stdout, written to DF_CACHE
    and returned.

    Raises:
        RuntimeError: If no KL files are found, or no file yields a record.
    """
    cached = _load_cached_experiment_df()
    if cached is not None:
        return cached

    # Find all KL result JSON files
    search_pattern = os.path.join(MODELS_ROOT, "**", "eval", "decomposed_kls.json")
    filenames = glob.glob(search_pattern, recursive=True)
    if not filenames:
        raise RuntimeError(f"No files found with pattern: {search_pattern}")

    print(f"Found {len(filenames)} KL files. Processing...")

    # Debug: count intervention files
    intervention_files = [f for f in filenames if "/state/" in f or "/symbol/" in f]
    print(f"Found {len(intervention_files)} intervention files")
    if intervention_files:
        print("Sample intervention files:")
        for f in intervention_files[:3]:
            print(f"  {f}")

    records = []
    for file_path in tqdm(filenames):
        records.extend(_records_from_kl_file(file_path))

    if not records:
        raise RuntimeError("No records found after processing JSON files.")

    df = pd.DataFrame(records)

    # Print summary of records
    print(f"Total records: {len(df)}")
    print(f"Records by kind: {df.groupby('kind').size()}")
    print(f"Records by architecture: {df.groupby('arch').size()}")
    print(f"Records by intervention type: {df.groupby('intervention').size()}")

    # Count intervention data specifically
    intervention_data = df[df.intervention != "none"]
    print(f"Intervention records: {len(intervention_data)}")
    print(f"Intervention by kind: {intervention_data.groupby('kind').size()}")
    print(f"Intervention by target: {intervention_data.groupby(['kind', 'target']).size()}")

    # Cache the dataframe
    df.to_pickle(DF_CACHE)
    print(f"Saved DataFrame to cache: {DF_CACHE}")

    return df


def _kl_extent(df_kind, kind, elements):
    """Return (min, max) of the KL values plotted for *elements*, or None if there are none.

    Mirrors what the grid plots draw per element: observational rows
    (``intervention == "none"``) with strictly positive KL, plus rows of
    intervention type *kind* whose ``target`` equals their ``element``.
    NaN values are ignored.
    """
    selected = df_kind[df_kind.element.isin(elements)]
    observed = selected.loc[(selected.intervention == "none") & (selected.kl > 0), 'kl']
    intervened = selected.loc[
        (selected.intervention == kind) & (selected.target == selected.element), 'kl'
    ]
    parts = [part.dropna() for part in (observed, intervened)]
    parts = [part for part in parts if not part.empty]
    if not parts:
        return None
    return min(part.min() for part in parts), max(part.max() for part in parts)


def calculate_global_limits(df, excluded_states=(7,)):
    """Calculate global min and max KL values for consistent scaling across all plots.

    The elements considered are chosen with the same rules as the grid plots:
    up to 8 states, either the most frequent state-intervention targets
    (skipping ``excluded_states``) or, without state interventions, the
    highest-mean-KL states other than 0; and up to 9 symbols, either the most
    frequent symbol-intervention targets or the highest positive-mean-KL
    symbols.

    Args:
        df: Records DataFrame with ``kind``, ``element``, ``kl``,
            ``intervention`` and ``target`` columns.
        excluded_states: State ids skipped when ranking state-intervention
            targets. The default ``(7,)`` matches the state hard-coded out of
            the state grid plot, so the limits cover exactly the plotted
            states; keep the two in sync.

    Returns:
        ``(y_min, y_max)``: the data range padded by 5% of its span on each
        side, with the lower bound clamped to 1e-5 so it is valid on a log
        axis; ``(1e-5, 0.01)`` when there is no data to plot.
    """
    print("Calculating global KL limits...")

    extents = []

    # States: top intervention targets, or highest mean KL as a fallback.
    df_states = df[df.kind == "state"]
    if not df_states.empty:
        state_interventions = df_states[df_states.intervention == "state"]
        if len(state_interventions) > 0:
            ranked = (state_interventions.groupby('target').size()
                      .sort_values(ascending=False).index.tolist())
            excluded = set(excluded_states or ())
            # Exclude before truncating, exactly like the state grid plot does.
            states = [s for s in ranked if s not in excluded][:8]
        else:
            kl_by_state = df_states.groupby('element')['kl'].mean().sort_values(ascending=False)
            states = kl_by_state[kl_by_state.index != 0].head(8).index.tolist()
        extents.append(_kl_extent(df_states, "state", states))

    # Symbols: top intervention targets, or highest positive mean KL as a fallback.
    df_symbols = df[df.kind == "symbol"]
    if not df_symbols.empty:
        symbol_interventions = df_symbols[df_symbols.intervention == "symbol"]
        if len(symbol_interventions) > 0:
            target_counts = symbol_interventions.groupby('target').size()
            symbols = target_counts.sort_values(ascending=False).head(9).index.tolist()
        else:
            kl_by_symbol = df_symbols.groupby('element')['kl'].mean().sort_values(ascending=False)
            symbols = kl_by_symbol[kl_by_symbol > 0].head(9).index.tolist()
        extents.append(_kl_extent(df_symbols, "symbol", symbols))

    extents = [extent for extent in extents if extent is not None]
    if not extents:
        # Fallback for cases with no data: small positive range for log scale.
        global_min_kl, global_max_kl = 0.00001, 0.01
    else:
        global_min_kl = min(lo for lo, _ in extents)
        global_max_kl = max(hi for _, hi in extents)
        # Add some padding; the lower bound must stay positive for log scale.
        y_range = global_max_kl - global_min_kl
        global_min_kl = max(0.00001, global_min_kl - 0.05 * y_range)
        global_max_kl = global_max_kl + 0.05 * y_range

    print(f"Global y-axis limits: [{global_min_kl:.6f}, {global_max_kl:.6f}]")
    return global_min_kl, global_max_kl




def create_state_grid_plot(df, y_min, y_max, *, y_floor=0.0001):
    """Plot per-state KL curves on a 2x4 grid and save ``states_grid.png``.

    States are the most frequent state-intervention targets (state 7
    excluded) or, when there are no state interventions, the states with the
    highest mean KL (state 0 excluded); at most eight are drawn. For each
    state and architecture a panel shows:

    * observational KL (``intervention == "none"``, ``kl > 0``) averaged in
      200-wide occurrence-count bins over (0, 2000] (dashed, squares);
    * KL under interventions targeting that state, averaged per intervention
      count ``ic`` (solid, circles);

    each with a mean ± SEM band, on a log y-axis.

    Args:
        df: Records DataFrame from ``load_experiment_data``.
        y_min: Unused; kept so the signature matches ``create_symbol_grid_plot``.
            The lower y-limit is ``y_floor``.
        y_max: Upper y-limit shared by every panel.
        y_floor: Lower y-limit shared by every panel (defaults to the
            previously hard-coded 1e-4).
    """
    print("Creating state grid plot...")

    df_states = df[df.kind == "state"]
    if df_states.empty:
        print("No state data found")
        return

    intervention_df = df_states[df_states.intervention == "state"]
    if len(intervention_df) > 0:
        target_counts = intervention_df.groupby('target').size()
        unique_states = target_counts.sort_values(ascending=False).index.tolist()
        # NOTE(review): state 7 is hard-coded out of this figure; the reason is
        # not visible in this file. Global y-limits should skip it as well.
        unique_states = [s for s in unique_states if s != 7][:8]
        print(f"Using {len(unique_states)} intervention target states: {unique_states}")
    else:
        kl_by_state = df_states.groupby('element')['kl'].mean().sort_values(ascending=False)
        unique_states = kl_by_state[kl_by_state.index != 0].head(8).index.tolist()
        print(f"Using states with highest KL: {unique_states}")

    architectures = sorted(df_states.arch.unique())
    arch_colors = {'lstm': '#1f77b4', 'transformer': '#ff7f0e'}

    # A4-friendly landscape layout. Lower-case locals so the module-level
    # FIG_WIDTH / FIG_HEIGHT constants are not shadowed.
    n_rows, n_cols = 2, 4
    fig_width, fig_height = 16, 9  # inches
    fig = plt.figure(figsize=(fig_width, fig_height))
    gs = GridSpec(n_rows, n_cols, figure=fig)

    # Observational rows are averaged in count bins (0, 200], ..., (1800, 2000]
    # and plotted at the bin centres; counts outside that range are dropped.
    bin_size = 200
    bin_edges = range(0, 2001, bin_size)
    bin_centers = [i + bin_size / 2 for i in range(0, 2001 - bin_size, bin_size)]

    # First pass: aggregate the curves for every selected state/architecture.
    subplot_data = {}
    for state in unique_states:
        df_state = df_states[df_states.element == state]
        if df_state.empty:
            continue

        subplot_data[state] = {'vanilla': {}, 'intervention': {}}
        for arch in architectures:
            df_arch = df_state[df_state.arch == arch]
            if df_arch.empty:
                continue

            df_vanilla = df_arch[df_arch.intervention == "none"]
            df_vanilla_positive = df_vanilla[df_vanilla.kl > 0]
            if not df_vanilla_positive.empty:
                count_bins = pd.cut(df_vanilla_positive['count'], bins=bin_edges, labels=bin_centers)
                # observed=False keeps empty bins (as NaN rows) independent of the
                # installed pandas version's default.
                grouped = df_vanilla_positive.groupby(count_bins, observed=False).agg({
                    'kl': ['mean', 'sem', 'count']
                }).reset_index()
                grouped.columns = ['count_bin', 'kl_mean', 'kl_sem', 'num_points']
                if not grouped.empty:
                    subplot_data[state]['vanilla'][arch] = grouped

            df_intervention = df_arch[(df_arch.intervention == "state") & (df_arch.target == state)]
            if not df_intervention.empty:
                grouped = df_intervention.groupby('ic').agg({
                    'kl': ['mean', 'sem', 'count']
                }).reset_index()
                grouped.columns = ['ic', 'kl_mean', 'kl_sem', 'num_points']
                grouped = grouped.sort_values('ic')
                if not grouped.empty:
                    subplot_data[state]['intervention'][arch] = grouped

    def draw_mean_sem(ax, x, data, color, **line_style):
        """Plot data['kl_mean'] against x with a shaded ±SEM band."""
        ax.plot(x, data['kl_mean'], color=color, linewidth=2, **line_style)
        ax.fill_between(x,
                        data['kl_mean'] - data['kl_sem'],
                        data['kl_mean'] + data['kl_sem'],
                        alpha=0.2, color=color)

    # Second pass: one panel per state that has data, in selection order.
    plotted_states = [s for s in unique_states if s in subplot_data][:n_rows * n_cols]
    for subplot_idx, state in enumerate(plotted_states):
        row, col = divmod(subplot_idx, n_cols)
        ax = fig.add_subplot(gs[row, col])
        ax.set_yscale('log')  # Use log scale for all state plots
        plotted_something = False

        for arch in architectures:
            color = arch_colors[arch]
            arch_label = "LSTM" if arch == "lstm" else "Transformer"

            vanilla_data = subplot_data[state]['vanilla'].get(arch)
            if vanilla_data is not None and not vanilla_data.empty:
                draw_mean_sem(ax, vanilla_data['count_bin'], vanilla_data, color,
                              marker='s', linestyle='--', label=f'{arch_label} Obs.')
                plotted_something = True

            int_data = subplot_data[state]['intervention'].get(arch)
            if int_data is not None and not int_data.empty:
                draw_mean_sem(ax, int_data['ic'], int_data, color,
                              marker='o', linestyle='-', label=f'{arch_label} Interv.')
                plotted_something = True

        if not plotted_something:
            ax.text(0.5, 0.5, "No data available", ha='center', va='center', transform=ax.transAxes)

        # Same y-limits on every panel so the states are directly comparable.
        ax.set_ylim(y_floor, y_max)

        ax.set_title(f"State {state}")
        if row == n_rows - 1:
            ax.set_xlabel('Count / Interventions')
        # No per-panel y-label: a single figure-level label is added below.
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.grid(True, alpha=0.3, linestyle='--')

    # Add a single y-label for the entire figure
    fig.supylabel('Decomposed KL Divergence', fontsize=26, x=0.01)

    # Shared legend at the bottom, de-duplicated by label across all panels.
    by_label = {}
    for panel in fig.axes:
        handles, labels = panel.get_legend_handles_labels()
        by_label.update(zip(labels, handles))
    fig.legend(by_label.values(), by_label.keys(),
               loc='lower center', ncol=4, fontsize=22, frameon=False, bbox_to_anchor=(0.5, 0.02))

    # Leave room at the bottom for the legend and on the left for the supylabel.
    plt.tight_layout(rect=[0.0, 0.05, 0.99, 0.98])

    output_path = os.path.join(PLOTS_ROOT, "states_grid.png")
    fig.savefig(output_path, dpi=PLOT_DPI, bbox_inches='tight')
    print(f"State grid plot saved to {output_path}")
    plt.close(fig)


def create_symbol_grid_plot(df, y_min, y_max):
    """Create a grid plot for symbols, aggregating data across all topology seeds.

    Symbols are the (at most nine) most frequent symbol-intervention targets
    or, without symbol interventions, the nine with the highest positive
    mean KL. Each panel of the 3x3 grid shows, per architecture:

    * observational KL (``intervention == "none"``, ``kl > 0``) averaged in
      200-wide occurrence-count bins over (0, 2000] (dashed, squares);
    * KL under interventions targeting that symbol, averaged per intervention
      count ``ic`` (solid, circles);

    each with a mean ± SEM band, on a log y-axis. The figure is saved as
    ``symbols_grid.png`` in PLOTS_ROOT.

    Args:
        df: Records DataFrame from ``load_experiment_data``.
        y_min: Lower y-limit shared by every panel.
        y_max: Upper y-limit shared by every panel.
    """
    print("Creating symbol grid plot...")

    # Filter for symbol data
    df_symbols = df[df.kind == "symbol"]
    if df_symbols.empty:
        print("No symbol data found")
        return

    # Choose which symbols get a panel.
    intervention_df = df_symbols[df_symbols.intervention == "symbol"]
    if len(intervention_df) > 0:
        # Most frequent intervention targets first.
        target_counts = intervention_df.groupby('target').size()
        print("Intervention counts by target symbol:")
        print(target_counts)
        unique_symbols = target_counts.sort_values(ascending=False).head(9).index.tolist()
        print(f"Using {len(unique_symbols)} intervention target symbols: {unique_symbols}")
    else:
        print("No intervention data found for symbols")
        # Fallback: symbols with the highest positive mean KL.
        kl_by_symbol = df_symbols.groupby('element')['kl'].mean().sort_values(ascending=False)
        kl_by_symbol = kl_by_symbol[kl_by_symbol > 0]
        unique_symbols = kl_by_symbol.head(9).index.tolist()
        print(f"Using symbols with highest KL: {unique_symbols}")

    architectures = sorted(df_symbols.arch.unique())
    arch_colors = {
        'lstm': '#1f77b4',        # blue
        'transformer': '#ff7f0e',  # orange
    }

    # Grid size and figure setup (module-level FIG_WIDTH x FIG_HEIGHT inches)
    grid_size = 3
    fig = plt.figure(figsize=(FIG_WIDTH, FIG_HEIGHT))
    gs = GridSpec(grid_size, grid_size, figure=fig)

    # Observational rows are averaged in count bins (0, 200], ..., (1800, 2000]
    # and plotted at the bin centres; counts outside that range are dropped.
    bin_size = 200
    bin_edges = range(0, 2001, bin_size)
    bin_centers = [i + bin_size / 2 for i in range(0, 2001 - bin_size, bin_size)]

    # First pass: aggregate the curves for every selected symbol/architecture.
    subplot_data = {}
    for symbol in unique_symbols:
        df_symbol = df_symbols[df_symbols.element == symbol]
        if df_symbol.empty:
            continue

        subplot_data[symbol] = {'vanilla': {}, 'intervention': {}}
        for arch in architectures:
            df_arch = df_symbol[df_symbol.arch == arch]
            if df_arch.empty:
                continue

            # Observational data, pooled across all topology seeds.
            df_vanilla = df_arch[df_arch.intervention == "none"]
            df_vanilla_positive = df_vanilla[df_vanilla.kl > 0]
            if not df_vanilla_positive.empty:
                count_bins = pd.cut(df_vanilla_positive['count'], bins=bin_edges, labels=bin_centers)
                # observed=False keeps empty bins (as NaN rows) independent of the
                # installed pandas version's default.
                grouped = df_vanilla_positive.groupby(count_bins, observed=False).agg({
                    'kl': ['mean', 'sem', 'count']
                }).reset_index()
                grouped.columns = ['count_bin', 'kl_mean', 'kl_sem', 'num_points']
                if not grouped.empty:
                    subplot_data[symbol]['vanilla'][arch] = grouped

            # Intervention data where this symbol is the target, by intervention count.
            df_intervention = df_arch[(df_arch.intervention == "symbol") & (df_arch.target == symbol)]
            if not df_intervention.empty:
                print(f"Found {len(df_intervention)} intervention records for symbol {symbol}, arch {arch}")
                grouped = df_intervention.groupby('ic').agg({
                    'kl': ['mean', 'sem', 'count']
                }).reset_index()
                grouped.columns = ['ic', 'kl_mean', 'kl_sem', 'num_points']
                if not grouped.empty:
                    subplot_data[symbol]['intervention'][arch] = grouped.sort_values('ic')

    # Add a single y-label for the entire figure - placed slightly further in
    fig.supylabel('KL Divergence', fontsize=26, x=0.02)

    def draw_mean_sem(ax, x, data, color, **line_style):
        """Plot data['kl_mean'] against x with a shaded ±SEM band."""
        ax.plot(x, data['kl_mean'], color=color, linewidth=2, **line_style)
        ax.fill_between(x,
                        data['kl_mean'] - data['kl_sem'],
                        data['kl_mean'] + data['kl_sem'],
                        alpha=0.2, color=color)

    # Second pass: one panel per symbol that has data, in selection order.
    plotted_symbols = [s for s in unique_symbols if s in subplot_data][:grid_size * grid_size]
    panels = []
    for subplot_idx, symbol in enumerate(plotted_symbols):
        row, col = divmod(subplot_idx, grid_size)
        ax = fig.add_subplot(gs[row, col])
        ax.set_yscale('log')  # log scale, matching the state plots
        panels.append(ax)
        plotted_something = False

        for arch in architectures:
            color = arch_colors[arch]
            arch_label = "LSTM" if arch == "lstm" else "Transformer"

            vanilla_data = subplot_data[symbol]['vanilla'].get(arch)
            if vanilla_data is not None and not vanilla_data.empty:
                draw_mean_sem(ax, vanilla_data['count_bin'], vanilla_data, color,
                              marker='s', linestyle='--', label=f'{arch_label} Observed')
                plotted_something = True

            int_data = subplot_data[symbol]['intervention'].get(arch)
            if int_data is not None and not int_data.empty:
                draw_mean_sem(ax, int_data['ic'], int_data, color,
                              marker='o', linestyle='-', label=f'{arch_label} Intervention')
                plotted_something = True

        if not plotted_something:
            ax.text(0.5, 0.5, "No data available",
                    ha='center', va='center', transform=ax.transAxes)

        # Same y-limits on every panel so the symbols are directly comparable.
        ax.set_ylim(y_min, y_max)

        ax.set_title(f"Symbol {symbol}")
        if row == grid_size - 1:
            ax.set_xlabel('Count / Interventions')
        # No per-panel y-label: the figure-level supylabel covers all panels.
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.grid(True, alpha=0.3, linestyle='--')

    # Single legend on the first panel, collected from ALL panels so series that
    # are missing from the first panel (e.g. one architecture) still appear.
    by_label = {}
    for panel in panels:
        handles, labels = panel.get_legend_handles_labels()
        for handle, label in zip(handles, labels):
            by_label.setdefault(label, handle)
    if by_label:
        panels[0].legend(by_label.values(), by_label.keys(), fontsize=14)

    # Add main title
    fig.suptitle(f"Symbols - {IDENTIFIER}", fontsize=28, y=0.98)

    # Reduce subplot spacing and margins
    plt.subplots_adjust(left=0.08, right=0.98, top=0.92, bottom=0.12, wspace=0.2, hspace=0.3)

    output_path = os.path.join(PLOTS_ROOT, "symbols_grid.png")
    fig.savefig(output_path, dpi=PLOT_DPI)
    print(f"Symbol grid plot saved to {output_path}")
    plt.close(fig)

def main():
    """Load the KL records, derive shared y-limits, and render the state and symbol grids."""
    print(f"Starting visualization analysis for {IDENTIFIER}...")

    df = load_experiment_data()
    print(f"Total records: {len(df)}")

    # One (y_min, y_max) pair for both figures so they share a y-scale.
    shared_limits = calculate_global_limits(df)
    for render in (create_state_grid_plot, create_symbol_grid_plot):
        render(df, *shared_limits)

    print(f"Analysis complete. Plots saved to {PLOTS_ROOT}")


if __name__ == "__main__":
    main()