import numpy as np
import torch
from sklearn.linear_model import LogisticRegression


class MLP(torch.nn.Module):
    """Small neural calibrator mapping one uncalibrated logit to a calibrated one.

    The network is ``Linear(1, hidden_size) -> ReLU -> Linear(hidden_size, 1)``.
    It is fitted with L-BFGS on binary cross-entropy with logits. Unlike Platt
    scaling (a single affine map), the hidden layer lets the mapping be
    non-linear.

    Args:
        max_iter: Maximum number of L-BFGS iterations performed by one call to
            :meth:`fit` (forwarded as ``max_iter`` to ``torch.optim.LBFGS``).
        hidden_size: Width of the hidden layer.
    """

    def __init__(self, max_iter=50, hidden_size=8):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(1, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )
        self.max_iter = max_iter

    def forward(self, x):
        """Map a float tensor of shape ``(..., 1)`` to calibrated logits ``(..., 1)``."""
        return self.network(x)

    def fit(self, logits, true_labels):
        """Fit the calibrator on (logit, label) pairs.

        Args:
            logits: Uncalibrated scores (tensor or array-like) with N elements;
                flattened to shape ``(N, 1)``.
            true_labels: Binary targets (0/1) with N elements; flattened to
                shape ``(N,)``.

        Returns:
            ``self``, so calls can be chained (``MLP().fit(x, y).predict(z)``).

        Raises:
            ValueError: If the inputs are empty or their element counts differ.
        """
        # reshape (not view): _ensure_tensor's clone() preserves strides, so a
        # dense non-contiguous input would make view() raise.
        logits = self._ensure_tensor(logits).reshape(-1, 1)
        # Flatten labels so (N,), (N, 1) and (1, N) inputs all match the (N,)
        # predictions; BCEWithLogitsLoss requires identical shapes.
        true_labels = self._ensure_tensor(true_labels).reshape(-1)

        n_samples = logits.shape[0]
        if n_samples == 0:
            # An empty batch yields a NaN loss and leaves the random init untouched.
            raise ValueError("cannot fit calibrator on empty input")
        if n_samples != true_labels.shape[0]:
            raise ValueError(
                "logits and true_labels must have the same number of elements, "
                f"got {n_samples} and {true_labels.shape[0]}"
            )

        # Train on the logits' device; the labels have to live there as well.
        device = logits.device
        self.to(device)
        true_labels = true_labels.to(device)

        optimizer = torch.optim.LBFGS(self.parameters(), max_iter=self.max_iter)
        loss_fn = torch.nn.BCEWithLogitsLoss()

        def closure():
            # LBFGS may re-evaluate the objective several times per step.
            optimizer.zero_grad()
            # reshape(-1) rather than squeeze(): squeeze() turns a single-sample
            # batch into a 0-d tensor whose shape no longer matches the labels.
            calibrated = self(logits).reshape(-1)
            loss = loss_fn(calibrated, true_labels)
            loss.backward()
            return loss

        optimizer.step(closure)
        return self

    def predict(self, logits):
        """Return calibrated probabilities for the given logits.

        Args:
            logits: Uncalibrated scores (tensor or array-like) with N elements.

        Returns:
            A float32 tensor of shape ``(N,)`` with sigmoid probabilities, on the
            same device as the model's parameters.
        """
        # Inputs must be on the same device as the fitted parameters.
        param_device = next(self.parameters()).device
        logits = self._ensure_tensor(logits).reshape(-1, 1).to(param_device)
        with torch.no_grad():
            # reshape(-1) keeps a consistent (N,) shape, including for N == 1.
            outputs = self(logits).reshape(-1)
            return torch.sigmoid(outputs)

    def _ensure_tensor(self, x):
        """Return a detached float32 tensor copy of ``x`` (tensor or array-like)."""
        if isinstance(x, torch.Tensor):
            # Force float32 to match the default dtype of the Linear layers.
            x = x.to(dtype=torch.float32)
        else:
            x = torch.tensor(x, dtype=torch.float32)
        return x.detach().clone()


def calibrate(train_logits, train_labels, test_logits, *args, **kwargs):
    """Fit an ``MLP`` calibrator on training logits and apply it to test logits.

    Note that this is not classic Platt scaling: instead of a single affine
    map, a small one-hidden-layer network with default hyperparameters is
    fitted.

    Args:
        train_logits: Uncalibrated logits used to fit the calibrator.
        train_labels: Binary labels matching ``train_logits``.
        test_logits: Uncalibrated logits to be calibrated.
        *args, **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        A dict with a single key ``"logits"``. NOTE(review): despite the key
        name, its value is the sigmoid *probabilities* produced by
        ``MLP.predict``, not logits.
    """
    calibrator = MLP()
    calibrator.fit(train_logits, train_labels)
    probabilities = calibrator.predict(test_logits)
    return {"logits": probabilities}
