import json


def create_model_from_config(model_config):
    model_type = model_config['model_type']

    if model_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        return create_autoencoder_from_config(model_config)
    else:
        raise NotImplementedError(f'Unknown model type: {model_type}')


def create_model_from_config_path(model_config_path):
    with open(model_config_path) as f:
        model_config = json.load(f)

    return create_model_from_config(model_config)


def create_pretransform_from_config(pretransform_config, sample_rate):
    pretransform_type = pretransform_config['type']

    if pretransform_type == 'autoencoder':
        from .autoencoders import create_autoencoder_from_config
        from .pretransforms import AutoencoderPretransform

        # Create fake top-level config to pass sample rate to autoencoder constructor
        # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
        autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
        autoencoder = create_autoencoder_from_config(autoencoder_config)

        scale = pretransform_config.get("scale", 1.0)
        model_half = pretransform_config.get("model_half", False)
        iterate_batch = pretransform_config.get("iterate_batch", False)
        chunked = pretransform_config.get("chunked", False)

        pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
    else:
        raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')

    enable_grad = pretransform_config.get('enable_grad', False)
    pretransform.enable_grad = enable_grad

    pretransform.eval().requires_grad_(pretransform.enable_grad)

    return pretransform


def create_bottleneck_from_config(bottleneck_config):
    bottleneck_type = bottleneck_config['type']

    if bottleneck_type == 'tanh':
        from .bottleneck import TanhBottleneck
        bottleneck = TanhBottleneck()
    elif bottleneck_type == 'vae':
        from .bottleneck import VAEBottleneck
        bottleneck = VAEBottleneck()
    else:
        raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')

    requires_grad = bottleneck_config.get('requires_grad', True)
    if not requires_grad:
        for param in bottleneck.parameters():
            param.requires_grad = False

    return bottleneck
