from sklearn.linear_model import LogisticRegression, RidgeClassifier, SGDClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

def _sigmoid(z):
    """Squash raw decision scores into (0, 1) with the logistic function."""
    return 1 / (1 + np.exp(-z))


def _positive_proba(model):
    """Return a scorer giving ``model``'s predicted probability of class 1.

    Assumes ``model.predict_proba`` returns one column per class with the
    positive class in column 1 (binary labels).
    """
    return lambda X: model.predict_proba(X)[:, 1]


def _sigmoid_decision(model):
    """Return a scorer mapping ``model.decision_function`` through a sigmoid.

    Used for ridge classifiers, which expose decision scores rather than
    ``predict_proba``.
    """
    return lambda X: _sigmoid(model.decision_function(X))


def _n_large_coefs(model, tol=0.01):
    """Count coefficients of a fitted linear model whose magnitude exceeds ``tol``."""
    return np.sum(np.abs(model.coef_) > tol)


def _time_single_prediction(predict, X):
    """Return the wall-clock seconds spent scoring only the first row of ``X``."""
    start = time.time()
    predict(X[0:1])
    return time.time() - start


def _xgb_num_rounds(X_train, y_train, params):
    """Choose the number of boosting rounds for the 'xgb' model.

    If ``params['early']`` is truthy, run ``xgb.cv`` for up to 100 rounds.
    Return the round count with the lowest mean test log-loss. Otherwise
    return 100.
    """
    if not params.get('early'):
        return 100
    dtrain = xgb.DMatrix(X_train, label=y_train)
    cv_results = xgb.cv(
        {'objective': 'binary:logistic'},
        dtrain,
        num_boost_round=100,
        nfold=num_folds,
    )
    # argmin gives a 0-based row index; the number of rounds is one more.
    # np.argmin works whether xgb.cv returned a DataFrame or a dict of lists.
    return 1 + int(np.argmin(cv_results['test-logloss-mean']))


def _fit_ours(X_train, y_train, X_test, params):
    """Fit the 'ours' model: ridge on an augmented design with a guarded alpha.

    When ``params`` contains 'munge_params', train on
    ``params['tmp_params']['X_munge']`` and ``['y_munge']`` instead of the
    supplied training set. The munged labels stay local to this helper.
    Returns the same 4-tuple as ``_fit_model``.
    """
    if 'munge_params' in params:
        X_train = params['tmp_params']['X_munge']
        y_train = params['tmp_params']['y_munge']

    # Build the augmented design.
    # Standardizing it (StandardScaler) was tried and seemed to make little
    # difference, so it is not applied.
    X_train_aug, X_test_aug = augment_X(X_train, X_test, params)

    # NOTE(review): scikit-learn >= 1.5 deprecates store_cv_values / cv_values_
    # in favour of store_cv_results / cv_results_; confirm the pinned version.
    model = RidgeClassifierCV(alphas=alphas, store_cv_values=True)
    model.fit(X_train_aug, y_train)

    # scikit-learn has no 1-SE rule, and plainly minimizing CV error sometimes
    # picks a near-zero regularization strength by chance. Guard against that:
    # bump the threshold slightly, then take the first alpha scoring below it.
    # This behaves like a 1-SE rule only if `alphas` is ordered from strongest
    # to weakest regularization -- TODO confirm the ordering of `alphas`.
    cv_scores = model.cv_values_.mean(axis=0).flatten()
    threshold = pick_threshold(cv_scores, patience) + ridge_cv_bump
    alpha_idx = np.min(np.where(cv_scores < threshold))

    model = RidgeClassifierCV(alphas=[alphas[alpha_idx]])
    model.fit(X_train_aug, y_train)
    return _sigmoid_decision(model), X_train_aug, X_test_aug, _n_large_coefs(model)


def _fit_model(model_name, X_train, y_train, X_test, params):
    """Fit the requested model, or fetch a pre-fitted one, and describe how to score it.

    Returns a tuple ``(predict, X_train_eval, X_test_eval, n_vars)``:

    * ``predict`` maps a feature matrix to positive-class scores in [0, 1].
    * ``X_train_eval`` and ``X_test_eval`` are the matrices to score (the
      augmented design for 'ours').
    * ``X_train_eval`` is None for the pre-fitted 'nid' and 'irf' models,
      which are not scored on the training set.
    * ``n_vars`` is the variable count to report.

    Raises:
        ValueError: if ``model_name`` is not recognised.
    """
    if model_name == 'lasso':
        # liblinear would also support the L1 penalty; saga is used here.
        model = LogisticRegressionCV(Cs=Cs, penalty='l1', solver='saga', max_iter=max_iter)
        model.fit(X_train, y_train)
        return _positive_proba(model), X_train, X_test, _n_large_coefs(model)

    if model_name == 'ridge':
        model = RidgeClassifierCV(alphas=alphas)
        model.fit(X_train, y_train)
        return _sigmoid_decision(model), X_train, X_test, _n_large_coefs(model)

    if model_name == 'enet':
        model = LogisticRegressionCV(
            Cs=Cs,
            penalty='elasticnet',
            solver='saga',
            l1_ratios=[params['l1r']],
            max_iter=max_iter,
        )
        model.fit(X_train, y_train)
        return _positive_proba(model), X_train, X_test, _n_large_coefs(model)

    if model_name == 'cart':
        # Seed the tree like 'rf' and 'xgb'. scikit-learn permutes features at
        # each split, so an unseeded tree breaks ties irreproducibly.
        model = DecisionTreeClassifier(random_state=random_state)
        model.fit(X_train, y_train)
        return _positive_proba(model), X_train, X_test, 0

    if model_name == 'knn':
        model = KNeighborsClassifier(params['k'])
        model.fit(X_train, y_train)
        return _positive_proba(model), X_train, X_test, 0

    if model_name == 'rf':
        # TODO: Pick mtry by CV
        model = RandomForestClassifier(random_state=random_state)
        model.fit(X_train, y_train)
        return _positive_proba(model), X_train, X_test, X_train.shape[1]

    if model_name == 'xgb':
        num_rounds = _xgb_num_rounds(X_train, y_train, params)
        model = XGBClassifier(random_state=random_state, n_estimators=num_rounds)
        model.fit(X_train, y_train)
        return _positive_proba(model), X_train, X_test, X_train.shape[1]

    if model_name == 'nid':
        nid = params['tmp_params']['nid']

        def predict(X):
            # NOTE(review): torch.tensor keeps X's dtype; assumes it matches the
            # network's parameter dtype -- TODO confirm.
            logits = nid.forward(torch.tensor(X)).detach().numpy().flatten()
            return _sigmoid(logits)

        return predict, None, X_test, X_train.shape[1]

    if model_name == 'irf':
        irf_model = params['tmp_params']['irf']
        return _positive_proba(irf_model), None, X_test, X_train.shape[1]

    if model_name == 'pygam':
        model = LogisticGAM()
        model.fit(X_train, y_train)
        # LogisticGAM.predict returns hard labels, which makes the AUC
        # degenerate. predict_proba returns the positive-class probability as
        # a 1-D array.
        return model.predict_proba, X_train, X_test, X_train.shape[1]

    if model_name == 'ours':
        return _fit_ours(X_train, y_train, X_test, params)

    raise ValueError(f"Unknown model: {model_name!r}")


def train_and_evaluate_classifier(X_train, y_train, X_test, y_test, params):
    """Fit (or load) the classifier named by ``params['model']`` and score it on the test set.

    Relies on module-level settings not defined in this function: Cs,
    max_iter, alphas, random_state, num_folds, patience, ridge_cv_bump, and
    the helpers augment_X and pick_threshold.

    Args:
        X_train, y_train: Training features and labels. Labels are assumed to
            be binary 0/1, because accuracy compares them with rounded scores.
        X_test, y_test: Held-out features and labels used for every metric.
        params: Configuration dict. ``params['model']`` selects one of
            'lasso', 'ridge', 'enet', 'cart', 'knn', 'rf', 'xgb', 'nid',
            'irf', 'pygam' or 'ours'. Model-specific keys:

            * 'l1r' for 'enet'.
            * 'k' for 'knn'.
            * 'early' for 'xgb'.
            * ``params['tmp_params']['nid']`` or ``['irf']`` for the
              pre-fitted models.
            * For 'ours', optionally 'munge_params' together with
              ``params['tmp_params']['X_munge']`` and ``['y_munge']``.

    Returns:
        A dict of single-element lists with these keys:

        * 'acc_bl': majority-class accuracy over the given train and test
          labels.
        * 'acc': test accuracy with a 0.5 cutoff.
        * 'auc': test ROC AUC.
        * 'vars': number of variables counted for the model.
        * 'runtime': wall-clock seconds for fitting plus scoring.
        * 'inference': wall-clock seconds to score a single test row.

    Raises:
        ValueError: if ``params['model']`` is not recognised.
    """
    start_time = time.time()

    predict, X_train_eval, X_test_eval, n_vars = _fit_model(
        params['model'], X_train, y_train, X_test, params
    )
    if X_train_eval is not None:
        # Training-set scores are not returned. They are still computed so that
        # 'runtime' keeps covering fit + training-set + test-set scoring.
        predict(X_train_eval)
    yhat_test = predict(X_test_eval)
    inference = _time_single_prediction(predict, X_test_eval)

    end_time = time.time()

    # Majority-class baseline over the caller's labels. Munged labels used by
    # 'ours' stay inside _fit_ours and do not influence it.
    acc_bl = np.mean(np.concatenate([y_train, y_test]))
    acc_bl = max(acc_bl, 1 - acc_bl)
    # np.round thresholds the scores at 0.5; an exact 0.5 rounds to 0.
    acc = np.mean(y_test == np.round(yhat_test))
    auc = roc_auc_score(y_test, yhat_test)

    return {
        'acc_bl': [acc_bl],
        'acc': [acc],
        'auc': [auc],
        'vars': [n_vars],
        'runtime': [end_time - start_time],
        'inference': [inference],
    }
