import torch



def calibrate(train_logits,
              train_labels,
              test_logits,
              *args, **kwargs):
    """Fit a temperature/bias calibration on training logits and apply it to test logits.

    Two scalars ``tau`` and ``b`` are fitted with L-BFGS (strong-Wolfe line
    search, at most 50 iterations). The objective is the binary cross-entropy
    of ``train_logits / tau + b`` against ``train_labels``. The same transform
    is then applied to ``test_logits``.

    Args:
        train_logits: Logits used to fit the calibration. Must match the shape
            of ``train_labels`` (required by ``BCEWithLogitsLoss``). They are
            detached before fitting, so no gradient flows back into them.
        train_labels: Binary targets; converted with ``.float()``.
        test_logits: Logits to calibrate.
        *args, **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        ``{"logits": sigmoid(test_logits / tau + b)}``. Despite the key name,
        the values are probabilities, not logits. The result does not carry
        gradients back to the fitted ``tau``/``b``.

        If the fit yields a non-finite ``tau`` or ``b``, or ``tau == 0``, the
        identity transform (``tau = 1``, ``b = 0``) is used instead, so the
        output is ``sigmoid(test_logits)`` rather than NaN/inf.

    Side effects:
        Prints the ``tau`` and ``b`` values used, followed by the final
        training loss.
    """
    # Detach so that the repeated closure evaluations performed by L-BFGS
    # never backpropagate into (and free) a graph the caller's logits may
    # still be attached to.
    logits = train_logits.detach()
    targets = train_labels.detach().float()

    device = logits.device
    tau = torch.nn.Parameter(torch.tensor(1.0, device=device))
    b = torch.nn.Parameter(torch.tensor(1.0, device=device))

    optimizer = torch.optim.LBFGS([tau, b],
                                  line_search_fn="strong_wolfe",
                                  max_iter=50)
    # Loop-invariant: build the loss module once, not on every closure call.
    loss_fn = torch.nn.BCEWithLogitsLoss()

    def closure():
        optimizer.zero_grad()
        # Temperature scaling followed by a bias shift: (logits / tau) + b.
        loss = loss_fn(logits / tau + b, targets)
        loss.backward()
        return loss

    optimizer.step(closure=closure)

    # Evaluate the final loss and sanity-check the fit without building a graph.
    with torch.no_grad():
        final_loss = loss_fn(logits / tau + b, targets).item()
        fitted_ok = (bool(torch.isfinite(tau) & torch.isfinite(b))
                     and tau.item() != 0.0)

    if fitted_ok:
        tau_value, b_value = tau.item(), b.item()
    else:
        # A diverged fit (NaN/inf) or a zero temperature would poison every
        # output; fall back to the identity transform instead.
        tau_value, b_value = 1.0, 0.0

    print(tau_value, b_value, final_loss)
    # Plain Python floats keep the result free of the optimizer's autograd
    # graph and of any device mismatch with test_logits.
    return {"logits": torch.sigmoid(test_logits / tau_value + b_value)}
