import datasets
import numpy as np
from scipy.stats import pearsonr
from skfp.preprocessing import MolFromSmilesTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import root_mean_squared_error
from sklearn.pipeline import make_pipeline

from mol2vec.mol2vec_fp import Mol2VecFingerprint


def load_data() -> tuple[list[str], list[str], np.ndarray, np.ndarray]:
    """Load the ToxBench ``"computational"`` config as SMILES strings and targets.

    The dataset is fetched with ``datasets.load_dataset`` (this may hit the
    network if it is not already cached locally). Each split is converted to
    a pandas DataFrame. The validation split is then appended to the training
    split, so only a train/test pair is returned.

    Returns:
        ``(smiles_train, smiles_test, y_train, y_test)``:
        - the SMILES strings come from the ``"smiles"`` column;
        - the regression targets come from the ``"abfep_affinity"`` column.
        Training entries list the train-split rows first, followed by the
        validation-split rows.
    """
    smiles_col, target_col = "smiles", "abfep_affinity"

    toxbench = datasets.load_dataset("karlleswing/toxbench", "computational")
    train_frame, valid_frame, test_frame = (
        toxbench[split].to_pandas() for split in ("train", "validation", "test")
    )

    # Validation rows are appended to the training set; no separate
    # validation split is returned.
    smiles_train = train_frame[smiles_col].tolist() + valid_frame[smiles_col].tolist()
    y_train = np.concatenate(
        (train_frame[target_col].to_numpy(), valid_frame[target_col].to_numpy())
    )

    smiles_test = test_frame[smiles_col].tolist()
    y_test = test_frame[target_col].to_numpy()

    return smiles_train, smiles_test, y_train, y_test


if __name__ == "__main__":
    # Compare two Mol2Vec model files. For each one:
    #   1. featurize SMILES -> RDKit molecules -> Mol2Vec fingerprints;
    #   2. fit a 500-tree random forest on the training data;
    #   3. report Pearson R and RMSE on the held-out test split.
    train_smiles, test_smiles, train_targets, test_targets = load_data()

    mol2vec_models = {
        "original": "outputs/mol2vec_original.pkl",
        "retrained": "outputs/mol2vec.model",
    }

    for label, weights_path in mol2vec_models.items():
        print(label)

        regressor = make_pipeline(
            MolFromSmilesTransformer(n_jobs=-1, suppress_warnings=True),
            Mol2VecFingerprint(weights_path, n_jobs=-1),
            # Fixed seed so repeated runs give comparable scores.
            RandomForestRegressor(n_estimators=500, n_jobs=-1, random_state=0),
        )
        regressor.fit(train_smiles, train_targets)
        predictions = regressor.predict(test_smiles)

        corr = pearsonr(test_targets, predictions).statistic
        rmse = root_mean_squared_error(test_targets, predictions)

        print(f"Pearson R: {corr:.3f}; RMSE: {rmse:.3f}")
        print()
