import os
import argparse

parser = argparse.ArgumentParser(description='Training Configuration')

# General arguments
parser.add_argument('--gpu', type=int, default=0, help='GPU device id.')
parser.add_argument('--seed', type=int, default=10, help='Random seed.')
parser.add_argument('--dataset_name', type=str, default='celeba256', choices=['celeba256'],
                    help='Dataset name.')
parser.add_argument('--batch_size', type=int, default=64, help='Mini batch size.')
parser.add_argument('--sampling_batch_size', type=int, default=50, help='Mini batch size for sampling.')

# Choice doesn't affect training
parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema",
                    help="VAE variant (does not affect training).")

# Model configuration
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate.')
parser.add_argument('--weight_decay', type=float, default=0.1, help='Weight decay.')
parser.add_argument('--denoise_timesteps', type=int, default=128, help='Denoising timesteps.')
# NOTE(review): nothing here enforces that batch_size is divisible by bootstrap_every.
parser.add_argument('--bootstrap_every', type=int, default=8, help='Bootstrap interval (divisor of batch size).')
parser.add_argument('--append_bst_batch', action='store_true',
                    help="if true, appends the bootstrap targets to the batch, "
                         "otherwise replaces the original batch elements")
# argparse checks `choices` only against command-line values, never against the
# default. The previous default ('shortcut') was therefore silently accepted even
# though it is not an allowed choice. 'ST' is assumed to be the plain shortcut
# training mode -- NOTE(review): confirm against the code that consumes train_type.
parser.add_argument('--train_type', type=str, default='ST', choices=['ST', 'ST-CSL'],
                    help='Training type.')

# Architecture, resolution and run length
# Both spellings are accepted. '--image-size' is kept for backward compatibility;
# '--image_size' matches the underscore style of every other option. Either one
# is stored on `args.image_size`.
parser.add_argument("--image-size", "--image_size", type=int, choices=[256, 512], default=256,
                    help='Image size (256 or 512).')
parser.add_argument('--model', type=str, default='DiT-B-2', help='Model name (e.g. DiT-B-2).')
parser.add_argument("--epochs", type=int, default=501, help='Number of training epochs.')
parser.add_argument('--ckpt_every', type=int, default=100, help='Save checkpoints at regular intervals')

def get_args(argv=None):
    """Parse training options using the module-level ``parser``.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) keeps the original behaviour and reads ``sys.argv[1:]``.
            Passing an explicit list makes it possible to build a config
            programmatically, e.g. from tests or notebooks.

    Returns:
        argparse.Namespace with one attribute per registered option.
    """
    return parser.parse_args(argv)
