import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.pipeline import make_pipeline

from chemberta.chemberta_fp import ChemBERTaFingerprint


def load_data(
    train_path: str = "chemberta/apistox_datasets/apistox_time_train.csv",
    test_path: str = "chemberta/apistox_datasets/apistox_time_test.csv",
) -> tuple[list[str], list[str], np.ndarray, np.ndarray]:
    """Load the train/test splits used for the fingerprint evaluation.

    Both CSV files must contain a ``SMILES`` column and a ``label`` column.

    Args:
        train_path: Path to the training-split CSV (defaults to the
            ``apistox_time_train.csv`` file used by this script).
        test_path: Path to the test-split CSV (defaults to the
            ``apistox_time_test.csv`` file used by this script).

    Returns:
        ``(smiles_train, smiles_test, y_train, y_test)``: the SMILES strings of
        each split as plain Python lists, and the labels of each split as
        NumPy arrays.
    """
    df_train = pd.read_csv(train_path)
    df_test = pd.read_csv(test_path)

    # Convert to the types promised by the signature; returning the raw
    # pandas Series would also drag the DataFrame index along.
    smiles_train = df_train["SMILES"].tolist()
    smiles_test = df_test["SMILES"].tolist()

    y_train = df_train["label"].to_numpy()
    y_test = df_test["label"].to_numpy()

    return smiles_train, smiles_test, y_train, y_test


if __name__ == "__main__":
    smiles_train, smiles_test, y_train, y_test = load_data()

    # Evaluate each ChemBERTa checkpoint with an identical downstream
    # random-forest classifier so only the embedding model differs.
    for model_type, model_path in [
        ("original", "DeepChem/ChemBERTa-77M-MLM"),
        ("retrained", "outputs/chemberta_model_mlm"),
    ]:
        print(model_type)

        pipeline = make_pipeline(
            ChemBERTaFingerprint(model_path),
            RandomForestClassifier(
                n_estimators=500, criterion="entropy", n_jobs=-1, random_state=0
            ),
        )
        pipeline.fit(smiles_train, y_train)

        # Only probability scores are needed for AUROC. A separate
        # `pipeline.predict` call would re-run the fingerprint transform over
        # the whole test set just to produce an unused result.
        # Column 1 is the score for the larger label; assumes binary labels.
        y_pred_proba = pipeline.predict_proba(smiles_test)[:, 1]

        auroc = roc_auc_score(y_test, y_pred_proba)
        print(f"AUROC: {auroc:.2%}")
        print()
