from code.utils import HParams
from code.realnvp_v2 import RealNVP


# Default hyperparameters for the MNIST RealNVP preset; RealNVP_MNIST wraps
# them in its class-level ``hp`` attribute.
_HP_MNIST = HParams(
    # --- RealNVP architecture (forwarded to the RealNVP constructor) ---
    image_shape=(1, 28, 28),  # presumably (channels, height, width) -- confirm
    d_hidden=32,              # presumably hidden width of the coupling nets
    n_blocks=8,
    n_scales=3,

    # --- Data processing ---
    logit_eps=0.001,          # presumably epsilon for the logit transform
    n_bits=8,                 # presumably bits per pixel -- confirm

    # --- Training ---
    full_batch_size=128,
    learning_rate=1e-3,
    # NOTE(review): presumably the epochs at which the LR is decayed;
    # verify against the training loop that reads this.
    learning_rate_decay=[20, 40, 60],
    l2_coeff=1e-3,
    clip_grad_norm=100,
    max_epoch=200,

    # --- Monitoring (units -- steps vs. epochs -- not visible here) ---
    print_freq=10,
    log_freq=20,
    sample_freq=1,
    ckpt_freq=5,
)

class RealNVP_MNIST(RealNVP):
    """RealNVP model preconfigured with the ``_HP_MNIST`` hyperparameters.

    ``hp`` is a class attribute, so all instances share one ``HParams``
    object. Only the architecture-related fields (``image_shape``,
    ``d_hidden``, ``n_blocks``, ``n_scales``, ``logit_eps``) are passed to
    ``RealNVP.__init__``. The remaining fields (training, monitoring) are
    presumably read from ``hp`` by other code -- TODO confirm.
    """

    # NOTE(review): presumably HParams(<HParams>) produces a copy, keeping
    # the module-level _HP_MNIST independent of this attribute -- verify in
    # code.utils.
    hp = HParams(_HP_MNIST)

    def __init__(self):
        # Hyperparameter names forwarded, in this order, as keyword
        # arguments to the RealNVP constructor.
        forwarded = (
            "image_shape",
            "d_hidden",
            "n_blocks",
            "n_scales",
            "logit_eps",
        )
        settings = self.hp
        super().__init__(**{key: getattr(settings, key) for key in forwarded})

