from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler, OneHotEncoder, MinMaxScaler
import numpy as np

def classLabel2IdxInProb(y, classes):
    '''
    this function is used to convert class labels to indices in the probability vector, for calculating roc_auc_score
    y: (n_samples,) array of class labels; plain lists/tuples are accepted and converted to an array
    classes: list of classes, the position of a class in this list is its index
    return:
        y_idx: float array with the same shape as y, holding the index of each label in classes
    note:
        labels that do not appear in classes keep the initial value 0
        if a class appears more than once in classes, the last position wins
    '''
    # Plain sequences have no .shape and compare as a whole instead of elementwise,
    # so turn them into an array first. Array-like inputs are left untouched.
    if not hasattr(y, 'shape'):
        y = np.asarray(y)
    y_idx = np.zeros(y.shape)
    for i, c in enumerate(classes):
        y_idx[y == c] = i
    return y_idx
def ProbIdx2ClassLabel(y_idx, classes):
    '''
    this function is used to convert indices in the probability vector to class labels
    y_idx: (n_samples,) indices into classes, either a numpy array or any iterable of indices
           float arrays holding whole numbers (e.g. the output of classLabel2IdxInProb) are accepted
    classes: list of classes
    return:
        y: (n_samples,) numpy array with classes[i] for every i in y_idx
    '''
    if isinstance(y_idx, np.ndarray):
        # Float indices cannot be used to index a list, so whole-number floats are
        # cast to int. Non-integral or non-finite values are left as they are and
        # fail at the lookup below, exactly as before.
        if (
            np.issubdtype(y_idx.dtype, np.floating)
            and np.all(np.isfinite(y_idx))
            and np.all(y_idx == np.floor(y_idx))
        ):
            y_idx = y_idx.astype(int)
        # tolist() yields native Python numbers for the lookup below
        y_idx = y_idx.tolist()
    return np.array([classes[i] for i in y_idx])

def evaluate(pred_probs, answers, multiclass=False):
    '''
    this function computes the ROC AUC of the predictions with roc_auc_score
    pred_probs: predicted scores; in the binary case either 1-d scores or a 2-d array
                whose column 1 is used as the score (presumably the positive class probability)
    answers: true labels, passed to roc_auc_score unchanged
    multiclass: when equal to False the binary AUC is computed, otherwise the
                one-vs-rest ('ovr') macro-averaged AUC over the full pred_probs
    return:
        result_auc: the value returned by roc_auc_score
    '''
    # NOTE: compared with == on purpose (not truthiness), so a value such as None
    # still selects the multiclass path like it always did.
    if multiclass == False:
        is_two_dim = len(pred_probs.shape) == 2
        scores = pred_probs[:, 1] if is_two_dim else pred_probs
        return roc_auc_score(answers, scores)
    return roc_auc_score(answers, pred_probs, multi_class='ovr', average='macro')


def preprocess(support_x_n, query_x_n, support_x_c, query_x_c, N_cols, C_cols, scaler_type='standard'):
    '''
    this function scales the numerical features, one-hot encodes the categorical features
    and joins both column-wise into the train / test feature matrices
    support_x_n, query_x_n: numerical features of the support (train) and query (test) set
    support_x_c, query_x_c: categorical features of the support (train) and query (test) set
    N_cols: numerical columns, only its length is used (empty means no numerical features)
    C_cols: categorical columns, only its length is used (empty means no categorical features)
    scaler_type: 'standard', 'minmax' or 'none'; only checked when N_cols is not empty
    return:
        X_train, X_test: numerical block followed by the one-hot block (or the only block present)
    raise:
        ValueError: if scaler_type is unknown, or if N_cols and C_cols are both empty
    '''
    has_numeric = len(N_cols) > 0
    has_categorical = len(C_cols) > 0

    # Without any column there is nothing to return; fail with a clear message
    # instead of an UnboundLocalError at the return statement.
    if not has_numeric and not has_categorical:
        raise ValueError('N_cols and C_cols are both empty, there is no feature to preprocess')

    if has_numeric:
        if scaler_type == 'standard':
            scaler = StandardScaler()
        elif scaler_type == 'minmax':
            scaler = MinMaxScaler()
        elif scaler_type == 'none':
            scaler = None
        else:
            raise ValueError(f'scaler_type should be standard, minmax or none, got {scaler_type}')

        if scaler is not None:
            # fit on the support set only, then apply the same scaling to the query set
            support_x_n = scaler.fit_transform(support_x_n)
            query_x_n = scaler.transform(query_x_n)

    if has_categorical:
        ohe = OneHotEncoder()
        # need all data to fit onehotencoder, to avoid missing categories in test data
        all_data = np.concatenate([support_x_c, query_x_c], axis=0)
        ohe.fit(all_data)
        support_x_c = ohe.transform(support_x_c).toarray()
        query_x_c = ohe.transform(query_x_c).toarray()

    if has_numeric and has_categorical:
        X_train = np.concatenate((support_x_n, support_x_c), axis=1)
        X_test = np.concatenate([query_x_n, query_x_c], axis=1)
    elif has_numeric:
        X_train = support_x_n
        X_test = query_x_n
    else:
        X_train = support_x_c
        X_test = query_x_c
    return X_train, X_test