def loss_forward(pred, label, min_max=None, *, theta=None):
    """Evaluate a fixed bivariate polynomial loss over per-class scores.

    With ``x = pred - theta[1]`` and ``y = onehot(label) - theta[0]``, every
    class entry contributes

        theta[2]*x + theta[3]*x**2 + theta[4]*x**3 + theta[5]*y*x
        + theta[6]*y*x**2 + theta[7]*y**2*x + theta[8]*x**4
        + theta[9]*x**3*y + theta[10]*x**2*y**2 + theta[11]*x*y**3

    and the per-sample loss is the negated mean of that polynomial over the
    classes.

    Args:
        pred: 2-D tensor of shape ``(batch, num_cat)``, one score per class.
            (Whether these are probabilities or logits is not established
            here -- confirm with callers.)
        label: 1-D tensor of ``batch`` class indices in ``[0, num_cat)``.
        min_max: optional ``(lo, hi)`` pair (sequence or tensor). When given,
            the loss is rescaled to ``(loss - lo) / (hi - lo)``; ``hi`` must
            differ from ``lo``.
        theta: optional sequence/tensor of the 12 polynomial coefficients.
            Defaults to the built-in constants below.

    Returns:
        Tensor of shape ``(batch,)``. When ``batch == 1`` a 0-d tensor is
        returned instead (kept for backward compatibility).
    """
    if theta is None:
        # NOTE(review): hard-coded coefficients; their provenance is not
        # visible here -- confirm before changing them.
        theta = (-2.0193, -1.2234, 0.1363, 0.1269, -0.4566, -0.1016,
                 -0.2545, 1.0971, -0.9203, 0.2368, 0.4795, 0.9975)
    # Create every helper tensor on pred's device: the previous CPU-only
    # FloatTensor / LongTensor constructions broke CUDA inputs.
    theta = torch.as_tensor(theta, dtype=torch.float32, device=pred.device)
    num_cat = pred.shape[1]

    label_onehot = torch.zeros(
        pred.shape[0], num_cat, dtype=torch.float32, device=pred.device
    )
    label_idx = label.to(device=pred.device, dtype=torch.long).unsqueeze(-1)
    label_onehot.scatter_(1, label_idx, 1.0)

    x = pred - theta[1]
    y = label_onehot - theta[0]
    poly = (
        theta[2] * x
        + theta[3] * x ** 2
        + theta[4] * x ** 3
        + theta[5] * y * x
        + theta[6] * y * x ** 2
        + theta[7] * y ** 2 * x
        + theta[8] * x ** 4
        + theta[9] * x ** 3 * y
        + theta[10] * x ** 2 * y ** 2
        + theta[11] * x * y ** 3
    )

    # Negated mean over the class axis -> shape (batch,). Reducing along an
    # explicit dim also works when num_cat == 1 (squeeze() used to drop it).
    loss = -(poly.sum(dim=1) / num_cat)
    if loss.shape[0] == 1:
        # Historical behaviour: a single sample yields a 0-d tensor.
        loss = loss.squeeze(0)

    # `is not None` instead of truthiness so a (lo, hi) tensor is accepted.
    if min_max is not None:
        return (loss - min_max[0]) / (min_max[1] - min_max[0])
    return loss