# Orthogonality Figure Generation Configuration
# Usage: uv run python experiments/figure_generation/orthogonality/run.py run_id=<your_run_id>

# MLflow run to analyze.
# run_id is required: override the placeholder, e.g. run_id=<your_run_id>.
run_id: '<run-id>'
# Optional MLflow experiment ID (inferred from the run if not provided).
experiment_id: '<experiment-id>'
# Both URIs are read from the environment via the oc.env resolver; "local" is
# the fallback value when the variable is unset.
tracking_uri: '${oc.env:MLFLOW_TRACKING_URI,local}'
registry_uri: '${oc.env:MLFLOW_REGISTRY_URI,local}'

# Layer (hook point) to analyze.
# Set to null to auto-select the last resid_post layer.
# Other example values: "blocks.3.hook_resid_post", "ln_final.hook_normalized"
layer: 'blocks.1.hook_resid_post'

# Checkpoint selection
# Option 1: Auto-select evenly-spaced checkpoints
n_checkpoints: 4  # Number of evenly-spaced checkpoints to analyze (used if checkpoint_steps is null)
max_step: null  # Filter checkpoints <= max_step (null = no filter)

# Option 2: Explicit checkpoint steps (overrides n_checkpoints if provided)
checkpoint_steps: [0, 50, 100, 200, 300, 400, 500, 600, 800, 1000, 80000, 200000, 500000]  # 13 steps: dense through step 1000, then 80k / 200k / 500k

# Per-panel step filtering (subsets of checkpoint_steps)
# null = use all steps from checkpoint_steps
dims_steps: [0, 100, 200, 300, 400, 500, 600, 800, 1000, 80000, 200000, 500000]  # 12 steps for dims plot (checkpoint_steps without step 50)
orthogonality_steps: [0, 200, 400, 1000, 200000, 500000]  # 6 steps for orthogonality plot

# Orthogonality analysis parameters
max_k: 20  # Maximum k for orthogonality curves
num_frozen_points: 10  # Frozen configurations per factor
batch_per_frozen: 200  # Batch size per frozen point
seed: 42  # Base seed for reproducibility
n_seed_iterations: 1  # Number of independent seed iterations (1 = single iteration, >1 for averaging)

# Compute device for SVD operations (the main bottleneck)
# "cuda" = use GPU via JAX (falls back to CPU if unavailable)
# "cpu" = force CPU
# "auto" = try GPU, fall back to CPU
# NOTE(review): "cuda" and "auto" are described identically above (GPU with CPU
# fallback) - confirm whether "cuda" is meant to fail when no GPU is present.
compute_device: auto

# Baseline settings (random init orthogonality baseline)
baseline:
  # Explicit path to the baseline file, or null for auto-detection.
  # Auto-detected path: experiments/figure_generation/orthogonality/baseline_{run_id}.pkl
  path: null  # Auto-detect based on run_id
  auto_generate: false  # Generate baseline if missing (slow: runs N random models)
  n_models: 5  # Number of random initializations (use 2-3 for quick testing)
  verify_compatibility: true  # Check baseline matches architecture
  ci_percentile: 90  # CI percentile (90 = 5th-95th)

# Belief regression settings (bottom row of the composite figure).
# Activations are projected onto vary-one PCA subspaces and regressed to
# predict beliefs.
belief_regression:
  # Enable/disable belief regression in the composite figure
  enabled: true
  # Batch size and seed for validation data generation
  batch_size: 256
  seed: 12345
  # Number of PCA components used for the projection, one entry per factor.
  # null = use intrinsic dimension (state_dim - 1) for each factor;
  # a list may mix values, e.g. [2, 2, 2, 3, 3] for 5 factors.
  k_per_factor: [2, 2, 2, 2, 2]
  # Checkpoint used for belief regression.
  # null = use last checkpoint from n_checkpoints
  checkpoint_step: null
  # Plotting options
  max_samples: 10000  # Max samples to plot (subsampled for performance)
  marker_size: 0.5  # Scatter marker size
  marker_opacity: 0.7  # Scatter marker opacity
  pred_marker_size: 0.75  # Prediction marker size
  pred_marker_opacity: 0.35  # Prediction marker opacity
  scatter_pad_frac: 0.0  # Padding fraction around scatter data (default 0.1)
  title_pad: 0  # Padding for Theory/Activations headers (default 2)
  preserve_aspect: true  # true = square plots, false = fill available space

# Caching settings for computed orthogonality data
cache:
  enabled: true              # Enable/disable caching of orthogonality data
  path: null                 # Cache file path (null = auto-generate from parameters)
  force_recompute: false    # Force recomputation even if a cache file exists

# Output settings
# NOTE(review): output_path carries a .png extension; presumably it should be
# kept in sync with output_format - confirm against the code that writes the figure.
output_path: figure_3_2layer.png
output_format: png  # One of: png, pdf, svg
dpi: 300

# Composite figure layout
# Layout options:
#   "full" (3 columns): 2x2 grid (a) | stacked dims+orth (b,c) | belief regression (d)
#   "regression" (2 columns): belief regression (a) | stacked dims+orth (b,c)
#   "graphs_only" (2 columns side-by-side): dims (a) | orthogonality (b)
# Each layout has its own figsize* and width_ratios* entries further down.
composite_figure:
  layout: "graphs_only"  # Options: "full", "regression", "graphs_only"

  # Figure sizes for each layout, given as [width, height] in inches
  figsize:
    - 6.75   # Overall figure width (inches) for "full" layout
    - 2.8    # Overall figure height (inches)
  figsize_regression:
    - 3.25   # Single ICML column width for "regression" layout
    - 5.0    # Height (taller to accommodate 5 factors)
  figsize_graphs_only:
    - 3.25   # Width for "graphs_only" layout (side-by-side)
    - 2.5    # Height (shorter, single row)

  # Relative column widths, one list per layout.
  # "full" layout (3 columns)
  width_ratios:
    - 1.2  # Left column (2x2 images)
    - 0.9  # Middle column (dims + orthogonality)
    - 0.6  # Right column (belief regression)
  # "regression" layout (2 columns)
  width_ratios_regression:
    - 0.8  # Left column (belief regression)
    - 1.0  # Right column (dims + orthogonality)
  # "graphs_only" layout (2 columns side-by-side)
  width_ratios_graphs_only:
    - 0.9  # Left column (dims plot)
    - 1.0  # Right column (orthogonality plot)

  # Manual gridspec positioning for "regression" layout (tighter margins)
  # F labels use ylabel positioning, so left margin can be tight
  # NOTE(review): regression_right is 1.45; if these values are figure-fraction
  # gridspec bounds (0-1), the grid extends past the right edge of the figure -
  # confirm this is intentional.
  regression_left: 0.0
  regression_right: 1.45
  regression_top: 0.8
  regression_bottom: 0.10
  regression_wspace: 0.15  # Gap between belief column and middle plots column
  regression_hspace: 0.3   # Vertical space between rows (enough for tick labels)

  # Per-panel aspect ratios (null = auto from grid cell)
  # Only use for IMAGE panels - data plots should use null to fill available space
  # Values read as width relative to height:
  #   1.0 = square, 1.5 = wider than tall, 0.75 = taller than wide
  panel_aspects:
    top_left: null      # Auto for left placeholder image (preserves image aspect)
    top_right: null     # null for data plots - fills grid cell
    bottom_left: null   # (unused in new layout)
    bottom_right: null  # null for data plots - fills grid cell

  # Panel content
  # Top-left panel: 2x2 grid of source images from figure1
  top_left_images:
    base_dir: experiments/figure_generation/figure1
    format: png  # png, pdf, svg
    # NOTE(review): the four entries below look like image base names under
    # base_dir (presumably combined with format) - confirm against the loader.
    factored_2d: freeze_vary_2d
    joint_3d: freeze_vary_3d
    factored_2d_centered: freeze_vary_2d_centered
    joint_3d_centered: freeze_vary_3d_centered
  # The keys below belong to composite_figure itself, not to top_left_images
  variance_threshold: 0.95  # Threshold for stacked dims plot
  orthogonality_cmap: viridis  # Colormap for orthogonality plot (teal_coral, viridis, plasma, etc.)
  orthogonality_average_linewidth: 1.2  # Linewidth for average lines in orthogonality plot
  orthogonality_marker_size: 4  # Marker size for orthogonality plot lines
  orthogonality_show_baseline_mean: false  # Show Random Init Mean line (false = only show CI band)

  # Theory line settings for stacked dims plot
  theory_lines:
    enabled: false
    show_simple: true        # Theory 1: average-scaled
    show_per_factor: true    # Theory 2: per-factor unique
    linestyle: ":"           # Dotted
    color: "0.7"             # Gray (matplotlib grayscale string; keep it quoted)
    linewidth: 1.0
    marker: "."              # Point markers (matplotlib uses "s" for squares)
    markersize: 1.0

  # Stacked-bar mode for the dims plot (alternative to the line plot).
  # When enabled, stacked bars are shown at discrete steps instead of continuous lines.
  dims_bar_mode:
    # Toggle for bar chart mode (replaces line plot when true)
    enabled: true
    # Steps to show as bars (default: init, late, final)
    steps: [0, 1000, 200000, 500000]
    # Width of bars relative to spacing (0-1)
    bar_width: 0.7
    # Inset controls (separate from line plot inset)
    inset_enabled: true
    # Inset position [x, y, w, h]
    inset_bounds: [0.55, 0.32, 0.36, 0.45]
    # Inset x-axis limits (step indices), null for auto
    # NOTE(review): four values are given although this is documented as axis
    # limits - confirm whether the consumer expects [min, max].
    inset_xlim: [0, 1, 2, 3]
    # Inset y-axis limits, null for auto
    inset_ylim: [0, 20]
    # "All" line styling (horizontal lines per bar, color-coded by step)
    all_linestyle: "--"  # Dashed line
    all_linewidth: 1.3

  # Stacked dims plot (panel b) settings
  # NOTE(review): the combined_* comments mention only the Joint line, but
  # show_joint is false and show_union is true here - confirm whether combined_*
  # also styles the Union line.
  dims_cmap: plasma  # Colormap for factor colors (Set2, Pastel1, tab10, etc. - complements viridis)
  show_joint: false  # Show "Joint" line (natural generation, all factors varying together)
  show_union: true   # Show "Union" line (concatenated vary-one, union of factor subspaces)
  combined_linestyle: "k--"  # Linestyle for Joint line (e.g., "k-" solid, "k--" dashed, "k:" dotted)
  combined_linewidth: 1   # Line width for Joint line
  combined_marker_size: 1.5  # Marker size for Joint line
  dims_xlim: null  # X-axis limits for dims plot [min, max], null for auto
  dims_ylim: null  # Y-axis limits for dims plot [min, max], null for auto
  dims_xscale: null  # X-axis scale for dims plot (e.g., "log"), null for linear

  # Inset zoom for the dims plot (panel b)
  # Enable inset with zoomed region near final training step
  dims_inset_enabled: true
  # Inset position/size [x, y, width, height] in axes coords
  dims_inset_bounds: [0.35, 0.58, 0.42, 0.35]
  # Inset x-axis limits (step indices), null for auto (last 25% of steps)
  dims_inset_xlim: [7, 9]
  # Inset y-axis limits, null for auto
  dims_inset_ylim: [0, 30]

  # Legend locations, given as matplotlib location strings:
  # "upper right", "upper left", "lower right", "lower left", "best", etc.
  dims_legend_loc: "upper right"  # Legend location for dims plot (panel b)
  orthogonality_legend_loc: "lower right"  # Legend location for orthogonality plot (panel c)

  # Spacing between panels (fraction of subplot size, used by constrained_layout)
  wspace: 0.0  # Horizontal spacing (fraction)
  hspace: 0.0  # Vertical spacing (fraction)

  # Left 2x2 grid spacing (subgridspec)
  left_grid_wspace: 0.0  # Horizontal spacing between images
  left_grid_hspace: 0.0  # Vertical spacing between images

  # Belief regression grid spacing (subgridspec)
  belief_wspace: 0.0  # Horizontal spacing between Theory/Activations columns
  belief_hspace: 0.0  # Vertical spacing between factor rows
  # F label placement in axes coords (negative x = left of Theory column,
  # y of 0.5 = vertically centered)
  belief_factor_label_x: -0.15
  belief_factor_label_y: 0.5
  # Horizontal alignment of F labels (right, center, left)
  belief_factor_label_ha: "right"
  # Column anchors (N, S, E, W, C, NE, NW, SE, SW, or null)
  belief_anchor_theory: null  # Anchor for Theory column
  belief_anchor_activations: null  # Anchor for Activations column

  # Panel label positions [x, y] in axes coordinates
  # The panel letters below follow the "full" layout naming (a-d).
  # NOTE(review): presumably the other layouts reuse these keys for their own
  # panel letters - confirm.
  label_a_pos: [-0.08, 1.10]  # (a) label for 2x2 grid
  label_b_pos: [-0.15, 1.12]  # (b) label for dims plot
  label_c_pos: [-0.15, 1.12]  # (c) label for orthogonality plot
  label_d_pos: [-0.35, 1.35]  # (d) label for belief regression

  # Axis label padding (reduces whitespace between stacked plots)
  middle_xlabel_labelpad: 1  # Padding for middle plot (b) xlabel
  bottom_xlabel_labelpad: 0  # Padding for bottom plot (c) xlabel (default matplotlib is 4.0)
