from code.utils import HParams
from code.realnvp_v2 import RealNVP


# Hyperparameter preset for RealNVP on 3x64x64 inputs quantized to 5 bits.
# RealNVP_CelebAHQ64_5bit copies this preset into its class-level `hp`.
_HP_CELEBAHQ64_5BIT = HParams(
    # --- RealNVP architecture (forwarded to the RealNVP constructor) ---
    image_shape=(3, 64, 64),
    d_hidden=80,
    n_blocks=10,
    n_scales=6,

    # --- Data processing ---
    # logit_eps is also forwarded to the RealNVP constructor.
    logit_eps=0.001,
    n_bits=5,

    # --- Training ---
    full_batch_size=32,
    learning_rate=1e-3,
    # NOTE(review): presumably epoch milestones for LR decay; confirm in
    # the training loop that consumes this preset.
    learning_rate_decay=[25, 50],
    l2_coeff=5e-4,
    clip_grad_norm=200,
    max_epoch=200,

    # --- Monitoring ---
    # NOTE(review): units (steps vs. epochs) are not visible here; confirm.
    print_freq=10,
    log_freq=50,
    sample_freq=1,
    ckpt_freq=1,
)

class RealNVP_CelebAHQ64_5bit(RealNVP):
    """RealNVP preset built from ``_HP_CELEBAHQ64_5BIT``.

    The model takes no constructor arguments. ``image_shape``, ``d_hidden``,
    ``n_blocks``, ``n_scales`` and ``logit_eps`` are read from the
    class-level ``hp`` and forwarded to ``RealNVP.__init__``. The other
    preset fields (``n_bits``, training and monitoring settings) are only
    stored on ``hp`` for use elsewhere.
    """

    # Class-level hyperparameters shared by all instances.
    hp = HParams(_HP_CELEBAHQ64_5BIT)

    def __init__(self):
        hp = self.hp
        super().__init__(
            image_shape=hp.image_shape,
            d_hidden=hp.d_hidden,
            n_blocks=hp.n_blocks,
            n_scales=hp.n_scales,
            logit_eps=hp.logit_eps,
        )
