"""
Visualization utilities for Fisher dimension experiments.
"""

from __future__ import annotations

from typing import Dict, List, Optional, Tuple, Union
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.figure import Figure
from matplotlib.axes import Axes

# Set a consistent plotting style.
# NOTE: importing this module mutates process-wide matplotlib state
# (the active style and rcParams).
# Matplotlib >= 3.6 ships the seaborn styles as 'seaborn-v0_8-*'; older
# releases only provide 'seaborn-whitegrid'. Use whichever one is available
# instead of raising OSError at import time.
_WHITEGRID_STYLE = next(
    (name for name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid')
     if name in plt.style.available),
    None,
)
if _WHITEGRID_STYLE is not None:
    plt.style.use(_WHITEGRID_STYLE)
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['font.size'] = 12


def save_figure(
    fig: Figure,
    path: Union[str, Path],
    dpi: int = 300,
    formats: Optional[List[str]] = None
) -> None:
    """
    Save figure to file(s), then close it.

    One file is written per format, named ``<path>.<fmt>``. The extension is
    appended rather than substituted, so stems that contain dots (for example
    ``results/delta_0.1``) are kept intact. If ``path`` already ends in an
    image extension that matplotlib can write (or one of ``formats``), that
    extension is dropped first. ``fig.png`` therefore still produces
    ``fig.pdf`` and ``fig.png``.

    Args:
        fig: Figure to save
        path: Output path (without extension)
        dpi: Resolution
        formats: List of formats (default: ['pdf', 'png']); a leading dot
            such as '.png' is accepted.
    """
    from .io import ensure_dir

    if formats is None:
        formats = ['pdf', 'png']

    # Treat '.png' and 'png' alike.
    extensions = [fmt.lstrip('.') for fmt in formats]

    path = Path(path)
    # Strip an existing image extension so that callers passing a full filename
    # get the same outputs as before. Other suffixes (e.g. '.1' in
    # 'delta_0.1') are part of the stem and are kept.
    known_exts = {ext.lower() for ext in fig.canvas.get_supported_filetypes()}
    known_exts.update(ext.lower() for ext in extensions)
    if path.suffix[1:].lower() in known_exts:
        path = path.with_suffix('')

    try:
        ensure_dir(path.parent)
        for ext in extensions:
            # Path.with_suffix would replace everything after the last dot in
            # the stem ('delta_0.1' -> 'delta_0.pdf'), so append instead.
            output_path = path.with_name(f'{path.name}.{ext}')
            fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        # Release the figure even when a save fails.
        plt.close(fig)


def _fit_proportionality_constant(
    x_raw: np.ndarray,
    y: np.ndarray,
    max_n: Optional[float] = None,
) -> float:
    """
    Estimate C in ``y ≈ C · x`` as the median of the ratios ``y / x``.

    A point is ignored if either coordinate is non-finite, if ``x <= 0`` (the
    ratio is undefined), or, when ``max_n`` is given, if ``y >= max_n``
    (ceiling cases that hit the maximum sample size).

    Returns:
        The median ratio, or 1.0 when no point is usable.
    """
    usable = np.isfinite(x_raw) & np.isfinite(y) & (x_raw > 0)
    if max_n is not None:
        usable &= (y < max_n)
    if not np.any(usable):
        return 1.0
    return float(np.median(y[usable] / x_raw[usable]))


def plot_fisher_vs_sample_complexity(
    fisher_dims: np.ndarray,
    sample_complexities: np.ndarray,
    graph_families: Optional[List[str]] = None,
    d: int = 10,
    delta: float = 0.1,
    fitted_constant: Optional[float] = None,
    max_n: Optional[float] = None,
    ax: Optional[Axes] = None
) -> Tuple[Figure, Axes]:
    """
    Plot Fisher dimension vs empirical sample complexity.

    Theory (Theorem 4.1): n* = C · F([G]) · log(d/δ)

    The x-axis is the raw prediction F([G]) · log(d/δ), and a red line
    y = C · x shows the fitted constant. Ceiling cases (n* >= max_n) are
    excluded from the fit, the scatter and the correlation annotation.

    Args:
        fisher_dims: Array of Fisher dimensions
        sample_complexities: Array of sample complexities n*
        graph_families: Optional list of family labels for coloring
        d: Number of nodes
        delta: Failure probability (1 - success_threshold). Default 0.1.
        fitted_constant: Optional constant C for the fitted line. If None, it
            is the median of n* / x over finite, positive, non-ceiling points
            (1.0 if there are none).
        max_n: Maximum sample size (to exclude ceiling cases from C fitting
            and from the plot)
        ax: Optional existing axes

    Returns:
        (figure, axes)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    else:
        fig = ax.figure

    fisher_dims = np.asarray(fisher_dims)
    sample_complexities = np.asarray(sample_complexities)

    # Raw theoretical prediction F([G]) · log(d/δ), where
    # log(d/δ) = log(d) + log(1/δ).
    log_factor = np.log(d) + np.log(1.0 / delta)
    x = fisher_dims * log_factor

    if fitted_constant is None:
        fitted_constant = _fit_proportionality_constant(x, sample_complexities, max_n)

    # Keep only non-ceiling points for plotting.
    if max_n is not None:
        keep = sample_complexities < max_n
    else:
        keep = np.ones(sample_complexities.shape, dtype=bool)
    x_plot = x[keep]
    y_plot = sample_complexities[keep]
    families_plot = (
        None if graph_families is None
        else [fam for fam, k in zip(graph_families, keep) if k]
    )

    if families_plot is not None:
        # dict.fromkeys keeps first-appearance order. Iterating a set() of
        # strings depends on hash seeding, which changed colours between runs.
        unique_families = list(dict.fromkeys(families_plot))
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_families)))
        for family, color in zip(unique_families, colors):
            mask = np.array([fam == family for fam in families_plot], dtype=bool)
            ax.scatter(
                x_plot[mask], y_plot[mask],
                c=[color],
                label=family,
                alpha=0.7,
                s=50
            )
    else:
        ax.scatter(x_plot, y_plot, alpha=0.7, s=50)

    # Fitted line y = C · x. It spans only the finite plotted points, so hidden
    # ceiling cases or NaNs cannot stretch it or make it vanish.
    finite_x = x_plot[np.isfinite(x_plot)]
    if finite_x.size > 0:
        x_line = np.array([0.0, np.max(finite_x) * 1.1])
        y_line = fitted_constant * x_line
        ax.plot(x_line, y_line, 'r-', linewidth=2,
                label=f'Fitted: $n^* = {fitted_constant:.1f} \\cdot x$')

    ax.set_xlabel(r'$\mathcal{F}([G]) \cdot \log(d/\delta)$')
    ax.set_ylabel('Empirical Sample Complexity $n^*$')
    ax.set_title('Fisher Dimension vs Sample Complexity')
    # Call legend() once: a second call replaces the first. Skip it entirely
    # when nothing is labelled, to avoid matplotlib's "no artists" warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='upper left')

    # Correlation annotation, computed on the plotted (non-ceiling) points.
    from scipy.stats import pearsonr, spearmanr
    valid = np.isfinite(x_plot) & np.isfinite(y_plot)
    if np.sum(valid) > 2:
        r_pearson, _ = pearsonr(x_plot[valid], y_plot[valid])
        r_spearman, _ = spearmanr(x_plot[valid], y_plot[valid])
        ax.annotate(
            f'Pearson r = {r_pearson:.3f}\nSpearman ρ = {r_spearman:.3f}\nFitted C = {fitted_constant:.2f}',
            xy=(0.98, 0.02), xycoords='axes fraction',
            verticalalignment='bottom',
            horizontalalignment='right',
            fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

    return fig, ax


def plot_bounds_tightness(
    predicted: np.ndarray,
    empirical: np.ndarray,
    families: List[str],
    fitted_constant: Optional[float] = None,
    max_n: Optional[float] = None,
    ax: Optional[Axes] = None
) -> Tuple[Figure, Axes]:
    """
    Plot predicted vs empirical sample complexity for bounds tightness.

    Points are coloured by family. Reference lines mark empirical/predicted
    ratios of 1, 0.5 and 2 across the range of the finite plotted values.
    They are omitted when nothing is left to plot.

    Args:
        predicted: Predicted n values
        empirical: Empirical n* values
        families: Family labels
        fitted_constant: The fitted C value for annotation
        max_n: Maximum sample size (to exclude ceiling cases from plot)
        ax: Optional existing axes

    Returns:
        (figure, axes)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    else:
        fig = ax.figure

    predicted = np.asarray(predicted)
    empirical = np.asarray(empirical)

    # Filter out ceiling cases from plotting.
    if max_n is not None:
        keep = empirical < max_n
    else:
        keep = np.ones(empirical.shape, dtype=bool)
    predicted_plot = predicted[keep]
    empirical_plot = empirical[keep]
    families_plot = [fam for fam, k in zip(families, keep) if k]

    # First-appearance order keeps family colours stable across runs;
    # set() iteration order over strings depends on hash seeding.
    unique_families = list(dict.fromkeys(families_plot))
    colors = plt.cm.tab10(np.linspace(0, 1, len(unique_families)))
    for family, color in zip(unique_families, colors):
        mask = np.array([fam == family for fam in families_plot], dtype=bool)
        ax.scatter(
            predicted_plot[mask], empirical_plot[mask],
            c=[color], label=family, alpha=0.7, s=50
        )

    # Reference lines. Use finite values only, and skip them when no values
    # remain: np.min on an empty array raises ValueError.
    finite_vals = np.concatenate([
        predicted_plot[np.isfinite(predicted_plot)],
        empirical_plot[np.isfinite(empirical_plot)],
    ])
    if finite_vals.size > 0:
        min_val = max(1, np.min(finite_vals))
        max_val = np.max(finite_vals)

        ax.plot([min_val, max_val], [min_val, max_val], 'k-', alpha=0.5, label='Ratio = 1')
        ax.plot([min_val, max_val], [0.5*min_val, 0.5*max_val], 'k--', alpha=0.3, label='Ratio = 0.5')
        ax.plot([min_val, max_val], [2*min_val, 2*max_val], 'k--', alpha=0.3, label='Ratio = 2')

    ax.set_xlabel(r'Predicted $n$ (theory): $C \cdot \mathcal{F}([G]) \cdot \log(d/\delta)$')
    ax.set_ylabel('Empirical $n^*$')
    ax.set_title('Bounds Tightness')
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='upper left')

    # Add fitted C annotation.
    if fitted_constant is not None:
        ax.annotate(
            f'Fitted C = {fitted_constant:.2f}',
            xy=(0.98, 0.02), xycoords='axes fraction',
            verticalalignment='bottom',
            horizontalalignment='right',
            fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

    return fig, ax


def plot_bounds_tightness_multipanel(
    families_data: Dict[str, Dict],
    max_n: Optional[float] = None,
) -> Tuple[Figure, List[Axes]]:
    """
    Plot predicted vs empirical sample complexity in multi-panel layout.
    Each family gets its own panel with its own fitted C.

    Panel colours come from a fixed map for 'low_F', 'medium_F' and 'high_F';
    any other family uses '#3498db'. Ratio reference lines (1, 0.5, 2) and
    their legend are drawn only when the panel has finite values left after
    filtering.

    Args:
        families_data: Dict mapping family_name to dict with keys:
            - 'predicted': array of predicted n values
            - 'empirical_n': array of empirical n* values
            - 'fitted_C': fitted constant for this family (default 1.0)
            - 'within_factor_2': proportion within factor 2 (default NaN)
        max_n: Maximum sample size (to exclude ceiling cases from plot)

    Returns:
        (figure, list of axes)
    """
    n_families = len(families_data)
    fig, axes = plt.subplots(1, n_families, figsize=(5 * n_families, 5))

    if n_families == 1:
        axes = [axes]

    family_colors = {'low_F': '#2ecc71', 'medium_F': '#3498db', 'high_F': '#e74c3c'}

    for ax, (family_name, family_data) in zip(axes, families_data.items()):
        predicted = np.array(family_data['predicted'])
        empirical = np.array(family_data['empirical_n'])
        fitted_C = family_data.get('fitted_C', 1.0)
        within_factor_2 = family_data.get('within_factor_2', float('nan'))

        # Filter ceiling cases.
        if max_n is not None:
            keep = empirical < max_n
            predicted = predicted[keep]
            empirical = empirical[keep]

        color = family_colors.get(family_name, '#3498db')
        ax.scatter(predicted, empirical, c=color, alpha=0.7, s=50)

        # Reference lines, computed over finite values only.
        finite_vals = np.concatenate([
            predicted[np.isfinite(predicted)],
            empirical[np.isfinite(empirical)],
        ])
        if finite_vals.size > 0:
            min_val = max(1, np.min(finite_vals))
            max_val = np.max(finite_vals)

            ax.plot([min_val, max_val], [min_val, max_val], 'k-', alpha=0.5, label='Ratio = 1')
            ax.plot([min_val, max_val], [0.5*min_val, 0.5*max_val], 'k--', alpha=0.3, label='Ratio = 0.5')
            ax.plot([min_val, max_val], [2*min_val, 2*max_val], 'k--', alpha=0.3, label='Ratio = 2')
            # Only the reference lines carry labels. Calling legend() on an
            # empty panel would emit a "no artists with labels" warning.
            ax.legend(loc='upper left', fontsize=8)

        ax.set_xlabel(r'Predicted $n$')
        ax.set_ylabel('Empirical $n^*$')
        ax.set_title(f'{family_name}')

        ax.annotate(
            f'C = {fitted_C:.1f}\nwithin 2x: {within_factor_2:.0%}',
            xy=(0.98, 0.02), xycoords='axes fraction',
            verticalalignment='bottom',
            horizontalalignment='right',
            fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

    fig.suptitle('Bounds Tightness (per beta family)', fontsize=14)
    # Lay out this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()

    return fig, axes


def plot_fisher_vs_complexity_multipanel(
    by_beta_range: Dict[str, Dict],
    stats_by_beta: Dict[str, Dict],
    d: int = 10,
    delta: float = 0.1,
    max_n: Optional[float] = None,
) -> Tuple[Figure, List[Axes]]:
    """
    Plot Fisher dimension vs sample complexity in multi-panel layout.
    Each beta range gets its own panel with its own fitted C.

    Each panel scatters n* against F · log(d/δ), with ceiling cases removed
    when ``max_n`` is given. It overlays the line y = C · x over the finite
    plotted range and annotates C and the Pearson r.

    Args:
        by_beta_range: Dict mapping beta_range_str to dict with keys:
            - 'fisher_dims': list of Fisher dimensions
            - 'sample_complexities': list of n* values
            - 'fitted_C': fitted constant (not read here; C comes from
              ``stats_by_beta``)
        stats_by_beta: Dict mapping beta_range_str to dict with keys:
            - 'fitted_C' (default 1.0), 'pearson_r' (default NaN);
              'spearman_r' may be present but is not displayed
        d: Number of nodes
        delta: Failure probability
        max_n: Maximum sample size (to exclude ceiling cases)

    Returns:
        (figure, list of axes)
    """
    n_panels = len(by_beta_range)
    fig, axes = plt.subplots(1, n_panels, figsize=(4 * n_panels, 4))

    if n_panels == 1:
        axes = [axes]

    log_factor = np.log(d) + np.log(1.0 / delta)  # log(d/δ)

    for ax, (br_key, br_data) in zip(axes, by_beta_range.items()):
        fisher_dims = np.array(br_data['fisher_dims'])
        sample_complexities = np.array(br_data['sample_complexities'])
        stats = stats_by_beta.get(br_key, {})
        fitted_C = stats.get('fitted_C', 1.0)
        pearson_r = stats.get('pearson_r', float('nan'))

        # x = F * log(d/δ)
        x = fisher_dims * log_factor

        # Filter ceiling cases.
        if max_n is not None:
            keep = sample_complexities < max_n
            x = x[keep]
            sample_complexities = sample_complexities[keep]

        ax.scatter(x, sample_complexities, alpha=0.7, s=40)

        # Fitted line over the finite plotted range; a NaN would otherwise
        # make np.max return NaN and the line would silently disappear.
        finite_x = x[np.isfinite(x)]
        if finite_x.size > 0:
            x_line = np.array([0, np.max(finite_x) * 1.1])
            y_line = fitted_C * x_line
            ax.plot(x_line, y_line, 'r-', linewidth=2, alpha=0.7)

        ax.set_xlabel(r'$\mathcal{F} \cdot \log(d/\delta)$')
        ax.set_ylabel('$n^*$')
        # The beta-range key is used verbatim as the panel title.
        ax.set_title(f'β: {br_key}', fontsize=10)

        ax.annotate(
            f'C={fitted_C:.0f}\nr={pearson_r:.2f}',
            xy=(0.98, 0.02), xycoords='axes fraction',
            verticalalignment='bottom',
            horizontalalignment='right',
            fontsize=9,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

    fig.suptitle('Fisher Dimension vs Sample Complexity (per beta range)', fontsize=12)
    fig.tight_layout()

    return fig, axes


def plot_exp3_by_graph_type(
    by_graph_type: Dict[str, Dict],
) -> Tuple[Figure, List[Axes]]:
    """
    Plot proxy correlation comparison, one horizontal-bar panel per graph type
    (laid out in a single row).

    Bars are sorted by |correlation| in ascending order, with NaN first.
    ``barh`` draws the first category at the bottom, so the strongest proxy
    ends up at the top. ``fisher_dimension`` is red; other proxies are green
    when positive and gray otherwise (zero, negative or NaN). A graph type with
    no correlations gets an empty panel titled "<type> (no data)".

    Args:
        by_graph_type: Dict mapping graph_type to dict with:
            - 'correlations': dict mapping proxy_name to correlation value
            - 'best_proxy': name of best proxy
            - 'fisher_correlation': correlation for fisher_dimension

    Returns:
        (figure, list of axes)
    """
    graph_types = list(by_graph_type.keys())
    n_types = len(graph_types)

    fig, axes = plt.subplots(1, n_types, figsize=(6 * n_types, 5))

    if n_types == 1:
        axes = [axes]

    for ax, gt in zip(axes, graph_types):
        gt_data = by_graph_type[gt]
        correlations = gt_data.get('correlations', {})

        if not correlations:
            ax.set_title(f'{gt} (no data)')
            continue

        # Ascending |r| with NaN first, so the strongest bar is at the top.
        sorted_items = sorted(
            correlations.items(),
            key=lambda x: abs(x[1]) if not np.isnan(x[1]) else -float('inf')
        )

        names = [item[0] for item in sorted_items]
        values = [item[1] for item in sorted_items]

        # Highlight fisher_dimension; colour the rest by sign.
        bar_colors = []
        for name, val in zip(names, values):
            if name == 'fisher_dimension':
                bar_colors.append('#e74c3c')  # Red for fisher_dimension
            elif val > 0:
                bar_colors.append('#2ecc71')  # Green for positive
            else:
                bar_colors.append('#95a5a6')  # Gray for non-positive / NaN

        bars = ax.barh(names, values, color=bar_colors, alpha=0.8)

        ax.axvline(x=0, color='black', linewidth=0.5)
        ax.set_xlabel('Spearman Correlation')
        ax.set_title(f'{gt}')

        # Value labels just past the end of each bar; NaN bars are left unlabelled.
        for bar, val in zip(bars, values):
            if not np.isnan(val):
                width = bar.get_width()
                ax.text(
                    width + 0.02 if width >= 0 else width - 0.02,
                    bar.get_y() + bar.get_height()/2,
                    f'{val:.2f}',
                    va='center',
                    ha='left' if width >= 0 else 'right',
                    fontsize=8
                )

        # Annotate best proxy and the Fisher-dimension correlation.
        best_proxy = gt_data.get('best_proxy', '')
        fisher_corr = gt_data.get('fisher_correlation', float('nan'))
        ax.annotate(
            f'best: {best_proxy}\nF_dim: {fisher_corr:.2f}',
            xy=(0.98, 0.02), xycoords='axes fraction',
            verticalalignment='bottom',
            horizontalalignment='right',
            fontsize=9,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

    fig.suptitle('Complexity Proxy Comparison (per graph type)', fontsize=14)
    fig.tight_layout()

    return fig, axes


def plot_exp2_grid(
    by_graph_type: Dict[str, Dict],
    stats_by_cell: Dict[str, Dict],
    max_n: Optional[float] = None,
) -> Tuple[Figure, np.ndarray]:
    """
    Plot predicted vs empirical n* on a grid: one row per graph type, one
    column per beta family (3x3 for 3 graph types × 3 families).

    Column (family) names are taken from the first graph type. If a later
    graph type lacks one of those families, its panel stays empty (with the
    default stats annotation) instead of raising KeyError.

    Args:
        by_graph_type: Dict mapping graph_type to dict with 'families' key,
            where families maps family_name to dict with:
            - 'predicted': array of predicted n values
            - 'empirical_n': array of empirical n* values
        stats_by_cell: Dict mapping graph_type to dict mapping family_name to:
            - 'fitted_C', 'within_factor_2', 'n_graphs'
        max_n: Maximum sample size (to exclude ceiling cases)

    Returns:
        (figure, axes). ``axes`` is a 2-D array of shape (n_rows, n_cols).
        For backward compatibility, a single-row grid returns the 1-D array
        of that row's axes instead.
    """
    graph_types = list(by_graph_type.keys())
    # Get family names from first graph type
    family_names = list(by_graph_type[graph_types[0]]['families'].keys())

    n_rows = len(graph_types)
    n_cols = len(family_names)

    # squeeze=False always returns a 2-D array, so grid[i, j] works for every
    # shape. With squeezing, a single column gave a 1-D array (axes[i, j]
    # raised IndexError) and a 1x1 grid gave a bare Axes (axes[j] raised
    # TypeError).
    fig, grid = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
    )

    type_colors = {'tree': '#2ecc71', 'chain': '#3498db', 'erdos_renyi': '#e74c3c'}

    for i, gt in enumerate(graph_types):
        families = by_graph_type[gt].get('families', {})
        for j, fn in enumerate(family_names):
            ax = grid[i, j]
            cell_data = families.get(fn, {})
            stats = stats_by_cell.get(gt, {}).get(fn, {})

            predicted = np.array(cell_data.get('predicted', []))
            empirical = np.array(cell_data.get('empirical_n', []))
            fitted_C = stats.get('fitted_C', 1.0)
            within_2x = stats.get('within_factor_2', float('nan'))

            # Filter ceiling cases.
            if max_n is not None and len(empirical) > 0:
                keep = empirical < max_n
                predicted = predicted[keep]
                empirical = empirical[keep]

            color = type_colors.get(gt, '#3498db')
            if len(predicted) > 0:
                ax.scatter(predicted, empirical, c=color, alpha=0.7, s=30)

                # Ratio reference lines (1, 0.5, 2) over the finite value range.
                finite_vals = np.concatenate([
                    predicted[np.isfinite(predicted)],
                    empirical[np.isfinite(empirical)],
                ])
                if finite_vals.size > 0:
                    min_val = max(1, np.min(finite_vals))
                    max_val = np.max(finite_vals)

                    ax.plot([min_val, max_val], [min_val, max_val], 'k-', alpha=0.5, linewidth=1)
                    ax.plot([min_val, max_val], [0.5*min_val, 0.5*max_val], 'k--', alpha=0.3, linewidth=1)
                    ax.plot([min_val, max_val], [2*min_val, 2*max_val], 'k--', alpha=0.3, linewidth=1)

            # Axis labels only on the outer edge of the grid.
            if i == n_rows - 1:
                ax.set_xlabel('Predicted $n$')
            if j == 0:
                ax.set_ylabel('Empirical $n^*$')

            # Title for top row only
            if i == 0:
                ax.set_title(fn, fontsize=11)

            # Row label on right side
            if j == n_cols - 1:
                ax.annotate(gt, xy=(1.02, 0.5), xycoords='axes fraction',
                            rotation=-90, va='center', ha='left', fontsize=10)

            # Stats annotation
            ax.annotate(
                f'C={fitted_C:.0f}\n2x:{within_2x:.0%}',
                xy=(0.98, 0.02), xycoords='axes fraction',
                verticalalignment='bottom',
                horizontalalignment='right',
                fontsize=9,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
            )

    fig.suptitle('Bounds Tightness (graph type × beta family)', fontsize=14)
    fig.tight_layout()

    # Preserve the 1-D return shape for single-row grids.
    axes = grid[0] if n_rows == 1 else grid
    return fig, axes


def plot_exp1_by_graph_type(
    by_graph_type: Dict[str, Dict],
    stats_by_type: Dict[str, Dict],
    d: int = 10,
    delta: float = 0.1,
    max_n: Optional[float] = None,
) -> Tuple[Figure, List[Axes]]:
    """
    Plot Fisher dimension vs sample complexity, one panel per graph type
    (laid out in a single row).

    Each panel scatters n* against F · log(d/δ), with ceiling cases removed
    when ``max_n`` is given, and overlays y = C · x. It is annotated with C and
    the Spearman ρ from ``stats_by_type``, and the title shows the graph count.

    Args:
        by_graph_type: Dict mapping graph_type to dict with keys:
            - 'fisher_dims': list of Fisher dimensions
            - 'sample_complexities': list of n* values
        stats_by_type: Dict mapping graph_type to dict with keys:
            - 'fitted_C' (default 1.0), 'spearman_r' (default NaN),
              'n_graphs' (default: number of Fisher dims before filtering);
              'pearson_r' may be present but is not displayed
        d: Number of nodes
        delta: Failure probability
        max_n: Maximum sample size (to exclude ceiling cases)

    Returns:
        (figure, list of axes)
    """
    n_types = len(by_graph_type)
    fig, axes = plt.subplots(1, n_types, figsize=(5 * n_types, 5))

    if n_types == 1:
        axes = [axes]

    log_factor = np.log(d) + np.log(1.0 / delta)  # log(d/δ)

    type_colors = {'tree': '#2ecc71', 'chain': '#3498db', 'erdos_renyi': '#e74c3c'}

    for ax, (gt, gt_data) in zip(axes, by_graph_type.items()):
        fisher_dims = np.array(gt_data['fisher_dims'])
        sample_complexities = np.array(gt_data['sample_complexities'])
        stats = stats_by_type.get(gt, {})
        fitted_C = stats.get('fitted_C', 1.0)
        spearman_r = stats.get('spearman_r', float('nan'))
        n_graphs = stats.get('n_graphs', len(fisher_dims))

        # x = F * log(d/δ)
        x = fisher_dims * log_factor

        # Filter ceiling cases.
        if max_n is not None:
            keep = sample_complexities < max_n
            x = x[keep]
            sample_complexities = sample_complexities[keep]

        color = type_colors.get(gt, '#3498db')
        ax.scatter(x, sample_complexities, c=color, alpha=0.7, s=50)

        # Fitted line over the finite plotted range.
        finite_x = x[np.isfinite(x)]
        if finite_x.size > 0:
            x_line = np.array([0, np.max(finite_x) * 1.1])
            y_line = fitted_C * x_line
            ax.plot(x_line, y_line, 'r-', linewidth=2, alpha=0.7)

        ax.set_xlabel(r'$\mathcal{F} \cdot \log(d/\delta)$')
        ax.set_ylabel('$n^*$')
        ax.set_title(f'{gt} (n={n_graphs})')

        # The value shown is the Spearman correlation, so label it ρ, not r
        # ("r" is used for the Pearson value elsewhere in this module).
        ax.annotate(
            f'C={fitted_C:.0f}\nρ={spearman_r:.2f}',
            xy=(0.98, 0.02), xycoords='axes fraction',
            verticalalignment='bottom',
            horizontalalignment='right',
            fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
        )

    fig.suptitle('Fisher Dimension vs Sample Complexity (per graph type)', fontsize=14)
    fig.tight_layout()

    return fig, axes


def plot_proxy_comparison(
    correlations: Dict[str, float],
    ax: Optional[Axes] = None
) -> Tuple[Figure, Axes]:
    """
    Plot a horizontal bar chart comparing proxy correlations.

    Proxies are sorted by |correlation| in ascending order, with NaN
    (undefined) correlations first. ``barh`` draws the first category at the
    bottom, so the strongest proxy appears at the top and NaNs at the bottom.
    This matches the ordering used by ``plot_exp3_by_graph_type``. Positive
    bars are green and all others red. NaN bars get no value label.

    Args:
        correlations: Dict mapping proxy name to correlation
        ax: Optional existing axes

    Returns:
        (figure, axes)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 6))
    else:
        fig = ax.figure

    def _strength(item):
        # Previously the negated key also gave NaN -inf, which placed NaN
        # proxies next to the strongest ones.
        value = item[1]
        return -float('inf') if np.isnan(value) else abs(value)

    ordered = sorted(correlations.items(), key=_strength)
    names = [name for name, _ in ordered]
    values = [value for _, value in ordered]

    colors = ['#2ecc71' if v > 0 else '#e74c3c' for v in values]

    bars = ax.barh(names, values, color=colors, alpha=0.7)

    ax.axvline(x=0, color='black', linewidth=0.5)
    ax.set_xlabel('Spearman Correlation')
    ax.set_title('Complexity Proxy Comparison')

    # Value labels just past the end of each bar. A NaN value has no bar end
    # to anchor to, so it is skipped.
    for bar, val in zip(bars, values):
        if np.isnan(val):
            continue
        width = bar.get_width()
        ax.text(
            width + 0.02 if width >= 0 else width - 0.02,
            bar.get_y() + bar.get_height()/2,
            f'{val:.3f}',
            va='center',
            ha='left' if width >= 0 else 'right',
            fontsize=9
        )

    return fig, ax


def plot_scaling(
    d_values: List[int],
    sample_complexities: Dict[str, np.ndarray],
    ax: Optional[Axes] = None
) -> Tuple[Figure, Axes]:
    """
    Plot sample complexity scaling with graph size.

    Each family is drawn as a line with markers. A dashed log(d) reference
    curve is added, scaled so that its mean equals the mean n* of the first
    family. It is omitted when there are no families, or when mean(log d) is
    not a positive finite number (e.g. every d == 1).

    Args:
        d_values: Node counts
        sample_complexities: Dict mapping family to array of n* values
        ax: Optional existing axes

    Returns:
        (figure, axes)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    else:
        fig = ax.figure

    colors = plt.cm.tab10(np.linspace(0, 1, len(sample_complexities)))

    for color, (family, n_values) in zip(colors, sample_complexities.items()):
        ax.plot(d_values, n_values, 'o-', color=color, label=family, markersize=8)

    # Theoretical log(d) scaling reference. Guard against an empty dict, which
    # made the [0] index raise IndexError, and against mean(log d) == 0, which
    # caused a division by zero.
    log_d = np.log(np.array(d_values))
    mean_log_d = np.mean(log_d) if log_d.size > 0 else float('nan')
    if sample_complexities and np.isfinite(mean_log_d) and mean_log_d > 0:
        first_family = next(iter(sample_complexities.values()))
        scale = np.mean(first_family) / mean_log_d
        ax.plot(d_values, scale * log_d, 'k--', alpha=0.5, label=r'$\propto \log d$')

    ax.set_xlabel('Number of nodes $d$')
    ax.set_ylabel('Sample Complexity $n^*$')
    ax.set_title('Scaling with Graph Size')
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    return fig, ax


def plot_lower_bound_verification(
    fisher_dims: np.ndarray,
    empirical_n: np.ndarray,
    theoretical_lower: np.ndarray,
    ax: Optional[Axes] = None
) -> Tuple[Figure, Axes]:
    """
    Plot lower bound verification.

    Empirical n* (circles) and theoretical lower bounds (triangles) are
    scattered against the Fisher dimension. A dashed diagonal n = F([G]) is
    drawn across the finite range of ``fisher_dims``.

    Args:
        fisher_dims: Fisher dimension values
        empirical_n: Empirical sample complexities
        theoretical_lower: Theoretical lower bounds
        ax: Optional existing axes

    Returns:
        (figure, axes)
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    else:
        fig = ax.figure

    ax.scatter(fisher_dims, empirical_n, label='Empirical $n^*$', alpha=0.7, s=50)
    ax.scatter(fisher_dims, theoretical_lower, label='Lower bound', alpha=0.5, s=30, marker='^')

    # Diagonal reference as a single segment from min to max. Plotting
    # fisher_dims against itself in data order retraces the line back and
    # forth when the input is unsorted, which garbles the dash pattern, and a
    # NaN entry breaks the line.
    dims = np.asarray(fisher_dims, dtype=float)
    finite_dims = dims[np.isfinite(dims)]
    if finite_dims.size > 0:
        lo, hi = float(np.min(finite_dims)), float(np.max(finite_dims))
        ax.plot([lo, hi], [lo, hi], 'k--', alpha=0.5, label=r'$n = \mathcal{F}([G])$')

    ax.set_xlabel(r'Fisher Dimension $\mathcal{F}([G])$')
    ax.set_ylabel('Sample Size $n$')
    ax.set_title('Lower Bound Verification')
    ax.legend()

    return fig, ax


def plot_benchmark_comparison(
    benchmark_names: List[str],
    fisher_direct: List[float],
    fisher_curvature: List[float],
    ax: Optional[Axes] = None
) -> Tuple[Figure, Axes]:
    """
    Draw grouped bars of Fisher dimension per benchmark graph.

    Each benchmark gets two side-by-side bars: the direct estimate (left) and
    the curvature-based estimate (right). Each bar is labelled with its height
    to one decimal place.

    Args:
        benchmark_names: Names of benchmark graphs (x tick labels)
        fisher_direct: Fisher dimensions (direct method)
        fisher_curvature: Fisher dimensions (curvature method)
        ax: Optional existing axes

    Returns:
        (figure, axes)
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(10, 6))

    positions = np.arange(len(benchmark_names))
    bar_width = 0.35
    half = bar_width / 2

    direct_bars = ax.bar(
        positions - half, fisher_direct, bar_width,
        label='Direct (Definition 3.1)'
    )
    curvature_bars = ax.bar(
        positions + half, fisher_curvature, bar_width,
        label='Curvature (Conjecture 4.1)'
    )

    ax.set_xlabel('Benchmark Graph')
    ax.set_ylabel(r'Fisher Dimension $\mathcal{F}([G])$')
    ax.set_title('Benchmark Graph Analysis')
    ax.set_xticks(positions)
    ax.set_xticklabels(benchmark_names, rotation=45, ha='right')
    ax.legend()

    # Height label centred above each bar, direct series first.
    for container in (direct_bars, curvature_bars):
        for rect in container:
            bar_height = rect.get_height()
            x_center = rect.get_x() + rect.get_width() / 2.
            ax.text(
                x_center, bar_height,
                f'{bar_height:.1f}',
                ha='center', va='bottom', fontsize=8
            )

    return fig, ax


def create_summary_figure(
    results: Dict[str, Dict]
) -> Figure:
    """
    Build a 2x3 overview figure from several experiments' results.

    A panel is created and drawn only when its experiment key is present in
    ``results``. Missing experiments leave their grid cell empty.

    Layout ((row, col): key -> drawer):
        (0, 0): 'exp1' -> plot_fisher_vs_sample_complexity
        (0, 1): 'exp2' -> plot_bounds_tightness
        (0, 2): 'exp3' -> plot_proxy_comparison
        (1, 0): 'exp4' -> plot_scaling
        (1, 1): 'exp5' -> plot_lower_bound_verification
        (1, 2): 'exp6' -> plot_benchmark_comparison

    Args:
        results: Mapping from experiment key to that experiment's result dict

    Returns:
        Multi-panel figure
    """
    fig = plt.figure(figsize=(16, 12))
    grid = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    def draw_exp1(panel_ax, data):
        plot_fisher_vs_sample_complexity(
            np.array(data['fisher_dims']),
            np.array(data['sample_complexities']),
            data.get('families'),
            data.get('d', 10),
            ax=panel_ax
        )

    def draw_exp2(panel_ax, data):
        plot_bounds_tightness(
            np.array(data['predicted']),
            np.array(data['empirical']),
            data['families'],
            ax=panel_ax
        )

    def draw_exp3(panel_ax, data):
        plot_proxy_comparison(data['correlations'], ax=panel_ax)

    def draw_exp4(panel_ax, data):
        plot_scaling(data['d_values'], data['sample_complexities'], ax=panel_ax)

    def draw_exp5(panel_ax, data):
        plot_lower_bound_verification(
            np.array(data['fisher_dims']),
            np.array(data['empirical_n']),
            np.array(data['theoretical_lower']),
            ax=panel_ax
        )

    def draw_exp6(panel_ax, data):
        plot_benchmark_comparison(
            data['names'],
            data['fisher_direct'],
            data['fisher_curvature'],
            ax=panel_ax
        )

    panels = (
        ('exp1', 0, 0, draw_exp1),
        ('exp2', 0, 1, draw_exp2),
        ('exp3', 0, 2, draw_exp3),
        ('exp4', 1, 0, draw_exp4),
        ('exp5', 1, 1, draw_exp5),
        ('exp6', 1, 2, draw_exp6),
    )

    for key, row, col, draw in panels:
        if key not in results:
            continue
        panel_ax = fig.add_subplot(grid[row, col])
        draw(panel_ax, results[key])

    fig.suptitle('Fisher Dimension Experimental Results', fontsize=14, fontweight='bold')

    return fig


def plot_graph_structure(
    dag: 'DAG',
    ax: Optional[Axes] = None,
    node_size: int = 500,
    font_size: int = 10
) -> Tuple[Figure, Axes]:
    """
    Draw a DAG as a labelled, directed networkx graph.

    Nodes are ``0 .. dag.num_nodes() - 1`` and edges come from ``dag.edges``.
    A planar layout is tried first; if networkx rejects it, a spring layout
    with a fixed seed (42) is used so that the drawing is reproducible.

    Args:
        dag: DAG to visualize (must expose ``num_nodes()``, ``num_edges()``
            and ``edges``)
        ax: Optional existing axes
        node_size: Size of nodes
        font_size: Font size for labels

    Returns:
        (figure, axes)
    """
    import networkx as nx

    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = plt.subplots(figsize=(8, 8))

    graph = nx.DiGraph()
    graph.add_nodes_from(range(dag.num_nodes()))
    graph.add_edges_from(dag.edges)

    try:
        layout = nx.planar_layout(graph)
    except nx.NetworkXException:
        # planar_layout raises for non-planar graphs.
        layout = nx.spring_layout(graph, seed=42)

    draw_options = dict(
        node_color='lightblue',
        node_size=node_size,
        font_size=font_size,
        font_weight='bold',
        arrows=True,
        arrowsize=20,
        with_labels=True,
    )
    nx.draw(graph, layout, ax=ax, **draw_options)

    ax.set_title(f'DAG ({dag.num_nodes()} nodes, {dag.num_edges()} edges)')

    return fig, ax
