import torch
from functools import partial


def get_source(src, **kwargs):
    """Return the source-term callable identified by ``src``.

    Args:
        src: Name of the source.
            ``'example1'`` selects :func:`source1` (peak value 1.0) with its
            centre bound to ``kwargs['mu']``. When ``mu`` is not given it
            defaults to ``0``, the same default :func:`source1` uses.
            ``'example2'`` selects :func:`source2` (peak value 0.2) as-is.
        **kwargs: Extra options. Only ``mu`` is read, and only for
            ``'example1'``; any other keys are ignored.

    Returns:
        A callable that takes the evaluation points ``x`` and returns the
        source values.

    Raises:
        NotImplementedError: If ``src`` is not a known source name.
    """
    if src == 'example1':
        return partial(source1, mu=kwargs.get('mu', 0))
    if src == 'example2':
        return source2
    raise NotImplementedError(f"unknown source: {src!r}")


def source1(x, mu=0, sigma=0.1):
    """Unnormalized Gaussian bump centred at ``mu`` with width ``sigma``.

    Computes ``exp(-(x - mu)**2 / (2 * sigma**2))`` elementwise. The value
    at ``x == mu`` is exactly 1. The curve is *not* scaled by
    ``1 / (sigma * sqrt(2*pi))``, so it is not a normalized probability
    density despite its Gaussian shape.

    Args:
        x: Evaluation points. ``torch.exp`` is applied to the result, so this
            is expected to be a ``torch.Tensor``.
        mu: Centre of the bump.
        sigma: Width (standard deviation) of the bump.

    Returns:
        Elementwise bump values, each in ``[0, 1]``.
    """
    return torch.exp(-0.5 * (x - mu) ** 2 / (sigma ** 2))


def source2(x, mu=0, sigma=0.05):
    """Unnormalized Gaussian bump with peak height 0.2.

    Computes ``exp(-(x - mu)**2 / (2 * sigma**2)) / 5`` elementwise. The
    value at ``x == mu`` is exactly 0.2. Like a plain Gaussian shape, this
    is *not* normalized to integrate to 1.

    Args:
        x: Evaluation points. ``torch.exp`` is applied to the result, so this
            is expected to be a ``torch.Tensor``.
        mu: Centre of the bump.
        sigma: Width (standard deviation) of the bump.

    Returns:
        Elementwise bump values, each in ``[0, 0.2]``.
    """
    return torch.exp(-0.5 * (x - mu) ** 2 / (sigma ** 2)) / 5
