import warnings

import numpy as np
from rdkit import rdBase
from skfp.datasets.tdc.benchmark import load_tdc_dataset, load_tdc_splits
from skfp.metrics import extract_pos_proba, multioutput_auroc_score
from skfp.preprocessing import MolFromSmilesTransformer
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.pipeline import make_pipeline

from mol2vec.mol2vec_fp import Mol2VecFingerprint


def get_classification_dataset_names() -> list[str]:
    """Return the TDC benchmark dataset names evaluated with a classifier.

    A fresh list is built on every call, so callers may mutate the result
    freely.
    """
    names = (
        "ames",
        "bioavailability_ma",
        "cyp1a2_veith",
        "cyp2c9_veith",
        "cyp2c9_substrate_carbonmangels",
        "cyp2c19_veith",
        "cyp2d6_veith",
        "cyp2d6_substrate_carbonmangels",
        "cyp3a4_veith",
        "cyp3a4_substrate_carbonmangels",
        "dili",
        "herg",
        "herg_karim",
        "hia_hou",
        "pampa_ncats",
        "pgp_broccatelli",
        "sarscov2_3clpro_diamond",
        "sarscov2_vitro_touret",
    )
    return list(names)


def get_regression_dataset_names() -> list[str]:
    """Return the TDC benchmark dataset names evaluated with a regressor.

    A fresh list is built on every call, so callers may mutate the result
    freely.
    """
    names = (
        "caco2_wang",
        "clearance_hepatocyte_az",
        "clearance_microsome_az",
        "half_life_obach",
        "ld50_zhu",
        "ppbr_az",
        "solubility_aqsoldb",
        "vdss_lombardo",
    )
    return list(names)


def parse_data(
    dataset_name: str,
    smiles_list: list[str],
    y: np.ndarray,
    splits: "tuple | None" = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Split a dataset into (train + valid) and test parts.

    The train and validation index sets are merged into a single training
    set. NaN labels in that training set are replaced with 0. Test labels
    are returned unchanged.

    Args:
        dataset_name: TDC dataset name. It is passed to ``load_tdc_splits``
            when ``splits`` is not given.
        smiles_list: SMILES strings, positionally aligned with ``y``.
        y: Labels, positionally aligned with ``smiles_list``.
        splits: Optional ``(train_idxs, valid_idxs, test_idxs)`` index
            sequences (lists or arrays). When ``None``, the splits are
            fetched with ``load_tdc_splits(dataset_name)``.

    Returns:
        ``(smiles_train, smiles_test, y_train, y_test)`` as NumPy arrays.
    """
    if splits is None:
        splits = load_tdc_splits(dataset_name)
    train_idxs, valid_idxs, test_idxs = splits

    # Concatenate explicitly instead of using ``+``: on NumPy arrays ``+``
    # adds the indices elementwise instead of joining them. The intp cast
    # keeps an empty validation split from promoting the result to float,
    # because float arrays cannot be used as indices.
    fit_idxs = np.concatenate(
        [
            np.asarray(train_idxs, dtype=np.intp),
            np.asarray(valid_idxs, dtype=np.intp),
        ]
    )
    test_idxs = np.asarray(test_idxs, dtype=np.intp)

    smiles_arr = np.asarray(smiles_list)
    y_arr = np.asarray(y)

    smiles_train = smiles_arr[fit_idxs]
    smiles_test = smiles_arr[test_idxs]

    # Only training labels are imputed: missing targets become 0 so the
    # estimator can be fit. Test labels are left untouched for scoring.
    y_train = np.nan_to_num(y_arr[fit_idxs], nan=0)
    y_test = y_arr[test_idxs]

    return smiles_train, smiles_test, y_train, y_test


if __name__ == "__main__":
    # Benchmark Mol2Vec fingerprints on TDC datasets: for each Mol2Vec
    # model, fit a random forest per dataset and report AUROC
    # (classification) and MAE (regression), plus their averages.

    # turn off unnecessary warnings
    rdBase.DisableLog("rdApp.*")
    warnings.simplefilter(action="ignore", category=FutureWarning)

    for model_type, model_path in [
        ("original", "outputs/mol2vec_original.pkl"),
        ("retrained", "outputs/mol2vec.model"),
    ]:
        print(model_type)

        auroc_scores = []
        for dataset_name in get_classification_dataset_names():
            smiles_list, y = load_tdc_dataset(dataset_name)

            smiles_train, smiles_test, y_train, y_test = parse_data(
                dataset_name, smiles_list, y
            )

            pipeline = make_pipeline(
                MolFromSmilesTransformer(suppress_warnings=True),
                Mol2VecFingerprint(model_path),
                RandomForestClassifier(
                    n_estimators=500,
                    criterion="entropy",
                    n_jobs=-1,
                    random_state=0,
                ),
            )
            pipeline.fit(smiles_train, y_train)
            y_pred = pipeline.predict_proba(smiles_test)
            # presumably reduces predict_proba output to positive-class
            # probabilities, as AUROC expects — confirm against skfp docs
            y_pred = extract_pos_proba(y_pred)

            auroc = multioutput_auroc_score(y_test, y_pred, suppress_warnings=True)
            auroc_scores.append(auroc)
            print(f"{dataset_name} AUROC: {auroc:.2%}")

        avg_auroc = np.mean(auroc_scores)
        print(f"Average TDC AUROC: {avg_auroc:.2%}")
        print()

        mae_scores = []
        for dataset_name in get_regression_dataset_names():
            smiles_list, y = load_tdc_dataset(dataset_name)
            smiles_train, smiles_test, y_train, y_test = parse_data(
                dataset_name, smiles_list, y
            )

            pipeline = make_pipeline(
                MolFromSmilesTransformer(suppress_warnings=True),
                Mol2VecFingerprint(model_path),
                RandomForestRegressor(n_estimators=500, n_jobs=-1, random_state=0),
            )
            pipeline.fit(smiles_train, y_train)
            y_pred = pipeline.predict(smiles_test)

            mae = mean_absolute_error(y_test, y_pred)
            mae_scores.append(mae)
            print(f"{dataset_name} MAE: {mae:.3f}")

        # Report the regression average (previously the stale AUROC average
        # was printed here by mistake).
        avg_mae = np.mean(mae_scores)
        print(f"Average TDC MAE: {avg_mae:.3f}")
        print()
