import ot
import torch

class dotdict(dict):
    """dict subclass that also exposes its keys as attributes.

    ``d.key`` reads ``d.get("key")``, so a missing ordinary key yields ``None``
    instead of raising. ``d.key = v`` stores ``d["key"] = v``, and ``del d.key``
    removes the entry (raising ``KeyError``, not ``AttributeError``, when it is
    absent). Real dict attributes such as ``keys``/``items`` win on reads,
    because ``__getattr__`` only runs after normal attribute lookup fails.

    Missing dunder names (``__deepcopy__``, ``__getstate__``, ...) raise
    ``AttributeError`` rather than returning ``None``, so ``copy``/``pickle``
    and other protocol probes do not mistake ``None`` for an existing hook.
    """

    def __getattr__(self, name):
        # Protocol machinery (copy, pickle, numpy, ...) probes dunder names via
        # getattr/hasattr; answering None would make it believe a hook exists
        # and then try to call it. Dunder keys actually stored are still served.
        if name not in self and name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return dict.get(self, name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

def get_w2(samples1, samples2, *, num_iter_max=100000):
    """Return the 2-Wasserstein distance between two empirical point clouds.

    Every row of ``samples1`` / ``samples2`` is one support point with uniform
    mass (``1/n`` and ``1/m``). The ground cost is ``ot.dist``'s default
    (squared Euclidean), so the exact optimal-transport cost computed by
    ``ot.emd2`` is W2**2 and its square root is W2.

    Args:
        samples1: torch tensor of samples, presumably shape (n, d) with one
            sample per row (only ``shape[0]`` is read here) — TODO confirm.
        samples2: torch tensor of samples, presumably shape (m, d), on a
            device/dtype compatible with ``samples1``.
        num_iter_max: iteration cap for POT's network-simplex solver,
            forwarded to ``ot.emd2`` as ``numItermax`` (100000 is POT's own
            default). Increase it for large sample sets; if the cap is hit the
            returned cost may be sub-optimal.

    Returns:
        The W2 distance as returned by ``ot.emd2`` (a torch scalar for torch
        inputs) raised to the power 0.5.
        NOTE(review): when the two clouds coincide the cost is 0 and the
        gradient of the square root there is infinite — confirm callers do not
        backpropagate through that case.

    Raises:
        ValueError: if either sample set is empty.
    """
    n, m = samples1.shape[0], samples2.shape[0]
    if n == 0 or m == 0:
        raise ValueError(
            f"get_w2 needs non-empty sample sets, got {n} and {m} samples"
        )
    # Pairwise cost matrix (n, m); ot.dist defaults to squared Euclidean.
    M = ot.dist(samples1, samples2)
    # Uniform marginals. Matching the samples' dtype avoids rounding 1/n to
    # float32 when the samples are float64.
    a = torch.ones((n,), dtype=samples1.dtype, device=samples1.device) / n
    b = torch.ones((m,), dtype=samples2.dtype, device=samples2.device) / m
    # emd2 yields the optimal transport cost, i.e. W2**2 for this cost matrix.
    return ot.emd2(a, b, M, numItermax=num_iter_max) ** 0.5