import warnings

import numpy as np
from rdkit import rdBase
from skfp.datasets.moleculenet import (
    load_moleculenet_benchmark,
    load_ogb_splits,
)
from skfp.metrics import extract_pos_proba, multioutput_auroc_score
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.pipeline import make_pipeline

from chemberta.chemberta_fp import ChemBERTaFingerprint


def parse_data(
    dataset_name: str,
    smiles_list: list[str],
    y: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split a dataset into train and test parts using its OGB split indices.

    The train and validation index sets returned by ``load_ogb_splits`` are
    merged into a single training set; the test index set is used as-is.

    Parameters
    ----------
    dataset_name : str
        Dataset name, passed to ``load_ogb_splits`` to obtain the split indices.
    smiles_list : list[str]
        SMILES strings of the whole dataset, positionally aligned with ``y``.
    y : np.ndarray
        Labels of the whole dataset; may contain NaN values.

    Returns
    -------
    smiles_train, smiles_test : np.ndarray
        Arrays of SMILES strings for the merged train+valid part and the
        test part.
    y_train : np.ndarray
        Training labels with every NaN replaced by 0.
    y_test : np.ndarray
        Test labels, returned unmodified (NaN values, if any, are kept).
    """
    train_idxs, valid_idxs, test_idxs = load_ogb_splits(dataset_name)

    # Concatenate explicitly: `+` joins Python lists, but would add NumPy
    # arrays elementwise and silently select the wrong rows. The integer dtype
    # keeps the result a valid index array even when one split is empty.
    fit_idxs = np.concatenate(
        [
            np.asarray(train_idxs, dtype=np.intp),
            np.asarray(valid_idxs, dtype=np.intp),
        ]
    )
    test_idxs = np.asarray(test_idxs, dtype=np.intp)

    # Array form is needed for fancy indexing with the index arrays.
    smiles = np.asarray(smiles_list)

    smiles_train = smiles[fit_idxs]
    smiles_test = smiles[test_idxs]

    y_train = y[fit_idxs]
    y_test = y[test_idxs]

    # NaN cannot be passed to the downstream estimators as a training target;
    # presumably NaN marks a missing label in multitask datasets - confirm.
    y_train = np.nan_to_num(y_train, nan=0)

    return smiles_train, smiles_test, y_train, y_test


if __name__ == "__main__":
    # Silence RDKit logging and FutureWarning messages so that only the
    # benchmark report is printed.
    rdBase.DisableLog("rdApp.*")
    warnings.simplefilter(action="ignore", category=FutureWarning)

    # (label, model location) pairs, evaluated in this order.
    checkpoints = (
        ("original", "DeepChem/ChemBERTa-77M-MLM"),
        ("retrained", "outputs/chemberta_model_mlm"),
    )

    for model_type, model_path in checkpoints:
        print(model_type)

        # --- classification datasets: report AUROC per dataset, then the mean ---
        classification_benchmark = load_moleculenet_benchmark(
            subset="classification_no_pcba"
        )
        auroc_scores = []
        for dataset_name, smiles, labels in classification_benchmark:
            X_train, X_test, labels_train, labels_test = parse_data(
                dataset_name, smiles, labels
            )

            embedder = ChemBERTaFingerprint(model_path, verbose=True)
            forest = RandomForestClassifier(
                n_estimators=500,
                criterion="entropy",
                n_jobs=-1,
                random_state=0,
            )
            clf_pipeline = make_pipeline(embedder, forest)
            clf_pipeline.fit(X_train, labels_train)

            # only the positive-class probabilities are passed to the metric
            pos_proba = extract_pos_proba(clf_pipeline.predict_proba(X_test))

            auroc = multioutput_auroc_score(
                labels_test, pos_proba, suppress_warnings=True
            )
            auroc_scores.append(auroc)
            print(f"{dataset_name} AUROC: {auroc:.2%}")

        auroc = np.mean(auroc_scores)
        print(f"Average MoleculeNet AUROC: {auroc:.2%}")

        # --- regression datasets: report MAE per dataset, then the mean ---
        regression_benchmark = load_moleculenet_benchmark(subset="regression")
        mae_scores = []
        for dataset_name, smiles, targets in regression_benchmark:
            X_train, X_test, targets_train, targets_test = parse_data(
                dataset_name, smiles, targets
            )

            embedder = ChemBERTaFingerprint(model_path)
            forest = RandomForestRegressor(
                n_estimators=500, n_jobs=-1, random_state=0
            )
            reg_pipeline = make_pipeline(embedder, forest)
            reg_pipeline.fit(X_train, targets_train)
            predictions = reg_pipeline.predict(X_test)

            mae = mean_absolute_error(targets_test, predictions)
            mae_scores.append(mae)
            print(f"{dataset_name} MAE: {mae:.3f}")

        mae = np.mean(mae_scores)
        print(f"Average MoleculeNet MAE: {mae:.3f}")
