# file: configs/celeba_config.yaml
# ============================================================
# CONFIGURATION FOR CELEBA
# ============================================================
# This configuration is tuned for training on the CelebA face dataset at
# 128x128 resolution. It uses a Fully Convolutional Network (FCN) built from
# residual blocks together with a perceptual (VGG-feature) reconstruction
# loss; both tend to suit natural images better than MLPs or pixel losses.

---
# --- Run Settings ---
run:
  # Name of this experiment group; logs and outputs are organized under it.
  sweep_name: 'celeba'
  # Root directory that receives every experiment output.
  log_dir: '../runs'
  # Worker count for the data loader. More workers can speed up loading on
  # large datasets; keep it at or below the number of available CPU cores.
  num_workers: 8


# --- Data Settings ---
data:
  # Dataset identifier; it must name a registered datamodule.
  name: 'celeba'
  # Samples per batch; reduce this if the GPU runs out of memory.
  batch_size: 64

  # Per-dataset properties; values set here take precedence over defaults.
  properties:
    celeba:
      # Root directory of the CelebA dataset.
      # Placeholder: replace <path_to_datasets> with a real location.
      dir: '<path_to_datasets>/CelebA'
      # Image shape as [channels, height, width].
      image_shape: [3, 128, 128]
      # Number of target classes, e.g. 2 for a binary attribute such as
      # Smiling / Not Smiling.
      num_classes: 2


# --- Main Model (Encoder/Generator) Settings ---
model:
  # Architecture family of the encoder/generator; 'fcn' keeps a spatial
  # latent map, which helps spatial tasks.
  # Options: 'mlp', 'conv', 'fcn'
  type: 'fcn'

  # Sizes of the latent code z and its task-relevant part z_1.
  latent_space:
    # Dimensionality of the full, flattened latent vector z.
    latent_dim: 32
    # Dimensionality of the task-relevant subspace z_1.
    target_dim: 16
    # Std-dev of Gaussian noise added to the encoder output while training;
    # 0.0 turns the encoder deterministic.
    encoder_noise_scale: 0.05

  # Spatial-latent parameters, consulted because type is 'fcn'.
  # NOTE(review): the z_1 share here (4 of 16 channels) differs from the one
  # in latent_space (16 of 32 dims); confirm which pair the fcn path uses.
  fcn_params:
    # Channel count of the spatial latent map prior to flattening.
    latent_channels: 16
    # Channels of that map that carry z_1.
    target_channels: 4

  # Architecture options shared by every component.
  architecture:
    # Nonlinearity; 'leaky_relu' is a common pick for adversarial training.
    # Options: 'relu', 'leaky_relu', 'swish', 'gelu'
    activation_type: 'leaky_relu'
    # Include bias terms in layers.
    use_bias: true

    # MLP sizes; ignored by 'fcn' but kept so the schema stays complete.
    mlp:
      encoder_h_dims: [256, 128]
      generator_h_dims: [128, 256]
      classifier_h_units: [64]
      latent_disc_h_units: [64]

    # Convolutional settings shared by the 'conv' and 'fcn' types.
    conv:
      encoder:
        # Block type; residual blocks usually train better in deeper stacks.
        # Options: 'vgg', 'residual'
        style: 'residual'
        # Output channels of each downsampling stage; the list length sets
        # the network depth and the values set its capacity.
        h_dims: [32, 64, 128, 128, 256]
        # Blocks stacked at each resolution.
        block_repeats: [1, 1, 1, 1, 1]
        # Downsampling op; strided convolution ('conv') can be steadier than
        # pooling. Options: 'maxpool', 'avgpool', 'conv'
        downsampling_method: 'conv'
        downsampling_factor: 2
        use_bn: true
        # Not used when type is 'fcn'.
        mlp_h_units: [128]

      decoder:
        style: 'residual'
        # Channels per upsampling stage (the encoder list, reversed).
        h_dims: [256, 128, 128, 64, 32]
        block_repeats: [1, 1, 1, 1, 1]
        # Upsampling op; nearest-neighbour resize followed by a conv can
        # avoid checkerboard artifacts.
        # Options: 'pixelshuffle', 'convtranspose', 'nearest'
        upsampling_method: 'nearest'
        upsampling_factor: 2
        use_bn: true
        # Not used when type is 'fcn'.
        mlp_h_units: [128]


# --- Adversarial Discriminator (Q-Network) Settings ---
discriminator_q:
  # Architecture family of the discriminator's feature extractor.
  type: 'conv'

  architecture:
    activation_type: 'leaky_relu'
    use_bias: true

    # MLP sizes; unused while type is 'conv'.
    mlp:
      encoder_h_dims: [256, 128]

    # Convolutional feature extractor.
    conv:
      encoder:
        style: 'residual'
        h_dims: [32, 64, 128, 128, 256]
        block_repeats: [1, 1, 1, 1, 1]
        downsampling_method: 'conv'
        downsampling_factor: 2
        # Batch norm is off here, unlike the main encoder/decoder.
        use_bn: false


# --- Training Loop Settings ---
training:
  # Number of passes over the training set.
  epochs: 100
  optimizer:
    # Learning rate for the main (encoder/generator) model.
    main:
      lr: 3.0e-4
    # Learning rate for the adversarial (discriminator) networks.
    # NOTE(review): 100x smaller than main.lr; confirm this is intentional
    # and not a typo for 3.0e-4.
    adversarial:
      lr: 3.0e-6
    # Settings shared by both optimizers; betas presumably follow the Adam
    # (beta1, beta2) convention; confirm.
    betas: [0.5, 0.999]
    weight_decay: 1.0e-4

  # Gradient clipping, with a separate threshold for each optimizer.
  gradient_clipping:
    enabled: true
    # Presumably clip-by-norm; confirm the accepted options in the trainer.
    algorithm: 'norm'
    main_clip_val: 10.0
    adversarial_clip_val: 10.0


# --- Loss Function Settings ---
loss:
  # Reconstruction objective; 'vgg_multi_layer' is a perceptual loss that
  # suits complex images.
  # Options: 'mse', 'l1', 'vgg_single_layer', 'vgg_multi_layer'
  recon_loss_type: 'vgg_multi_layer'
  # Adversarial objective; 'softplus' is a common, stable alternative to BCE.
  # Options: 'bce', 'softplus'
  gan_loss_type: 'softplus'

  # Perceptual (VGG-feature) loss settings.
  vgg:
    # true when the generator emits values in [-1, 1] through a Tanh.
    input_is_tanh: true
    # Distance between feature maps. Options: 'l1', 'mse'
    loss_type: 'l1'
    # Backbone network. Options: 'vgg16', 'vgg19'
    model_type: 'vgg19'
    # Feature index for 'vgg_single_layer'; ignored by 'vgg_multi_layer'.
    layer: 15
    # Feature index -> weight map used by 'vgg_multi_layer'.
    layers_and_weights:
      3: 0.25
      8: 0.25
      15: 0.25
      24: 0.25

  # Weight applied to each term of the total loss.
  weights:
    gamma_rec: 1.0      # reconstruction
    gamma_cls: 1.0      # supervised classification on z_1
    gamma_l: 0.1        # latent discriminator on z_0
    gamma_gan: 1.0      # image discriminator on x_rec
    gamma_info: 0.01    # InfoGAN mutual information on z_0
    gamma_proto: 0.1    # prototypical loss on z_1
    gamma_prior: 0.01   # AAE-style prior matching on z_0

  # NOTE(review): presumably the moving-average momentum of the class
  # prototypes used by the prototypical loss; confirm.
  prototype_momentum: 0.99
  # NOTE(review): presumably the std-dev of noise fed to the Q-network;
  # confirm.
  noise_std_for_q: 0.01

  # R1 gradient penalty on the image discriminator.
  r1_penalty:
    gamma_r1: 10.0
    interval: 1
    warmup_steps: 100
    warmup_epochs: 0
    warmup_epochs: 0


# --- Evaluation and Logging Settings ---
evaluation:
  # Logging / visualization cadence and sample count.
  log_interval: 25
  intervention_interval: 10
  num_visualization_samples: 3

  # Latent clustering and traversal plots.
  cluster_method: 'tsne'
  pca_components: 3
  traversal_steps: 9
  traversal_std: 2.5
  tsne_perplexity: 30

  # Tuning knobs for individual metrics.
  metric_settings:
    mig_bins: 20
    dci_estimators: 50
    dci_max_depth: 10
    # 2.0 is the width of a [-1, 1] pixel range, matching the Tanh output
    # declared by loss.vgg.input_is_tanh.
    ssim_data_range: 2.0
    style_cluster_k: 10

  # When each metric runs. `mode` takes one value or a list of values from:
  #   'online'    - at the end of each validation epoch
  #   'post-hoc'  - after training, through the 'evaluate' command
  #   'training'  - periodically inside the training loop (can be slow)
  metrics:
    ssim:
      mode: 'online'
    lpips:
      mode: 'online'
    fid:
      mode: 'online'
    dci:
      mode: 'online'
    sap:
      mode: 'online'
    mig:
      mode: 'online'
    linear_probe:
      mode: 'online'
    identity_ari:
      mode: 'online'
    identity_silhouette:
      mode: 'online'
    style_leakage_ari:
      mode: 'online'
    style_silhouette:
      mode: 'online'
    identity_nmi:
      mode: 'online'
    style_leakage_nmi:
      mode: 'online'

  # When each visualization is produced (same `mode` values as above).
  visualizations:
    intervention:
      mode: 'online'
    subspace_replacement:
      mode: 'online'
    latent_traversal:
      mode: 'online'
    clustering:
      mode: 'online'
    correlation:
      mode: 'post-hoc'


# --- Callback Settings ---
callbacks:
  # Persists model weights and latent vectors.
  ArtifactSaver:
    enabled: true
  # Runs the online metrics configured under evaluation.metrics.
  Metrics:
    enabled: true
  # Produces the online visualizations configured under
  # evaluation.visualizations.
  OnlineVisualizations:
    enabled: true


# --- Model Analysis Settings ---
analysis:
  # Log a description of the model architecture.
  log_model_architecture: true
  # NOTE(review): presumably the maximum module nesting depth shown in the
  # logged architecture graph; confirm.
  graph_depth: 10