from train_unet import Config




def get_default_config():
    """Return a ``Config`` with every field left at its class default.

    Nothing is overridden here; the effective values are whatever
    ``train_unet.Config`` sets when constructed with no arguments.
    """
    return Config()

def get_custom_config():
    """Return a ``Config`` preloaded with this experiment's hand-picked settings.

    A new ``Config`` is constructed, and then the fields listed in
    ``overrides`` are assigned onto it in the order written. Any field not
    listed keeps its ``train_unet.Config`` default. The ``overrides`` mapping
    is built inside the function, so the ``meta_path`` list is a new object
    on every call.
    """
    cfg = Config()

    overrides = {
        # Model parameters
        # NOTE(review): value is passed through as-is; presumably a local
        # directory or hub model id — confirm how train_unet resolves it.
        "pretrained_model_name_or_path": "stable-diffusion-xl-1.0-inpainting-0.1",
        "lora_rank": 16,
        "lora_alpha": 1.0,
        # Training parameters
        "learning_rate": 1e-6,
        "train_batch_size": 4,
        "num_train_epochs": 20,
        # NOTE(review): SDXL checkpoints are commonly trained at 1024px;
        # 512 is presumably intentional for this run — confirm.
        "resolution": 512,
        # Data parameters
        "meta_path": ["data/meta.json"],
        "output_dir": "experiment/unet",
    }

    # setattr(obj, name, v) is equivalent to ``obj.name = v``; the dict keeps
    # insertion order, so fields are assigned in the same sequence as above.
    for field_name, value in overrides.items():
        setattr(cfg, field_name, value)

    return cfg