


import torch
# Require PyTorch >= 1.6. Compare (major, minor) as integers rather than as raw
# strings: for a plain-str version, '1.10.0' >= '1.6.0' is False lexicographically.
assert tuple(int(part) for part in torch.__version__.split('.')[:2]) >= (1, 6), (
    f'PyTorch >= 1.6.0 is required, found {torch.__version__}')
import torch.nn as nn


def layer_norm(d_model, condition=True):
    """Optionally construct a LayerNorm module.

    Args:
        d_model: Normalized shape passed straight to ``nn.LayerNorm``.
        condition: Truthiness decides whether a module is created at all.

    Returns:
        ``nn.LayerNorm(d_model)`` when ``condition`` is truthy, otherwise ``None``.
    """
    if not condition:
        return None
    return nn.LayerNorm(d_model)


def generate_square_subsequent_mask(sz, *, device=None, dtype=torch.float32):
    """Build a square causal (subsequent-position) float mask.

    Entry ``(i, j)`` is ``0.0`` when ``j <= i`` and ``-inf`` when ``j > i``.
    This matches the additive float-mask convention of
    ``torch.nn.Transformer.generate_square_subsequent_mask``.

    Args:
        sz: Side length of the square mask.
        device: Device to allocate the mask on. Defaults to torch's current
            default device, which is CPU unless the caller changed it.
            Passing it avoids a separate ``.to(device)`` copy.
        dtype: Floating-point dtype of the mask. Defaults to ``torch.float32``,
            matching the previous hard-coded ``.float()`` result.

    Returns:
        A ``(sz, sz)`` tensor of ``dtype`` on ``device``.
    """
    # Fill everything with -inf, then keep only the strictly-upper triangle;
    # triu zeroes the diagonal and everything below it.
    return torch.triu(
        torch.full((sz, sz), float('-inf'), device=device, dtype=dtype),
        diagonal=1,
    )
