import numpy as np
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    log_loss,
    mean_absolute_error,
    mean_squared_error,
    precision_score,
    r2_score,
    recall_score,
    roc_auc_score,
)


class Utility:
    """Train a model on a data subset and score it with a named metric.

    Classification metrics (``utility_name``): ``accuracy``, ``f1``,
    ``precision``, ``recall``, ``neglogloss`` (negated log loss), ``auc``,
    ``jaccard``, ``gamma`` (fraction of test samples predicted positive),
    ``tp`` (fraction of test samples that are true positives), ``tpr`` and
    ``am`` (mean of true-positive and true-negative rates).

    Regression metrics: ``mae`` and ``mse`` (both negated, so larger means a
    smaller error) and ``r2``.
    """

    def __init__(self, utility_name: str, threshold: float = None):
        # Name of the metric to compute; see the class docstring.
        self.utility_name = utility_name
        # Decision threshold for 1-D score outputs (score >= threshold is
        # labeled positive). None means "calibrate from the training labels".
        self.threshold = threshold

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array.

        Objects exposing ``detach()`` / ``cpu()`` (such as torch tensors) are
        detached and moved to host memory first, so tensors that track
        gradients or live on an accelerator convert cleanly.
        """
        if hasattr(values, "detach"):
            values = values.detach()
        if hasattr(values, "cpu"):
            values = values.cpu()
        return np.asarray(values)

    @staticmethod
    def _binary_confusion_counts(y_true, y_pred):
        """Return ``(tn, fp, fn, tp)`` treating label 1 as the positive class.

        Unlike ``confusion_matrix(...).ravel()``, this always yields four
        counts, even when only one label occurs in ``y_true`` and ``y_pred``.
        """
        actual_pos = y_true == 1
        predicted_pos = y_pred == 1
        tp = int(np.sum(actual_pos & predicted_pos))
        fp = int(np.sum(~actual_pos & predicted_pos))
        fn = int(np.sum(actual_pos & ~predicted_pos))
        tn = int(np.sum(~actual_pos & ~predicted_pos))
        return tn, fp, fn, tp

    def _calibrate_threshold(self, y_prob, y_train):
        """Choose a threshold that reproduces the training positive rate.

        Roughly ``mean(y_train) * len(y_prob)`` scores end up satisfying
        ``score >= threshold`` (ties can add a few more).

        Args:
            y_prob: 1-D array-like of positive-class scores.
            y_train: training labels; assumed to be 0/1 so that their mean is
                the positive-class proportion.

        Returns:
            The calibrated threshold, or ``np.inf`` when no score should be
            labeled positive (zero target proportion or empty ``y_prob``).

        Raises:
            ValueError: if ``y_train`` is empty.
        """
        labels = self._to_numpy(y_train)
        if labels.size == 0:
            raise ValueError("y_train must be non-empty to calibrate a threshold")
        sorted_probs = np.sort(self._to_numpy(y_prob))
        n = len(sorted_probs)
        # Scores at positions >= idx (n - idx of them) are labeled positive.
        idx = max(0, int((1 - float(np.mean(labels))) * n))
        if idx >= n:
            # Zero target proportion (or no scores): predict nothing positive
            # instead of indexing past the end of the array.
            return np.inf
        return sorted_probs[idx]

    def _compute_classification_utility(self, y_prob, y_true, y_train):
        """Evaluate ``self.utility_name`` on classifier outputs.

        Args:
            y_prob: model scores. A 1-D array is thresholded into 0/1
                predictions; a 2-D array is arg-maxed along axis 1.
            y_true: test labels.
            y_train: training labels, used only to calibrate the threshold
                when ``self.threshold`` is None and ``y_prob`` is 1-D.

        Returns:
            The metric value.

        Raises:
            ValueError: if ``self.utility_name`` is not a classification metric.
        """
        y_prob = self._to_numpy(y_prob)
        y_true = self._to_numpy(y_true)
        if y_true.ndim == 2 and y_true.shape[1] == 1:
            # A column vector of labels would broadcast against 1-D predictions.
            y_true = y_true.ravel()

        if y_prob.ndim == 1:
            # Calibrate only when a threshold is actually needed.
            threshold = (
                self._calibrate_threshold(y_prob, y_train)
                if self.threshold is None
                else self.threshold
            )
            y_pred = (y_prob >= threshold).astype(int)
        else:
            y_pred = np.argmax(y_prob, axis=1)

        name = self.utility_name
        # More than two distinct test labels -> macro-average the per-class
        # scores; otherwise use sklearn's default binary averaging.
        average = "macro" if np.unique(y_true).size > 2 else "binary"

        if name == "accuracy":
            return accuracy_score(y_true, y_pred)
        if name == "f1":
            return f1_score(y_true, y_pred, average=average)
        if name == "precision":
            return precision_score(y_true, y_pred, average=average)
        if name == "recall":
            return recall_score(y_true, y_pred, average=average)
        if name == "neglogloss":
            return -log_loss(y_true, y_prob)
        if name == "auc":
            if y_prob.ndim == 2 and y_prob.shape[1] == 2:
                # sklearn's binary ROC AUC wants only the positive-class column.
                return roc_auc_score(y_true, y_prob[:, 1])
            if y_prob.ndim == 2:
                return roc_auc_score(y_true, y_prob, multi_class="ovr")
            return roc_auc_score(y_true, y_prob)
        if name == "gamma":
            return np.mean(y_pred)
        if name == "tp":
            return np.mean((y_true == 1) & (y_pred == 1))
        if name == "jaccard":
            tn, fp, fn, tp = self._binary_confusion_counts(y_true, y_pred)
            return tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
        if name == "tpr":
            tn, fp, fn, tp = self._binary_confusion_counts(y_true, y_pred)
            return tp / (tp + fn) if (tp + fn) > 0 else 0.0
        if name == "am":
            tn, fp, fn, tp = self._binary_confusion_counts(y_true, y_pred)
            tpr = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            tnr = tn / (tn + fp) if (tn + fp) > 0 else 0.0
            return (tpr + tnr) / 2
        raise ValueError(f"Unknown classification utility: {name!r}")

    def _compute_regression_utility(self, y_pred, y_true):
        """Evaluate ``self.utility_name`` on regressor outputs.

        Error metrics are negated so that a larger value means a better fit.

        Raises:
            ValueError: if ``self.utility_name`` is not a regression metric.
        """
        y_pred = self._to_numpy(y_pred)
        y_true = self._to_numpy(y_true)
        if self.utility_name == "mae":
            return -mean_absolute_error(y_true, y_pred)
        if self.utility_name == "mse":
            return -mean_squared_error(y_true, y_pred)
        if self.utility_name == "r2":
            return r2_score(y_true, y_pred)
        raise ValueError(f"Unknown regression utility: {self.utility_name!r}")

    def compute_utility(self, X_train, y_train, configuration):
        """Train a fresh model on ``(X_train, y_train)`` and score it on the test set.

        Args:
            X_train: training features, passed unchanged to ``model.fit``.
            y_train: training labels, passed unchanged to ``model.fit``.
            configuration: mapping containing
                ``"learning_setting"``: dict with ``"model"`` (model class),
                ``"optimizer"``, ``"criterion"`` and optional ``"model_kwargs"``;
                ``"trainset"``: object whose ``X.shape[1]`` is the input dimension;
                ``"testset"``: object with ``X`` (features) and ``y`` (labels).

        Returns:
            The utility value. Returns ``0`` when ``X_train`` is empty, or, for
            classifiers, when ``y_train`` has fewer distinct labels than
            ``model.num_class``.
        """
        setting = configuration["learning_setting"]
        model = setting["model"](
            input_dim=configuration["trainset"].X.shape[1],
            optimizer_instance=setting["optimizer"],
            criterion_instance=setting["criterion"],
            **setting.get("model_kwargs", {}),
        )

        if model.is_classifier:
            distinct_labels = np.unique(self._to_numpy(y_train)).size
            if distinct_labels < model.num_class or len(X_train) == 0:
                return 0
            model.fit(X_train, y_train)
            y_prob = model.predict(configuration["testset"].X)
            return self._compute_classification_utility(
                y_prob, configuration["testset"].y, y_train
            )

        if len(X_train) == 0:
            return 0
        model.fit(X_train, y_train)
        y_pred = model.predict(configuration["testset"].X)
        return self._compute_regression_utility(y_pred, configuration["testset"].y)
