#!/usr/bin/env python3
import numpy as np
import itertools
import random
import torch
import nltk

def assert_close(a, b, *, atol=1e-3):
    """Assert that two tensors are element-wise equal within an absolute tolerance.

    Args:
        a: First tensor (or anything ``torch.as_tensor`` accepts).
        b: Second tensor; must have the same shape as ``a``.
        atol: Maximum allowed absolute difference ``|a - b|`` per element.
            Defaults to ``1e-3``.

    Raises:
        AssertionError: If the shapes differ, if any element differs by more
            than ``atol``, or if an element is NaN in exactly one input. For
            value mismatches the message lists the offending values from
            ``a`` and ``b`` and their indices.
    """
    a = torch.as_tensor(a)
    b = torch.as_tensor(b)
    # The mismatch report indexes both tensors with one boolean mask, which
    # only works for identical shapes; report that clearly instead of letting
    # an unrelated IndexError escape.
    if a.shape != b.shape:
        raise AssertionError(
            "shape mismatch: a %s vs b %s" % (tuple(a.shape), tuple(b.shape))
        )
    # ``NaN > atol`` is always False, so a NaN on only one side would slip
    # through a plain tolerance check; flag NaN-presence mismatches explicitly.
    # NaN on both sides, and equal infinities, still count as close.
    out_mask = (torch.abs(a - b) > atol) | (torch.isnan(a) != torch.isnan(b))
    if out_mask.any():
        # Explicit raise rather than ``assert`` so the check survives ``python -O``;
        # the offending values are only gathered on failure.
        raise AssertionError(
            "\na:\n%s\nb:\n%s\nidxs:\n%s"
            % (a[out_mask], b[out_mask], out_mask.nonzero())
        )

