# ============================================================
# FADER NETWORK BASELINE CONFIG - MORPHO-MNIST
# ============================================================
# This is a self-contained configuration file.
# To run, use the baseline training script:
# python -m user_extensions.baselines.main --config configs/baselines/fader_morphomnist.yaml

# --- System to Train ---
# Names the system to be trained. Presumably the training script resolves this
# string to an implementation — confirm against the script's lookup logic.
system:
  name: "FaderSystem"

# --- Run Settings ---
# Run-level bookkeeping: naming, output location, and worker count.
run:
  # Identifier for this run/sweep; presumably used to name outputs — confirm.
  sweep_name: "fader_baseline_morphomnist"
  # NOTE(review): relative path. What it is resolved against (working directory
  # vs. this file's location) depends on the consuming script — confirm.
  log_dir: "../runs"
  # Presumably the number of data-loading workers — TODO confirm.
  num_workers: 4

# --- Data Settings ---
# Dataset selection plus a table of per-dataset properties.
data:
  # Matches the `morpho-mnist` key under `properties` below.
  name: "morpho-mnist"
  batch_size: 256
  # Entries exist for mnist, morpho-mnist and celeba. Presumably only the entry
  # matching `name` is read for a given run — confirm against the data loader.
  # Every `dir` contains the literal placeholder `<path_to_datasets>`, which has
  # to be replaced with a real location before running.
  properties:
    mnist:
      dir: "<path_to_datasets>/mnist"
      # assumes [channels, height, width] ordering — TODO confirm
      image_shape: [1, 28, 28]
      num_classes: 10
      # Presumably the number of samples held out for validation — confirm.
      val_split_size: 10000
    morpho-mnist:
      dir: "<path_to_datasets>/morpho-mnist-global"
      # assumes [channels, height, width] ordering — TODO confirm
      image_shape: [1, 28, 28]
      num_classes: 10
      # Presumably the number of samples held out for validation — confirm.
      val_split_size: 10000
    celeba:
      dir: "<path_to_datasets>/CelebA"
      # assumes [channels, height, width] ordering — TODO confirm
      image_shape: [3, 64, 64]
      num_classes: 2
      # NOTE(review): unlike the two entries above, no `val_split_size` is given
      # here — confirm whether the loader needs one for this dataset.


# --- Main Model (Encoder/Generator) Settings ---
# Encoder/decoder configuration for the main model.
model:
  # Matches the `conv` key under `architecture` below; presumably selects which
  # architecture sub-section is used — confirm.
  type: "conv"
  latent_space:
    latent_dim: 16
    target_dim: 0 # Not used by Fader, but key must exist

  architecture:
    # `activation_type` and `use_bias` sit above the `conv` sub-section;
    # presumably they apply to both encoder and decoder — confirm.
    activation_type: "relu"
    use_bias: false
    conv:
      # Two stages (h_dims and block_repeats both have two entries).
      encoder:
        style: "vgg"
        # Presumably per-stage channel widths — TODO confirm.
        h_dims: [32, 64]
        block_repeats: [1, 1]
        downsampling_method: "maxpool"
        downsampling_factor: 2
        use_bn: false
        # Presumably hidden widths of a fully connected head — TODO confirm.
        mlp_h_units: [128]

      # Mirrors the encoder: h_dims is the encoder's list in reverse order.
      decoder:
        style: "vgg"
        h_dims: [64, 32]
        block_repeats: [1, 1]
        upsampling_method: "pixelshuffle"
        upsampling_factor: 2
        use_bn: false
        mlp_h_units: [128]

# --- Fader Discriminator Settings ---
# Architecture of the Fader discriminator. Note that it sets its own `use_bias`
# and `activation_type`, which differ from the values under `model.architecture`.
fader_discriminator:
  architecture:
    # Presumably hidden-layer widths (two layers of 128 units) — TODO confirm.
    h_units: [128, 128]
    use_bias: true
    activation_type: "leaky_relu"

# --- Training Loop Settings ---
# Training length and optimizer hyperparameters.
training:
  epochs: 100
  optimizer:
    # Separate learning rates for the `main` and `adversarial` groups (currently
    # equal). The values are written with a decimal point and a signed exponent,
    # so YAML 1.1 loaders also resolve them as floats rather than strings; keep
    # that form when editing.
    main:
      lr: 2.0e-4
    adversarial:
      lr: 2.0e-4
    # `betas` and `weight_decay` are siblings of `main`/`adversarial`; presumably
    # they are shared by both optimizers — confirm against the training code.
    betas: [0.5, 0.999]
    weight_decay: 1.0e-5

# --- Loss Function Settings ---
# Loss selection and term weights.
loss:
  recon_loss_type: "mse"
  # NOTE(review): presumably the length of a ramp for the adversarial weight;
  # the unit (steps vs. samples vs. epochs) is not visible here — confirm.
  lambda_schedule: 10000
  weights:
    # Presumably the weight of the reconstruction term — confirm.
    gamma_rec: 1.0
    # Presumably the weight (or final value) of the adversarial term — confirm.
    lambda_fader: 0.1

# --- Evaluation and Logging Settings ---
# Evaluation cadence, metric parameters, and which metrics/visualizations run.
evaluation:
  # NOTE(review): the units of the two intervals (steps vs. epochs) are not
  # visible in this file — confirm against the evaluation callback.
  log_interval: 50
  intervention_interval: 10
  cluster_method: "tsne"
  tsne_perplexity: 30

  # Parameters consumed by individual metrics.
  metric_settings:
    dci_estimators: 50
    dci_max_depth: 10
    # presumably reflects images scaled to [-1, 1] (range width 2) — TODO confirm
    ssim_data_range: 2.0

  # Each entry carries a `mode` list. The values used in this file are `online`
  # and `post-hoc`; `fid` is the only metric marked `post-hoc`.
  metrics:
    ssim:
      mode: [online]
    lpips:
      mode: [online]
    fid:
      mode: [post-hoc]
    BaselineSAP:
      mode: [online]
    AttributeInvarianceProbe:
      mode: [online]
    BaselineDCI:
      mode: [online]

  # Same `mode` convention as `metrics` above; both entries are `online`.
  visualizations:
    BaselineClustering:
      mode: [online]
    FaderConditionalGeneration:
      mode: [online]


# --- Callbacks ---
# Callback toggles, keyed by callback name.
callbacks:
  # Presumably drives the `evaluation` section of this file — confirm.
  BaselineEvaluation:
    enabled: true

# --- Analysis ---
# Model-inspection options.
analysis:
  log_model_architecture: true
  # Presumably the nesting depth used when rendering the architecture graph, and
  # only relevant when `log_model_architecture` is true — TODO confirm.
  graph_depth: 5