import math

import torch


def _assert_both_none_or_tensor(a, b):
    assert not ((a is None) and (b is None))
    assert not ((a is not None) and (b is not None))


def get_math_module(value):
    """Return the module whose ``log``/``exp`` functions apply to ``value``.

    Args:
        value: a Python scalar (``int`` or ``float``) or a ``torch.Tensor``.

    Returns:
        ``math`` for Python scalars, ``torch`` for tensors.

    Raises:
        NotImplementedError: for any other type.
    """
    if isinstance(value, torch.Tensor):
        return torch
    # ints (previously rejected) are accepted too: math.log/math.exp handle
    # them, so e.g. ``variance=1`` or ``log_sigma=0`` now work.
    if isinstance(value, (int, float)):
        return math
    raise NotImplementedError(
        f"Unsupported type for math dispatch: {type(value).__name__}"
    )


def _likelihood_constant(variance=None, log_sigma=None):
    """Return the normalising constant of a zero-mean Gaussian log-density.

    For ``N(0, sigma**2)`` the constant is ``-0.5 * log(2*pi) - log(sigma)``.
    Exactly one of ``variance`` or ``log_sigma`` must be given.

    Args:
        variance: ``sigma**2`` as a Python scalar or ``torch.Tensor``.
        log_sigma: ``log(sigma)`` as a Python scalar or ``torch.Tensor``.

    Returns:
        The constant, as a float for scalar input or a tensor for tensor input.

    Raises:
        AssertionError: if not exactly one of the two arguments is given.
    """
    _assert_both_none_or_tensor(variance, log_sigma)
    if variance is not None:
        # log(sigma) = 0.5 * log(sigma**2). Using log(variance) directly (as
        # before) subtracted log(sigma) twice and gave a wrong constant.
        log_sigma = 0.5 * get_math_module(variance).log(variance)
    return -0.5 * math.log(2 * math.pi) - log_sigma


def compute_normal_log_prob(*args, variance=None, log_sigma=None):
    """Evaluate the log-density of a zero-mean normal distribution at ``x``.

    Computes ``-0.5*log(2*pi) - log(sigma) - 0.5 * x**2 / sigma**2``, i.e.
    ``log N(x; 0, sigma**2)``. The spread is given by exactly one of
    ``variance`` or ``log_sigma``.

    Args:
        *args: exactly one positional argument ``x`` (float or tensor).
        variance: ``sigma**2`` as a Python scalar or ``torch.Tensor``.
        log_sigma: ``log(sigma)`` as a Python scalar or ``torch.Tensor``.

    Returns:
        The log-density evaluated at ``x``.

    Raises:
        AssertionError: if there is not exactly one positional argument, or
            not exactly one of ``variance``/``log_sigma`` is given.
    """
    assert len(args) == 1
    # per the original comment, presumably (batchsize, num_target_points) —
    # TODO confirm with callers
    x = args[0]
    _assert_both_none_or_tensor(variance, log_sigma)
    if variance is None:
        sigma = get_math_module(log_sigma).exp(log_sigma)
        variance = sigma**2
    else:
        # Derive log(sigma) = 0.5 * log(variance) here so the normalising
        # constant subtracts log(sigma) once, not log(variance).
        log_sigma = 0.5 * get_math_module(variance).log(variance)
    const = _likelihood_constant(log_sigma=log_sigma)
    return const - 0.5 * x**2 / variance
