import torch
import logging
import os.path as osp
from datetime import datetime
from easydict import EasyDict

# Baseline configuration: a single EasyDict populated at import time with
# data, dataloader, diffusion, UNet, optimizer and logging settings.
cfg = EasyDict(__name__='Config: baseline')

# data
cfg.root_dir = 'data/cifar10'  # folder containing the images used for training
cfg.list_file = 'data/cifar10/cifar10.txt'  # text file listing the paths of all images

cfg.resolution = 32
# Crop bounds; presumably the min/max fraction used for random cropping —
# NOTE(review): confirm against the dataset/transform code that reads them.
cfg.min_crop = 0.5
cfg.max_crop = 1.0

# dataloader
cfg.batch_size = 24
cfg.num_workers = 8
cfg.prefetch_factor = 2
cfg.seed = 6666

# diffusion
cfg.schedule = 'linear'
cfg.num_timesteps = 1000
cfg.mean_type = 'eps'
cfg.var_type = 'fixed_small'
cfg.loss_type = 'mse'
cfg.clamp = 1.0

# unet
cfg.unet = EasyDict()
cfg.unet.in_dim = 3
cfg.unet.dim = 128
cfg.unet.out_dim = 3
cfg.unet.dim_mult = [1, 1, 2, 2, 4, 4]
cfg.unet.num_heads = None
cfg.unet.head_dim = 64
cfg.unet.dim_scale = 4
cfg.unet.num_res_blocks = 2
# NOTE(review): if each dim_mult level halves the spatial size, these scales
# map to 4x4, 2x2 and 1x1 feature maps at resolution=32 — confirm this is
# intended rather than carried over from a higher-resolution config.
cfg.unet.attn_scales = [1 / 8, 1 / 16, 1 / 32]
cfg.unet.dropout = 0
cfg.unet.use_checkpoint = False
cfg.unet.use_scale_shift_norm = True

# optimizer
cfg.num_steps = 10_000_000
cfg.lr = 1.0e-4
cfg.weight_decay = 0.0

# acceleration
cfg.use_ema = True
cfg.use_fp16 = False

# training
cfg.ema_decay = 0.9999
cfg.viz_num = 16
cfg.nrow = 4
# Written as plain 100: the previous `10_0` spelling grouped digits
# misleadingly (it read like a truncated 10_000) but had the same value.
cfg.viz_interval = 100
cfg.save_interval = 100

# logging
cfg.log_interval = 100
# Timestamp is taken once, when this module is first imported, so every
# reader of cfg.log_dir in the same process sees the same directory name.
cfg.log_dir = f'test/log_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}'
