#!/usr/bin/env python3
import json
import glob
import os
import sys
import re
import random
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.font_manager as fm
from matplotlib.gridspec import GridSpec

from tqdm import tqdm

from intervention_sampling.automata_register import AUTOMATA_REGISTER

# Suffix shared by the models/data/plots directory names of one experiment run.
# NOTE(review): presumably encodes the run's hyper-parameters — confirm with
# the code that generates these directories.
postfix = "_500_0.1_2_1000_100"

# Spacing (and window width) of the occurrence-count bins used for vanilla data
BIN_SIZE = 150
# The binning-semiring plots use an x-axis range of 0 .. 4 * NUM_SAMPLES
NUM_SAMPLES = 500

# Directory paths (all relative to the current working directory)
MODELS_ROOT = "experiments_parityfree_intervention/models" + postfix
DATASETS_ROOT = f"experiments_parityfree_intervention/data{postfix}/datasets"
AUTOMATON = "parity_free_hp"
BASE_DATA = os.path.join(DATASETS_ROOT, AUTOMATON)
PLOTS_ROOT = "experiments_parityfree_intervention/plots" + postfix
# NOTE(review): DECOMPOSED_ROOT and COMBINED_ROOT are created below but are not
# referenced anywhere else in this file.
DECOMPOSED_ROOT = os.path.join(PLOTS_ROOT, "decomposed")
COMBINED_ROOT = os.path.join(PLOTS_ROOT, "combined")

# Plot configuration
PLOT_DPI = 300  # Higher resolution for publications
FIG_WIDTH = 24
FIG_HEIGHT = 8

# Filesystem-safe identifier derived from MODELS_ROOT ("/" replaced by "__");
# it names the cached DataFrame pickle that main() reads and writes.
HASH = MODELS_ROOT.replace("/","__")
# NOTE(review): the "df_storage" directory is not created by the loop below.
DF_STORE = f"df_storage/{HASH}.bu2"

# Create output directories (runs at import time)
for directory in [PLOTS_ROOT, DECOMPOSED_ROOT, COMBINED_ROOT]:
    os.makedirs(directory, exist_ok=True)

import os
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import matplotlib as mpl
import pathlib

# Register the Times font files from the user's font directory with matplotlib
# and make "Times New Roman" the global default. Everything in this section
# runs at import time and mutates matplotlib's global state.

# Path to your fonts
font_dir = os.path.expanduser('~/.fonts/')

# Font files
times_regular = os.path.join(font_dir, 'Times.TTF')
times_bold = os.path.join(font_dir, 'Timesbd.TTF')
times_italic = os.path.join(font_dir, 'Timesi.TTF')
times_bold_italic = os.path.join(font_dir, 'Timesbi.TTF')

# Verify fonts exist (only warns; missing files are skipped further down)
font_files = [times_regular, times_bold, times_italic, times_bold_italic]
for font_file in font_files:
    if not os.path.exists(font_file):
        print(f"Warning: Font file not found: {font_file}")

# Look up the default font once, allowing matplotlib to rebuild its font cache
# if the match turns out to be stale.
# NOTE(review): this does not unconditionally clear the cache — per the
# matplotlib docs, rebuild_if_missing only triggers a rebuild when the first
# match points to a nonexistent font file; confirm this is sufficient.
from matplotlib.font_manager import findfont, FontProperties
findfont(FontProperties(), rebuild_if_missing=True)

# Manually add font files to matplotlib's font manager
for font_file in font_files:
    if os.path.exists(font_file):
        fm.fontManager.addfont(font_file)

# Print every registered font family whose name contains "time" so the user
# can check that Times is now available
font_names = sorted([f.name for f in fm.fontManager.ttflist])
print("Available font families:")
for name in font_names:
    if 'time' in name.lower():
        print(f" - {name}")
#sns.set(style="whitegrid", context="talk", font="Times New Roman")

# Set Times as the default font family globally.
# NOTE(review): this is applied even when the font files above were missing.
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['font.serif'] = ['Times New Roman']
plt.rcParams['font.size'] = 26        # Base font size (minimum size requested)
plt.rcParams['axes.titlesize'] = 35   # Subplot title size (scaled proportionally)
plt.rcParams['axes.labelsize'] = 32   # Axis label size (scaled proportionally)
plt.rcParams['xtick.labelsize'] = 26  # X-axis tick label size (minimum size)
plt.rcParams['ytick.labelsize'] = 26  # Y-axis tick label size (minimum size)
plt.rcParams['legend.fontsize'] = 26  # Legend font size (scaled proportionally)
plt.rcParams['figure.titlesize'] = 40 # Figure title size (scaled proportionally)
# # Set up Times New Roman font
# plt.rcParams['font.family'] = 'serif'
# plt.rcParams['font.serif'] = ['Times New Roman']
# Regex pattern to extract metadata from file paths.
# Expected layout (forward slashes only; at least one directory must precede
# the architecture component):
#   <...>/<arch>/<type>/<semiring>/<am_idx>/<intervention>/<target>/train/<ic>/<mseed>/eval/decomposed_kls.json
# parse_metadata() applies it with .match(), i.e. anchored at the path start.
PATH_PATTERN = re.compile(
    r".*/"  # Any directories leading up to models
    r"(?P<arch>lstm|transformer)/"  # Architecture
    r"(?P<type>vanilla|intervention)/"  # Model type
    r"(?P<semiring>none|binning|alo)/"  # Semiring type
    r"(?P<am_idx>\d+)/"  # Automaton index
    r"(?P<intervention>none|state|symbol)/"  # Intervention type
    r"(?P<target>none|\d+)/"  # Target, can be 'none' for vanilla models
    r"train/(?P<ic>\d+)/"  # Intervention count or seed for vanilla
    r"(?P<mseed>\d+)/eval/decomposed_kls\.json$"  # Model seed, then the fixed result file name
)


def set_consistent_y_range(fig, axes):
    """
    Give every visible subplot the same log-scaled y-axis range.

    The shared range is derived from the positive, finite y-values of all
    lines and of all collection paths (error bands, fills) drawn on the
    visible axes. The lower bound is lowered if needed so the range spans at
    least one decade, both bounds are padded by 10%, and the lower bound is
    floored at 1e-6. When nothing plottable is found, the default range
    0.001 - 1.0 is applied so that empty figures are formatted like all the
    others. The right and top spines of every visible axes are hidden.

    Args:
        fig: The matplotlib figure (not used; kept for call compatibility)
        axes: A single axes, an iterable of axes, or a numpy array of axes
    """
    # Normalise the argument to a flat sequence of axes
    if isinstance(axes, np.ndarray):
        axes = axes.ravel()
    elif not hasattr(axes, '__iter__'):
        axes = [axes]

    # Materialise once so a one-shot iterable survives both passes below
    visible_axes = [ax for ax in axes if ax.get_visible()]

    # Find the global extent of the positive, finite y data
    y_min = float('inf')
    y_max = float('-inf')
    for ax in visible_axes:
        # Lines carry the curves; collection paths carry error bands and fills
        y_arrays = [line.get_ydata() for line in ax.get_lines()]
        for collection in ax.collections:
            y_arrays.extend(path.vertices[:, 1] for path in collection.get_paths())

        for values in y_arrays:
            values = np.asarray(values, dtype=float).ravel()
            # Only positive values can be shown on a log axis; NaN/inf are
            # dropped so one bad sample cannot hide the extent of the rest
            # or produce an infinite axis limit.
            usable = values[np.isfinite(values) & (values > 0)]
            if usable.size:
                y_min = min(y_min, float(usable.min()))
                y_max = max(y_max, float(usable.max()))

    if y_min == float('inf'):
        # No valid data found, use defaults
        y_min, y_max = 0.001, 1.0
    else:
        # Set a minimum ratio between max and min for log scale
        if y_max / y_min < 10:
            y_min = y_max / 10
        # Add some padding
        y_min = y_min * 0.9
        y_max = y_max * 1.1

    # Ensure we don't go too close to zero
    y_min = max(0.000001, y_min)
    # If all data lie below the floor the limits would be inverted; keep a
    # valid one-decade range instead.
    if y_max <= y_min:
        y_max = y_min * 10

    # Set log scale and limits for all visible axes
    for ax in visible_axes:
        ax.set_yscale('log')
        ax.set_ylim(y_min, y_max)

        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)


def parse_metadata(file_path):
    """
    Extract experiment metadata from a result-file path.

    The path is matched against PATH_PATTERN. The captured ``am_idx``, ``ic``
    and ``mseed`` fields are converted to int, and ``target`` is converted to
    int unless it is the literal string 'none'. Vanilla models additionally
    get ``seed`` (a copy of ``ic``) and an ``intervention_count`` of 0; for
    all other models ``intervention_count`` is a copy of ``ic``.

    Args:
        file_path: Path to the data file

    Returns:
        Dictionary of metadata, or None (after printing a message) when the
        path does not match the pattern
    """
    matched = PATH_PATTERN.match(file_path)
    if matched is None:
        print(f"Failed to match path: {file_path}")
        return None

    metadata = matched.groupdict()

    # Numeric path components become integers
    metadata.update({key: int(metadata[key]) for key in ("am_idx", "ic", "mseed")})

    # The target stays the string 'none' for vanilla models, otherwise an int
    raw_target = metadata["target"]
    metadata["target"] = raw_target if raw_target == "none" else int(raw_target)

    # For vanilla models the 'ic' path component holds the seed instead of an
    # intervention count
    is_vanilla = metadata["type"] == "vanilla"
    if is_vanilla:
        metadata["seed"] = metadata["ic"]
    metadata["intervention_count"] = 0 if is_vanilla else metadata["ic"]

    return metadata


def _arcs_path_candidates(metadata):
    """Return the arcs.txt locations to try for a model, in priority order."""
    am_idx = str(metadata["am_idx"])
    if metadata["type"] == "vanilla":
        train_dir = os.path.join(BASE_DATA, "vanilla", "train", am_idx)
        # For vanilla models the per-run sub-directory is named after the seed
        run_dir = str(metadata["seed"])
    else:
        train_dir = os.path.join(
            BASE_DATA,
            metadata["semiring"],
            AUTOMATON,
            am_idx,
            metadata["intervention"],
            str(metadata["target"]),
            "train",
        )
        # For intervention models it is named after the intervention count
        run_dir = str(metadata["intervention_count"])

    return [
        os.path.join(train_dir, "arcs.txt"),
        os.path.join(train_dir, run_dir, "arcs.txt"),
    ]


def _count_path_elements(transitions, semiring):
    """Return (state counts, symbol counts) contributed by one transition path."""
    states = [src for src, _tgt, _sym in transitions]
    symbols = [sym for _src, _tgt, sym in transitions]

    # The target of the last transition is the final state of the path
    if transitions:
        states.append(transitions[-1][1])

    if semiring != "binning":
        # Every semiring except 'binning' counts an element at most once per path
        return Counter(set(states)), Counter(set(symbols))
    return Counter(states), Counter(symbols)


def get_counts(metadata, semiring):
    """
    Get occurrence counts for states and symbols from the arcs file.

    The arcs file is looked up directly in the model's ``train`` data
    directory first and, if it is not there, in the per-run sub-directory
    (seed for vanilla models, intervention count otherwise). Every non-blank
    line of the file is evaluated as a Python expression that must yield a
    sequence of ``(source_state, target_state, symbol)`` triples. Lines that
    cannot be parsed are reported and skipped without contributing any
    counts.

    NOTE(security): lines are parsed with ``eval``, which executes arbitrary
    code. Only use this on arcs files you generated yourself; consider
    ``ast.literal_eval`` if the files are known to contain pure literals.

    Args:
        metadata: Dictionary of metadata from parse_metadata
        semiring: Semiring type; 'binning' counts every occurrence, any other
            value counts each state/symbol at most once per line

    Returns:
        Tuple of (state_counter, symbol_counter); both are empty when no arcs
        file is found
    """
    candidates = _arcs_path_candidates(metadata)
    arcs_path = next((path for path in candidates if os.path.exists(path)), None)
    if arcs_path is None:
        # Report the last location tried
        print(f"Arc file not found: {candidates[-1]}")
        return Counter(), Counter()

    state_counter = Counter()
    symbol_counter = Counter()

    with open(arcs_path) as f:
        for line in f:
            text = line.strip()
            if not text:
                # Blank lines carry no path and are not parse errors
                continue
            try:
                # See NOTE(security) in the docstring
                transitions = eval(text)
                line_states, line_symbols = _count_path_elements(transitions, semiring)
            except Exception as e:
                print(f"Error parsing line in {arcs_path}: {e}")
                continue

            # Commit only after the whole line parsed, so a malformed line
            # never leaves partial counts behind
            state_counter.update(line_states)
            symbol_counter.update(line_symbols)

    return state_counter, symbol_counter


def load_experiment_data():
    """
    Load and process all experiment data from JSON files.

    Every ``decomposed_kls.json`` below MODELS_ROOT whose path matches
    PATH_PATTERN is turned into records:

    * vanilla models yield one record per state and per symbol contribution,
      once for each of the 'binning' and 'alo' counting schemes, with
      ``count`` taken from get_counts();
    * intervention models yield a single record for the intervened
      state/symbol, with ``count`` set to the intervention count.

    Files that cannot be parsed, that have an intervention kind other than
    'state'/'symbol', or that lack a contribution for their target are
    reported and skipped.

    Returns:
        Pandas DataFrame with processed records

    Raises:
        RuntimeError: If no record could be built from any file
    """
    records = []
    vanilla_files_processed = 0
    intervention_files_processed = 0

    # get_counts() re-reads the whole arcs file; for vanilla models its result
    # depends only on (am_idx, seed, semiring), so share it between model seeds
    counts_cache = {}

    # Find all result JSON files
    search_pattern = os.path.join(MODELS_ROOT, "**", "eval", "decomposed_kls.json")
    for file_path in tqdm(glob.glob(search_pattern, recursive=True)):
        # Parse metadata from file path
        metadata = parse_metadata(file_path)
        if not metadata:
            continue

        # Load KL divergence data
        with open(file_path) as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                print(f"Failed to parse JSON from {file_path}")
                continue

        state_contributions = data.get("state_contributions", {})
        symbol_contributions = data.get("symbol_contributions", {})
        total_kl = data.get("total_kl")

        # Process vanilla model data
        if metadata["type"] == "vanilla":
            vanilla_files_processed += 1

            # We need to loop over both semirings since their counting differs
            for semiring in ("binning", "alo"):
                cache_key = (metadata["am_idx"], metadata["seed"], semiring)
                if cache_key not in counts_cache:
                    counts_cache[cache_key] = get_counts(metadata, semiring)
                state_counts, symbol_counts = counts_cache[cache_key]

                # Create a record for each state contribution
                for state_id, element_kl in state_contributions.items():
                    element_id = int(state_id)
                    records.append({
                        **metadata,
                        "semiring": semiring,
                        "kind": "state",
                        "element": element_id,
                        "kl": element_kl,
                        "count": state_counts.get(element_id, 0),
                        "total_kl": total_kl,
                    })

                # Create a record for each symbol contribution
                for symbol_id, element_kl in symbol_contributions.items():
                    # Purely numeric symbol ids become ints, others stay strings
                    element_id = symbol_id
                    if isinstance(symbol_id, str) and symbol_id.isdigit():
                        try:
                            element_id = int(symbol_id)
                        except ValueError:
                            # isdigit() accepts characters that int() rejects
                            element_id = symbol_id

                    records.append({
                        **metadata,
                        "semiring": semiring,
                        "kind": "symbol",
                        "element": element_id,
                        "kl": element_kl,
                        "count": symbol_counts.get(element_id, 0),
                        "total_kl": total_kl,
                    })

        # Process intervention model data
        else:
            intervention_files_processed += 1

            kind = metadata["intervention"]
            contributions = {
                "state": state_contributions,
                "symbol": symbol_contributions,
            }.get(kind)
            if contributions is None:
                # Without this guard the record would silently reuse a KL
                # value left over from a previously processed file
                print(f"Skipping {file_path}: unsupported intervention kind '{kind}'")
                continue

            target_key = str(metadata["target"])
            if target_key not in contributions:
                print(f"Skipping {file_path}: no {kind} contribution for target {target_key}")
                continue

            records.append({
                **metadata,
                "kind": kind,
                "element": metadata["target"],
                "kl": contributions[target_key],
                "count": metadata["intervention_count"],
                "total_kl": total_kl,
            })

    # Check if we found any records
    if not records:
        raise RuntimeError("No records found after processing JSON files. Check file paths and patterns.")

    print(f"Processed {vanilla_files_processed} vanilla files and {intervention_files_processed} intervention files")

    # Convert to DataFrame
    return pd.DataFrame(records)


# Display name and line colour for each known architecture
_ARCH_STYLES = {
    'lstm': ('LSTM', '#1f77b4'),                # blue
    'transformer': ('Transformer', '#ff7f0e'),  # orange
}
# Colour used for an architecture that is missing from _ARCH_STYLES
_FALLBACK_ARCH_COLOR = '#7f7f7f'


def _element_label_map(kind):
    """Return the registered automaton's label mapping for ``kind``, or {} if none."""
    if AUTOMATON not in AUTOMATA_REGISTER:
        return {}
    attr_name = {"state": "state_map", "symbol": "symbol_map"}.get(kind)
    if attr_name is None:
        return {}
    return getattr(AUTOMATA_REGISTER[AUTOMATON], attr_name, {})


def _matches_value(series, value):
    """Return a mask of entries equal to ``value`` directly or after str() conversion."""
    return (series == value) | (series.astype(str) == str(value))


def _plot_mean_with_band(ax, x, mean, err, color, **line_kwargs):
    """Draw a mean curve plus a shaded ``mean +/- err`` band in one colour."""
    ax.plot(x, mean, color=color, linewidth=2, **line_kwargs)
    ax.fill_between(x, mean - err, mean + err, alpha=0.2, color=color)


def _plot_intervention_curve(ax, df_targeted, arch_name, color):
    """Plot mean KL (with standard error) per intervention count."""
    if df_targeted.empty:
        return

    grouped = df_targeted.groupby('intervention_count')['kl'].agg(['mean', 'sem']).reset_index()
    # Only positive means can be shown on the log-scaled y-axis
    grouped = grouped[grouped['mean'] > 0]
    if grouped.empty:
        return

    _plot_mean_with_band(
        ax,
        grouped['intervention_count'],
        grouped['mean'],
        grouped['sem'],
        color,
        marker='o',
        label=f'{arch_name} Intervention',
    )


def _bin_by_count(df, bin_size):
    """Aggregate KL values in windows of +/- bin_size // 2 around multiples of bin_size."""
    half_width = bin_size // 2
    counts = df['count']

    rows = []
    for center in np.arange(0, counts.max() + bin_size, bin_size):
        # NOTE(review): both window edges are inclusive, so a count that lies
        # exactly on an edge contributes to two neighbouring bins.
        window = df[(counts >= max(0, center - half_width)) & (counts <= center + half_width)]
        if window.empty:
            continue
        rows.append({
            'count_bin': center,
            'kl_mean': window['kl'].mean(),
            # The standard error is undefined for a single point; draw no band there
            'kl_se': window['kl'].sem() if len(window) > 1 else 0,
            'num_points': len(window),
        })
    return pd.DataFrame(rows)


def _plot_vanilla_curve(ax, df_vanilla, arch_name, color):
    """Plot mean KL (with standard error) of vanilla runs, binned by occurrence count."""
    if df_vanilla.empty:
        return

    # Only positive kl values can be shown on the log-scaled y-axis
    df_positive = df_vanilla[df_vanilla.kl > 0]
    if df_positive.empty:
        return

    binned = _bin_by_count(df_positive, BIN_SIZE)
    if binned.empty:
        return

    _plot_mean_with_band(
        ax,
        binned['count_bin'],
        binned['kl_mean'],
        binned['kl_se'],
        color,
        marker='s',      # square marker to distinguish from intervention
        linestyle='--',  # dashed line to distinguish from intervention
        label=f'{arch_name} Observed (binned)',
    )


def create_combined_architecture_plots(df):
    """
    Create plots comparing different architectures side by side.

    For every semiring ('binning', 'alo') and every kind (state/symbol) the
    elements are laid out on figures of 1x3 subplots, one element per
    subplot (state 0 is skipped for state plots). Each subplot overlays, for
    every architecture found in ``df``:
    1) Intervention data: mean decomposed KL with standard error per
       intervention count (solid line, round markers)
    2) Vanilla data: mean decomposed KL with standard error, binned by
       occurrence count (dashed line, square markers)

    All subplots of a figure share one log-scaled y-range. Figures are saved
    as ``all_arch_<semiring>_<kind>_combined_group<N>.png`` in PLOTS_ROOT.

    Args:
        df: DataFrame with all experiment data
    """
    print("Creating combined architecture plots...")

    architectures = sorted(df.arch.unique())

    # Use 1x3 grid layout with multiple figures if needed
    n_rows = 1
    n_cols = 3
    n_plots = n_rows * n_cols

    for semiring in ("binning", "alo"):
        print(f"  Processing {semiring} semiring...")

        df_semiring = df[df.semiring == semiring]
        if df_semiring.empty:
            print(f"    No data for {semiring} semiring")
            continue

        for kind in sorted(df_semiring.kind.unique()):
            df_kind = df_semiring[df_semiring.kind == kind]
            if df_kind.empty:
                print(f"    No data for {semiring} semiring and {kind} kind")
                continue

            # Skip the start state (state 0) for state plots and sort the
            # remaining elements consistently by their string form
            elements = sorted(
                (
                    elem for elem in df_kind.element.unique()
                    if not (kind == "state" and (elem == 0 or elem == "0"))
                ),
                key=str,
            )

            # Get labels from automata register if available
            element_labels = _element_label_map(kind)

            # Calculate how many figures we need
            n_figures = (len(elements) + n_plots - 1) // n_plots

            for fig_idx in range(n_figures):
                fig, axes = plt.subplots(n_rows, n_cols, figsize=(FIG_WIDTH, FIG_HEIGHT))
                axes_flat = axes.flatten()

                # Elements shown on this figure
                first = fig_idx * n_plots
                fig_elements = elements[first:first + n_plots]

                for i, element in enumerate(fig_elements):
                    ax = axes_flat[i]

                    # str() keeps non-string labels from breaking capitalize()
                    element_label = str(
                        element_labels.get(element, f"{kind.capitalize()} {element}")
                    ).capitalize()

                    df_element = df_kind[_matches_value(df_kind.element, element)]

                    for arch in architectures:
                        df_arch_elem = df_element[df_element.arch == arch]
                        if df_arch_elem.empty:
                            continue

                        arch_name, color = _ARCH_STYLES.get(
                            arch, (str(arch), _FALLBACK_ARCH_COLOR)
                        )

                        # Split into vanilla and intervention
                        df_vanilla = df_arch_elem[df_arch_elem.type == "vanilla"]
                        df_intervention = df_arch_elem[df_arch_elem.type == "intervention"]

                        # 1. Intervention runs that targeted exactly this element
                        df_targeted = df_intervention[
                            (df_intervention.intervention == kind)
                            & _matches_value(df_intervention.target, element)
                        ]
                        _plot_intervention_curve(ax, df_targeted, arch_name, color)

                        # 2. Vanilla data (binned by count)
                        _plot_vanilla_curve(ax, df_vanilla, arch_name, color)

                    ax.set_title(element_label)
                    ax.set_ylabel('Decomposed KL Divergence')
                    ax.set_xlabel('Occurrence count')

                    # Set x-axis limits based on semiring
                    if semiring == "alo":
                        ax.set_xlim(0, 400)
                    else:  # binning
                        ax.set_xlim(0, 4*NUM_SAMPLES)

                    ax.grid(True, alpha=0.3, linestyle='--')

                    # Add legend only to first subplot to avoid clutter
                    if i == 0:
                        # Remove duplicate labels in legend
                        handles, labels = ax.get_legend_handles_labels()
                        by_label = dict(zip(labels, handles))
                        ax.legend(by_label.values(), by_label.keys())

                # Hide empty subplots
                for unused_ax in axes_flat[len(fig_elements):]:
                    unused_ax.set_visible(False)

                # Set consistent y-axis range and log scale
                set_consistent_y_range(fig, axes_flat)

                # No main title as per requirement
                plt.tight_layout()

                # NOTE(review): figures are written to PLOTS_ROOT although a
                # dedicated COMBINED_ROOT directory exists; confirm which
                # location is intended.
                output_path = os.path.join(
                    PLOTS_ROOT,
                    f"all_arch_{semiring}_{kind}_combined_group{fig_idx+1}.png"
                )
                fig.savefig(output_path, dpi=PLOT_DPI)
                plt.close(fig)


def main():
    """
    Main execution function.

    Loads the experiment DataFrame (from the DF_STORE pickle when it exists,
    otherwise from the raw result files, caching it afterwards), keeps one
    run per configuration, and creates the combined architecture plots.

    A configuration is one combination of the group-by columns below. Only
    configurations with at least 10 runs are kept; if there are more, 10 are
    sampled with a fixed random state, and of those the run with the lowest
    ``total_kl`` is retained.

    Raises:
        RuntimeError: If no configuration has enough runs with a valid
            ``total_kl``
    """
    print("Starting analysis...")

    # Load and process data
    print("Loading experiment data...")
    if os.path.exists(DF_STORE):
        print(f"Loading from {DF_STORE}...")
        df = pd.read_pickle(DF_STORE)
    else:
        df = load_experiment_data()
        # The cache directory may not exist yet; without it the pickle write
        # would fail only after the long load has finished
        store_dir = os.path.dirname(DF_STORE)
        if store_dir:
            os.makedirs(store_dir, exist_ok=True)
        df.to_pickle(DF_STORE)
        print(f"Saved to {DF_STORE}")

    # Columns that identify one configuration; its rows are the individual runs
    groupby_cols = ['arch', 'type', 'semiring', 'am_idx', 'intervention', 'target',
                    'ic', 'intervention_count', 'kind', 'element']

    # Number of runs a configuration needs, and the number sampled from it
    min_runs = 10

    # Index label of the best run of every configuration that qualifies
    best_run_labels = []
    for _, group in df.groupby(groupby_cols):
        if len(group) < min_runs:
            continue

        # If we have more runs than needed, sample exactly min_runs of them
        if len(group) > min_runs:
            sampled_group = group.sample(n=min_runs, random_state=42)
        else:
            sampled_group = group

        # The best run is the one with minimum total_kl
        valid_total_kl = sampled_group['total_kl'].dropna()
        if valid_total_kl.empty:
            continue
        best_run_labels.append(valid_total_kl.idxmin())

    if not best_run_labels:
        raise RuntimeError(
            f"No configuration has at least {min_runs} runs with a valid total_kl; "
            "nothing to plot."
        )

    # Select all best runs at once instead of concatenating row by row
    filtered_df = df.loc[best_run_labels].reset_index(drop=True)
    print(f"Original DataFrame shape: {df.shape}")
    print(f"Filtered DataFrame shape: {filtered_df.shape}")

    # Create combined architecture plots (comparing LSTM and Transformer)
    create_combined_architecture_plots(filtered_df)

    print(f"Analysis complete. Plots saved to {PLOTS_ROOT}")

if __name__ == "__main__":
    main()