import lightgbm
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler

# Binary LightGBM classifier configuration used by the experiments below.
# The seed is fixed via random_state.
# NOTE(review): n_jobs=20 is hard-coded -- adjust to the cores available on the
# machine running this script.
lgb = lightgbm.LGBMClassifier(
    objective='binary',
    class_weight='balanced',
    random_state=2023,
    n_jobs=20,
)


# --- Baseline: train on the real training split, score on the test split ---
X_train = np.load("../proc_data/icu/raw_train_data.npy")
y_train = np.load("../proc_data/icu/train_label.npy")
X_test = np.load("../proc_data/icu/raw_test_data.npy")
y_test = np.load("../proc_data/icu/test_label.npy")

# Fit a fresh copy of the shared configuration. `fit` returns the estimator
# itself, so fitting `lgb` directly would leave `lgb_train` as just another
# name for `lgb`, silently overwritten by any later refit in this script.
lgb_train = clone(lgb).fit(X_train, y_train)

# Hard class predictions; not consumed by the metrics below, kept available
# for threshold-based metrics (e.g. the imported f1_score).
pred_train = lgb_train.predict(X_test)

# Positive-class probabilities, computed once and shared by both metrics.
proba_train = lgb_train.predict_proba(X_test)[:, 1]
print("AUROC (real data): " + str(np.around(roc_auc_score(y_test, proba_train), 4)))
print("AUPRC (real data): " + str(np.around(average_precision_score(y_test, proba_train), 4)))
print("**********************************")


# --- Post-process the synthetic samples into usable training features ---
# Steps: clamp to [0, 1], snap the discrete columns to valid values, then undo
# the min-max scaling using ranges taken from the real training data.
# (presumably the generator emits min-max-scaled features -- TODO confirm)
syn_data = np.clip(np.load("EHRDiff.npy"), 0, 1)

# Binary columns are rounded to the nearest of {0, 1}.
# NOTE(review): the original comments call column 0 "label" and column 2
# "gender". Column 0 is nevertheless passed to the classifier as a feature and
# the scaler below is fitted on train[:, 1:], so "label" may be a stale
# comment -- confirm the column layout of EHRDiff.npy.
binary_cols = [0, 2]
syn_data[:, binary_cols] = np.rint(syn_data[:, binary_cols])

# Columns 5..7 form the ICU-type group (per the variable name). Each row is
# turned into a strict one-hot vector that keeps only its largest entry.
# `syn_icu_type` is a view into `syn_data`, so the writes below update
# `syn_data` in place and no copy-back is needed.
syn_icu_type = syn_data[:, 5:8]
max_ind = np.argmax(syn_icu_type, axis=1)  # must be read before zeroing
syn_icu_type[:, :] = 0
syn_icu_type[np.arange(len(max_ind)), max_ind] = 1

# The scaler is fitted on the real training matrix without its first column,
# then used in reverse to map the [0, 1] synthetic values onto those ranges.
train = np.load("../proc_data/icu/raw_train_data_label.npy")
scaler = MinMaxScaler()
scaler.fit(train[:, 1:])
syn_data = scaler.inverse_transform(syn_data)

# --- Train on synthetic data only, score on the real test split ---
# Labels are assigned purely by row position: the first half of the rows is
# treated as class 0 and the second half as class 1.
# NOTE(review): this presumes EHRDiff.npy stores the two classes back to back
# in equal numbers -- confirm against the script that generated it.
num_samples = syn_data.shape[0]
if num_samples % 2:
    # With an odd row count the two halves cannot cover every row; fail here
    # with a clear message instead of a length mismatch inside fit().
    raise ValueError(
        "Expected an even number of synthetic samples (half per class), got "
        + str(num_samples)
    )
num_each_label = num_samples // 2
labels = np.repeat([0, 1], num_each_label)

# Fit a fresh copy of the shared configuration so that models fitted earlier
# from `lgb` are not overwritten (fit returns the estimator itself).
lgb_syn = clone(lgb).fit(syn_data, labels)

# Hard class predictions; not consumed by the metrics below, kept available
# for threshold-based metrics (e.g. the imported f1_score).
pred_syn = lgb_syn.predict(X_test)

# Positive-class probabilities, computed once and shared by both metrics.
proba_syn = lgb_syn.predict_proba(X_test)[:, 1]
print("AUROC (syn data): " + str(np.around(roc_auc_score(y_test, proba_syn), 4)))
print("AUPRC (syn data): " + str(np.around(average_precision_score(y_test, proba_syn), 4)))
print("**********************************")


# aug_X = np.concatenate((X_train, syn_data))
# aug_y = np.concatenate((y_train, labels))
# lgb_aug = lgb.fit(aug_X, aug_y)
# pred_aug = lgb_aug.predict(X_test)
# print("AUROC (aug data): " + str(np.around(roc_auc_score(y_test, lgb_aug.predict_proba(X_test)[:, 1]), 4)))
# print("AUPRC (aug data): " + str(np.around(average_precision_score(y_test, lgb_aug.predict_proba(X_test)[:, 1]), 4)))
# print("**********************************")
