# Run from base dir as:
# >> python -m floral.main task@_global_=<TASK>
# Check the task directory to see what tasks are available.
#
# To check how the config is resolved:
# >> python -m floral.main --cfg all --resolve
#
---
# Hydra defaults list. `_self_` is listed first, so the packages that follow are
# composed on top of this file and may override the values it defines.
# `@_global_` places each package's content at the root of the config.
defaults:
  - _self_
  - method@_global_: floral   # default training method package
  - task@_global_: synthetic_linear  # task package (NOTE: can override default method config)
  - extras@_global_: none  # this package adds flexibility to mix extra cfgs with some tasks and methods
  - override hydra/job_logging: disable_hydra_logging  # use flwr logger instead

# Use $SCRATCH (e.g. '/scratch/$USER') as IO dir if that env variable is set,
# otherwise use the current dir.
scratch_dir: ${oc.env:SCRATCH,.}
# Data lives in `$DATA_DIR/data` if DATA_DIR is set, otherwise in `${scratch_dir}/data`.
data_dir: ${oc.env:DATA_DIR,${scratch_dir}}/data
# Putting output_dir in /scratch introduces some weird behavior in sweeps due to
# cluster config (in our cluster, each partition has its own /scratch.)
# It therefore resolves to `$OUTPUT_DIR/outputs` if OUTPUT_DIR is set, otherwise `./outputs`.
output_dir: ${oc.env:OUTPUT_DIR,.}/outputs
# If false, sweeps will terminate if a history file with the same name already exists
overwrite_sweep: false
# If true, all sweep run files except history will be cleared after completion of run
clear_sweep_run_files: true

# High-level args (should be overridden from the command line).
# `experiment` and `task_dir` are used below to build hydra's run/sweep output dirs.
experiment: experiment
identifier: identifier
task_dir: "id=${identifier}"

# Simple hydra configs, check conf/hydra for more
hydra:
  run:
    # Output dir of a single (non-sweep) run
    dir: ${output_dir}/${experiment}/${task_dir}
  sweep:
    # Root dir shared by all jobs of a sweep
    dir: ${output_dir}/${experiment}
    # Per-job subdir, built by hydra from the job's overrides
    subdir: ${hydra.job.override_dirname}
  sweeper:
    params: ???  # should be filled by task/sweeper package

# These should be filled at runtime (do not change).
# `???` is OmegaConf's mandatory-value marker: accessing the key before it has
# been set raises an error.
task: ???
method: ???
extras: ???
logdir: ???

# Logger config args
max_logfiles: 3
loglevel: DEBUG

# Helper args (both disabled by default)
show_cfg: false
wandb: false

# Task-specific args (most should be overridden by task and method packages).
# Several of these are referenced through interpolation further down in this file
# (e.g. `${lr}`, `${weight_decay}`, `${local_epochs}`, `${continue_training}`).
num_rounds: 10
round_timeout: null
local_epochs: 1
deterministic: true
seed: 0
dataloader:
  num_workers: 0
batch_size: 32
test_batch_size: 512
lr: 0.1
weight_decay: 0.0
continue_training: false

# Data split args
train_proportion: 0.8  # for synthetic datasets
val_proportion: 0.0  # for synthetic and flwr datasets
validation_mode: false  # for tff datasets (val split is predetermined)
unseen_clients: 0.0  # a ratio or exact number of testing clients (unseen during training)

# Allows for extra user-specified private_modules without breaking the original method.
# Empty by default.
extra_private_modules: []

# Method-dependent model instantiation is in floral.training.utils
# (mandatory value: must be provided by another config or an override)
model: ???

# Datasets are in floral.dataset (mainly handled by task package)
# (mandatory value: must be provided by another config or an override)
dataset: ???

# Method-dependent param groups are defined in floral.training.utils
# NOTE: Requires param groups for instantiation.
# `lr` and `weight_decay` are interpolated from the top-level args of the same name.
optimizer:
  _target_: torch.optim.SGD
  lr: ${lr}
  weight_decay: ${weight_decay}

# Separate optimizer config with a fixed lr of 1.0 (not tied to `${lr}`).
global_optimizer:
  _target_: torch.optim.SGD
  lr: 1.0

# Loss functions should be declared in the config
# (default target: mean squared error)
loss_fn:
  _target_: torch.nn.MSELoss

# This is an empty regularizer as a template (see conf/method/floral for an example)
regularizer:
  _target_: floral.training.utils.Regularizer
  # No regularizers configured by default
  regularizers: {}

# This is a helper interface class that contains all the training-specific methods
# Base trainer's args are often defined at the global level.
# All three args below are interpolated from the top-level keys of the same name.
trainer:
  _target_: floral.training.Trainer
  local_epochs: ${local_epochs}
  loss_fn: ${loss_fn}
  regularizer: ${regularizer}

# Default Flower Client (no extra args are set here)
client:
  _target_: floral.client.FlowerClient

# Default Flower Strategy with weighted average metrics aggregation
strategy:
  _target_: floral.server.strategy.FedAvgBase
  fraction_fit: 1.0
  fraction_evaluate: 1.0
  # Both aggregation functions are built from the same factory target
  evaluate_metrics_aggregation_fn:
    _target_: floral.server.strategy.get_metrics_aggregation_fn
  fit_metrics_aggregation_fn:
    _target_: floral.server.strategy.get_metrics_aggregation_fn
  # Opts for our version of FedAvg, which should be filled at runtime
  # (`???` marks them as mandatory until then)
  global_model: ???
  global_optimizer: ???
  save_path: ???
  # Mirrors the top-level `continue_training` flag
  continue_training: ${continue_training}

# It is important to choose these carefully depending on your system's resources and memory
# Defaults: 1 CPU and no GPU.
client_resources:
  num_cpus: 1
  num_gpus: 0
  # XXX
  # num_cpus: 4
  # num_gpus: 1
  # XXX

ray_init_args:
  # Don't connect to an existing ray cluster, create a new one instead
  # If you connect to an existing Ray cluster, then flower will create as many actors as can fit (given resources).
  # This is fine as long as the Ray cluster was created ONLY for this run. Otherwise, it will be messy,
  # and Ray will need to manage clients from different runs simultaneously with more actors than necessary.
  address: local
  # Below are the default `ray_init_args` from Flower source code
  ignore_reinit_error: true
  include_dashboard: false

keep_ray_initialized: false
