import logging

import torch

import utils

# Module logger. logging.getLogger hands back the same Logger object to every
# caller that asks for the name 'custom'.
logger = logging.getLogger(name='custom')
# Short module-level alias for the project helper utils.log_mean_exp.
lme = utils.log_mean_exp


def parse_unimodal_vae_params(args, modality: str):
    """ Gather the settings needed to instantiate a UniModalVae object.

    Modality-specific settings are read from attributes named
    ``<modality>_<suffix>`` on ``args``. The shared settings ``stoc_dist`` and
    ``learn_prior`` are read directly. The stochastic dimension is looked up
    as ``args.stoc_dim[modality]``.

    :param args: hyperparameters
    :param modality: modality for which to construct VAE
    :return: dict with keys 'shapes', 'deterministic_layer',
        'stochastic_layer', 'reconstruction_layer', 'generic_layer' (each a
        nested dict of layer settings) and 'modality' (the name passed in)
    """
    def _opt(suffix):
        # Per-modality options live on ``args`` as ``<modality>_<suffix>``.
        return getattr(args, f'{modality}_{suffix}')

    # Each section is built in the same order as the attributes are read, so a
    # missing attribute raises the same AttributeError as before.
    shapes = {'input': _opt('input_shape')}
    deterministic = {
        'specs_bu': _opt('det_specs_bu'),
        'specs_td': _opt('det_specs_td'),
    }
    stochastic = {
        'dim': args.stoc_dim[modality],
        'dist_type': args.stoc_dist,
        'specs': _opt('stoc_specs'),
        'merge_layer': _opt('merge_layer'),
        'upsampling': _opt('stoc_upsampling'),
        'learn_prior': args.learn_prior,
    }
    reconstruction = {
        'dist_type': _opt('rec_dist'),
        'specs': _opt('rec_specs'),
    }
    generic = {'nonlin': _opt('nonlin')}

    return dict(
        shapes=shapes,
        deterministic_layer=deterministic,
        stochastic_layer=stochastic,
        reconstruction_layer=reconstruction,
        generic_layer=generic,
        modality=modality,
    )


def define_model(package, args, device, checkpoint=None):
    """Instantiate ``package.Model`` from ``args``, optionally restore its weights, and move it to ``device``.

    :param package: includes everything model-specific; must expose a ``Model``
        class whose constructor takes ``args``
    :param args: for defining new model from args (forwarded to ``Model(args)``)
    :param device: target passed to ``Module.to``
    :param checkpoint: optional mapping whose ``'state_dict'`` entry is loaded
        into the new model; any falsy value (e.g. ``None``) skips loading
    :return: the model after ``.to(device)``
    """
    net = package.Model(args)

    if not checkpoint:
        # Nothing to restore: return the freshly initialised model.
        return net.to(device)

    # Weights are restored before the device move.
    net.load_state_dict(checkpoint['state_dict'])
    return net.to(device)
