default: &DEFAULT

  # General
  # Baseline parameter count used for computing compression; computed when None.
  # NOTE(review): YAML reads a bare `None` as the string "None", not null;
  # presumably the config loader converts it - confirm.
  n_params_baseline: None
  verbose: true
  arch: 'fno2d'

  # Distributed computing
  distributed:
    use_distributed: false
    wireup_info: 'mpi'
    wireup_store: 'tcp'
    model_parallel_size: 2
    seed: 666

  # FNO related
  # NOTE(review): the bare `None` values in this section are parsed by YAML as
  # the string "None", not null; presumably the loader converts them - confirm.
  fno2d:
    data_channels: 3
    n_modes_height: 24
    n_modes_width: 24
    hidden_channels: 64
    projection_channel_ratio: 4
    n_layers: 4
    domain_padding: 0.078125
    domain_padding_mode: 'one-sided'  # other value noted here: symmetric
    fft_norm: 'forward'
    norm: None
    skip: 'soft-gating'
    implementation: 'reconstructed'

    # Channel MLP
    use_channel_mlp: 1
    channel_mlp_expansion: 0.5
    channel_mlp_dropout: 0

    separable: false
    factorization: None
    rank: 1.0
    fixed_rank_modes: None
    dropout: 0.0
    tensor_lasso_penalty: 0.0
    joint_factorization: false
    stabilizer: None  # or 'tanh'

  # Optimizer
  opt:
    n_epochs: 500
    # Floats are written with an explicit decimal point in the mantissa:
    # YAML 1.1 loaders (e.g. PyYAML) resolve `1e-3` as a string, whereas
    # `1.0e-3` is a float under both YAML 1.1 and 1.2.
    learning_rate: 1.0e-3
    training_loss: 'h1'
    weight_decay: 1.0e-4
    amp_autocast: false

    scheduler_T_max: 500  # For cosine only, typically take n_epochs
    scheduler_patience: 5  # For ReduceLROnPlateau only
    scheduler: 'StepLR'  # Or 'CosineAnnealingLR' or 'ReduceLROnPlateau'
    # NOTE(review): step_size and gamma are presumably the StepLR arguments -
    # confirm against the training script.
    step_size: 100
    gamma: 0.5

  # Dataset related
  data:
    # Alternative locations (commented out):
    #   '/home/nikola/HDD/NavierStokes/2D'
    #   '/data'
    folder: '/data/navier_stokes/'
    batch_size: 16
    n_train: 10000
    train_resolution: 128
    # A trailing third entry is commented out in each of the three lists
    # below (1000, 1024 and 1 respectively).
    n_tests: [2000, 1000]
    test_resolutions: [128, 1024]
    test_batch_sizes: [16, 4]
    encode_input: true
    encode_output: false
    num_workers: 0
    pin_memory: false
    persistent_workers: false

  # Patching
  patching:
    levels: 0  # other value noted here: 1
    padding: 0  # other value noted here: 0.078125
    stitching: true

  # Weights & Biases (wandb) logging
  wandb:
    log: true
    # If None, config will be used but you can override it here.
    name: None
    group: 'super-resolution'
    project: 'Refactored-TFNO'
    entity: 'nvr-ai-algo'  # put your username here
    sweep: false
    log_output: true
    eval_interval: 1

original_fno:
  # NOTE(review): `arch` names 'tfno2d' but the model section below is keyed
  # `fno2d` - confirm which one the loader is expected to pick up.
  arch: 'tfno2d'

  fno2d:
    # NOTE(review): `default.fno2d` spells the mode counts `n_modes_height` /
    # `n_modes_width` and has no `width` key - confirm these names are still
    # read by the consumer.
    modes_height: 64
    modes_width: 64
    width: 64
    hidden_channels: 256
    n_layers: 4
    domain_padding: 0.078125
    domain_padding_mode: 'one-sided'
    fft_norm: 'forward'
    norm: None
    skip: 'linear'

    use_channel_mlp: 0
    # Explicit null with no nested mapping.
    # NOTE(review): looks like a leftover key - confirm it is still needed.
    channel_mlp: null
    channel_mlp_expansion: 0.5
    channel_mlp_dropout: 0

    separable: false
    factorization: None
    rank: 1.0
    fixed_rank_modes: None

  wandb:
    log: false
    # If None, config will be used but you can override it here.
    name: None
    group: 'wandb_group'
    
  
distributed_mg_tucker:
  # NOTE(review): these overrides are keyed `tfno2d`, while `default` sets
  # arch to 'fno2d' and keeps its model settings under `fno2d` - confirm that
  # this section is actually applied.
  tfno2d:
    factorization: 'Tucker'
    compression: 0.42
    # NOTE(review): the other profiles use a fractional domain_padding
    # (0.078125) - confirm that 9 is intended here.
    domain_padding: 9

  distributed:
    use_distributed: true
    wireup_info: 'mpi'
    wireup_store: 'tcp'
    model_parallel_size: 2
    seed: 666

  patching:
    levels: 1
    padding: 16
    stitching: true

