from matplotlib import pyplot as plt

def loss_figure(A):
    """Plot ``A`` on the current pyplot axes and return the pyplot module.

    Args:
        A: Values to plot. An object exposing ``.detach()`` (e.g. a PyTorch
            tensor) is detached from the autograd graph, copied to CPU and
            converted to a NumPy array before plotting. Any other
            array-like (list, NumPy array, ...) is passed to ``plt.plot``
            unchanged. Presumably a 1-D history of loss values -- TODO
            confirm with callers.

    Returns:
        The ``matplotlib.pyplot`` module itself, so callers can chain e.g.
        ``loss_figure(x).savefig(path)`` or ``loss_figure(x).show()``.

    Note:
        Plotting goes through pyplot's global state, so repeated calls draw
        onto the same current axes unless the caller opens a new figure.
    """
    if hasattr(A, "detach"):
        # ``Tensor.numpy()`` rejects tensors that live on a GPU, so move the
        # data to host memory first; ``.cpu()`` is a no-op for CPU tensors.
        A = A.detach().cpu().numpy()
    plt.plot(A)
    return plt
