# Model and experiment identification
model_name: mvaug                        # Name of the model being trained
logging_dir: logs                        # Directory for training logs
output_dir: /mnt/your_path/logs/mvaug/   # Where checkpoints and other outputs are written
# NOTE(review): 'None' is the *string* "None", not a YAML null; presumably the
# trainer treats this string as "no external tracker" (e.g. no W&B) - confirm.
report_to: 'None'
tracker_name: cosmos2_trainer            # Name used for the training tracker

# Training termination criteria
# Training ends as soon as either limit below is reached (step count or epoch
# count), i.e. the effective length is whichever limit is hit first.
train_steps: 1000000   # Upper bound on training steps
train_epochs: 10000    # Upper bound on passes over the training set

# Logging, checkpointing and validation cadence (all counted in training steps)
steps_to_save: 1000   # Write a checkpoint every 1000 steps
steps_to_log: 5       # Log metrics (loss, learning rate, ...) every 5 steps
steps_to_val: 1000    # Run validation every 1000 steps

# Mixed-precision settings
mixed_precision: bf16   # Train in BF16 mixed precision (speed/precision trade-off)
allow_tf32: false       # Keep TF32 disabled, favoring numerical precision

# Distributed training settings
nccl_timeout: 600   # NCCL communication timeout in seconds (avoids hanging indefinitely)
seed: 42            # Global random seed for reproducible runs

# Prompt embeddings and classifier-free guidance (CFG)
prompt_emb_mode: online   # Compute prompt embeddings on the fly during training
# When train_w_cfg is true, captions are randomly dropped (with probability
# caption_dropout_p) so the model also learns the unconditional branch that
# CFG needs at sampling time. It is disabled here.
train_w_cfg: false
caption_dropout_p: 0.2    # Caption drop probability; only relevant when train_w_cfg is true
load_val: true            # Load the validation dataset when training starts

# VAE (variational autoencoder) memory optimizations
enable_slicing: true   # Run VAE forward passes in slices to lower peak memory
enable_tiling: true    # Run VAE operations in tiles to cut memory further

# Transformer model configuration
transformer:
  # Weights to initialize from. Keep exactly ONE model_path uncommented: if
  # several are active, most YAML loaders silently keep only the last one.
  # model_path: '/mnt/your_path/huggingface/models/Cosmos-Predict2-2B-Video2World/transformer/diffusion_pytorch_model.safetensors'
  # model_path: '/mnt/your_path/logs/mvaug/step_19000/diffusion_pytorch_model.safetensors'
  # Resume from a checkpoint saved under output_dir. The path must be absolute,
  # like the alternatives above; without the leading '/', it would be resolved
  # relative to the current working directory.
  model_path: '/mnt/your_path/logs/mvaug/step_12000/diffusion_pytorch_model.safetensors'

  # Transformer architecture hyperparameters
  config:
    # 16 VAE latent channels + 2 extra input channels (previously documented as
    # Canny edge + padding mask). NOTE(review): confirm the breakdown; with
    # concat_padding_mask the padding-mask channel may be appended inside the model.
    in_channels: 18
    out_channels: 16              # Output channels (matches the 16 VAE latent channels)
    num_attention_heads: 16       # Attention heads per layer
    attention_head_dim: 128       # Width of each attention head
    num_layers: 28                # Number of transformer blocks
    mlp_ratio: 4.0                # MLP hidden width = 4x the model width
    text_embed_dim: 1024          # Dimension of the text prompt embeddings
    # Rank of the low-rank AdaLN modulation (AdaLN-LoRA). This is part of the
    # architecture and is unrelated to the train_type: 'lora' fine-tuning switch.
    adaln_lora_dim: 256
    max_size: [128, 240, 240]     # Maximum input size: [frames, height, width]
    patch_size: [1, 2, 2]         # Spatio-temporal patch size: [frame, height, width]
    rope_scale: [1.0, 3.0, 3.0]   # Rotary position embedding (RoPE) scaling for [T, H, W]
    concat_padding_mask: true     # Concatenate the padding mask to the input (behavior still being verified)
    extra_pos_embed_type: null    # No additional positional embedding
    use_view_embed: true          # Add per-view embeddings for the multi-camera input

# Dataset and data-loading configuration
data:
  # Training split. The *_list keys appear to be parallel lists: entry i of
  # each one presumably describes the same source dataset - confirm.
  train:
    jsonl_path_list:      # Metadata JSONL files
      - /mnt/your_path/data/sl_cloth_meta.jsonl
      - /mnt/your_path/a2d_data_process/sl_meta/sl_box.jsonl
    video_folder_list:    # Raw video directories
      - /mnt/your_path/data/sl_cloth
      - /mnt/your_path/a2d_data_mp4
    dataset_name_list: [a2d, a2d]                # Dataset identifiers (select the processing path)
    dataset_source_list: [processed, processed]  # Both sources are already preprocessed
    cam_use:                                     # Cameras per source: head + two hand-mounted cameras
      - [head, left_hand, right_hand]
      - [head, left_hand, right_hand]

    # Sampling and preprocessing hyperparameters
    n_view: 3                    # Camera views sampled per example
    sample_size: [384, 512]      # Frame size after resizing: [height, width]
    sample_n_frames: 360         # Frames sampled per video
    preprocess: resize           # Resize frames to sample_size
    pad_mode: first              # If a clip is too short, pad it with its first frame
    chunk: 25                    # Chunk size for frame-wise processing
    fps: 30                      # Target frame rate for resampling (NOTE(review): confirm it matches capture FPS)
    camera_pose: false           # Camera-pose input not used in this experiment
    n_previous: 4                # Previous (context) frames per example
    previous_pick_mode: random   # Context frames are picked at random
    random_crop: false           # No random crop; deterministic resize only
    use_adc: true                # NOTE(review): earlier note said "adaptive data caching"; confirm what ADC means
    random_padding: false        # Deterministic (non-random) padding
    use_recaption: false         # Use the original captions rather than recaptions
    task_recap_file: /mnt/your_path/task_recap.json   # Task-level recaption file (presumably read only when use_recaption is true)
    step_recap_file: /mnt/your_path/step_recap.json   # Step-level recaption file (presumably read only when use_recaption is true)
    max_stride: 5                # Upper bound on the frame-sampling stride
    fixed_anchor_view: true      # Keep the same anchor view for multi-view alignment
    ignore_seek: false           # NOTE(review): exact meaning of "seek" is not visible here - confirm
    filter_action: false         # No filtering by action
    traj_add_pc: false           # Point-cloud trajectory input disabled
    use_unified_prompt: true     # Same prompt template across datasets
    acwm: false                  # Action-conditioned mode disabled (not needed for this augmentation model)

  # Validation split: mirrors the training settings above.
  # NOTE(review): val sets fix_timestep but not random_padding, whereas train
  # sets random_padding but not fix_timestep - confirm this is intended.
  val:
    jsonl_path_list:
      - /mnt/your_path/data/sl_cloth_meta.jsonl
      - /mnt/your_path/a2d_data_process/sl_meta/sl_box.jsonl
    video_folder_list:
      - /mnt/your_path/data/sl_cloth
      - /mnt/your_path/a2d_data_mp4
    dataset_name_list: [a2d, a2d]
    dataset_source_list: [processed, processed]
    cam_use:
      - [head, left_hand, right_hand]
      - [head, left_hand, right_hand]

    n_view: 3
    sample_size: [384, 512]
    sample_n_frames: 360
    preprocess: resize
    pad_mode: first
    chunk: 25
    fps: 30
    camera_pose: false
    n_previous: 4
    previous_pick_mode: random
    random_crop: false
    use_adc: true
    fix_timestep: false          # Sample random timesteps during validation instead of a fixed one
    use_recaption: false
    task_recap_file: /mnt/your_path/task_recap.json
    step_recap_file: /mnt/your_path/step_recap.json
    max_stride: 5
    fixed_anchor_view: true
    ignore_seek: false
    filter_action: false
    traj_add_pc: false
    use_unified_prompt: true
    acwm: false

# DataLoader settings
batch_size: 4                # Per-GPU batch size; tune to available GPU memory
dataloader_num_workers: 2    # Data-loading worker processes
pin_memory: true             # Pinned host memory for faster host-to-GPU copies

# LoRA (low-rank adaptation) settings. train_type is 'full', so the
# LoRA-specific keys below are inactive in this run.
train_type: full              # 'full' = full fine-tuning, 'lora' = LoRA adapters
target_modules: all-linear    # Modules that get LoRA adapters (train_type: lora only)
rank: 256                     # LoRA rank (train_type: lora only)
lora_alpha: 256               # LoRA scaling factor (train_type: lora only)
prev_checkpoint: null         # No earlier LoRA checkpoint; start from scratch

# Training-time options
gradient_checkpointing: true   # Trade recomputation for lower activation memory
noise_to_first_frame: 0.2      # Noise ratio injected into the first frame (augmentation)
wo_hand_cond: false            # "wo" = without; false keeps hand conditioning enabled
wo_cap_prompt: false           # false keeps caption-prompt conditioning enabled

# Optimizer settings (AdamW)
# Exponent-form floats are written with a decimal point (1.0e-4, not 1e-4):
# YAML 1.1 loaders such as PyYAML resolve `1e-4` to the *string* "1e-4",
# whereas `1.0e-4` is a float under both YAML 1.1 and 1.2.
optimizer: adamw            # AdamW optimizer
lr: 1.0e-4                  # Learning rate (target of the warmup phase)
beta1: 0.9                  # AdamW beta1 (first-moment decay)
beta2: 0.95                 # AdamW beta2 (second-moment decay)
# Plain AdamW has no third beta; presumably this is only read by optimizers
# that take one and is ignored otherwise - confirm.
beta3: 0.999
epsilon: 1.0e-8             # Denominator epsilon for numerical stability
weight_decay: 5.0e-5        # Weight decay (regularization)
optimizer_8bit: true        # 8-bit optimizer states to reduce memory
optimizer_torchao: false    # Do not use the TorchAO optimizer implementation
scale_lr: false             # Do not scale the learning rate with batch size

# Gradient clipping and accumulation
max_grad_norm: 1.0                # Clip the global gradient norm to 1.0
gradient_accumulation_steps: 1    # One batch per optimizer step (no accumulation)

# Learning-rate schedule
lr_scheduler: constant_with_warmup   # Linear warmup, then a constant learning rate
lr_warmup_steps: 1000                # Warmup length in steps (ramps up to lr)
lr_num_cycles: 1                     # Cosine-cycle count; unused by a constant schedule
lr_power: 1.0                        # Polynomial-decay power; unused by a constant schedule

# Diffusion timestep sampling / loss weighting
flow_weighting_scheme: none   # No timestep weighting ('none' is a string here, not null)
# In the diffusers flow-matching convention these only matter for the
# 'logit_normal' (mean/std) and 'mode' (mode_scale) schemes; presumably they
# are ignored while the scheme is 'none' - confirm against the trainer.
flow_logit_mean: 0.0          # Mean of the logit-normal timestep distribution
flow_logit_std: 1.0           # Std of the logit-normal timestep distribution
flow_mode_scale: 1.29         # Scale for the 'mode' timestep-sampling scheme

# Training mode, augmentation and inference
train_mode: video_only     # Video-only input (no additional modalities)
use_color_jitter: 0.3      # Color-jitter strength (a float, despite the use_ prefix)
num_inference_step: 35     # Number of denoising steps at inference

# Validation settings
val_shuffle: true   # Shuffle validation samples (avoids order bias)

# DeepSpeed (distributed training)
use_deepspeed: true
deepspeed:
  zero_optimization:
    stage: 2   # ZeRO-2: partition optimizer states and gradients across ranks
    # To also offload optimizer states to CPU (less GPU memory, slower steps),
    # uncomment:
    # offload_optimizer:
    #   device: cpu
  fp16:
    enabled: false   # FP16 off; BF16 is used instead
  bf16:
    enabled: true    # BF16 on, matching mixed_precision: bf16
