import torch
from sacred import Ingredient
from advas.models.generators import (GeneratorResidual, GeneratorImprovedMNIST,
                                     GeneratorBatchNormImprovedMNIST,
                                     GeneratorInstanceNormImprovedMNIST,
                                     GeneratorBNMNIST, GeneratorDCGAN,
                                     GeneratorStyleGan, GeneratorStyleGan2)

# Sacred ingredient that owns the generator configuration (named configs and
# defaults) and the captured `generator_dispatch` function defined below.
# NOTE(review): the previous comment here said this "has to come before
# importing generators", but the generator classes are imported at the top of
# this file, before this line -- confirm whether an ordering constraint
# still applies.
generator_ingredient = Ingredient('generator_model')


# @generator_ingredient.named_config
# def stylegan():
#     generator_type = 'stylegan'

# @generator_ingredient.named_config
# def stylegan2():
#     generator_type = 'stylegan2'

# @generator_ingredient.named_config
# def residual():
#     generator_type = 'residual'


@generator_ingredient.named_config
def improvedmnist():
    """Named config selecting the 'improvedmnist' generator type."""
    # generator_dispatch builds GeneratorImprovedMNIST for this value
    generator_type = 'improvedmnist'


@generator_ingredient.named_config
def improvedbnmnist():
    """Named config selecting the 'improvedbnmnist' generator type."""
    # generator_dispatch builds GeneratorBatchNormImprovedMNIST for this value
    generator_type = 'improvedbnmnist'


@generator_ingredient.named_config
def improvedinmnist():
    """Named config selecting the 'improvedinmnist' generator type."""
    # generator_dispatch builds GeneratorInstanceNormImprovedMNIST for this value
    generator_type = 'improvedinmnist'


@generator_ingredient.named_config
def bnmnist():
    """Named config selecting the 'bnmnist' generator type."""
    # generator_dispatch builds GeneratorBNMNIST with default arguments
    generator_type = 'bnmnist'


@generator_ingredient.named_config
def inmnist():
    """Named config selecting the 'inmnist' generator type."""
    # generator_dispatch builds GeneratorBNMNIST(norm_type='instance')
    generator_type = 'inmnist'


@generator_ingredient.config
def cfg():
    """Default configuration of the 'generator_model' ingredient."""
    # key that generator_dispatch uses to pick the generator class
    generator_type = 'residual'
    # wrapped in torch.Size and passed as base_shape to GeneratorResidual
    base_shape = (1, 16, 16)
    # passed to GeneratorStyleGan / GeneratorStyleGan2
    style_dim = 512
    # forwarded as `sigmoid` to GeneratorResidual and GeneratorStyleGan
    sigmoid_output = False


@generator_ingredient.capture
def generator_dispatch(data_shape, generator_type, style_dim, base_shape,
                       sigmoid_output, _log):
    """Build the generator model selected by ``generator_type``.

    As a function captured by ``generator_ingredient``, every argument the
    caller does not pass is filled in from the ingredient's configuration.

    Args:
        data_shape: Passed as the first argument to ``GeneratorResidual`` and
            ``GeneratorStyleGan``; ``GeneratorStyleGan2`` receives
            ``data_shape[0]`` and ``data_shape[1]``. The remaining generator
            types do not use it.
        generator_type: One of 'residual', 'stylegan', 'stylegan2',
            'improvedmnist', 'bnmnist', 'inmnist', 'dcgan',
            'improvedinmnist', 'improvedbnmnist'.
        style_dim: Used only by the 'stylegan' and 'stylegan2' generators.
        base_shape: Used only by the 'residual' generator, converted with
            ``torch.Size``.
        sigmoid_output: Forwarded as ``sigmoid`` to the 'residual' and
            'stylegan' generators; the other types do not receive it.
        _log: Logger used to report the selected type (presumably the one
            sacred injects into captured functions).

    Returns:
        The constructed generator instance.

    Raises:
        NotImplementedError: If ``generator_type`` is not a supported name.
    """
    # Zero-argument factories: only the selected generator is constructed.
    # Note that 'dcgan' is built with fixed hyper-parameters and, like the
    # MNIST-specific generators, does not look at data_shape.
    factories = {
        'residual': lambda: GeneratorResidual(
            data_shape,
            base_shape=torch.Size(base_shape),
            sigmoid=sigmoid_output),
        'stylegan': lambda: GeneratorStyleGan(
            data_shape, style_dim=style_dim, sigmoid=sigmoid_output),
        'stylegan2': lambda: GeneratorStyleGan2(
            data_shape[0], data_shape[1], style_dim),
        'improvedmnist': GeneratorImprovedMNIST,
        'bnmnist': GeneratorBNMNIST,
        'inmnist': lambda: GeneratorBNMNIST(norm_type='instance'),
        'dcgan': lambda: GeneratorDCGAN(img_dim=28, nc=1, nz=100, ngf=64),
        'improvedinmnist': GeneratorInstanceNormImprovedMNIST,
        'improvedbnmnist': GeneratorBatchNormImprovedMNIST,
    }

    try:
        build = factories[generator_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable values (e.g. a list from a config file),
        # which previously also ended in NotImplementedError.
        raise NotImplementedError(
            f"Unknown generator_type {generator_type!r}; "
            f"expected one of {sorted(factories)}") from None

    p = build()

    _log.info(f"Modeling {generator_type}")

    return p
