import torch

from sacred import Ingredient
from advas.models.proxies import (TestGanWassersteinProxy as TWP,
                                  ProxyImprovedMNIST, ProxyBEGAN,
                                  ProxyBatchNormImprovedMNIST,
                                  ProxyInstanceNormImprovedMNIST,
                                  ProxyBNMNIST, ProxyDCGAN,
                                  ProxyStyleGan2)
import torch.optim as optim

# Sacred ingredient holding the proxy config (optimizer_type, proxy_type, lr)
# and the captured functions optimizer_dispatch / proxy_method below.
# NOTE(review): this originally said it "has to come before importing models",
# but the proxy models are already imported above this line — presumably the
# ordering constraint applies to modules that import this ingredient; confirm
# or drop.
proxy_method_ingredient = Ingredient('proxy_method')


@proxy_method_ingredient.named_config
def began():
    """Named config: use ProxyBEGAN(data_shape) as the proxy."""
    proxy_type = "began"

@proxy_method_ingredient.named_config
def inbegan():
    """Named config: use ProxyBEGAN(data_shape, norm_type='instance')."""
    proxy_type = "inbegan"

@proxy_method_ingredient.named_config
def improvedmnist():
    """Named config: use ProxyImprovedMNIST (same as the cfg default)."""
    proxy_type = "improvedmnist"


@proxy_method_ingredient.named_config
def improvedbnmnist():
    """Named config: use ProxyBatchNormImprovedMNIST as the proxy."""
    proxy_type = "improvedbnmnist"


@proxy_method_ingredient.named_config
def improvedinmnist():
    """Named config: use ProxyInstanceNormImprovedMNIST as the proxy."""
    proxy_type = "improvedinmnist"

# @proxy_method_ingredient.named_config
# def stylegan2():
#     proxy_type = "stylegan2"


@proxy_method_ingredient.named_config
def bnmnist():
    """Named config: use ProxyBNMNIST() with its default norm_type."""
    proxy_type = "bnmnist"

@proxy_method_ingredient.named_config
def inmnist():
    """Named config: use ProxyBNMNIST(norm_type='instance') as the proxy."""
    proxy_type = "inmnist"

@proxy_method_ingredient.config
def cfg():
    """Default proxy_method config; named configs override proxy_type."""
    # Optimizer for the proxy; optimizer_dispatch accepts "ADAM", "RMSPROP" or "SGD".
    optimizer_type = "ADAM"
    # 'improvedmnist' is Improved WGAN MNIST arch.
    proxy_type = "improvedmnist"
    # Learning rate passed to the proxy optimizer.
    lr = 1e-4


@proxy_method_ingredient.capture
def optimizer_dispatch(proxy_model, optimizer_type, lr, _log):
    """Build the optimizer for ``proxy_model`` from the ingredient config.

    Being a sacred captured function, any of ``optimizer_type``, ``lr`` and
    ``_log`` that the caller omits are filled in from the ``proxy_method``
    config / sacred's logger.

    Args:
        proxy_model: Module whose ``parameters()`` are optimized.
        optimizer_type: One of ``"ADAM"``, ``"RMSPROP"`` or ``"SGD"``.
        lr: Learning rate for the optimizer.
        _log: Logger used to report the chosen optimizer.

    Returns:
        The constructed ``torch.optim`` optimizer.

    Raises:
        NotImplementedError: If ``optimizer_type`` is not one of the above.
    """
    if optimizer_type == "ADAM":
        # beta1=0.5 instead of torch's default 0.9.
        opt = optim.Adam(proxy_model.parameters(), lr=lr, betas=(0.5, 0.999))
    elif optimizer_type == "RMSPROP":
        opt = optim.RMSprop(proxy_model.parameters(), lr=lr)
    elif optimizer_type == "SGD":
        opt = optim.SGD(proxy_model.parameters(), lr=lr)
    else:
        raise NotImplementedError(f"Unknown optimizer: {optimizer_type}")
    _log.info(f"Running proxy with optimizer: {optimizer_type}")
    return opt


@proxy_method_ingredient.capture
def proxy_method(data_shape, proxy_type, _log):
    """Construct the proxy model selected by ``proxy_type`` and its optimizer.

    Being a sacred captured function, ``proxy_type`` and ``_log`` are filled
    in from the ``proxy_method`` config / sacred's logger when omitted.

    Args:
        data_shape: Input data shape. Passed whole to the ``standard``,
            ``began`` and ``inbegan`` proxies; ``stylegan2`` receives
            ``data_shape[1]``; the other proxies ignore it.
        proxy_type: Key selecting the proxy architecture (see ``builders``).
        _log: Logger for the initial parameter norm and the proxy type.

    Returns:
        ``(proxy_model, optimizer)``; the optimizer is built by
        ``optimizer_dispatch`` from the ingredient config.

    Raises:
        NotImplementedError: If ``proxy_type`` is not a known key.
    """
    # Zero-argument factories: only the selected model is constructed, and
    # data_shape[1] is only indexed for 'stylegan2'. A single lookup table
    # also rules out the if/elif fall-through that previously made
    # 'standard' raise after building its model.
    builders = {
        'standard': lambda: TWP(data_shape),
        'stylegan2': lambda: ProxyStyleGan2(data_shape[1]),
        'improvedmnist': ProxyImprovedMNIST,
        'bnmnist': ProxyBNMNIST,
        'inmnist': lambda: ProxyBNMNIST(norm_type='instance'),
        # NOTE(review): ignores data_shape and hard-codes img_dim=28, nc=1 —
        # presumably MNIST-sized input only; confirm.
        'dcgan': lambda: ProxyDCGAN(img_dim=28, nc=1, ndf=64),
        'began': lambda: ProxyBEGAN(data_shape),
        'inbegan': lambda: ProxyBEGAN(data_shape, norm_type='instance'),
        'improvedinmnist': ProxyInstanceNormImprovedMNIST,
        'improvedbnmnist': ProxyBatchNormImprovedMNIST,
    }
    if proxy_type not in builders:
        raise NotImplementedError(f"Invalid proxy_type, {proxy_type}.")
    proxy_model = builders[proxy_type]()

    # Global L2 norm of all parameters, for logging only; no_grad avoids
    # building an autograd graph for this diagnostic.
    with torch.no_grad():
        norm = torch.cat([p.flatten() for p in proxy_model.parameters()]).norm()
    msg = f"Proxy: ||w||={norm.item()}"
    _log.info(msg)
    _log.info(f"Proxy Type: {proxy_type}")
    optimizer = optimizer_dispatch(proxy_model)
    return proxy_model, optimizer
