import numpy as np
import torch
import torch.nn as nn
import torchvision
import geomloss
from geomloss import SamplesLoss
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader


# Global compute device: the current CUDA device when PyTorch reports CUDA as
# usable, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def line(n: int = 80, *, char: str = "=") -> str:
    """Return a horizontal separator made of *char* repeated *n* times.

    Args:
        n: Number of repetitions. Zero or negative yields an empty string
            (standard ``str * int`` semantics).
        char: The string to repeat. Defaults to ``"="``, matching the
            original hard-coded separator.

    Returns:
        The separator string, e.g. ``line(3) == "==="``.
    """
    return char * n
