import pandas as pd
import numpy as np
import warnings
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, accuracy_score, recall_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
import lightgbm as lgb
from xgboost import XGBClassifier
import copy

from CHRE.utils import Data_Entropy, cg_Global, KnnAndRknn, cal_target_value, CG_UnderSample, CG_OverSamples, \
    OverSamples_SMOTE_ENTROPY_, remove_syn_min, Entropy_all_MinMax, Is_Break, CHRE_ensemble, KNN_nei_Matrix_Pseudo, \
    KNN_nei_Matrix, RKNN_nei, loadfile, Stratified_fold_K_version_2

# Install a global "ignore" filter so warnings raised after this point are not shown.
warnings.filterwarnings("ignore")


def chre_resampling(X_train, y_train, X_test, y_test, classifier, lamba_aim1=0.8, lamba_aim2=0.5, K=5, iter_T=10):
    """Run the iterative CHRE resampling loop and return the ensemble's output.

    A private copy of ``classifier`` is first fitted on the raw training data.
    Then, for at most ``iter_T`` rounds, the training data is resampled with
    the ``CHRE.utils`` helpers, the copy is refitted on the resampled data,
    and a snapshot of the fitted model is stored. All stored snapshots are
    finally handed to ``CHRE_ensemble``.

    Args:
        X_train, y_train: Training features and labels; wrapped in a
            ``Data_Entropy`` container.
        X_test, y_test: Test features and labels; wrapped in a second
            ``Data_Entropy`` container that is also passed to
            ``CG_OverSamples`` and ``CHRE_ensemble``.
        classifier: Estimator exposing ``fit``/``predict``. It is deep-copied,
            so the caller's instance is never fitted here.
        lamba_aim1, lamba_aim2: Stored on the shared ``cg_Global`` state for
            the helpers to read; their exact meaning lives in ``CHRE.utils``.
        K: Passed to ``Entropy_all_MinMax`` and ``KnnAndRknn`` and stored on
            the shared state.
        iter_T: Upper bound on the number of resampling rounds.

    Returns:
        Whatever ``CHRE_ensemble(model_list, train, test)`` returns.
    """
    estimator = copy.deepcopy(classifier)
    entropy_bounds = Entropy_all_MinMax(k=K)
    train_set = Data_Entropy(X_train, y_train)
    test_set = Data_Entropy(X_test, y_test)

    # Baseline model on the untouched training data; it becomes snapshot 0.
    state = cg_Global()
    estimator.fit(train_set.x, train_set.y)
    state.model_list.append(copy.deepcopy(estimator))
    state.Data_list.append(copy.deepcopy(train_set))
    state.Entropy_list = entropy_bounds
    state.K = K
    state.lamba_aim1 = lamba_aim1
    state.lamba_aim2 = lamba_aim2

    # Record the baseline predictions on both containers.
    test_set.pred = estimator.predict(test_set.x)
    if not hasattr(estimator, 'predict_proba'):
        # No probability API: build a two-column array from the hard labels,
        # column 1 holding the prediction and column 0 its complement.
        test_set.pred_prob = np.zeros((len(test_set.x), 2))
        test_set.pred_prob[:, 1] = estimator.predict(test_set.x)
        test_set.pred_prob[:, 0] = 1 - test_set.pred_prob[:, 1]
    else:
        test_set.pred_prob = estimator.predict_proba(test_set.x)
    train_set.pred = estimator.predict(train_set.x)

    neighbours = KnnAndRknn(train_set, K=K, Gloabl_item=state)

    rounds_done = 0
    while rounds_done < iter_T:
        # Helpers see the zero-based index of the round currently running.
        state.No_T = rounds_done
        cal_target_value(neighbours, train_set, state)
        resampled = CG_UnderSample(neighbours, train_set, Gloabl_item=state)
        resampled = CG_OverSamples(neighbours, train_set, test_set, resampled, Global_item=state)
        resampled = OverSamples_SMOTE_ENTROPY_(resampled, neighbours, state)
        remove_syn_min(resampled, state)

        estimator.fit(resampled.x, resampled.y)

        rounds_done += 1
        state.model_list.append(copy.deepcopy(estimator))
        state.Data_list.append(copy.deepcopy(resampled))

        # Is_Break is always evaluated, but its verdict is only honoured once
        # at least two rounds have completed.
        stop_flag = Is_Break(state, train_set)
        if rounds_done >= 2 and (stop_flag == 1 or stop_flag == -1):
            break

    return CHRE_ensemble(state.model_list, train_set, test_set)


def evaluate_model(classifier, X, y, n_splits=5):
    """Cross-validate the CHRE pipeline around ``classifier`` and summarise test metrics.

    The data is partitioned by ``Stratified_fold_K_version_2``. For every fold,
    ``chre_resampling`` is run with the fold's training part and its output is
    scored against the held-out labels.

    Args:
        classifier: Estimator handed unchanged to ``chre_resampling`` for each fold.
        X: Feature matrix passed to the fold splitter.
        y: Target vector passed to the fold splitter. The four-way unpacking
            of the confusion matrix below requires a two-class problem.
        n_splits: Number of folds requested from the splitter and evaluated.
            Defaults to 5, the value previously hard-coded in both places.

    Returns:
        dict: Maps each metric name (``test_auc``, ``test_accuracy``,
        ``test_recall``, ``test_f1``, ``test_gmean``) to
        ``{'mean': ..., 'std': ...}``, both rounded to 4 decimals. ``std`` is
        the sample standard deviation (``ddof=1``) across folds.
    """
    all_data = Stratified_fold_K_version_2(X, y, n_spli=n_splits)
    metrics_summary = {
        'test_auc': [], 'test_accuracy': [], 'test_recall': [], 'test_f1': [], 'test_gmean': []
    }

    for fold in range(n_splits):
        X_train, y_train = all_data["train_x"][fold], all_data["train_y"][fold]
        X_test, y_test = all_data["test_x"][fold], all_data["test_y"][fold]

        y_test_pred = chre_resampling(X_train, y_train, X_test, y_test, classifier)

        # NOTE(review): AUC is computed from the same array as the label-based
        # metrics; if chre_resampling returns hard labels rather than scores,
        # this is not a ranking-based AUC -- confirm this is intended.
        metrics_summary['test_auc'].append(roc_auc_score(y_test, y_test_pred))
        metrics_summary['test_accuracy'].append(accuracy_score(y_test, y_test_pred))
        # Recall is needed twice (reported metric and G-mean); compute it once.
        recall = recall_score(y_test, y_test_pred)
        metrics_summary['test_recall'].append(recall)
        metrics_summary['test_f1'].append(f1_score(y_test, y_test_pred))

        # G-mean = sqrt(specificity * recall); only tn and fp are needed here.
        tn, fp, _fn, _tp = confusion_matrix(y_test, y_test_pred).ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        gmean = np.sqrt(specificity * recall)
        metrics_summary['test_gmean'].append(gmean)

    print('#' * 20)
    metrics_results = {}
    for metric, fold_values in metrics_summary.items():
        mean_value = np.mean(fold_values)
        std_value = np.std(fold_values, ddof=1)
        print(f'Mean {metric}: {mean_value:.4f}, Std {metric}: {std_value:.4f}')
        metrics_results[metric] = {'mean': np.around(mean_value, 4), 'std': np.around(std_value, 4)}

    return metrics_results


# Load the data set and separate the binary target column from the features.
data = pd.read_csv('CDC Diabetes Health Indicators Dataset/data.csv')
y = data['Diabetes_binary'].to_numpy()
X = data.drop(columns='Diabetes_binary').to_numpy()

# Standardise the feature columns.
# NOTE(review): the scaler is fitted on the full data set before
# evaluate_model splits it into folds, so held-out rows contribute to the
# scaling statistics -- confirm this is acceptable for the reported scores.
scaler = StandardScaler().fit(X)
X = scaler.transform(X)

# Base learners that get wrapped by the CHRE procedure, keyed by display name.
classifiers = {
    "LightGBM": lgb.LGBMClassifier(verbosity=-1, random_state=42),
    "XGBoost": XGBClassifier(random_state=42),
    "Random Forest": RandomForestClassifier(random_state=42),
    "Decision Tree": DecisionTreeClassifier(random_state=42),
}

# Evaluate every base learner and collect one summary row per learner.
results = []
for name, clf in classifiers.items():
    print(f"==================CHRE with {name}===================")
    metrics = evaluate_model(clf, X, y)
    metrics.update({'Classifier': f'CHRE-{name}'})
    results.append(metrics)

results_df = pd.DataFrame(results)
print(results_df)
