#!/usr/bin/env python3
import json
import glob
import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.font_manager as fm
from tqdm import tqdm

# Directory layout: every input and output of this script lives under BASE_DIR.
BASE_DIR = "experiments_random_100_10"
MODELS_ROOT = os.path.join(BASE_DIR, "models")  # searched for eval/decomposed_kls.json results
DATASETS_ROOT = os.path.join(BASE_DIR, "data/datasets")  # searched for arcs.txt files
PLOTS_ROOT = os.path.join(BASE_DIR, "plots")  # figures are written here

# Parameters from train script.
# NOTE(review): the INTERVENTION_* values are not referenced in the part of this
# file visible here; presumably they mirror the training sweep -- confirm.
INTERVENTION_START = 50
INTERVENTION_END = 2000
INTERVENTION_STEP = 200
BIN_SIZE = 200  # Width of the occurrence-count bins used for vanilla results

# Location for storing the assembled DataFrame for reuse. HASH is MODELS_ROOT
# with every "/" replaced by "__", so it can serve as a single file name.
HASH = MODELS_ROOT.replace("/", "__")
DF_STORE = f"df_storage/{HASH}.pkl"

# Create output directories at import time (no error if they already exist)
os.makedirs(PLOTS_ROOT, exist_ok=True)
os.makedirs("df_storage", exist_ok=True)

# Plot configuration
PLOT_DPI = 300  # Higher resolution for publications
FIG_WIDTH = 24  # Passed to plt.subplots(figsize=...) as the width
FIG_HEIGHT = 8  # Passed to plt.subplots(figsize=...) as the height

# Set up fonts and styling
import matplotlib as mpl
import pathlib

# Directory expected to hold the Times font files ("~" expands to the user's home)
font_dir = os.path.expanduser('~/.fonts/')

# Font files
times_regular = os.path.join(font_dir, 'Times.TTF')
times_bold = os.path.join(font_dir, 'Timesbd.TTF')
times_italic = os.path.join(font_dir, 'Timesi.TTF')
times_bold_italic = os.path.join(font_dir, 'Timesbi.TTF')

# Warn about (but do not abort on) any font file that is missing
font_files = [times_regular, times_bold, times_italic, times_bold_italic]
for font_file in font_files:
    if not os.path.exists(font_file):
        print(f"Warning: Font file not found: {font_file}")

# Resolve the default font once. With rebuild_if_missing=True matplotlib may
# rebuild its font cache when the cached match points at a file that no longer
# exists.
# NOTE(review): this does not unconditionally clear the font cache -- confirm
# that is sufficient if stale entries for these fonts are a concern.
from matplotlib.font_manager import findfont, FontProperties
findfont(FontProperties(), rebuild_if_missing=True)

# Manually add the font files that exist to matplotlib's font manager
for font_file in font_files:
    if os.path.exists(font_file):
        fm.fontManager.addfont(font_file)

# Print every registered font family whose name contains "time"
# (case-insensitive) so it is easy to verify that Times is now available
font_names = sorted([f.name for f in fm.fontManager.ttflist])
print("Available font families:")
for name in font_names:
    if 'time' in name.lower():
        print(f" - {name}")
# Apply the seaborn theme.
# NOTE(review): assumes the files above register under the family name
# "Times New Roman" -- check against the list printed above.
sns.set(style="whitegrid", context="talk", font="Times New Roman")

# Set Times as the default font family globally. These assignments run after
# sns.set(...), so they are the values left in rcParams for the keys below.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 26        # Base font size (minimum size requested)
plt.rcParams['axes.titlesize'] = 35   # Subplot title size (scaled proportionally)
plt.rcParams['axes.labelsize'] = 32   # Axis label size (scaled proportionally)
plt.rcParams['xtick.labelsize'] = 26  # X-axis tick label size (minimum size)
plt.rcParams['ytick.labelsize'] = 26  # Y-axis tick label size (minimum size)
plt.rcParams['legend.fontsize'] = 26  # Legend font size (scaled proportionally)
plt.rcParams['figure.titlesize'] = 40 # Figure title size (scaled proportionally)

# Regex patterns used by parse_metadata to recover run metadata from the path
# of a decomposed_kls.json file. All three expect "/" as the path separator.

# Layout 1 (intervention runs, no explicit "intervention" directory):
#   .../<arch>/<semiring>/<N>st_<M>sym/<seed>/<state|symbol>/<target>/
#       <intervention_count>/<mseed>/eval/decomposed_kls.json
INTERVENTION_PATTERN_1 = re.compile(
    r".*/"  # Any directories leading up to models
    r"(?P<arch>lstm|transformer|rnn)/"  # Architecture
    r"(?P<semiring>alo|binning)/"  # Semiring type (alo or binning)
    r"(?P<num_states>\d+)st_(?P<num_symbols>\d+)sym/"  # Number of states and symbols
    r"(?P<seed>\d+)/"  # Seed value
    r"(?P<intervention>state|symbol)/"  # Intervention type
    r"(?P<target>\d+)/"  # Target
    r"(?P<intervention_count>\d+)/(?P<mseed>\d+)/eval/decomposed_kls\.json$"  # Intervention count and model seed
)

# Layout 2 (intervention runs, explicit "intervention" directory; the path
# carries no state/symbol counts):
#   .../<arch>/intervention/<semiring>/<seed>/<state|symbol>/<target>/
#       train/<intervention_count>/<mseed>/eval/decomposed_kls.json
INTERVENTION_PATTERN_2 = re.compile(
    r".*/"  # Any directories leading up to models
    r"(?P<arch>lstm|transformer|rnn)/"  # Architecture
    r"(?P<type>intervention)/"  # Type is intervention
    r"(?P<semiring>alo|binning)/"  # Semiring type
    r"(?P<seed>\d+)/"  # Seed value
    r"(?P<intervention>state|symbol)/"  # Intervention type
    r"(?P<target>\d+)/"  # Target
    r"train/(?P<intervention_count>\d+)/"  # Intervention count
    r"(?P<mseed>\d+)/eval/decomposed_kls\.json$"  # Model seed
)

# Vanilla layout (semiring, intervention and target are the literal "none"):
#   .../<arch>/vanilla/none/<seed>/none/none/train/<intervention_count>/
#       <mseed>/eval/decomposed_kls.json
VANILLA_PATTERN = re.compile(
    r".*/"  # Any directories leading up to models
    r"(?P<arch>lstm|transformer|rnn)/"  # Architecture
    r"(?P<type>vanilla)/"  # Type is vanilla
    r"(?P<semiring>none)/"  # Semiring is none for vanilla
    r"(?P<seed>\d+)/"  # Seed value (was am_idx)
    r"(?P<intervention>none)/"  # Intervention is none
    r"(?P<target>none)/"  # Target is none
    r"train/(?P<intervention_count>\d+)/"  # Any integer is accepted (reportedly always 1 for vanilla -- not enforced here)
    r"(?P<mseed>\d+)/eval/decomposed_kls\.json$"  # Model seed
)

def parse_metadata(file_path):
    """Extract run metadata from the path of a result file.

    The path is tried against ``INTERVENTION_PATTERN_1``,
    ``INTERVENTION_PATTERN_2`` and ``VANILLA_PATTERN``, in that order; the
    named groups of the first pattern that matches become the metadata dict.

    Args:
        file_path: Path of a ``decomposed_kls.json`` file.

    Returns:
        A dict with the captured fields (numeric ones converted to ``int``)
        plus ``type``, ``num_states`` and ``num_symbols``. Layouts that do not
        encode the state/symbol counts get the fixed values 50 and 10, and
        vanilla runs get ``target`` 0. ``None`` is returned, after printing a
        message, when no pattern matches.
    """
    def to_int_in_place(fields, names):
        # Convert the listed entries to int; anything that is absent or does
        # not parse as an integer is left exactly as it was captured.
        for name in names:
            raw = fields.get(name)
            if raw is None:
                continue
            try:
                fields[name] = int(raw)
            except (ValueError, TypeError):
                pass

    # Layout 1: the state/symbol counts are encoded in the path itself.
    found = INTERVENTION_PATTERN_1.match(file_path)
    if found is not None:
        info = found.groupdict()
        info["type"] = "intervention"
        to_int_in_place(
            info,
            ("seed", "target", "mseed", "num_states", "num_symbols", "intervention_count"),
        )
        return info

    # Layout 2: explicit "intervention" directory, counts not in the path.
    found = INTERVENTION_PATTERN_2.match(file_path)
    if found is not None:
        info = found.groupdict()
        to_int_in_place(info, ("seed", "target", "mseed", "intervention_count"))
        info.update(num_states=50, num_symbols=10)
        return info

    # Vanilla layout: semiring/intervention stay "none"; the caller
    # (load_experiment_data) assigns the semiring later.
    found = VANILLA_PATTERN.match(file_path)
    if found is not None:
        info = found.groupdict()
        to_int_in_place(info, ("seed", "mseed", "intervention_count"))
        info["target"] = 0  # conventional default for vanilla runs
        info.update(num_states=50, num_symbols=10)
        return info

    # Nothing matched.
    print(f"Failed to match path: {file_path}")
    return None

def get_counts(metadata):
    """
    Get occurrence counts for states and symbols from a run's ``arcs.txt`` file.

    Several candidate locations under ``DATASETS_ROOT`` are tried in order and
    the first file that exists is used. Every non-blank line of that file is
    evaluated as a Python expression and iterated as ``(src, tgt, sym)``
    triples.

    Counting depends on ``metadata["semiring"]``:

    * ``"binning"``: every transition adds one to its source state and to its
      symbol; the target state of the last transition is counted once more.
    * anything else (e.g. ``"alo"``): each distinct state (sources plus the
      final target) and each distinct symbol is counted once per line.

    Args:
        metadata: Dict as produced by ``parse_metadata``. ``seed`` and ``type``
            are required. ``num_states`` (default 50), ``num_symbols``
            (default 10) and ``semiring`` (default ``"alo"``) are optional.
            ``intervention_count`` is always read; non-vanilla runs also need
            ``intervention`` and ``target``.

    Returns:
        ``(state_counts, symbol_counts)``: two dicts mapping a state / symbol
        to its count. Both are empty when no candidate file exists. Lines that
        fail to parse are reported on stdout and skipped; if the file cannot
        be read, whatever was counted so far is returned.
    """
    seed_val = metadata["seed"]
    num_states = metadata.get("num_states", 50)
    num_symbols = metadata.get("num_symbols", 10)
    semiring = metadata.get("semiring", "alo")

    # Candidate locations of the arcs file, in order of preference
    if metadata["type"] == "vanilla":
        arcs_paths = [
            # Vanilla-specific directory
            os.path.join(DATASETS_ROOT, "vanilla", "train", str(seed_val), "arcs.txt"),
            # Within vanilla directory in type subdirectory
            os.path.join(DATASETS_ROOT, "vanilla", "train", str(seed_val), str(metadata["intervention_count"]), "arcs.txt"),
            # Alternative vanilla structure
            os.path.join(DATASETS_ROOT, "alo", f"{num_states}st_{num_symbols}sym", str(seed_val), "vanilla", "0", "train", "1", "arcs.txt"),
        ]
    else:
        arcs_paths = [
            # Structure from train script
            os.path.join(DATASETS_ROOT, semiring, f"{num_states}st_{num_symbols}sym", str(seed_val),
                         metadata["intervention"], str(metadata["target"]), "train", str(metadata["intervention_count"]), "arcs.txt"),
            # Alternative intervention structure
            os.path.join(DATASETS_ROOT, "intervention", semiring, str(seed_val), metadata["intervention"],
                         str(metadata["target"]), "train", str(metadata["intervention_count"]), "arcs.txt"),
        ]

    # Use the first candidate that exists; without one there is nothing to count
    arcs_path = next((path for path in arcs_paths if os.path.exists(path)), None)
    if arcs_path is None:
        return {}, {}

    state_counts = {}
    symbol_counts = {}

    try:
        with open(arcs_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank line (e.g. a trailing newline) carries no
                    # transitions; skip it instead of reporting a parse error.
                    continue
                try:
                    # SECURITY NOTE(review): eval() executes arbitrary code
                    # found in the arcs file. Only run this on files you
                    # generated yourself. If every line is a plain literal
                    # (lists/tuples of numbers and strings), switching to
                    # ast.literal_eval would remove the risk -- confirm the
                    # file format before changing it.
                    transitions = eval(line)

                    if semiring == "binning":
                        # Count every occurrence
                        for src, tgt, sym in transitions:
                            state_counts[src] = state_counts.get(src, 0) + 1
                            symbol_counts[sym] = symbol_counts.get(sym, 0) + 1

                        # The last target state is never a source, so add it here
                        if transitions:
                            final_state = transitions[-1][1]
                            state_counts[final_state] = state_counts.get(final_state, 0) + 1
                    else:  # 'alo' or other semirings
                        # Count each state/symbol at most once per sequence
                        seen_states = set()
                        seen_symbols = set()

                        for src, tgt, sym in transitions:
                            seen_states.add(src)
                            seen_symbols.add(sym)

                        if transitions:
                            seen_states.add(transitions[-1][1])

                        for state in seen_states:
                            state_counts[state] = state_counts.get(state, 0) + 1
                        for symbol in seen_symbols:
                            symbol_counts[symbol] = symbol_counts.get(symbol, 0) + 1
                except Exception as e:
                    # eval() of arbitrary text can raise anything; keep going
                    # with the remaining lines (best effort).
                    print(f"Error parsing line in {arcs_path}: {e}")
                    continue
    except (OSError, UnicodeError) as e:
        # Opening or decoding the file failed; return what was counted so far.
        print(f"Error opening or reading {arcs_path}: {e}")

    return state_counts, symbol_counts

def load_experiment_data():
    """Load all ``decomposed_kls.json`` results into one long-format DataFrame.

    The list of result files is read from ``<BASE_DIR>/kl_files.txt`` when that
    file exists (only lines containing ``"models/"`` are used, re-rooted under
    ``MODELS_ROOT``); otherwise ``MODELS_ROOT`` is searched recursively. Files
    whose path ``parse_metadata`` cannot interpret, or that cannot be read or
    decoded as JSON, are reported on stdout and skipped.

    Records produced:

    * Vanilla runs: one record per state and per symbol listed in the JSON's
      ``state_contributions`` / ``symbol_contributions``, emitted once for each
      of the ``"alo"`` and ``"binning"`` semirings, with ``count`` taken from
      ``get_counts`` (0 when the element was not counted).
    * Intervention runs: a single record for the targeted state or symbol,
      with ``count`` set to the run's ``intervention_count``.

    Each record holds the parsed metadata plus ``kind``, ``element``, ``kl``,
    ``count`` and ``total_kl``.

    Returns:
        pandas.DataFrame with one row per record.

    Raises:
        RuntimeError: If no record could be produced at all.
    """
    def symbol_key(symbol_id):
        # Digit-only ids become ints so they can match integer keys in the
        # counts; every other id is kept unchanged.
        try:
            return int(symbol_id) if symbol_id.isdigit() else symbol_id
        except (ValueError, AttributeError):
            return symbol_id

    records = []
    vanilla_files_processed = 0
    intervention_files_processed = 0

    # get_counts() re-parses a whole arcs file on every call, and many vanilla
    # result files (different architectures / model seeds) share one dataset.
    # Cache its result per dataset; the key lists every metadata field that
    # get_counts reads for vanilla runs.
    counts_cache = {}

    # Find all result JSON files
    search_pattern = os.path.join(MODELS_ROOT, "**", "eval", "decomposed_kls.json")

    # Prefer the explicit file list if it exists
    file_list_path = os.path.join(BASE_DIR, "kl_files.txt")
    if os.path.exists(file_list_path):
        print(f"Loading file list from {file_list_path}")
        filenames = []
        with open(file_list_path) as f:
            for line in f:
                line = line.strip()
                if "models/" in line:
                    line = line.replace("models/", "")
                    filenames.append(os.path.join(MODELS_ROOT, line))
        print(f"Found {len(filenames)} files in file list")
    else:
        print(f"File list not found, using glob with pattern: {search_pattern}")
        filenames = glob.glob(search_pattern, recursive=True)
        print(f"Found {len(filenames)} files with glob")

    # Debug - print some filenames
    if filenames:
        print("Sample filenames:")
        for fname in filenames[:5]:
            print(f"  {fname}")

    for file_path in tqdm(filenames):
        # Parse metadata from file path
        metadata = parse_metadata(file_path)
        if not metadata:
            continue

        # Load KL divergence data. ValueError covers malformed JSON and
        # undecodable bytes; OSError covers missing or unreadable files. One
        # bad file must not abort the whole load.
        try:
            with open(file_path) as f:
                data = json.load(f)
        except (ValueError, OSError) as e:
            print(f"Failed to parse JSON from {file_path}: {e}")
            continue

        state_contributions = data.get("state_contributions", {})
        symbol_contributions = data.get("symbol_contributions", {})
        total_kl = data.get("total_kl", 0)

        if metadata["type"] == "vanilla":
            vanilla_files_processed += 1

            # Vanilla runs have no semiring of their own; emit their records
            # once per semiring so they can be compared against both.
            for semiring in ["alo", "binning"]:
                metadata_copy = metadata.copy()
                metadata_copy["semiring"] = semiring

                cache_key = (
                    metadata_copy["seed"],
                    metadata_copy["intervention_count"],
                    metadata_copy["num_states"],
                    metadata_copy["num_symbols"],
                    semiring,
                )
                if cache_key not in counts_cache:
                    counts_cache[cache_key] = get_counts(metadata_copy)
                state_counts, symbol_counts = counts_cache[cache_key]

                # One record per state contribution
                for state_id, kl in state_contributions.items():
                    element_id = int(state_id)
                    records.append({
                        **metadata_copy,
                        "kind": "state",
                        "element": element_id,
                        "kl": float(kl),
                        "count": state_counts.get(element_id, 0),
                        "total_kl": total_kl,
                    })

                # One record per symbol contribution
                for symbol_id, kl in symbol_contributions.items():
                    element_id = symbol_key(symbol_id)
                    records.append({
                        **metadata_copy,
                        "kind": "symbol",
                        "element": element_id,
                        "kl": float(kl),
                        "count": symbol_counts.get(element_id, 0),
                        "total_kl": total_kl,
                    })

                # Debug info for the first few vanilla files
                if vanilla_files_processed <= 5:
                    print(f"Vanilla file #{vanilla_files_processed}, {file_path}:")
                    print(f"  Semiring: {semiring}")
                    print(f"  Created {len(state_contributions)} state records and {len(symbol_contributions)} symbol records")

        else:
            intervention_files_processed += 1

            # Only the contribution of the targeted element is kept
            # (0 when the JSON has no entry for it).
            if metadata["intervention"] == "state":
                kl = float(state_contributions.get(str(metadata["target"]), 0))
            elif metadata["intervention"] == "symbol":
                kl = float(symbol_contributions.get(str(metadata["target"]), 0))
            else:
                continue

            # Use intervention_count directly as the occurrence count
            records.append({
                **metadata,
                "kind": metadata["intervention"],
                "element": metadata["target"],
                "kl": kl,
                "count": metadata["intervention_count"],
                "total_kl": total_kl,
            })

    if not records:
        raise RuntimeError("No records found after processing JSON files. Check file paths and patterns.")

    print(f"Processed {vanilla_files_processed} vanilla files and {intervention_files_processed} intervention files")

    df = pd.DataFrame(records)

    # Safety net: a vanilla record must never keep the placeholder semiring
    # "none" taken from its path.
    df.loc[(df['type'] == 'vanilla') & (df['semiring'] == 'none'), 'semiring'] = 'alo'

    # Print summary of records by type and semiring
    print("\nSummary by type and semiring:")
    for kind in ["state", "symbol"]:
        for typ in ["vanilla", "intervention"]:
            for semiring in ["alo", "binning"]:
                subset = df[(df["kind"] == kind) & (df["type"] == typ) & (df["semiring"] == semiring)]
                count = len(subset)
                unique_elements = len(subset["element"].unique())
                print(f"{kind}-{typ}-{semiring}: {count} records for {unique_elements} unique elements")

    return df

def _bin_kl_by_count(df_positive, bin_size):
    """Aggregate KL values into fixed-width occurrence-count bins.

    Bins are ``[k * bin_size, (k + 1) * bin_size)`` starting at 0 and reaching
    past the largest count. Empty bins are omitted.

    Returns:
        A list of dicts with ``count_bin`` (bin centre), ``kl_mean``, ``kl_se``
        (standard error of the mean, 0 for a single-point bin) and
        ``num_points``.
    """
    max_count = df_positive['count'].max()
    bin_edges = np.arange(0, max_count + bin_size + 1, bin_size)

    binned = []
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = df_positive[(df_positive['count'] >= lower) & (df_positive['count'] < upper)]
        if in_bin.empty:
            continue
        binned.append({
            'count_bin': lower + bin_size / 2,
            'kl_mean': in_bin['kl'].mean(),
            'kl_se': in_bin['kl'].sem() if len(in_bin) > 1 else 0,
            'num_points': len(in_bin),
        })
    return binned


def _plot_mean_with_band(ax, x, mean, err, *, marker, linestyle, label, color):
    """Draw a mean line on ``ax`` with a shaded ``mean +/- err`` band."""
    ax.plot(x, mean, marker=marker, linestyle=linestyle, label=label,
            color=color, linewidth=2)
    ax.fill_between(x, mean - err, mean + err, alpha=0.2, color=color)


def create_main_plots(df):
    """
    Create two main plots - one for states and one for symbols,
    with one subplot per semiring ("alo" on the left, "binning" on the right).

    For every architecture present in ``df`` each subplot shows:

    * vanilla records with positive KL, averaged in ``BIN_SIZE``-wide
      occurrence-count bins (dashed line, square markers);
    * intervention records of the matching kind, averaged per
      ``intervention_count`` and restricted to groups with a positive mean
      (solid line, round markers).

    Both lines get a shaded band of one standard error; where the standard
    error is undefined (a single observation) the band has zero width.

    The figures are saved as ``<PLOTS_ROOT>/state_kl_vs_count.png`` and
    ``<PLOTS_ROOT>/symbol_kl_vs_count.png``.

    Args:
        df: DataFrame as returned by ``load_experiment_data``; the columns
            ``arch``, ``kind``, ``semiring``, ``type``, ``intervention``,
            ``intervention_count``, ``element``, ``count`` and ``kl`` are used.
    """
    print("Creating main plots...")

    architectures = sorted(df.arch.unique())
    semirings = ["alo", "binning"]

    # Legend labels and colours per architecture. The path patterns also accept
    # architectures other than lstm/transformer (e.g. "rnn"), so anything
    # without a dedicated entry gets its own name as label and a spare colour
    # instead of raising KeyError.
    arch_labels = {'lstm': 'LSTM', 'transformer': 'Transformer', 'rnn': 'RNN'}
    arch_colors = {
        'lstm': '#1f77b4',        # blue
        'transformer': '#ff7f0e'  # orange
    }
    spare_colors = ['#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
    unlisted = [arch for arch in architectures if arch not in arch_colors]
    for idx, arch in enumerate(unlisted):
        arch_colors[arch] = spare_colors[idx % len(spare_colors)]

    # Print debug info about number of records and elements
    print("\nDetailed record counts:")
    for kind in ["state", "symbol"]:
        for semiring in semirings:
            for arch in architectures:
                subset = df[(df["kind"] == kind) & (df["semiring"] == semiring) & (df["arch"] == arch)]
                vanilla_subset = subset[subset["type"] == "vanilla"]
                vanilla_count = len(vanilla_subset)
                vanilla_elements = len(vanilla_subset["element"].unique())
                intervention_count = len(subset[subset["type"] == "intervention"])
                print(f"{kind}-{semiring}-{arch}: Vanilla={vanilla_count} records for {vanilla_elements} elements, "
                      f"Intervention={intervention_count} records")

    # One figure per kind
    for kind in ["state", "symbol"]:
        print(f"Processing {kind} plot...")

        fig, axes = plt.subplots(1, 2, figsize=(FIG_WIDTH, FIG_HEIGHT))

        for sem_idx, semiring in enumerate(semirings):
            ax = axes[sem_idx]

            df_sem = df[(df["kind"] == kind) & (df["semiring"] == semiring)]

            if df_sem.empty:
                # The subplot is left completely untouched in this case
                print(f"No data for {kind} and {semiring}")
                continue

            for arch in architectures:
                df_arch = df_sem[df_sem["arch"] == arch]
                if df_arch.empty:
                    continue

                arch_label = arch_labels.get(arch, str(arch))
                color = arch_colors[arch]

                # --- Vanilla data: binned by occurrence count ---
                df_vanilla = df_arch[df_arch["type"] == "vanilla"]
                if df_vanilla.empty:
                    print(f"No vanilla data for {kind}-{semiring}-{arch}")
                else:
                    print(f"Processing {len(df_vanilla)} vanilla records for {kind}-{semiring}-{arch}")

                    df_vanilla_positive = df_vanilla[df_vanilla["kl"] > 0]
                    if df_vanilla_positive.empty:
                        print(f"No positive KL values found for {kind}-{semiring}-{arch} vanilla")
                    else:
                        print(f"Found {len(df_vanilla_positive)} positive KL vanilla records")

                        min_count = df_vanilla_positive['count'].min()
                        max_count = df_vanilla_positive['count'].max()
                        print(f"Count range: {min_count} - {max_count}")

                        binned_data = _bin_kl_by_count(df_vanilla_positive, BIN_SIZE)
                        if binned_data:
                            print(f"Created {len(binned_data)} bins for plotting")
                            df_binned = pd.DataFrame(binned_data)
                            _plot_mean_with_band(
                                ax,
                                df_binned['count_bin'],
                                df_binned['kl_mean'],
                                df_binned['kl_se'],
                                marker='s',      # square marker
                                linestyle='--',  # dashed line
                                label=f'{arch_label} Observed (binned)',
                                color=color,
                            )
                        else:
                            print(f"No bins created for {kind}-{semiring}-{arch}")

                # --- Intervention data: averaged per intervention count ---
                df_intervention = df_arch[(df_arch["type"] == "intervention") & (df_arch["intervention"] == kind)]
                if not df_intervention.empty:
                    print(f"Processing {len(df_intervention)} intervention records for {kind}-{semiring}-{arch}")

                    grouped = df_intervention.groupby('intervention_count').agg({
                        'kl': ['mean', 'sem']
                    }).reset_index()

                    # Flatten the multi-index columns
                    grouped.columns = ['intervention_count', 'kl_mean', 'kl_sem']

                    # Keep groups with a positive mean only
                    positive_groups = grouped[grouped['kl_mean'] > 0]

                    if not positive_groups.empty:
                        positive_groups = positive_groups.sort_values('intervention_count')

                        # The standard error is NaN for a group with a single
                        # run; use 0 there, as the vanilla bins do, so the band
                        # is not interrupted.
                        kl_sem = positive_groups['kl_sem'].fillna(0)

                        _plot_mean_with_band(
                            ax,
                            positive_groups['intervention_count'],
                            positive_groups['kl_mean'],
                            kl_sem,
                            marker='o',     # circle marker
                            linestyle='-',  # solid line
                            label=f'{arch_label} Intervention',
                            color=color,
                        )

            # Subplot title
            if semiring == "alo":
                ax.set_title("At-least-once")
            else:
                ax.set_title("Binning")

            # Axis labels and frame
            ax.set_ylabel('Decomposed KL Divergence')
            ax.set_xlabel('Occurrence count')
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)

            # Fixed x-axis window per semiring
            if semiring == "alo":
                ax.set_xlim(100, 1500)
            else:  # binning
                ax.set_xlim(100, 1250)

            ax.grid(True, alpha=0.3, linestyle='--')

            # Legend only on the first subplot, without duplicate labels
            if sem_idx == 0:
                handles, labels = ax.get_legend_handles_labels()
                by_label = dict(zip(labels, handles))
                ax.legend(by_label.values(), by_label.keys())

        # NOTE(review): the y-axis is left on matplotlib's default scale even
        # though only positive KL values are kept above; a call to
        # set_consistent_y_range(fig, axes) was previously commented out here.

        plt.tight_layout()

        output_path = os.path.join(PLOTS_ROOT, f"{kind}_kl_vs_count.png")
        fig.savefig(output_path, dpi=PLOT_DPI)
        plt.close(fig)

# Display names for known architectures; any other architecture falls back
# to its raw name instead of raising NameError or reusing a stale label.
_ARCH_DISPLAY_NAMES = {'lstm': 'LSTM', 'transformer': 'Transformer'}


def _bin_vanilla_kl(df_positive, bin_size):
    """Aggregate KL values into fixed-width occurrence-count bins.

    Bins are ``[k * bin_size, (k + 1) * bin_size)`` for ``k >= 0``; rows with
    a negative or missing ``count`` are ignored.

    Args:
        df_positive: DataFrame with ``count`` and ``kl`` columns.
        bin_size: Width of each count bin.

    Returns:
        DataFrame with one row per non-empty bin and the columns
        ``count_bin`` (bin center), ``kl_mean``, ``kl_se`` and ``num_points``.
    """
    columns = ['count_bin', 'kl_mean', 'kl_se', 'num_points']
    in_range = df_positive[df_positive['count'] >= 0]
    if in_range.empty:
        return pd.DataFrame(columns=columns)

    # Renamed so reset_index() does not collide with the 'count' aggregate
    bin_index = (in_range['count'] // bin_size).rename('bin_index')
    stats = (
        in_range.groupby(bin_index)['kl']
        .agg(['mean', 'sem', 'count'])
        .reset_index()
    )
    return pd.DataFrame({
        'count_bin': (stats['bin_index'] + 0.5) * bin_size,
        'kl_mean': stats['mean'],
        # A single-sample bin has no standard error; use a zero-width band
        'kl_se': stats['sem'].fillna(0),
        'num_points': stats['count'],
    })[columns]


def _plot_mean_with_band(ax, x, mean, err, *, marker, linestyle, label, color):
    """Draw a mean line on ``ax`` with a shaded +/- ``err`` band around it."""
    ax.plot(
        x,
        mean,
        marker=marker,
        linestyle=linestyle,
        label=label,
        color=color,
        linewidth=2,
    )
    ax.fill_between(x, mean - err, mean + err, alpha=0.2, color=color)


def _plot_vanilla_series(ax, df_vanilla, tag, arch_label, color, bin_size):
    """Bin the positive-KL vanilla records by count and plot them on ``ax``."""
    df_vanilla_positive = df_vanilla[df_vanilla.kl > 0]
    if df_vanilla_positive.empty:
        print(f"No positive KL values found for {tag} vanilla")
        return

    print(f"Found {len(df_vanilla_positive)} positive KL vanilla records")
    min_count = df_vanilla_positive['count'].min()
    max_count = df_vanilla_positive['count'].max()
    print(f"Count range: {min_count} - {max_count}")

    df_binned = _bin_vanilla_kl(df_vanilla_positive, bin_size)
    if df_binned.empty:
        print(f"No bins created for {tag}")
        return

    print(f"Created {len(df_binned)} bins for plotting")
    _plot_mean_with_band(
        ax,
        df_binned['count_bin'],
        df_binned['kl_mean'],
        df_binned['kl_se'],
        marker='s',  # square marker
        linestyle='--',  # dashed line
        label=f'{arch_label} Observed (binned)',
        color=color,
    )


def _plot_intervention_series(ax, df_intervention, arch_label, color):
    """Plot mean KL per intervention count (with SEM band) on ``ax``."""
    grouped = (
        df_intervention.groupby('intervention_count')['kl']
        .agg(['mean', 'sem'])
        .reset_index()
    )
    grouped.columns = ['intervention_count', 'kl_mean', 'kl_sem']

    # Only strictly positive means are drawn (keeps the series usable on a
    # log axis; this function itself leaves the y-axis linear).
    positive_groups = grouped[grouped['kl_mean'] > 0]
    if positive_groups.empty:
        return
    positive_groups = positive_groups.sort_values('intervention_count')

    _plot_mean_with_band(
        ax,
        positive_groups['intervention_count'],
        positive_groups['kl_mean'],
        # Single-sample groups have NaN SEM; treat as zero, matching the
        # handling of single-sample bins in the vanilla series.
        positive_groups['kl_sem'].fillna(0),
        marker='o',  # circle marker
        linestyle='-',  # solid line
        label=f'{arch_label} Intervention',
        color=color,
    )


def create_combined_plot(df):
    """
    Create a combined plot with state and symbol data in a single row of four plots,
    maintaining the same overall dimensions as the original plots.

    The four panels are (state, alo), (state, binning), (symbol, alo) and
    (symbol, binning). For every architecture found in ``df.arch`` each panel
    shows the binned vanilla KL (dashed, square markers) and the mean
    intervention KL per intervention count (solid, circle markers), both with
    a shaded standard-error band. A panel with no matching rows is left
    untouched (no title, labels or legend).

    Args:
        df: DataFrame providing the columns ``arch``, ``kind``, ``semiring``,
            ``type``, ``kl``, ``count``, ``intervention`` and
            ``intervention_count``.

    Returns:
        None. The figure is written to
        ``PLOTS_ROOT/combined_state_symbol_kl_vs_count.png`` and then closed.
    """
    print("Creating combined state and symbol plot...")

    # Get all architectures (we expect lstm and transformer)
    architectures = sorted(df.arch.unique())
    semirings = ["alo", "binning"]
    kinds = ["state", "symbol"]

    # Create a figure with 4 subplots in a single row
    fig, axes = plt.subplots(1, 4, figsize=(FIG_WIDTH, FIG_HEIGHT))

    # Setup colors for different architectures
    arch_colors = {
        'lstm': '#1f77b4',        # blue
        'transformer': '#ff7f0e'  # orange
    }
    # Unexpected architectures get a color-cycle entry ("CN" notation) rather
    # than a KeyError; offset by 2 to stay clear of the default blue/orange.
    for position, arch in enumerate(architectures):
        arch_colors.setdefault(arch, f"C{(position + 2) % 10}")

    # One panel per (kind, semiring) combination, laid out left to right
    panels = [(kind, semiring) for kind in kinds for semiring in semirings]
    for subplot_idx, (kind, semiring) in enumerate(panels):
        print(f"Processing {kind}-{semiring} subplot...")
        ax = axes[subplot_idx]
        is_first_panel = subplot_idx == 0

        # Filter data for current kind and semiring
        df_sem = df[(df.kind == kind) & (df.semiring == semiring)]
        if df_sem.empty:
            print(f"No data for {kind} and {semiring}")
            continue

        for arch in architectures:
            df_arch = df_sem[df_sem.arch == arch]
            if df_arch.empty:
                continue

            arch_label = _ARCH_DISPLAY_NAMES.get(arch, str(arch))
            color = arch_colors[arch]
            tag = f"{kind}-{semiring}-{arch}"

            # Vanilla (observed) data, binned by occurrence count
            df_vanilla = df_arch[df_arch.type == "vanilla"]
            if df_vanilla.empty:
                print(f"No vanilla data for {tag}")
            else:
                print(f"Processing {len(df_vanilla)} vanilla records for {tag}")
                _plot_vanilla_series(
                    ax, df_vanilla, tag, arch_label, color, BIN_SIZE
                )

            # Intervention data whose intervention target matches this kind
            df_intervention = df_arch[
                (df_arch.type == "intervention") & (df_arch.intervention == kind)
            ]
            if not df_intervention.empty:
                print(f"Processing {len(df_intervention)} intervention records for {tag}")
                _plot_intervention_series(ax, df_intervention, arch_label, color)

        # Set subplot title based on kind and semiring
        semiring_title = "At-least-once" if semiring == "alo" else "Binning"
        ax.set_title(f"{kind.capitalize()}, {semiring_title}")

        # Only the first subplot carries the shared y-axis label
        if is_first_panel:
            ax.set_ylabel('Decomposed KL Divergence')
        ax.set_xlabel('Occurrence count')
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

        # Set x-axis limits based on semiring
        if semiring == "alo":
            ax.set_xlim(100, 1500)
        else:  # binning
            ax.set_xlim(100, 1250)

        # Add grid
        ax.grid(True, alpha=0.3, linestyle='--')

        # Legend only on the first subplot, with duplicate labels collapsed
        if is_first_panel:
            handles, labels = ax.get_legend_handles_labels()
            by_label = dict(zip(labels, handles))
            ax.legend(by_label.values(), by_label.keys(), fontsize=18,
                      loc='upper right', frameon=True)

    # Adjust layout with more space
    fig.tight_layout(pad=1.0, w_pad=0.5)

    # Save figure
    output_path = os.path.join(PLOTS_ROOT, "combined_state_symbol_kl_vs_count.png")
    fig.savefig(output_path, dpi=PLOT_DPI)
    print(f"Combined plot saved to {output_path}")
    plt.close(fig)

def main():
    """Main execution function.

    Loads the per-element experiment DataFrame from the pickle cache at
    ``DF_STORE`` (rebuilding it via ``load_experiment_data`` when the cache is
    missing or lacks the ``element`` column), keeps the first row of every
    unique configuration, prints summary statistics, and generates the plots.
    """
    print("Starting analysis...")

    # Load and process data
    print("Loading experiment data...")
    if os.path.exists(DF_STORE):
        print(f"Loading from {DF_STORE}...")
        df = pd.read_pickle(DF_STORE)
        print(f"Loaded DataFrame with shape: {df.shape}")

        # Check if the stored DataFrame has the expected structure
        # Look for 'element' column which should exist in the proper per-element structure
        if 'element' not in df.columns:
            print("Cached DataFrame appears to be using old aggregated format. Recreating...")
            # Remove the old cache file
            os.remove(DF_STORE)
            # Recreate the DataFrame with proper per-element structure
            df = load_experiment_data()
            os.makedirs(os.path.dirname(DF_STORE), exist_ok=True)
            df.to_pickle(DF_STORE)
        else:
            # Make sure semirings are correctly set if loading from cache
            df.loc[(df['type'] == 'vanilla') & (df['semiring'] == 'none'), 'semiring'] = 'alo'
    else:
        df = load_experiment_data()
        os.makedirs(os.path.dirname(DF_STORE), exist_ok=True)
        df.to_pickle(DF_STORE)
        print(f"Saved to {DF_STORE}")

    # Print summary statistics before filtering
    print(f"Total records before filtering: {len(df)}")
    print(f"Unique architectures: {df.arch.unique()}")
    print(f"Unique semirings: {df.semiring.unique()}")
    print(f"Number of vanilla records: {len(df[df.type == 'vanilla'])}")
    print(f"Number of intervention records: {len(df[df.type == 'intervention'])}")

    # Filter the DataFrame down to one run per configuration. The selection is
    # deterministic: the first row of each group is kept (no random sampling).
    print("\nFiltering DataFrame to select the first run from each unique configuration...")

    # Define groupby columns for identifying unique configurations
    groupby_cols = ['arch', 'type', 'semiring', 'seed', 'intervention', 'target',
                    'intervention_count', 'kind', 'element']

    # NOTE(review): with pandas' default dropna=True, rows holding NaN in any
    # of these key columns are excluded from every group -- confirm that
    # load_experiment_data never leaves these columns missing.
    selected_rows = [
        group.iloc[0]
        for _, group in tqdm(df.groupby(groupby_cols))
        if len(group) > 0
    ]

    # Passing the column list keeps the expected columns even when no rows were
    # selected, so the attribute-style column access below cannot fail.
    filtered_df = pd.DataFrame(selected_rows, columns=df.columns)
    filtered_df = filtered_df.reset_index(drop=True)

    print(f"Original DataFrame shape: {df.shape}")
    print(f"Filtered DataFrame shape: {filtered_df.shape}")

    # Print counts by type, semiring, and kind after filtering
    print("\nRecord counts by type, semiring, and kind after filtering:")
    for typ in ["vanilla", "intervention"]:
        for semiring in ["alo", "binning"]:
            for kind in ["state", "symbol"]:
                subset = filtered_df[(filtered_df.type == typ) &
                                     (filtered_df.semiring == semiring) &
                                     (filtered_df.kind == kind)]
                count = len(subset)
                unique_elements = len(subset.element.unique())
                print(f"{typ}-{semiring}-{kind}: {count} records for {unique_elements} unique elements")

    # Use the filtered DataFrame for plotting
    create_main_plots(filtered_df)
    create_combined_plot(filtered_df)

    print(f"Analysis complete. Plots saved to {PLOTS_ROOT}")
    
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
