import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss
from scipy.special import softmax
from skopt import BayesSearchCV
from skopt.space import Real

from calibration_methods.linear_calibrator import LinearCalibrator


def neg_log_loss_scorer(estimator, X, y):
    """Score *estimator* on ``(X, y)`` with the negative log loss.

    Custom scorer with the ``scorer(estimator, X, y)`` signature, so it can be
    passed as ``scoring=`` to a search object. Larger values are better.

    Args:
        estimator: Fitted object exposing ``predict_proba(X)``. The result may
            be a torch tensor (detached and moved to host memory before
            conversion) or anything ``np.asarray`` accepts.
        X: Inputs forwarded unchanged to ``estimator.predict_proba``.
        y: True labels forwarded to ``log_loss``.

    Returns:
        ``-log_loss(y, probs)``, or ``float("-inf")`` if anything raised while
        predicting or scoring (the error is printed, not re-raised).
    """
    try:
        probs = estimator.predict_proba(X)
        if hasattr(probs, "detach"):
            # Torch tensor: drop the autograd graph and copy to the CPU first.
            probs = probs.detach().cpu().numpy()
        else:
            # Already array-like; previously this case always raised
            # AttributeError and was silently scored as -inf.
            probs = np.asarray(probs)

        # Rows that do not sum to ~1 are treated as unnormalised scores and
        # pushed through a softmax so that log_loss receives probabilities.
        if not np.allclose(probs.sum(axis=1), 1.0, atol=1e-5):
            probs = softmax(probs, axis=1)
        return -log_loss(y, probs)
    except Exception as e:
        # Deliberate best-effort: a failing candidate gets the worst possible
        # score instead of aborting the whole search.
        # NOTE(review): a non-finite score may not be handled well by the
        # optimiser consuming it -- confirm.
        print(f"Scoring error: {e}")
        return float("-inf")


# See note in matrix_scaling.py regarding hyper-parameters
def DirichletODIR(num_classes):
    """Return an un-fitted Bayesian search over an ``odir=True`` calibrator.

    The search tunes ``reg_lambda`` and ``reg_mu`` of a
    ``LinearCalibrator(use_logits=False, odir=True)``, scoring candidates with
    ``neg_log_loss_scorer``.

    Args:
        num_classes: Accepted for interface compatibility; not used here.

    Returns:
        A configured ``BayesSearchCV`` instance (not yet fitted).
    """
    base_calibrator = LinearCalibrator(use_logits=False, odir=True)

    # Both hyper-parameters are sampled log-uniformly from the same range.
    search_space = {
        name: Real(1e-4, 1e3, prior="log-uniform")
        for name in ("reg_lambda", "reg_mu")
    }

    search_settings = dict(
        cv=2,  # two cross-validation folds per candidate
        n_iter=12,  # number of candidate settings to evaluate
        scoring=neg_log_loss_scorer,
        refit=True,
        verbose=1,
        n_jobs=-1,  # run candidate evaluations in parallel
        random_state=42,
    )
    return BayesSearchCV(base_calibrator, search_space, **search_settings)


def DirichletL2(num_classes):
    """Return an un-fitted Bayesian search over an ``odir=False`` calibrator.

    The search tunes only ``reg_lambda`` of a
    ``LinearCalibrator(use_logits=False, odir=False)``, scoring candidates
    with ``neg_log_loss_scorer``.

    Args:
        num_classes: Accepted for interface compatibility; not used here.

    Returns:
        A configured ``BayesSearchCV`` instance (not yet fitted).
    """
    base_calibrator = LinearCalibrator(use_logits=False, odir=False)

    # Single hyper-parameter, sampled log-uniformly.
    search_space = {"reg_lambda": Real(1e-4, 1e3, prior="log-uniform")}

    search_settings = dict(
        cv=2,  # two cross-validation folds per candidate
        n_iter=12,  # number of candidate settings to evaluate
        scoring=neg_log_loss_scorer,
        refit=True,
        verbose=1,
        n_jobs=-1,  # run candidate evaluations in parallel
        random_state=42,
    )
    return BayesSearchCV(base_calibrator, search_space, **search_settings)
