# file: configs/base.yaml
# ============================================================
# BASE CONFIGURATION
# ============================================================
# This file contains all default parameters for the framework.
# Experiment-specific YAML files can override these values.

# --- Run Settings ---
run:
  # Label of the experiment group or sweep; logs are organized under it.
  sweep_name: 'morpho'
  # Base directory that receives every experiment log and artifact.
  log_dir: '../runs'
  # How many worker processes are used for data loading.
  num_workers: 4

# --- Data Settings ---
data:
  # Dataset selector; has to be one of the keys of the DATASETS registry.
  # Options: 'mnist', 'morpho-mnist', 'yaleb', 'dsprites', 'celeba'
  # NOTE(review): `properties` below has no 'yaleb' or 'dsprites' entry;
  # confirm where those datasets get their properties before selecting them.
  name: 'morpho-mnist'
  # Batch size (samples per batch).
  batch_size: 256

  # Per-dataset properties. Only the entry whose key equals `data.name` is
  # merged; `image_shape` and `num_classes` are then set automatically from it.
  properties:

    mnist:
      dir: '<path_to_datasets>/mnist'
      image_shape: [1, 28, 28]
      num_classes: 10
      val_split_size: 10000

    morpho-mnist:
      dir: '<path_to_datasets>/morpho-mnist-global'
      image_shape: [1, 28, 28]
      num_classes: 10
      val_split_size: 10000

    # NOTE(review): unlike the entries above, no `val_split_size` is given
    # here; verify that this is intentional.
    celeba:
      dir: '<path_to_datasets>/CelebA'
      image_shape: [3, 64, 64]
      num_classes: 2


# --- Main Model (Encoder/Generator) Settings ---
model:
  # Main architecture family of the autoencoder.
  # Options: 'mlp', 'conv', 'fcn' (Fully Convolutional)
  type: 'conv'

  # Sizes of the latent space.
  latent_space:
    # Full dimensionality of the latent vector (z).
    latent_dim: 16
    # Size of the task-relevant subspace (z_1); the residual subspace (z_0)
    # takes the remaining `latent_dim - target_dim` dimensions.
    target_dim: 8
    # Std. dev. of the Gaussian noise added to the encoder output while training.
    # A value of 0.0 yields a deterministic encoder.
    encoder_noise_scale: 0.05

  # Architecture options shared by all components.
  architecture:
    # Activation used by most layers.
    # Options: 'relu', 'leaky_relu', 'swish', 'gelu', 'elu', 'tanh'
    activation_type: 'leaky_relu'
    # Toggles bias terms in linear and convolutional layers.
    use_bias: true

    # Read when model.type is 'mlp'.
    mlp:
      encoder_h_dims: [256, 128]
      generator_h_dims: [128, 256]
      classifier_h_units: [64]
      latent_disc_h_units: [64]

    # Read when model.type is 'conv' or 'fcn'.
    conv:
      encoder:
        # Convolutional block style.
        # Options: 'vgg', 'residual'
        style: 'vgg'
        # Channel count of each downsampling block, in order.
        h_dims: [32, 64]
        # How many convolutional blocks are repeated at each resolution.
        block_repeats: [1, 1]
        # Downsampling method.
        # Options: 'maxpool', 'avgpool', 'conv'
        downsampling_method: 'maxpool'
        downsampling_factor: 2
        # Toggles Batch Normalization.
        use_bn: false
        # Hidden units of the final MLP head ('conv' type only).
        mlp_h_units: [128]

      decoder:
        style: 'vgg'
        h_dims: [64, 32]
        block_repeats: [1, 1]
        # Upsampling method.
        # Options: 'pixelshuffle', 'convtranspose', 'nearest'
        upsampling_method: 'pixelshuffle'
        upsampling_factor: 2
        use_bn: false
        # Hidden units of the initial MLP pre-processor ('conv' type only).
        mlp_h_units: [128]

  # Spatial latent parameters, read when model.type is 'fcn'.
  fcn_params:
    # Channel count of the whole spatial latent map.
    latent_channels: 16
    # Channel count of the task-relevant subspace (z_1).
    target_channels: 8


# --- Adversarial Discriminator (Q-Network) Settings ---
discriminator_q:
  # Architecture family of the discriminator's feature extractor.
  # Options: 'mlp', 'conv'
  type: 'conv'

  architecture:
    activation_type: 'leaky_relu'
    use_bias: true

    # Discriminator backbone settings for the MLP variant.
    mlp:
      encoder_h_dims: [256, 128]

    # Discriminator backbone settings for the conv variant.
    conv:
      encoder:
        style: 'vgg'
        h_dims: [32, 64]
        block_repeats: [1, 1]
        downsampling_method: 'maxpool'
        downsampling_factor: 2
        use_bn: false


# --- Training Loop Settings ---
training:
  epochs: 100
  optimizer:
    main:
      # Learning rate of the main models (Encoder, Generator, Classifier).
      lr: 3.0e-4
    adversarial:
      # Learning rate of the adversarial models
      # (DiscriminatorQ, LatentDiscriminator, PriorDiscriminator).
      lr: 3.0e-6
    # Beta parameters for Adam/AdamW.
    betas: [0.5, 0.999]
    # Weight decay, shared by all optimizers.
    weight_decay: 1.0e-4

  gradient_clipping:
    enabled: true
    # Which clipping algorithm to apply. Options: 'norm', 'value'
    algorithm: 'norm'
    # Clip value used with the main optimizer.
    main_clip_val: 10.0
    # Clip value used with the adversarial optimizer.
    adversarial_clip_val: 10.0


# --- Loss Function Settings ---
loss:
  # Reconstruction loss selector.
  # Options: 'mse', 'l1', 'vgg_single_layer', 'vgg_multi_layer'
  recon_loss_type: 'mse'
  # GAN loss selector for the image discriminator.
  # Options: 'bce', 'softplus'
  gan_loss_type: 'softplus'

  # Options of the VGG-based perceptual loss.
  vgg:
    # Enable when the generator output lies in [-1, 1] (uses Tanh).
    input_is_tanh: true
    # Metric applied when comparing VGG features. Options: 'l1', 'mse'
    loss_type: 'l1'
    # Feature-extraction network. Options: 'vgg16', 'vgg19'
    model_type: 'vgg19'
    # Layer index consumed by the 'vgg_single_layer' loss.
    layer: 15
    # Layer index -> weight pairs for the 'vgg_multi_layer' loss
    # (left commented out in this base file).
    # layers_and_weights:
    #   3: 0.25
    #   8: 0.25
    #   15: 0.25
    #   24: 0.25

  # Weight of each term in the total loss.
  weights:
    gamma_rec: 1.0      # reconstruction term
    gamma_cls: 1.0      # supervised classification term on z_1
    gamma_l: 0.1        # latent discriminator term (adversarial on z_0)
    gamma_gan: 1.0      # image discriminator term (adversarial on x_rec)
    gamma_info: 0.01    # InfoGAN-style mutual information term on z_0
    gamma_proto: 0.01   # prototypical term on z_1
    gamma_prior: 0.01   # AAE-style prior matching term on z_0

  # Momentum of the Exponential Moving Average that updates class prototypes.
  prototype_momentum: 0.99
  # Std. dev. of the noise added to reconstructions before they reach the Q-network.
  noise_std_for_q: 0.01

  # R1 gradient penalty of the image discriminator.
  r1_penalty:
    # Penalty weight.
    gamma_r1: 10.0
    # Apply the penalty every `interval` global steps.
    interval: 16
    # Length of the linear penalty warmup, in steps.
    warmup_steps: 100
    # Alternative warmup length in epochs (set `warmup_steps` to 0 to use).
    warmup_epochs: 0

# --- Evaluation and Logging Settings ---
evaluation:
  # Metrics are logged every N global steps; the same value also sets how
  # often training-time metrics run.
  log_interval: 25
  # Expensive online visualizations run, and artifacts are saved, every N epochs.
  intervention_interval: 10
  # How many unique samples/batches feed the sample-dependent visualizations.
  num_visualization_samples: 1

  # Parameters of the clustering and latent traversal visualizations.
  # `cluster_method` options: 'pca', 'tsne'
  cluster_method: 'tsne'
  pca_components: 3
  traversal_steps: 9
  traversal_std: 2.5
  tsne_perplexity: 30

  # Per-metric parameters.
  metric_settings:
    mig_bins: 20
    dci_estimators: 50
    dci_max_depth: 10
    ssim_data_range: 2.0  # for images in the [-1, 1] range; use 1.0 for [0, 1]
    style_cluster_k: 10

  # Selects which metrics run, and at what point.
  # `mode` accepts a single string or a list of strings:
  #   'online': during the validation epoch end.
  #   'post-hoc': after training, via the 'evaluate' command.
  #   'training': periodically during the training loop (can be slow).
  metrics:
    ssim:
      mode: 'online'
    lpips:
      mode: 'post-hoc'
    fid:
      mode: 'post-hoc'
    dci:
      mode: 'online'
    sap:
      mode: 'online'
    mig:
      mode: 'online'
    linear_probe:
      mode: 'online'
    identity_ari:
      mode: 'online'
    identity_silhouette:
      mode: 'online'
    style_leakage_ari:
      mode: 'online'
    style_silhouette:
      mode: 'online'
    identity_nmi:
      mode: 'online'
    style_leakage_nmi:
      mode: 'online'

  # Selects which visualizations are generated, and at what point.
  visualizations:
    intervention:
      mode: 'online'
    subspace_replacement:
      mode: 'online'
    latent_traversal:
      mode: 'online'
    clustering:
      mode: 'online'
    correlation:
      mode: 'post-hoc'


# --- Callback Settings ---
callbacks:
  # Stores model weights and latent vectors during validation/testing.
  ArtifactSaver:
    enabled: true
  # Calculates metrics and logs them during training and validation.
  Metrics:
    enabled: true
  # Produces and stores the visualizations marked 'online' during validation.
  OnlineVisualizations:
    enabled: true


# --- Model Analysis Settings ---
analysis:
  # When true, model summaries and graphs are written to the run directory on startup.
  log_model_architecture: true
  # Depth limit applied when printing model summaries and drawing graphs.
  graph_depth: 10