import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
from scipy.special import softmax
from skopt import BayesSearchCV
from skopt.space import Real

from calibration_methods.linear_calibrator import LinearCalibrator


# Define a custom scorer for negative log loss that works with our classifier
def neg_log_loss_scorer(estimator, X, y):
    """Score ``estimator`` on ``(X, y)`` as the negative log loss.

    The callable follows the ``scorer(estimator, X, y)`` signature so it can
    be passed as ``scoring=`` to a search object; larger values are better.

    Args:
        estimator: Fitted object exposing ``predict_proba(X)``. The result may
            be a torch tensor (anything with ``.detach()``) or something
            ``np.asarray`` can convert, e.g. a NumPy array.
        X: Inputs forwarded unchanged to ``estimator.predict_proba``.
        y: True labels forwarded unchanged to ``log_loss``.

    Returns:
        ``-log_loss(y, probs)`` on success, or ``float("-inf")`` if anything
        raised while predicting or scoring (the error message is printed).
    """
    try:
        raw = estimator.predict_proba(X)
        # Tensors have to be detached and moved to host memory before the
        # NumPy conversion; array-like outputs are converted directly instead
        # of failing on the missing ``.detach`` attribute.
        if hasattr(raw, "detach"):
            raw = raw.detach().cpu().numpy()
        probs = np.asarray(raw)
        # Rows that do not sum to 1 are treated as unnormalised scores and
        # pushed through a softmax so that log_loss receives probabilities.
        if not np.allclose(probs.sum(axis=1), 1.0, atol=1e-5):
            probs = softmax(probs, axis=1)
        # NOTE(review): log_loss infers the label set from ``y``; if a fold is
        # missing some classes it raises and this scorer returns -inf.
        # Consider passing ``labels=`` once the label encoding is confirmed.
        return -log_loss(y, probs)
    except Exception as e:
        # Deliberate best-effort: a failing candidate gets the worst score
        # rather than aborting the whole search.
        print(f"Scoring error: {str(e)}")
        return float("-inf")


def MatrixScaling():
    """Return a new ``LinearCalibrator`` for matrix scaling.

    The calibrator is built with ``use_logits=True`` and ``reg_lambda=0.0``
    (presumably no regularisation penalty; confirm against
    ``LinearCalibrator``).
    """
    return LinearCalibrator(use_logits=True, reg_lambda=0.0)


# Reference implementation:
# https://github.com/dirichletcal/experiments_dnn/blob/master/scripts/tune_cal_odir.py
# The default hyper-parameter grid search used there is too large for experiments on ImageNet.
# Therefore, the hyper-parameter search below is restricted to be more manageable.


def MatrixScalingODIR(num_classes):
    """Build a Bayesian hyper-parameter search around an ODIR calibrator.

    The returned ``BayesSearchCV`` is unfitted. It wraps a
    ``LinearCalibrator(use_logits=True, odir=True)`` and tunes ``reg_lambda``
    and ``reg_mu``, scoring candidates with ``neg_log_loss_scorer``.

    Args:
        num_classes: Not used by this function; kept in the signature for
            callers that pass it.

    Returns:
        A configured ``BayesSearchCV`` instance.
    """
    base_calibrator = LinearCalibrator(use_logits=True, odir=True)

    # Both regularisation hyper-parameters are sampled log-uniformly from
    # the same range, [1e-4, 1e3].
    search_space = {
        name: Real(1e-4, 1e3, prior="log-uniform")
        for name in ("reg_lambda", "reg_mu")
    }

    return BayesSearchCV(
        base_calibrator,
        search_space,
        n_iter=20,  # number of hyper-parameter settings sampled
        cv=2,  # integer cv => 2-fold cross-validation per setting
        scoring=neg_log_loss_scorer,
        refit=True,  # refit the best setting once the search finishes
        verbose=1,
        n_jobs=-1,  # run evaluations in parallel on all available cores
        random_state=42,  # fixed seed for a reproducible search
    )
