from code.utils import HParams
from code.realnvp_v2 import RealNVP


# Baseline hyperparameter set for CIFAR-10 (n_bits = 8).
_HP_CIFAR10 = HParams(
    # --- RealNVP architecture -------------------------------------------
    # These entries are forwarded to the RealNVP constructor by the model
    # classes defined further down in this module.
    image_shape=(3, 32, 32),  # presumably (channels, height, width) -- confirm
    d_hidden=64,
    n_blocks=12,
    n_scales=6,

    # --- Data processing ------------------------------------------------
    logit_eps=0.001,  # also forwarded to the RealNVP constructor
    n_bits=8,

    # --- Training -------------------------------------------------------
    full_batch_size=64,
    learning_rate=1e-3,
    learning_rate_decay=[25, 50],  # presumably decay milestones -- confirm units
    l2_coeff=1e-5,
    clip_grad_norm=500,
    max_epoch=300,

    # --- Monitoring -----------------------------------------------------
    # NOTE(review): the units of these frequencies (iterations vs. epochs)
    # are not visible in this file; check the training loop that reads them.
    print_freq=10,
    log_freq=50,
    sample_freq=1,
    ckpt_freq=1,
)


# 5-bit variant: start from a clone of the baseline set, then override
# only the two entries that differ (batch size and bit depth).
_HP_CIFAR10_5BIT = _HP_CIFAR10.clone()
_HP_CIFAR10_5BIT.full_batch_size = 128
_HP_CIFAR10_5BIT.n_bits = 5


class RealNVP_CIFAR10(RealNVP):
    """RealNVP model preconfigured with the baseline CIFAR-10 settings.

    The full hyperparameter set is exposed as the class attribute ``hp``.
    The constructor takes no arguments: it reads the architecture-related
    entries from ``hp`` and passes them to the parent constructor.
    """

    # Built from the module-level baseline set (presumably an independent
    # copy, so edits to ``hp`` do not touch ``_HP_CIFAR10`` -- confirm
    # against the HParams implementation).
    hp = HParams(_HP_CIFAR10)

    def __init__(self):
        """Initialise the parent RealNVP with the values stored in ``hp``."""
        cfg = self.hp
        super().__init__(
            image_shape=cfg.image_shape,
            d_hidden=cfg.d_hidden,
            n_blocks=cfg.n_blocks,
            n_scales=cfg.n_scales,
            logit_eps=cfg.logit_eps,
        )


class RealNVP_CIFAR10_5bit(RealNVP):
    """RealNVP model preconfigured with the 5-bit CIFAR-10 settings.

    The full hyperparameter set is exposed as the class attribute ``hp``.
    The constructor takes no arguments: it reads the architecture-related
    entries from ``hp`` and passes them to the parent constructor.
    """

    # Built from the module-level 5-bit set (presumably an independent
    # copy, so edits to ``hp`` do not touch ``_HP_CIFAR10_5BIT`` -- confirm
    # against the HParams implementation).
    hp = HParams(_HP_CIFAR10_5BIT)

    def __init__(self):
        """Initialise the parent RealNVP with the values stored in ``hp``."""
        cfg = self.hp
        super().__init__(
            image_shape=cfg.image_shape,
            d_hidden=cfg.d_hidden,
            n_blocks=cfg.n_blocks,
            n_scales=cfg.n_scales,
            logit_eps=cfg.logit_eps,
        )
