import warnings

import numpy as np
import pandas as pd
import torch
from datasets import tqdm
from rdkit import rdBase
from skfp.metrics import bedroc_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score

from chemberta.chemberta_fp import ChemBERTaFingerprint


def _to_index_list(idxs) -> list[int]:
    """Return an index collection (list, NumPy array or CPU tensor) as Python ints."""
    return np.asarray(idxs, dtype=np.int64).tolist()


def load_welqrate_datasets(
    *,
    datasets_dir: str = "chemberta/welqrate_datasets",
    splits_dir: str = "chemberta/welqrate_datasets/scaffold_split_idxs",
    assays: tuple[str, ...] = (
        "AID1798",
        "AID1843",
        "AID2258",
        "AID2689",
        "AID435008",
        "AID435034",
        "AID463087",
        "AID485290",
        "AID488997",
    ),
    seeds: tuple[int, ...] = (1, 2, 3, 4, 5),
) -> list[tuple[str, np.ndarray, np.ndarray, list[tuple[list[int], list[int]]]]]:
    """Load WelQrate assays together with their scaffold train/test splits.

    For each assay, SMILES are read from the ``SMILES`` column of
    ``{datasets_dir}/{assay}_actives.csv`` and
    ``{datasets_dir}/{assay}_inactives.csv``. Actives come first in the
    concatenated arrays and are labelled ``1.0``; inactives follow and are
    labelled ``0.0``.

    For each seed, ``{splits_dir}/{assay}_2d_scaffold_seed{seed}.pt`` is read
    with ``torch.load``. Its ``"train"`` and ``"valid"`` index sets are merged
    into one training set, and ``"test"`` is used as the test set.

    Args:
        datasets_dir: Directory containing the ``*_actives.csv`` /
            ``*_inactives.csv`` files.
        splits_dir: Directory containing the ``*.pt`` split files. It is not
            derived from ``datasets_dir``; pass it explicitly if both change.
        assays: Assay identifiers to load, in output order.
        seeds: Split seeds to load for every assay, in output order.

    Returns:
        One ``(assay, smiles, labels, splits)`` tuple per assay, where
        ``smiles`` and ``labels`` are aligned NumPy arrays and ``splits`` is a
        list of ``(train_idxs, test_idxs)`` pairs of Python ``int`` lists.
    """
    all_datasets = []
    for assay in assays:
        smiles_active = pd.read_csv(
            f"{datasets_dir}/{assay}_actives.csv",
            usecols=["SMILES"],
        )["SMILES"].to_numpy()
        smiles_inactive = pd.read_csv(
            f"{datasets_dir}/{assay}_inactives.csv",
            usecols=["SMILES"],
        )["SMILES"].to_numpy()

        smiles_all = np.concatenate((smiles_active, smiles_inactive))

        # Labels line up with smiles_all: actives first, then inactives.
        y_pos = np.ones(len(smiles_active))
        y_neg = np.zeros(len(smiles_inactive))
        y_all = np.concatenate((y_pos, y_neg))

        split_idxs = []
        for seed in seeds:
            idxs = torch.load(f"{splits_dir}/{assay}_2d_scaffold_seed{seed}.pt")
            # Normalize before merging: "+" concatenates lists, but on tensors
            # or ndarrays it would add element-wise (with broadcasting) and
            # silently produce wrong indices.
            train_idxs = _to_index_list(idxs["train"]) + _to_index_list(
                idxs["valid"]
            )
            test_idxs = _to_index_list(idxs["test"])
            split_idxs.append((train_idxs, test_idxs))

        all_datasets.append((assay, smiles_all, y_all, split_idxs))

    return all_datasets


if __name__ == "__main__":
    # Silence RDKit's per-molecule logging and library FutureWarnings so the
    # benchmark output stays readable.
    rdBase.DisableLog("rdApp.*")
    warnings.simplefilter(action="ignore", category=FutureWarning)

    datasets = load_welqrate_datasets()

    models_to_compare = [
        ("original", "DeepChem/ChemBERTa-77M-MLM"),
        ("retrained", "outputs/chemberta_model_mlm"),
    ]

    for model_label, checkpoint in models_to_compare:
        print(model_label)

        fingerprinter = ChemBERTaFingerprint(checkpoint, verbose=True)
        per_assay_auroc = []
        per_assay_bedroc = []

        for assay_name, smiles_all, y_all, split_idxs in datasets:
            print(assay_name)

            split_auroc = []
            split_bedroc = []

            # Embed every molecule once; each split only selects rows from it.
            features = fingerprinter.transform(smiles_all)

            for train_rows, test_rows in tqdm(split_idxs):
                forest = RandomForestClassifier(
                    n_estimators=500,
                    criterion="entropy",
                    n_jobs=-1,
                    random_state=0,
                )
                forest.fit(features[train_rows], y_all[train_rows])

                # Score = predicted probability of class 1 (the actives).
                scores = forest.predict_proba(features[test_rows])[:, 1]
                truth = y_all[test_rows]

                split_auroc.append(roc_auc_score(truth, scores))
                split_bedroc.append(bedroc_score(truth, scores, alpha=20))

            # Per-assay result = mean over all scaffold-split seeds.
            mean_auroc = np.mean(split_auroc)
            mean_bedroc = np.mean(split_bedroc)
            print(f"AUROC: {mean_auroc:.2%}; BEDROC: {mean_bedroc:.2%}")

            per_assay_auroc.append(mean_auroc)
            per_assay_bedroc.append(mean_bedroc)

        # Benchmark-level result = unweighted mean over assays.
        overall_auroc = np.mean(per_assay_auroc)
        overall_bedroc = np.mean(per_assay_bedroc)

        print(f"AUROC: {overall_auroc:.2%}")
        print(f"BEDROC: {overall_bedroc:.2%}")
