import datasets
import numpy as np
from scipy.stats import pearsonr
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import root_mean_squared_error
from sklearn.pipeline import make_pipeline

from chemberta.chemberta_fp import ChemBERTaFingerprint


def load_data() -> tuple[list[str], list[str], np.ndarray, np.ndarray]:
    """Load the toxbench splits and return train/test SMILES and targets.

    The "computational" configuration of the ``karlleswing/toxbench`` dataset
    is fetched with ``datasets.load_dataset``. The validation split is appended
    to the training split, so no separate validation set is returned.

    Returns:
        A tuple ``(smiles_train, smiles_test, y_train, y_test)`` where the
        SMILES entries are lists taken from the ``"smiles"`` column and the
        targets are arrays taken from the ``"abfep_affinity"`` column. The
        training entries are train followed by validation, in that order.
    """
    splits = datasets.load_dataset("karlleswing/toxbench", "computational")

    # Convert each split to pandas once, then pull both columns from the frame.
    frames = {
        name: splits[name].to_pandas()
        for name in ("train", "validation", "test")
    }
    smiles = {name: frame["smiles"].tolist() for name, frame in frames.items()}
    targets = {
        name: frame["abfep_affinity"].to_numpy()
        for name, frame in frames.items()
    }

    # Fold validation into training; inputs and targets are joined in the same
    # order so they stay aligned row for row.
    fit_smiles = smiles["train"] + smiles["validation"]
    fit_targets = np.concatenate((targets["train"], targets["validation"]))

    return fit_smiles, smiles["test"], fit_targets, targets["test"]


if __name__ == "__main__":
    # The same train (+validation) / test data is reused for every checkpoint
    # compared below.
    smiles_train, smiles_test, y_train, y_test = load_data()

    # (label printed in the report, model path handed to ChemBERTaFingerprint).
    checkpoints = (
        ("original", "DeepChem/ChemBERTa-77M-MLM"),
        ("retrained", "outputs/chemberta_model_mlm"),
    )

    for model_type, model_path in checkpoints:
        print(model_type)

        # First step: the project's fingerprint transformer, built from the
        # model path (presumably maps SMILES to feature vectors; confirm in
        # chemberta.chemberta_fp).
        featurizer = ChemBERTaFingerprint(model_path)
        # Second step: random forest with a fixed seed, identical for both
        # checkpoints so only the fingerprint model differs between runs.
        regressor = RandomForestRegressor(n_estimators=500, n_jobs=-1, random_state=0)

        pipeline = make_pipeline(featurizer, regressor)
        pipeline.fit(smiles_train, y_train)

        y_pred = pipeline.predict(smiles_test)

        # Score the test predictions: Pearson correlation and RMSE.
        correlation = pearsonr(y_test, y_pred)
        corr = correlation.statistic
        rmse = root_mean_squared_error(y_test, y_pred)

        print(f"Pearson R: {corr:.3f}; RMSE: {rmse:.3f}")
        print()
