import torch

def stable_rank(x):
    """Return the squared ratio of the Frobenius norm to the spectral norm of ``x``.

    Both norms are computed with ``torch.linalg.matrix_norm`` (``ord='fro'``
    and ``ord=2``), so ``x`` is treated as a matrix or a batch of matrices.
    No guard is applied for a zero spectral norm.
    """
    frobenius = torch.linalg.matrix_norm(x, ord='fro')
    spectral = torch.linalg.matrix_norm(x, ord=2)
    ratio = frobenius / spectral
    return ratio ** 2

def relative_error(x, y, p='fro', dim=None, keepdim=False, out=None, dtype=None):
    """Return ``norm(x - y) / norm(x + y)`` computed with ``torch.norm``.

    Args:
        x: First tensor.
        y: Second tensor; must be compatible with ``x`` for ``x - y`` and
            ``x + y``.
        p: Norm order, forwarded to ``torch.norm``.
        dim: Dimension(s) to reduce over, forwarded to ``torch.norm``.
        keepdim: Whether reduced dimensions are kept, forwarded to
            ``torch.norm``.
        out: Optional tensor that receives the final ratio. When given, it is
            also the returned tensor.
        dtype: Optional dtype forwarded to both ``torch.norm`` calls.

    Returns:
        Tensor holding the ratio of the two norms. No guard is applied for a
        zero denominator.
    """
    # NOTE(review): the denominator is the norm of ``x + y`` rather than of a
    # single reference tensor -- kept as-is; confirm this is the intended
    # definition.
    #
    # ``out`` is deliberately NOT forwarded to the norm calls: sharing one
    # buffer between them would let the second norm overwrite the first and
    # make the result degenerate to 1.
    error = torch.norm(x - y, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
    scale = torch.norm(x + y, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
    if out is None:
        return error / scale
    return torch.div(error, scale, out=out)

def add_bos_token(tokenizer, encoded_batch, attn_mask, device):
    """Prepend the tokenizer's BOS id to every row of a batch.

    Args:
        tokenizer: Object exposing a ``bos_token_id`` attribute.
        encoded_batch: 2-D tensor of token ids; dimension 0 is used as the
            batch size and the BOS column is concatenated along dimension 1.
        attn_mask: 2-D tensor concatenated along dimension 1 with a column of
            ones, so the new BOS position is marked as attended.
        device: Device on which the new BOS and mask columns are created.
            NOTE(review): ``encoded_batch`` and ``attn_mask`` are presumably
            already on this device -- they are not moved here.

    Returns:
        Tuple ``(encoded_batch, attn_mask)`` with one extra leading column each.

    Raises:
        ValueError: If ``tokenizer.bos_token_id`` is ``None``.
    """
    bos_id = tokenizer.bos_token_id
    if bos_id is None:
        raise ValueError("tokenizer has no bos_token_id; cannot prepend a BOS token")

    # Build the columns with an explicit (batch, 1) shape directly on the
    # target device. This keeps the shape correct for an empty batch and
    # avoids building a Python list and copying it across devices.
    column_shape = (encoded_batch.size(dim=0), 1)
    bos_column = torch.full(column_shape, bos_id, dtype=torch.int64, device=device)
    mask_column = torch.ones(column_shape, dtype=torch.int64, device=device)

    encoded_batch = torch.cat([bos_column, encoded_batch], dim=1)
    attn_mask = torch.cat([mask_column, attn_mask], dim=1)
    return encoded_batch, attn_mask

def print_cuda_max_memory(print_result=True):
    """Print the value of ``torch.cuda.max_memory_allocated(0)`` scaled by 1024**3.

    Args:
        print_result: When falsy, nothing is queried or printed.

    Returns:
        None in all cases.
    """
    if not print_result:
        return None
    # The byte count is divided by 1024 three times before being printed with
    # a "GB" suffix.
    peak_bytes = torch.cuda.max_memory_allocated(0)
    peak_scaled = peak_bytes / 1024 / 1024 / 1024
    print("torch.cuda.max_memory_allocated: %fGB" % peak_scaled)
    return None