import numpy as np
import torch
from sklearn.isotonic import IsotonicRegression


class IsotonicCalibration:
    """Calibrate scores with a monotone mapping learned by isotonic regression.

    Wraps scikit-learn's ``IsotonicRegression`` so it accepts torch tensors
    (including tensors that require grad or live on a non-CPU device) as well as
    NumPy arrays / array-likes, and returns torch tensors from ``predict``.
    """

    def __init__(self):
        # Scores outside the fitted range are clipped to the boundary outputs
        # instead of producing NaN at prediction time.
        self.isotonic_model = IsotonicRegression(out_of_bounds='clip')
        # Score-sorted copy of the training scores from the last successful fit().
        self.probs_train = None
        # Training targets reordered to match ``probs_train``.
        self.true_train = None

    @staticmethod
    def _to_numpy(values):
        """Convert a tensor or array-like to a NumPy array, flattening (n, 1) columns."""
        if isinstance(values, torch.Tensor):
            # detach() drops autograd history and cpu() handles accelerator
            # tensors; a bare .numpy() raises in either case.
            arr = values.detach().cpu().numpy()
        else:
            arr = np.asarray(values)
        # A single-column 2-D input is treated as 1-D, matching what
        # IsotonicRegression accepts for X.
        if arr.ndim == 2 and arr.shape[1] == 1:
            arr = arr.reshape(-1)
        return arr

    def fit(self, probs, true):
        """Fit the isotonic mapping from ``probs`` to ``true``.

        Args:
            probs: training scores as a torch tensor or array-like, shape (n,)
                or (n, 1).
            true: training targets aligned with ``probs``; same accepted forms.

        Raises:
            ValueError: propagated from scikit-learn when the inputs are not
                1-D (or a single column) or their lengths differ.
        """
        probs_np = self._to_numpy(probs)
        true_np = self._to_numpy(true)

        # IsotonicRegression sorts and de-duplicates X itself, so fit on the raw
        # arrays; this also validates shapes before we index with argsort below.
        self.isotonic_model.fit(probs_np, true_np)

        # Store a score-sorted copy of the training data in the public
        # attributes. A stable sort keeps tied scores in input order.
        sort_idx = np.argsort(probs_np, kind="stable")
        self.probs_train = probs_np[sort_idx]
        self.true_train = true_np[sort_idx]

    def predict(self, logits):
        """Map scores through the fitted isotonic function.

        Args:
            logits: scores on the same scale as the ``probs`` passed to
                ``fit``; torch tensor or array-like, shape (n,) or (n, 1).

        Returns:
            A float32 tensor of calibrated values with shape (n,). It is placed
            on ``logits.device`` when ``logits`` is a tensor, otherwise on CPU.
        """
        scores_np = self._to_numpy(logits)
        calibrated = self.isotonic_model.predict(scores_np)
        result = torch.from_numpy(calibrated.astype(np.float32))
        if isinstance(logits, torch.Tensor):
            # Keep the output next to the input so callers can combine them.
            result = result.to(logits.device)
        return result


def calibrate(train_logits, train_labels, test_logits, *args, **kwargs):
    """Fit isotonic calibration on the training split and apply it to test scores.

    Args:
        train_logits: tensor of training scores used to fit the calibrator.
        train_labels: tensor of training targets aligned with ``train_logits``.
        test_logits: tensor of scores to calibrate.
        *args, **kwargs: accepted and ignored (presumably so this matches a
            shared calibrator call signature; confirm with callers).

    Returns:
        A dict with the single key ``"logits"``. If any of the three input
        tensors is empty, the value is ``test_logits`` itself, untouched.
        Otherwise it is the float32 tensor produced by
        ``IsotonicCalibration.predict``. Note: despite the key name, those
        values are isotonic-regression predictions on the scale of
        ``train_labels``, not logits.
    """
    # Bail out early when there is nothing to fit or nothing to calibrate.
    has_data = all(
        tensor.numel() > 0
        for tensor in (train_logits, train_labels, test_logits)
    )
    if not has_data:
        return {"logits": test_logits}

    calibrator = IsotonicCalibration()
    calibrator.fit(train_logits, train_labels)
    calibrated = calibrator.predict(test_logits)
    return {"logits": calibrated}