from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyper-parameters and output location for a DiT training run.

    Every attribute below is type-annotated so that ``@dataclass`` registers
    it as a field; any of them can therefore be overridden at construction
    time, e.g. ``TrainingConfig(channels=32, num_epochs=500)``.

    ``output_dir`` defaults to a name built from ``channels``, ``num_blocks``,
    ``class_dropout_prob`` and ``num_epochs``. When it is left at that
    default, ``__post_init__`` rebuilds it from the instance's actual values,
    so overriding those fields is reflected in the directory name. Pass
    ``output_dir`` explicitly to use a custom location.
    """

    # Deliberately NOT annotated: unannotated class attributes are ignored by
    # @dataclass, so this stays a plain class constant rather than a field.
    _OUTPUT_DIR_TEMPLATE = 'diffbatt_{channels}_{num_blocks}_{class_dropout_prob}_{num_epochs}'

    sequence_length: int = 256
    num_epochs: int = 10000
    lr_warmup_steps: int = 100
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-3
    seed: int = 0
    print_every: int = 1000  # how often to print the loss
    channels: int = 16  # number of channels
    input_channels: int = 1
    num_blocks: int = 4  # number of DiT blocks
    class_dropout_prob: float = 0.0

    # Directory name derived from the class-level defaults above (also a plain
    # class constant, not a field). Kept so that ``TrainingConfig.output_dir``
    # and ``TrainingConfig().output_dir`` are the same string as before.
    _DEFAULT_OUTPUT_DIR = _OUTPUT_DIR_TEMPLATE.format(
        channels=channels,
        num_blocks=num_blocks,
        class_dropout_prob=class_dropout_prob,
        num_epochs=num_epochs,
    )
    output_dir: str = _DEFAULT_OUTPUT_DIR

    def __post_init__(self) -> None:
        """Re-derive ``output_dir`` from this instance's values if it was not set explicitly."""
        # The field default was computed once from the class defaults; without
        # this, overriding e.g. ``channels`` would leave a stale directory name.
        if self.output_dir == self._DEFAULT_OUTPUT_DIR:
            self.output_dir = self._OUTPUT_DIR_TEMPLATE.format(
                channels=self.channels,
                num_blocks=self.num_blocks,
                class_dropout_prob=self.class_dropout_prob,
                num_epochs=self.num_epochs,
            )