import numpy as np
from sklearn.linear_model import LogisticRegression


def get_logr_utility(X_train, y_train, X_val, y_val, per_sample=True, n_classes=None):
    """Evaluate a logistic-regression model trained on (X_train, y_train).

    Args:
        X_train, y_train: Training features and labels. May be empty.
        X_val, y_val: Validation features and labels.
        per_sample: If True, return a per-validation-sample correctness vector;
            otherwise return a single accuracy value.
        n_classes: Number of classes used for the fallback value. When falsy
            (None or 0), it is inferred as the number of distinct labels in y_val.

    Returns:
        If per_sample is False: the mean accuracy from ``model.score``.
        If per_sample is True: an int array holding 1 where the prediction
        matches y_val and 0 otherwise.
        When there is no training data, or fitting raises, the uniform-guess
        accuracy ``1 / n_classes`` is returned instead. It is a scalar, or a
        float array of length ``len(y_val)`` when per_sample is True.
    """
    if not n_classes:
        n_classes = len(np.unique(y_val))

    def _fallback():
        # Uniform-guess accuracy, shaped like the normal return value.
        chance = 1 / n_classes
        return np.full(len(y_val), chance) if per_sample else chance

    if len(y_train) == 0:
        return _fallback()

    model = LogisticRegression(max_iter=1000, solver='liblinear')
    try:
        model.fit(X_train, y_train)
    except Exception:
        # Best-effort: for example, scikit-learn refuses to fit training labels
        # that contain only one class. Catch Exception rather than using a bare
        # except, so KeyboardInterrupt and SystemExit still propagate.
        return _fallback()

    if not per_sample:
        return model.score(X_val, y_val)
    return (model.predict(X_val) == y_val).astype(int)


def get_logr_utility_conditional(X_train, y_train, X_val, y_val, X_selected, y_selected, weight=None,
                                 n_classes=None):
    """Validation accuracy of logistic regression trained on X_train plus X_selected.

    The selected rows are appended (along axis 0) after the base training rows
    before fitting.

    Args:
        X_train, y_train: Base training features and labels.
        X_val, y_val: Validation features and labels.
        X_selected, y_selected: Extra rows appended to the training set.
        weight: Optional per-sample weights, passed to ``fit`` as
            ``sample_weight``. They must align with the concatenated training
            set: base rows first, then selected rows.
        n_classes: Optional number of classes for the fallback value. When
            None (or 0), the fallback is 0.5, which was the original
            hard-coded value.

    Returns:
        The mean accuracy on (X_val, y_val) from ``model.score``. If the
        combined training set is empty, or fitting raises, the fallback
        ``1 / n_classes`` is returned instead (0.5 when n_classes is not given).
    """
    fallback = 1 / n_classes if n_classes else 0.5

    X_train = np.concatenate((X_train, X_selected), axis=0)
    y_train = np.concatenate((y_train, y_selected), axis=0)

    if len(y_train) == 0:
        return fallback

    model = LogisticRegression(max_iter=5000, solver='liblinear')
    try:
        model.fit(X_train, y_train, sample_weight=weight)
    except Exception:
        # Best-effort: for example, scikit-learn refuses to fit training labels
        # that contain only one class. Catch Exception rather than using a bare
        # except, so KeyboardInterrupt and SystemExit still propagate.
        return fallback

    return model.score(X_val, y_val)