# Paths (Windows-style). Both end with a trailing backslash so further
# sub-paths can be appended by plain string concatenation.
root_dir = 'C:\\gkw\\experiments\\cv_demo_cvae_fan_ppp_mask\\'
output_dir = f'{root_dir}output\\'


# Dataset parameter settings.
def get_data_config(data_name):
    """Return the dataset configuration for ``data_name``.

    For ``'cifar10'`` no configuration is currently defined: the function
    prints ``None`` and returns an empty dict. Every other name falls through
    to the single-channel, two-class configuration below.

    Args:
        data_name: Dataset identifier string.

    Returns:
        dict: A new dict on every call. It is empty for ``'cifar10'``.
        Otherwise it holds the keys ``in_channels``, ``num_classes``,
        ``batch_size``, ``data_dir``, ``train_dir``, ``test_dir`` and
        ``align_dir``.
    """
    if data_name == 'cifar10':
        # CIFAR-10 settings are disabled. The previous values were
        # in_channels=3, num_classes=10, batch_size=32,
        # data_dir='C:\\gkw\\experiments\\cifar10\\images' with '\\train' and
        # '\\test' sub-directories. Callers receive an empty dict, so any key
        # lookup on the result will raise KeyError.
        print("None")
        return {}

    # NOTE(review): any name other than 'cifar10' (including typos) lands here.
    # NOTE(review): train_dir lives under 'fantest' while test_dir and
    # align_dir live under 'pumptest', and data_dir is 'dataset\\fan'.
    # Confirm this mix of fan/pump directories is intentional.
    return {
        'in_channels': 1,
        'num_classes': 2,
        'batch_size': 256,
        'data_dir': 'C:\\gkw\\dataset\\fan',
        'train_dir': 'C:\\gkw\\fantest\\train_0',
        'test_dir': 'C:\\gkw\\pumptest\\test',
        'align_dir': 'C:\\gkw\\pumptest\\train_0',
    }


# Training hyperparameters. Presumably consumed by the optimizer and
# training loop elsewhere; confirm against the callers.
lr = 1e-3            # presumably the optimizer learning rate
momentum = 0.9       # presumably the SGD momentum factor
weight_decay = 5e-4  # presumably the L2 weight-decay coefficient
num_epochs = 200     # number of training epochs (by name; confirm usage)
