#!/usr/bin/env python3
import os
import json
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties
from collections import defaultdict
from pathlib import Path

from intervention_sampling.automata_register import AUTOMATA_REGISTER


def load_kl_data(base_dir, intervention_start, intervention_end, intervention_step):
    """Load all KL decomposition data from the evaluation directories.

    For every intervention count ``i`` in
    ``range(intervention_start, intervention_end + 1, intervention_step)`` this
    reads ``<base_dir>/transformer_<i>/eval/decomposed_kls.json``. Counts whose
    file does not exist are skipped with a printed warning.

    Every per-key series is kept index-aligned with
    ``all_data['intervention_counts']``. A state, symbol or transition that is
    absent from a given file gets ``np.nan`` at that file's position, whether
    the gap is at the start, in the middle or at the end of the sweep.

    Args:
        base_dir: Directory containing the ``transformer_<i>`` model folders.
        intervention_start: First intervention count (inclusive).
        intervention_end: Last intervention count (inclusive).
        intervention_step: Step between consecutive intervention counts.

    Returns:
        Tuple ``(all_data, all_states, all_symbols, all_transitions)``.
        ``all_data`` has the keys ``'intervention_counts'`` and ``'total_kl'``
        (lists), plus ``'states'``, ``'symbols'`` and ``'transitions'``. Each
        of the last three maps a key to a list of values aligned with
        ``'intervention_counts'``. The three sets hold every key seen in any
        file.

    Raises:
        KeyError: If a JSON file lacks ``total_kl`` or one of the
            ``state_contributions`` / ``symbol_contributions`` /
            ``transition_contributions`` sections.
    """
    all_data = {
        'intervention_counts': [],
        'total_kl': [],
        'states': defaultdict(list),
        'symbols': defaultdict(list),
        'transitions': defaultdict(list),
    }

    # Collect all state, symbol, and transition keys to ensure consistent plotting
    all_states = set()
    all_symbols = set()
    all_transitions = set()

    # (bucket in all_data, section name in the JSON file, set of keys seen so far)
    categories = (
        ('states', 'state_contributions', all_states),
        ('symbols', 'symbol_contributions', all_symbols),
        ('transitions', 'transition_contributions', all_transitions),
    )

    for i in range(intervention_start, intervention_end + 1, intervention_step):
        kl_file = os.path.join(base_dir, f"transformer_{i}", "eval", "decomposed_kls.json")

        if not os.path.exists(kl_file):
            print(f"Warning: No KL data found for intervention count {i}")
            continue

        with open(kl_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Index this file's values must occupy in every per-key series.
        row = len(all_data['intervention_counts'])
        all_data['intervention_counts'].append(i)
        all_data['total_kl'].append(data['total_kl'])

        for bucket, section, seen in categories:
            for key, value in data[section].items():
                seen.add(key)
                series = all_data[bucket][key]
                # Back-fill NaN for earlier files that lacked this key, so the
                # value lands at index `row` instead of being shifted left.
                series.extend([np.nan] * (row - len(series)))
                series.append(value)

    # Pad keys that are missing from the trailing files.
    n_counts = len(all_data['intervention_counts'])
    for bucket, _, seen in categories:
        for key in seen:
            series = all_data[bucket][key]
            series.extend([np.nan] * (n_counts - len(series)))

    return all_data, all_states, all_symbols, all_transitions

def set_plot_style():
    """Configure a seaborn "ticks" look with Times New Roman serif text.

    The seaborn style is applied *before* the font settings. ``sns.set_style``
    resets ``font.family`` as part of every style, so calling it afterwards
    would silently revert the figures to a sans-serif face.
    """
    # Base seaborn "ticks" style. Spines are removed separately, per figure,
    # with sns.despine().
    sns.set_style("ticks")

    # Font settings go last so they win over the seaborn defaults.
    plt.rcParams.update({
        'font.family': 'serif',
        'font.serif': ['Times New Roman'],
        'font.size': 12,
        'axes.titlesize': 14,
        'axes.labelsize': 12,
        'xtick.labelsize': 10,
        'ytick.labelsize': 10,
        'legend.fontsize': 10,
        'figure.titlesize': 16
    })

def plot_total_kl(data, output_dir, title):
    """Draw the total KL divergence against intervention count and save it.

    Writes ``total_kl_plot.png`` (300 dpi) and ``total_kl_plot.pdf`` into
    ``output_dir``.

    Args:
        data: Dictionary from ``load_kl_data``; reads the
            ``'intervention_counts'`` and ``'total_kl'`` lists.
        output_dir: Directory the image files are written to.
        title: Text appended to the axes title.
    """
    figure, axis = plt.subplots(figsize=(8, 6))

    line_style = dict(marker='o', markersize=8, linewidth=2, color='#1f77b4')
    sns.lineplot(
        x=data['intervention_counts'],
        y=data['total_kl'],
        ax=axis,
        **line_style,
    )

    axis.set_title(f'Total KL Divergence vs. Intervention Count - {title}')
    axis.set_xlabel('Intervention Count')
    axis.set_ylabel('Total KL Divergence')
    axis.grid(True, linestyle='--', alpha=0.7)
    sns.despine()

    # Leave the top 5% of the canvas free, matching the other figures.
    plt.tight_layout(rect=[0, 0, 1, 0.95])

    stem = os.path.join(output_dir, 'total_kl_plot')
    plt.savefig(stem + '.png', dpi=300)
    plt.savefig(stem + '.pdf')
    plt.close()

def plot_state_contributions(data, states, output_dir, state_map=None, main_title=""):
    """Plot the state contributions as subplots with consistent y-axis scale.

    Subplots are laid out three per row, with fewer columns when there are
    fewer than three states. They are ordered by numeric state id, with any
    non-numeric ids last in alphabetical order. Writes
    ``state_contributions_plot.png`` (300 dpi) and
    ``state_contributions_plot.pdf`` into ``output_dir``. If ``states`` is
    empty, a warning is printed and nothing is drawn.

    Args:
        data: Dictionary containing KL data (as returned by ``load_kl_data``);
            reads ``'intervention_counts'`` and ``data['states'][state]``.
        states: Iterable of state identifiers (strings).
        output_dir: Directory to save plots.
        state_map: Optional dictionary mapping state IDs (int keys for
            numeric ids) to state names shown in subplot titles.
        main_title: Figure-level title.
    """
    # Numeric ids in numeric order first, then non-numeric ids alphabetically.
    # The leading tag keeps int and str from ever being compared.
    sorted_states = sorted(
        states, key=lambda s: (0, int(s)) if s.isdecimal() else (1, s)
    )
    n_states = len(sorted_states)
    if n_states == 0:
        print("Warning: no state contributions to plot; skipping state plot.")
        return

    n_cols = min(3, n_states)
    n_rows = (n_states + 2) // 3  # Ceiling division

    # Global y-axis limits for consistent scaling, padded 5% for visual clarity.
    finite_values = [
        v for state in sorted_states for v in data['states'][state] if not np.isnan(v)
    ]
    ylim = None
    if finite_values:
        lo, hi = min(finite_values), max(finite_values)
        span = hi - lo
        # For a flat series, pad by 5% of its magnitude (0.05 around zero).
        # Scaling lo/hi multiplicatively would invert the limits for negatives.
        pad = 0.05 * span if span > 0 else (0.05 * abs(hi) or 0.05)
        ylim = (lo - pad, hi + pad)

    # squeeze=False always returns a 2-D array, so ravel() works for every layout.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows),
        sharex=True, sharey=True, squeeze=False,
    )
    flat_axes = axes.ravel()

    for ax, state in zip(flat_axes, sorted_states):
        sns.lineplot(
            x=data['intervention_counts'],
            y=data['states'][state],
            marker='o',
            markersize=6,
            linewidth=2,
            ax=ax
        )

        # Use state_map for the title if available (numeric ids are int keys).
        map_key = int(state) if state.isdecimal() else state
        if state_map and map_key in state_map:
            title = f'State {state}: {state_map[map_key]}'
        else:
            title = f'State {state}'

        ax.set_title(title)
        ax.set_xlabel('Intervention Count')
        ax.set_ylabel('KL Contribution')
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.grid(True, linestyle='--', alpha=0.7)
        sns.despine(ax=ax)

    # Hide the unused cells of the last row.
    for ax in flat_axes[n_states:]:
        ax.set_visible(False)

    fig.suptitle(main_title)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(os.path.join(output_dir, 'state_contributions_plot.png'), dpi=300)
    fig.savefig(os.path.join(output_dir, 'state_contributions_plot.pdf'))
    plt.close(fig)

def plot_symbol_contributions(data, symbols, output_dir, symbol_map=None, main_title=""):
    """Plot the symbol contributions as subplots with consistent y-axis scale.

    Subplots are laid out three per row, with fewer columns when there are
    fewer than three symbols. They are ordered as numeric symbols (numerically),
    then other symbols (alphabetically), then ``<EOS>`` last. Writes
    ``symbol_contributions_plot.png`` (300 dpi) and
    ``symbol_contributions_plot.pdf`` into ``output_dir``. If ``symbols`` is
    empty, a warning is printed and nothing is drawn.

    Args:
        data: Dictionary containing KL data (as returned by ``load_kl_data``);
            reads ``'intervention_counts'`` and ``data['symbols'][symbol]``.
        symbols: Iterable of symbol identifiers (strings).
        output_dir: Directory to save plots.
        symbol_map: Optional dictionary mapping symbol IDs (int keys for
            numeric ids) to symbol names shown in subplot titles.
        main_title: Figure-level title.
    """
    def symbol_order(symbol):
        # The leading tag keeps int and str from ever being compared,
        # which previously raised TypeError for mixed symbol sets.
        if symbol == "<EOS>":
            return (2, 0)
        if symbol.isdecimal():
            return (0, int(symbol))
        return (1, symbol)

    sorted_symbols = sorted(symbols, key=symbol_order)
    n_symbols = len(sorted_symbols)
    if n_symbols == 0:
        print("Warning: no symbol contributions to plot; skipping symbol plot.")
        return

    n_cols = min(3, n_symbols)
    n_rows = (n_symbols + 2) // 3  # Ceiling division

    # Global y-axis limits for consistent scaling, padded 5% for visual clarity.
    finite_values = [
        v for symbol in sorted_symbols for v in data['symbols'][symbol] if not np.isnan(v)
    ]
    ylim = None
    if finite_values:
        lo, hi = min(finite_values), max(finite_values)
        span = hi - lo
        # For a flat series, pad by 5% of its magnitude (0.05 around zero).
        # Scaling lo/hi multiplicatively would invert the limits for negatives.
        pad = 0.05 * span if span > 0 else (0.05 * abs(hi) or 0.05)
        ylim = (lo - pad, hi + pad)

    # squeeze=False always returns a 2-D array, so ravel() works for every layout.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows),
        sharex=True, sharey=True, squeeze=False,
    )
    flat_axes = axes.ravel()

    for ax, symbol in zip(flat_axes, sorted_symbols):
        sns.lineplot(
            x=data['intervention_counts'],
            y=data['symbols'][symbol],
            marker='o',
            markersize=6,
            linewidth=2,
            ax=ax
        )

        # Use symbol_map for the title if available (numeric ids are int keys).
        # isdecimal() matches exactly the strings int() can parse.
        map_key = int(symbol) if symbol.isdecimal() else symbol
        if symbol_map and map_key in symbol_map:
            title = f'Symbol {symbol}: {symbol_map[map_key]}'
        else:
            title = f'Symbol {symbol}'

        ax.set_title(title)
        ax.set_xlabel('Intervention Count')
        ax.set_ylabel('KL Contribution')
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.grid(True, linestyle='--', alpha=0.7)
        sns.despine(ax=ax)

    # Hide the unused cells of the last row.
    for ax in flat_axes[n_symbols:]:
        ax.set_visible(False)

    fig.suptitle(main_title)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(os.path.join(output_dir, 'symbol_contributions_plot.png'), dpi=300)
    fig.savefig(os.path.join(output_dir, 'symbol_contributions_plot.pdf'))
    plt.close(fig)

def plot_transition_contributions(data, transitions, output_dir, state_map=None, symbol_map=None, main_title=""):
    """Plot the transition contributions as subplots with consistent y-axis scale.

    Uses one row for up to three transitions; otherwise a 3-column grid of at
    most 4 rows per figure. Further transitions are spread over additional
    figures, saved as ``transition_contributions_plot_part<k>.{png,pdf}``.
    With a single figure the suffix is omitted. All subplots share one y-range.
    If ``transitions`` is empty, a warning is printed and nothing is drawn.

    Args:
        data: Dictionary containing KL data (as returned by ``load_kl_data``);
            reads ``'intervention_counts'`` and
            ``data['transitions'][transition]``.
        transitions: Iterable of transition identifiers (strings).
        output_dir: Directory to save plots.
        state_map: Optional dictionary mapping state IDs (int keys for
            numeric ids) to state names.
        symbol_map: Optional dictionary mapping symbol IDs (int keys for
            numeric ids) to symbol names.
        main_title: Figure-level title applied to every figure.
    """
    n_transitions = len(transitions)
    if n_transitions == 0:
        # Without this guard the figure-count division below divides by zero.
        print("Warning: no transition contributions to plot; skipping transition plots.")
        return

    # Determine layout based on number of transitions
    if n_transitions <= 3:
        # For 1-3 transitions, use a single wide row
        n_cols = n_transitions
        n_rows = 1
        figsize = (5 * n_cols, 5)
    else:
        # For many transitions, use a grid layout (at most 4 rows per figure)
        n_cols = 3
        n_rows = min(4, (n_transitions + 2) // 3)
        figsize = (15, 12)

    per_figure = n_rows * n_cols
    n_figures = (n_transitions + per_figure - 1) // per_figure

    # Sort transitions for consistent ordering
    sorted_transitions = sorted(transitions)

    # Global y-axis limits for all transitions, padded 5% for visual clarity.
    finite_values = [
        v for transition in sorted_transitions
        for v in data['transitions'][transition] if not np.isnan(v)
    ]
    ylim = None
    if finite_values:
        lo, hi = min(finite_values), max(finite_values)
        span = hi - lo
        # For a flat series, pad by 5% of its magnitude (0.05 around zero).
        # Scaling lo/hi multiplicatively would invert the limits for negatives.
        pad = 0.05 * span if span > 0 else (0.05 * abs(hi) or 0.05)
        ylim = (lo - pad, hi + pad)

    def map_label(mapping, raw):
        # Numeric ids are looked up as ints; fall back to the raw id string.
        if not mapping:
            return raw
        key = int(raw) if raw.isdecimal() else raw
        return mapping.get(key, raw)

    def transition_title(transition):
        # Parsed as "<state>-<symbol>-><next_state>" (e.g. "0-1->2"), the shape
        # the previous split('-') parsing implied. Anything else falls back to
        # the raw identifier instead of raising IndexError.
        head, arrow, next_state = transition.rpartition("->")
        state_part, dash, symbol = head.partition("-")
        if not (arrow and dash):
            return f'Transition {transition}'
        return (
            f'{map_label(state_map, state_part)}:'
            f'{map_label(symbol_map, symbol)}→'
            f'{map_label(state_map, next_state)}'
        )

    for fig_idx in range(n_figures):
        chunk = sorted_transitions[fig_idx * per_figure:(fig_idx + 1) * per_figure]

        # squeeze=False always returns a 2-D array, so ravel() works for every layout.
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=figsize, sharex=True, sharey=True, squeeze=False
        )
        flat_axes = axes.ravel()

        for ax, transition in zip(flat_axes, chunk):
            sns.lineplot(
                x=data['intervention_counts'],
                y=data['transitions'][transition],
                marker='o',
                markersize=6,
                linewidth=2,
                ax=ax
            )

            ax.set_title(transition_title(transition))
            ax.set_xlabel('Intervention Count')
            ax.set_ylabel('KL Contribution')
            if ylim is not None:
                ax.set_ylim(*ylim)
            ax.grid(True, linestyle='--', alpha=0.7)
            sns.despine(ax=ax)

        # Hide unused subplots (only the last figure can be partially filled).
        for ax in flat_axes[len(chunk):]:
            ax.set_visible(False)

        fig.suptitle(main_title)
        fig.tight_layout(rect=[0, 0, 1, 0.95])

        filename_suffix = f"_part{fig_idx + 1}" if n_figures > 1 else ""
        fig.savefig(os.path.join(output_dir, f'transition_contributions_plot{filename_suffix}.png'), dpi=300)
        fig.savefig(os.path.join(output_dir, f'transition_contributions_plot{filename_suffix}.pdf'))
        plt.close(fig)

def _id_sort_key(identifier):
    """Sort key placing numeric ids first (numerically), then other ids alphabetically."""
    return (0, int(identifier)) if identifier.isdecimal() else (1, identifier)


def _mapped_name(mapping, identifier):
    """Return ``mapping``'s name for ``identifier``, or ``identifier`` itself.

    Numeric ids are looked up as ints. Non-numeric ids are looked up as-is
    rather than crashing in ``int()``.
    """
    if not mapping:
        return identifier
    key = int(identifier) if identifier.isdecimal() else identifier
    return mapping.get(key, identifier)


def _write_summary_csv(data, states, symbols, transitions, output_dir, state_map, symbol_map):
    """Write every loaded series to ``<output_dir>/kl_decomposition_summary.csv``.

    Columns are ``intervention_count``, ``total_kl``, then one column per
    state, symbol and transition, in a deterministic sorted order.
    """
    summary_data = {'intervention_count': data['intervention_counts'], 'total_kl': data['total_kl']}

    # Add state, symbol, and transition data with mapped names in column descriptions
    for state in sorted(states, key=_id_sort_key):
        state_name = _mapped_name(state_map, state)
        summary_data[f'state_{state}_{state_name}'] = data['states'][state]

    for symbol in sorted(symbols, key=_id_sort_key):
        symbol_name = _mapped_name(symbol_map, symbol)
        summary_data[f'symbol_{symbol}_{symbol_name}'] = data['symbols'][symbol]

    for transition in sorted(transitions):
        # Replace special characters in column names
        clean_transition = transition.replace('-', '_').replace('>', '_to_').replace('<', '').replace(':', '_')
        summary_data[f'transition_{clean_transition}'] = data['transitions'][transition]

    df = pd.DataFrame(summary_data)
    csv_path = os.path.join(output_dir, 'kl_decomposition_summary.csv')
    df.to_csv(csv_path, index=False)
    print(f"Summary data saved to {csv_path}")


def main():
    """Command-line entry point: load KL decompositions, plot them, write a CSV.

    Parses the CLI arguments and optionally pulls state/symbol label maps from
    ``AUTOMATA_REGISTER``. It then writes the total, state, symbol and
    transition plots plus ``kl_decomposition_summary.csv`` into
    ``--output-dir``. Returns early if no KL decomposition file was found.
    """
    parser = argparse.ArgumentParser(description="Plot KL decomposition results.")
    parser.add_argument('--base-dir', default='experiments_testing/models', help='Base directory containing model folders')
    parser.add_argument('--intervention-start', type=int, default=50, help='Starting intervention count')
    parser.add_argument('--intervention-end', type=int, default=1000, help='Ending intervention count')
    parser.add_argument('--intervention-step', type=int, default=100, help='Intervention count step size')
    parser.add_argument('--output-dir', default='experiments_testing/plots', help='Directory to save plots')
    parser.add_argument('--automata-name', default=None, help='Name of automaton in register to use for state/symbol mapping')
    args = parser.parse_args()

    # Initialize state and symbol maps if automata name is provided
    state_map = None
    symbol_map = None

    if args.automata_name is not None:
        try:
            automata = AUTOMATA_REGISTER[args.automata_name]
            # Read both maps before assigning either one. A missing attribute
            # then really means "default labels", as the warning says.
            found_state_map = automata.state_map
            found_symbol_map = automata.symbol_map
        except KeyError:
            print(f"Warning: Automaton '{args.automata_name}' not found in register. Using default labels.")
        except AttributeError:
            print(f"Warning: Automaton '{args.automata_name}' does not have mapping attributes. Using default labels.")
        else:
            state_map, symbol_map = found_state_map, found_symbol_map
            print(f"Using state and symbol mappings from {args.automata_name} automaton")

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    set_plot_style()

    # Title = name of the output directory. abspath() normalises trailing
    # separators ("plots/") and "." so the title is never empty.
    title = os.path.basename(os.path.abspath(args.output_dir))

    print(f"Loading KL data from {args.base_dir}...")
    data, states, symbols, transitions = load_kl_data(
        args.base_dir,
        args.intervention_start,
        args.intervention_end,
        args.intervention_step
    )

    if not data['intervention_counts']:
        print("No data found. Check if the KL decomposition files exist.")
        return

    print("Generating total KL plot...")
    plot_total_kl(data, args.output_dir, title)

    print("Generating state contribution plots...")
    plot_state_contributions(data, states, args.output_dir, state_map, title)

    print("Generating symbol contribution plots...")
    plot_symbol_contributions(data, symbols, args.output_dir, symbol_map, title)

    print("Generating transition contribution plots...")
    plot_transition_contributions(data, transitions, args.output_dir, state_map, symbol_map, title)

    print(f"All plots saved to {args.output_dir}")

    # Create a summary CSV file with all the data
    _write_summary_csv(data, states, symbols, transitions, args.output_dir, state_map, symbol_map)

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()