from typing import List

import torch


def create_gaussion_dist(dims: List[int], sigma: float, mu: float, device: int):
    """Evaluate an unnormalised isotropic Gaussian on a regular grid.

    Every axis is sampled with ``torch.linspace`` over ``[-1, 1]``; the grid
    has one axis per entry of ``dims``. The same scalar ``mu`` is subtracted
    from every coordinate, so the bump is centred at ``(mu, mu, ..., mu)``.
    No normalisation factor is applied: a grid point lying exactly on the
    centre evaluates to 1.

    Args:
        dims: Number of samples along each grid axis.
        sigma: Standard deviation shared by all axes.
        mu: Centre coordinate shared by all axes.
        device: Target passed to ``Tensor.to(device=...)`` for the result.

    Returns:
        A tensor of shape ``tuple(dims)`` holding
        ``exp(-||x - mu||^2 / (2 * sigma^2))`` at every grid point.
    """
    # One coordinate grid per axis ("ij" keeps the axis order of ``dims``),
    # stacked along a new leading dimension.
    per_axis_grids = torch.meshgrid(
        *(torch.linspace(-1, 1, steps=n) for n in dims),
        indexing="ij",
    )
    centred = torch.stack(per_axis_grids) - mu

    # Squared Euclidean distance to the centre: reduce over the stacked
    # coordinate dimension.
    squared_dist = (centred**2).sum(dim=0)
    variance_term = 2.0 * sigma**2

    gaussian = torch.exp(-squared_dist / variance_term)
    # The grid is built with torch's defaults and only the result is moved.
    return gaussian.to(device=device)
