from sacred import Ingredient
import advas.objectives as obj
from advas import NormalizeType

# Sacred ingredient that owns the ``objective`` config section and the
# config/capture functions registered below. It has to be created before the
# models are imported.
# NOTE(review): presumably because model modules import this ingredient —
# confirm against the models package.
objective_ingredient = Ingredient('objective')


# Named config ``lsgan``: select the LSGAN objective.
@objective_ingredient.named_config
def lsgan():
    objective_type = 'LSGAN'  # objective_dispatch builds obj.LSGan


# Named config ``began``: select the BEGAN objective.
@objective_ingredient.named_config
def began():
    objective_type = 'BEGAN'  # objective_dispatch builds obj.BEGan


# Named config ``wgan_norm``: WGAN objective with ``weight_norm`` set to 1.0.
@objective_ingredient.named_config
def wgan_norm():
    objective_type = 'WGAN_norm'  # contains 'WGAN', so obj.WGan is built
    weight_norm = 1.  # forwarded to obj.WGan(weight_norm=...)


# Named config ``wgan_gp``: WGAN objective with ``GP_strength`` set to 10.0.
@objective_ingredient.named_config
def wgan_gp():
    objective_type = 'WGAN_GP'  # contains 'WGAN', so obj.WGan is built
    GP_strength = 10.  # forwarded to obj.WGan(GP_strength=...)


# Named config ``wgan_clamp``: WGAN objective with ``clamp_limit`` = 0.01.
@objective_ingredient.named_config
def wgan_clamp():
    objective_type = 'WGAN_clamp'  # contains 'WGAN', so obj.WGan is built
    clamp_limit = 0.01  # forwarded to obj.WGan(clamp_limit=...)


# Default configuration of the ``objective`` ingredient. Sacred turns every
# assignment in this body into a config entry (named configs override a
# subset of them); the inline comments become the entries' documentation.
@objective_ingredient.config
def cfg():
    GP_strength = None  # forwarded to obj.WGan only
    regularizer_strength = 0.  # -1 / -2 select Total / Advas normalization
    unbiased = True  # forwarded to every objective
    ignore_proxy_reg = False  # forwarded to every objective
    do_sqrt = False  # forwarded to every objective
    clamp_limit = None  # forwarded to obj.WGan only
    weight_norm = None  # forwarded to obj.WGan only
    # NOTE(review): default type is 'WGAN_GP' while GP_strength defaults to
    # None; confirm how obj.WGan behaves with GP_strength=None.
    objective_type = 'WGAN_GP'  # any value containing 'WGAN' -> obj.WGan


@objective_ingredient.capture
def objective_dispatch(GP_strength, clamp_limit, weight_norm,
                       regularizer_strength, unbiased, ignore_proxy_reg,
                       do_sqrt, objective_type):
    """Build the objective selected by the ``objective`` ingredient config.

    The function is registered with ``objective_ingredient.capture``, so
    Sacred fills any argument the caller omits from the ``objective`` config.

    Args:
        GP_strength, clamp_limit, weight_norm: Forwarded unchanged to
            ``obj.WGan``; not used by the other objectives.
        regularizer_strength: Forwarded to the objective. The sentinel
            values ``-1`` and ``-2`` are not real strengths: they select
            ``NormalizeType.Total`` and ``NormalizeType.Advas`` respectively,
            and the objective then receives a strength of ``1``. Any other
            value is passed through together with ``NormalizeType.Standard``.
        unbiased, ignore_proxy_reg, do_sqrt: Forwarded unchanged to every
            objective.
        objective_type: Any string containing ``'WGAN'`` selects
            ``obj.WGan``; ``'LSGAN'`` selects ``obj.LSGan``; ``'BEGAN'``
            selects ``obj.BEGan``.

    Returns:
        tuple: ``(objective, normtype)``.

    Raises:
        ValueError: If ``objective_type`` is not a string or does not name a
            known objective.
    """
    # Negative "strengths" encode the normalization scheme rather than an
    # actual regularizer weight.
    special_normtypes = {-1: NormalizeType.Total, -2: NormalizeType.Advas}
    if regularizer_strength in special_normtypes:
        normtype = special_normtypes[regularizer_strength]
        regularizer_strength = 1
    else:
        normtype = NormalizeType.Standard

    # 'WGAN_GP', 'WGAN_norm' and 'WGAN_clamp' all build obj.WGan; the named
    # configs differ only in which WGan keyword arguments they set. The
    # isinstance guard turns e.g. ``None`` into the ValueError below instead
    # of a TypeError from the substring test.
    if isinstance(objective_type, str) and 'WGAN' in objective_type:
        objective = obj.WGan(GP_strength=GP_strength, clamp_limit=clamp_limit,
                             weight_norm=weight_norm,
                             regularizer_strength=regularizer_strength,
                             unbiased=unbiased, do_sqrt=do_sqrt,
                             ignore_proxy_reg=ignore_proxy_reg)
    elif objective_type == 'LSGAN':
        objective = obj.LSGan(regularizer_strength=regularizer_strength,
                              unbiased=unbiased, do_sqrt=do_sqrt,
                              ignore_proxy_reg=ignore_proxy_reg)
    elif objective_type == 'BEGAN':
        objective = obj.BEGan(regularizer_strength=regularizer_strength,
                              ignore_proxy_reg=ignore_proxy_reg,
                              do_sqrt=do_sqrt, unbiased=unbiased)
    else:
        raise ValueError(
            "Invalid objective type: {!r} (expected a type containing "
            "'WGAN', 'LSGAN' or 'BEGAN')".format(objective_type))
    return objective, normtype
