
from sklearn.linear_model import LogisticRegression as lr
from sklearn.svm import SVC as svm
from sklearn.ensemble import RandomForestClassifier as rfc
from util import get_results_sur, get_results_surclust, get_results_EKR, focal_loss, weight_oversample, class_loss
from imblearn.over_sampling import SMOTE, BorderlineSMOTE, ADASYN, KMeansSMOTE
from sklearn.neural_network import MLPClassifier as MLP
from imblearn.over_sampling import RandomOverSampler 
from imblearn.under_sampling import ClusterCentroids, TomekLinks
from imblearn.ensemble import EasyEnsembleClassifier, BalancedBaggingClassifier
from xgboost import XGBClassifier as xgb


def Models(weighted, mod_typ, x):
    """Build an unfitted base classifier.

    Parameters
    ----------
    weighted : int
        When equal to 1, ``class_weight='balanced'`` is passed to the
        estimator constructor. It is not passed for ``'nn'``.
    mod_typ : str
        One of ``'lr'``, ``'rf'``, ``'svm'``, ``'nn'``, ``'xgb'``.
    x : DataFrame-like
        Training features. Only ``len(x.columns)`` is read, and only for
        ``'nn'``, to size the hidden layers.

    Returns
    -------
    The constructed, unfitted estimator.

    Raises
    ------
    ValueError
        If ``mod_typ`` is not a supported model type.
    """
    args = {}
    if weighted == 1:
        args['class_weight'] = 'balanced'

    if mod_typ == 'lr':
        model = lr(max_iter=1000000, fit_intercept=False, **args)
    elif mod_typ == 'rf':
        model = rfc(random_state=42, **args)
    elif mod_typ == 'svm':
        # probability=True so that predict_proba is available to callers.
        model = svm(probability=True, random_state=42, **args)
    elif mod_typ == 'nn':
        # Two hidden layers, each (n_features + 2) / 2 wide (truncated).
        # class_weight is deliberately not forwarded here.
        width = int((len(x.columns) + 2) / 2)
        model = MLP(hidden_layer_sizes=(width,) * 2, random_state=42,
                    max_iter=1000000)
    elif mod_typ == 'xgb':
        # NOTE(review): XGBClassifier documents no `class_weight` parameter.
        # Presumably it is forwarded as an unused booster param and has no
        # effect. Consider `scale_pos_weight` or sample weights instead.
        model = xgb(n_jobs=1, **args)
    else:
        # Previously fell through to `return model` with `model` unbound.
        raise ValueError(
            f"Unknown mod_typ {mod_typ!r}; expected one of "
            "'lr', 'rf', 'svm', 'nn', 'xgb'"
        )

    return model


def ORG(x, y, reweight=0, mod_typ='lr'):
    """Fit a base classifier on ``(x, y)``, optionally with loss reweighting.

    Parameters
    ----------
    x, y :
        Training features and labels.
    reweight : {0, 1, 'focal', 'class'}
        ``1``
            ``Models`` adds ``class_weight='balanced'``. For ``'nn'``, which
            gets no class weights, the data are randomly oversampled instead.
        ``'focal'``
            Per-sample weights from ``focal_loss(x, y, model)``.
        ``'class'``
            Per-sample weights from ``class_loss(x, y)``.
        Any other value
            Plain, unweighted fit.
    mod_typ : str
        Base model type forwarded to ``Models``.

    Returns
    -------
    The fitted model.

    Notes
    -----
    For ``'nn'`` with ``'focal'`` or ``'class'``, the weights of the
    ``y == 1`` samples go to ``weight_oversample``. The MLP is then fitted on
    the resampled data it returns.
    """
    model = Models(reweight, mod_typ, x)
    is_nn = mod_typ == 'nn'

    if reweight == 1 and is_nn:
        ros = RandomOverSampler(sampling_strategy='auto', random_state=42)
        xn, yn = ros.fit_resample(x, y)
        model.fit(xn, yn)
    elif reweight == 'focal' or reweight == 'class':
        if reweight == 'focal':
            weights = focal_loss(x, y, model)
        else:
            weights = class_loss(x, y)

        if is_nn:
            # Emulate per-sample weighting by oversampling the positives.
            xn, yn = weight_oversample(x, y, weights[y == 1])
            model.fit(xn, yn)
        else:
            # Pass by keyword: XGBClassifier.fit makes sample_weight
            # keyword-only and deprecates positional use.
            model.fit(x, y, sample_weight=weights)
    else:
        model.fit(x, y)

    return model
  


def Sur_Ens(x, y, xs, ys, x_ns, clust=0, weighted=0, mod_typ='lr'):
    """Run the RBE, CBE, or EKR ensemble and return class-1 probabilities.

    Parameters
    ----------
    x, y, xs, ys :
        Passed through to the ``get_results_*`` helpers. Presumably training
        and evaluation data; TODO confirm against ``util``.
    x_ns :
        Only used when ``clust == 1``, passed to ``get_results_surclust``.
    clust : {0, 1, 'ekr'}
        ``0`` uses ``get_results_sur``, ``1`` uses ``get_results_surclust``,
        and ``'ekr'`` uses ``get_results_EKR``.
    weighted, mod_typ :
        Forwarded to ``Models`` to build the base classifier.

    Returns
    -------
    tuple
        Column ``1`` of the train and test probability frames, presumably
        the positive-class probabilities.

    Raises
    ------
    ValueError
        If ``clust`` is not one of the supported options.
    """
    # Validate first. An unknown value used to leave proba_train unbound.
    if clust not in (0, 1, 'ekr'):
        raise ValueError(
            f"Unknown clust {clust!r}; expected one of 0, 1, 'ekr'"
        )

    model = Models(weighted, mod_typ, x)
    n_features = len(x.columns)

    if clust == 0:
        proba_train, proba_test = get_results_sur(
            model, x, y, xs, ys, 1, n_features)
    elif clust == 1:
        proba_train, proba_test = get_results_surclust(
            model, x, y, xs, ys, x_ns, 1, n_features)
    else:  # clust == 'ekr'
        proba_train, proba_test = get_results_EKR(
            model, x, y, xs, ys, n_features)

    return proba_train.loc[:, 1], proba_test.loc[:, 1]
            




def Smote(x, y, ver='gen', mod_typ='lr'):
    """Oversample ``(x, y)`` with a SMOTE variant, then fit an unweighted model.

    Parameters
    ----------
    x, y :
        Training features and labels.
    ver : {'gen', 'bord', 'asd', 'kmean'}
        ``SMOTE``, ``BorderlineSMOTE``, ``ADASYN``, or ``KMeansSMOTE``.
    mod_typ : str
        Base model type forwarded to ``Models`` (no class weighting).

    Returns
    -------
    tuple
        ``(fitted_model, x_resampled, y_resampled)``.

    Raises
    ------
    ValueError
        If ``ver`` is not a supported oversampler.
    """
    # Validate first. An unknown value used to leave `oversample` unbound.
    if ver not in ('gen', 'bord', 'asd', 'kmean'):
        raise ValueError(
            f"Unknown ver {ver!r}; expected one of "
            "'gen', 'bord', 'asd', 'kmean'"
        )

    model = Models(0, mod_typ, x)

    # Cap neighbours at 5 and at sum(y) - 1. Assumes 0/1 labels, so sum(y)
    # is the positive-sample count; TODO confirm.
    knn = int(min(5, sum(y) - 1))

    if ver == 'gen':
        oversample = SMOTE(random_state=0, k_neighbors=knn)
    elif ver == 'bord':
        oversample = BorderlineSMOTE(random_state=0, k_neighbors=knn)
    elif ver == 'asd':
        # ADASYN names the parameter n_neighbors, not k_neighbors.
        oversample = ADASYN(random_state=0, n_neighbors=knn)
    else:  # ver == 'kmean'
        oversample = KMeansSMOTE(random_state=0, k_neighbors=knn)

    xs, ys = oversample.fit_resample(x, y)
    model.fit(xs, ys)

    return model, xs, ys
    
  
    
def UnderSample(x, y, ver='clus', mod_typ='lr'):
    """Undersample ``(x, y)``, then fit an unweighted model on the result.

    Parameters
    ----------
    x, y :
        Training features and labels.
    ver : {'clus', 'clust', 'tomek'}
        ``'clus'`` (the default) or ``'clust'`` selects ``ClusterCentroids``.
        ``'tomek'`` selects ``TomekLinks``.
    mod_typ : str
        Base model type forwarded to ``Models`` (no class weighting).

    Returns
    -------
    tuple
        ``(fitted_model, x_resampled, y_resampled)``.

    Raises
    ------
    ValueError
        If ``ver`` is not a supported undersampler.
    """
    # The default 'clus' never matched the original 'clust' branch, so calling
    # with defaults failed with an unbound `undersamp`. Accept both spellings.
    if ver not in ('clus', 'clust', 'tomek'):
        raise ValueError(
            f"Unknown ver {ver!r}; expected one of 'clus', 'clust', 'tomek'"
        )

    model = Models(0, mod_typ, x)

    if ver == 'tomek':
        undersamp = TomekLinks()
    else:  # 'clus' or 'clust'
        undersamp = ClusterCentroids(random_state=42)

    xs, ys = undersamp.fit_resample(x, y)
    model.fit(xs, ys)

    return model, xs, ys

    

def Ensemble(x, y, ver='easy', mod_typ='lr'):
    """Fit an off-the-shelf imbalanced-learn ensemble around a base model.

    Parameters
    ----------
    x, y :
        Training features and labels.
    ver : {'easy', 'bag'}
        ``EasyEnsembleClassifier`` or ``BalancedBaggingClassifier``.
    mod_typ : str
        Base model type forwarded to ``Models`` (no class weighting) and
        used as the ensemble's ``estimator``.

    Returns
    -------
    The fitted ensemble.

    Raises
    ------
    ValueError
        If ``ver`` is not a supported ensemble.
    """
    # Validate first. An unknown value used to leave `ensbl` unbound.
    if ver not in ('easy', 'bag'):
        raise ValueError(f"Unknown ver {ver!r}; expected one of 'easy', 'bag'")

    model = Models(0, mod_typ, x)

    if ver == 'easy':
        ensbl = EasyEnsembleClassifier(estimator=model, random_state=42)
    else:  # ver == 'bag'
        ensbl = BalancedBaggingClassifier(estimator=model, random_state=42)

    ensbl.fit(x, y)

    return ensbl
         