import numpy as np
import torch
from sklearn.linear_model import LogisticRegression


class PlattScaling:
    """Platt scaling: model calibrated probabilities as ``sigmoid(a * logit + b)``.

    The scalar slope ``a`` (initialised to 1) and intercept ``b`` (initialised
    to 0) are fitted with ``torch.optim.LBFGS`` by minimising binary
    cross-entropy against the true labels. Calling :meth:`fit` again starts
    from the current parameter values rather than re-initialising them.
    """

    def __init__(self, max_iter=50):
        """
        Args:
            max_iter: Maximum number of L-BFGS iterations used by ``fit``
                (forwarded as ``torch.optim.LBFGS(max_iter=...)``).
        """
        self.a = torch.nn.Parameter(torch.tensor(1.0))
        self.b = torch.nn.Parameter(torch.tensor(0.0))
        self.max_iter = max_iter

    def fit(self, logits, true_labels):
        """Fit ``a`` and ``b`` on a calibration set.

        Args:
            logits: Uncalibrated scores (tensor, NumPy array, or sequence);
                flattened to one score per sample.
            true_labels: Binary targets with one entry per score, in any shape
                (e.g. ``(N,)`` or ``(N, 1)``); flattened likewise.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If ``logits`` is empty or the number of labels differs
                from the number of logits.
        """
        logits = self._ensure_tensor(logits).reshape(-1, 1)
        # Flatten labels and place them next to the logits so that column-shaped
        # labels, or labels living on another device, do not break the loss.
        true_labels = (
            self._ensure_tensor(true_labels).reshape(-1).to(logits.device).float()
        )
        if logits.shape[0] == 0:
            raise ValueError("cannot fit Platt scaling on empty logits")
        if true_labels.numel() != logits.shape[0]:
            raise ValueError(
                f"got {logits.shape[0]} logits but {true_labels.numel()} labels"
            )

        self._move_params_to_device(logits.device)
        optimizer = torch.optim.LBFGS([self.a, self.b], max_iter=self.max_iter)
        loss_fn = torch.nn.BCEWithLogitsLoss()  # applies the sigmoid internally

        def closure():
            # LBFGS may re-evaluate the objective several times per step.
            optimizer.zero_grad()
            calibrated = (logits * self.a + self.b).squeeze(-1)
            loss = loss_fn(calibrated, true_labels)
            loss.backward()
            return loss

        optimizer.step(closure)
        return self

    def predict(self, logits):
        """Return calibrated probabilities ``sigmoid(a * logits + b)``.

        Args:
            logits: Uncalibrated scores (tensor, NumPy array, or sequence).

        Returns:
            A tensor of probabilities on the parameters' device, detached from
            autograd. The flattened result is passed through ``squeeze()``, so a
            single input score yields a 0-dim tensor.
        """
        logits = self._ensure_tensor(logits).reshape(-1, 1).to(self.a.device)
        # Inference only: without no_grad the output would hold a graph back to
        # a/b, keeping it alive and making ``.numpy()`` on the result fail.
        with torch.no_grad():
            calibrated = logits * self.a + self.b
            return torch.sigmoid(calibrated).squeeze()

    def _ensure_tensor(self, x):
        """Return a detached copy of ``x`` as a tensor.

        Non-tensor inputs are converted to float32; tensors keep their dtype
        and device. The copy ensures the caller's data is never modified or
        tracked by autograd here.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.tensor(x, dtype=torch.float32)
        return x.detach().clone()

    def _move_params_to_device(self, device):
        """Relocate ``a`` and ``b`` to ``device`` in place."""
        # Re-binding ``.data`` keeps the same leaf Parameter objects while moving
        # their storage; moving ``.data`` (not the Parameter) avoids creating an
        # autograd-tracked copy.
        self.a.data = self.a.data.to(device)
        self.b.data = self.b.data.to(device)



def calibrate(train_logits, train_labels, test_logits, *args, **kwargs):
    """Calibrate ``test_logits`` with Platt scaling fitted on a training split.

    Args:
        train_logits: Logits used to fit the calibrator.
        train_labels: Binary (0/1) labels matching ``train_logits``.
        test_logits: Logits to be calibrated.
        *args, **kwargs: Accepted and ignored (presumably so this function
            matches a shared calibration-function signature; verify callers).

    Returns:
        dict: ``{"logits": probs}``, where ``probs`` is a tensor of calibrated
        probabilities (already passed through the sigmoid, values in [0, 1])
        for ``test_logits``, detached from autograd. The key is named
        ``"logits"`` for backward compatibility even though it holds
        probabilities.
    """
    model = PlattScaling()
    model.fit(train_logits, train_labels)
    # Detach so the result carries no autograd graph back to the fitted
    # parameters and can be converted with ``.numpy()`` directly.
    return {"logits": model.predict(test_logits).detach()}
