# Experiment Metadata
# Identifies the run: the PDE being trained on, the training mode, and the run name.
equation: "poisson"  # alternative value noted by the author: "darcy"
# NOTE(review): `training_mode` and `mode` hold the same value; which of the two
# the trainer actually reads is not visible from this file, so keep them in sync.
training_mode: "unified"
mode: "unified"
# The name appears to encode settings from this file (equation, resolution 128,
# unified mode, guided residual, freq_complex residual mode, scalar gate);
# presumably it must be updated by hand when those change — TODO confirm.
run_name: "pois_128_unified_guided_freq_comp_scalar_full_joint"

# Data Configuration
# NOTE(review): the paths are relative; presumably resolved against the launch
# directory — confirm against the data-loading code.
data: "datasets/DiffusionPDE/training/poisson/"  # training data directory
validate_data: "datasets/DiffusionPDE/testing/poisson.mat"  # validation data (.mat file)
dataset: "diffusion_pde_poisson"
main_dataset: "DiffusionPDE"
sigma_data: 0.5  # sigma_data value used in the EDM loss

# Model Configuration
cond: true  # NOTE(review): presumably enables conditioning in the model — confirm
model_type: "SongUNOResidual"

model_channels: 64
num_blocks: 4
channel_mult: [1, 2, 4, 4]
# Resolution to train on; the data is downsampled if this is lower than the
# original data resolution. Other values noted by the author: 64, 128.
resolution: 128
# Whether to downsample and then train, or to train on the original resolution.
train_downsample: true
normalizer: "ScaledGaussian"  # one of: UnitGaussian, ScaledGaussian, ScaledGaussian2


# Sparse conditioning training (a mask channel is added)
# Set to false for no mask channel in the architecture, or for full observations.
use_sparse_conditioning: false
random_sample_masking: false  # enables random sample masking
# NOTE(review): both curricula are switched off; their schedules are not
# described anywhere in this file.
enable_sparsity_curriculum: false
enable_sample_curriculum: false

### Unified Training Task Probabilities
# The five values below sum to 1.0 (0.5 + 0.5 + 0.0 + 0.0 + 0.0): only the two
# "full" tasks have non-zero probability in this run.
# NOTE(review): whether the trainer renormalises these or requires them to sum
# to 1 is not visible from this file — confirm before editing a single entry.
unified_task_prob_full_fwd: 0.5
unified_task_prob_full_inv: 0.5
unified_task_prob_sparse_fwd: 0.0
unified_task_prob_sparse_inv: 0.0
unified_task_prob_uncond: 0.0
# Disabled (commented-out) sparse-observation range, kept for reference:
# sparse_obs_range_start: 0.01
# sparse_obs_range_end: 0.5

# Neural Operator Configuration
# NOTE(review): the original comment only said "standard"; presumably that is
# the alternative value for spectral_conv — confirm.
spectral_conv: "tucker"
fmult: 1.0
rank: 0.1
noise_src: "grf"  # one of: grf, gauss
rbf_scale: 0.05  # scale for the RBF kernel

# PDE Residual config
# Use ground truth to guide the PDE residual computation in unified mode.
guided_pde_residual_mode: true
# One of: "concat", "freq_attn_real", "freq_attn", "freq_complex", "freq_mag".
pde_residual_mode: "freq_complex"
# NOTE(review): the original comment read "spatial scalar"; presumably the
# accepted values are "spatial" and "scalar" — confirm.
pde_residual_gate_type: "scalar"
pde_residual_step_mode: "one_step"
use_alpha: true
use_gating: true  # enable gating for the PDE residual
gating_mode: 1
normalize_pde_residual: true  # normalize the PDE residual
spectral_inject_pos: "post"  # position for spectral injection: "post" or "pre"
spatial_film_pos: "none"


# Training Configuration
batch_gpu: 90  # other values noted by the author: 10, 90
# NOTE(review): the unit of `duration` is not stated in this file; `lr_rampup`
# below is documented in epochs — confirm whether the two share a unit.
duration: 20
seed: 33
workers: 4
lr: 1.0e-4  # learning rate for the optimizer
lr_rampup: 5  # learning-rate ramp-up duration, in epochs
# Enables torch.backends.cudnn.allow_tf32 and torch.backends.cuda.matmul.allow_tf32.
use_fast_math: true


# Logging and Checkpointing
# NOTE(review): the original comments describe `snap` and `dump` as "per tick";
# presumably both are intervals counted in ticks, and the unit of `tick` itself
# is not stated in this file — confirm against the training loop.
tick: 10  # frequency of logging updates
snap: 50  # model snapshot frequency (per tick)
dump: 50  # model dumping frequency (per tick)
pde_plot_ticks: 10  # how often to plot the PDE residual

# WandB Configuration
# NOTE(review): presumably passed through to the Weights & Biases client, whose
# documented modes are online, offline, and disabled — confirm how the trainer
# forwards these keys.
wandb_mode: online
wandb_project: exps_iclr
wandb_team: physics-diffusion