{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data valuation on the `pol` dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Opendataval\n",
    "from opendataval.dataloader import Register, DataFetcher, mix_labels, add_gauss_noise\n",
    "from opendataval.dataval import (\n",
    "    AME,\n",
    "    DVRL,\n",
    "    BetaShapley,\n",
    "    DataBanzhaf,\n",
    "    DataOob,\n",
    "    DataShapley,\n",
    "    InfluenceSubsample,\n",
    "    KNNShapley,\n",
    "    LavaEvaluator,\n",
    "    LeaveOneOut,\n",
    "    RandomEvaluator,\n",
    "    RobustVolumeShapley,\n",
    ")\n",
    "\n",
    "from opendataval.experiment import ExperimentMediator\n",
    "\n",
    "from opendataval.model.api import ClassifierSkLearnWrapper\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 1] Set up an environment\n",
    "`ExperimentMediator` is the central object of an `opendataval` experiment. It combines a dataset (optionally corrupted with synthetic noise) with a prediction model. It then computes data values with various data valuation algorithms and evaluates them.\n",
    "\n",
    "The following cells set up `ExperimentMediator` with a pre-registered dataset and a prediction model.\n",
    "- Dataset: pol\n",
    "- Model: sklearn's `GradientBoostingClassifier`\n",
    "- Metric: classification accuracy\n",
    "\n",
    "### [Step 1-1] Set up a dataset\n",
    "`DataFetcher.setup` loads the dataset and splits it into training, validation, and test sets of the given sizes. It then applies `mix_labels` to corrupt the labels of a fraction `noise_rate` of the training points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up hyperparameters\n",
    "dataset_name = \"pol\"\n",
    "train_count, valid_count, test_count = 1000, 100, 300\n",
    "noise_rate = 0.2\n",
    "noise_kwargs = {'noise_rate': noise_rate}\n",
    "metric_name = \"accuracy\"\n",
    "\n",
    "fetcher = DataFetcher.setup(\n",
    "    dataset_name=dataset_name,\n",
    "    cache_dir=\"../data_files/\",  \n",
    "    force_download=False,\n",
    "    train_count=train_count,\n",
    "    valid_count=valid_count,\n",
    "    test_count=test_count,\n",
    "    add_noise=mix_labels,\n",
    "    noise_kwargs=noise_kwargs\n",
    ")"
   ]
  },
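  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell is a minimal, hypothetical sketch of label-flipping noise, which is the kind of corruption `mix_labels` introduces into the training labels. It uses NumPy only and assumes integer class labels. The helper `flip_labels` is an illustration, not `opendataval`'s implementation. It selects `noise_rate` of the points uniformly at random without replacement. It then shifts each selected label by a random non-zero offset modulo the number of classes, so every selected label changes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper (not part of opendataval): flip a fraction of integer labels.\n",
    "def flip_labels(y, noise_rate, num_classes, seed=0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    y_noisy = y.copy()\n",
    "    num_noisy = int(noise_rate * len(y))\n",
    "    # Choose which points to corrupt, without replacement\n",
    "    noisy_idx = rng.choice(len(y), size=num_noisy, replace=False)\n",
    "    # A non-zero offset modulo num_classes always yields a different class\n",
    "    offsets = rng.integers(1, num_classes, size=num_noisy)\n",
    "    y_noisy[noisy_idx] = (y_noisy[noisy_idx] + offsets) % num_classes\n",
    "    return y_noisy, noisy_idx\n",
    "\n",
    "# Toy binary labels: 1000 points with 20% label noise\n",
    "y_clean = np.random.default_rng(1).integers(0, 2, size=1000)\n",
    "y_noisy, noisy_idx = flip_labels(y_clean, noise_rate=0.2, num_classes=2)\n",
    "print(len(noisy_idx), (y_noisy != y_clean).sum())"
   ]
  },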
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Step 1-2] Set up a prediction model\n",
    "Next, set up a prediction model. `ClassifierSkLearnWrapper` and `RegressionSkLearnWrapper` let you use `sklearn` classifiers and regressors as prediction models. The following code uses a gradient boosting classifier, `GradientBoostingClassifier`, with custom parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters for GradientBoostingClassifier are passed as keyword arguments.\n",
    "pred_model = ClassifierSkLearnWrapper(GradientBoostingClassifier,\n",
    "                                      fetcher.label_dim[0],\n",
    "                                      loss='log_loss',\n",
    "                                      learning_rate=0.05,\n",
    "                                      n_estimators=200)\n",
    "\n",
    "# Check the parameters of the underlying sklearn model.\n",
    "print('learning_rate: ', pred_model.model.get_params()['learning_rate'])\n",
    "print('n_estimators: ', pred_model.model.get_params()['n_estimators'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Step 1-3] Combine the dataset and the model\n",
    "- Pass the `DataFetcher` from [Step 1-1] and the prediction model from [Step 1-2] to `ExperimentMediator`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med = ExperimentMediator(fetcher, pred_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 2] Compute data values\n",
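    "`opendataval` provides various state-of-the-art data valuation algorithms. `ExperimentMediator.compute_data_values()` runs every evaluator in the list below on the mediator's dataset and prediction model, and it stores the resulting data values. Some evaluators are commented out; uncomment them to include them in the comparison."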
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_evaluators = [ \n",
    "    RandomEvaluator(),\n",
    "    LeaveOneOut(), # leave one out\n",
    "    InfluenceSubsample(num_models=1000), # influence function\n",
    "#     DVRL(rl_epochs=2000), # Data valuation using Reinforcement Learning\n",
    "    KNNShapley(k_neighbors=valid_count), # KNN-Shapley\n",
    "#     DataShapley(gr_threshold=1.05, cache_name=\"cached\"), # Data-Shapley\n",
    "#     BetaShapley(gr_threshold=1.05, cache_name=\"cached\"), # Beta-Shapley\n",
    "#     DataBanzhaf(num_models=1000), # Data-Banzhaf\n",
    "#     AME(num_models=1000), # Average Marginal Effects\n",
    "    DataOob(num_models=1000), # Data-OOB\n",
    "    LavaEvaluator(),\n",
    "#     RobustVolumeShapley(gr_threshold=1.05)\n",
    "]"
   ]
  },
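  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a K-nearest-neighbor classifier, Shapley values can be computed exactly. The computation uses a recursion over the training points sorted by distance to each validation point, which is the idea behind `KNNShapley`. The next cell is a minimal NumPy sketch of that recursion. It assumes Euclidean distance, and it assumes a utility equal to the fraction of the K nearest neighbors whose label matches the validation label. The helper `knn_shapley` is hypothetical, and `opendataval`'s implementation may differ in details such as normalization. Because Shapley values are efficient, the sketch's values sum to the average K-NN utility over the validation points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: exact KNN-Shapley values averaged over validation points.\n",
    "def knn_shapley(x_train, y_train, x_valid, y_valid, k):\n",
    "    n = len(x_train)\n",
    "    values = np.zeros(n)\n",
    "    for x_v, y_v in zip(x_valid, y_valid):\n",
    "        # Sort training points from nearest to farthest (Euclidean distance)\n",
    "        order = np.argsort(np.linalg.norm(x_train - x_v, axis=1))\n",
    "        match = (y_train[order] == y_v).astype(float)\n",
    "        s = np.zeros(n)\n",
    "        s[n - 1] = match[n - 1] / n  # farthest point\n",
    "        # Walk from the second-farthest point to the nearest; rank = j + 1\n",
    "        for j in range(n - 2, -1, -1):\n",
    "            s[j] = s[j + 1] + (match[j] - match[j + 1]) / k * min(k, j + 1) / (j + 1)\n",
    "        values[order] += s  # map values back to the original indices\n",
    "    return values / len(x_valid)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x_tr, y_tr = rng.normal(size=(50, 2)), rng.integers(0, 2, size=50)\n",
    "x_va, y_va = rng.normal(size=(10, 2)), rng.integers(0, 2, size=10)\n",
    "knn_values = knn_shapley(x_tr, y_tr, x_va, y_va, k=5)\n",
    "print(knn_values.sum())"
   ]
  },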
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# compute data values.\n",
    "exper_med = exper_med.compute_data_values(data_evaluators=data_evaluators)"
   ]
  },
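  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`DataOob` is based on bagging. Many models are trained on bootstrap samples, and each training point is scored by how often the models that did not see it (its out-of-bag models) predict its label correctly. The next cell is a minimal sketch of that idea with scikit-learn decision trees on synthetic data. The helper `data_oob`, the tree depth, and the number of models are illustrative assumptions, not `opendataval`'s defaults."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "# Hypothetical helper: mean out-of-bag correctness of each training point.\n",
    "def data_oob(x, y, num_models=200, seed=0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    n = len(x)\n",
    "    correct = np.zeros(n)\n",
    "    counts = np.zeros(n)\n",
    "    for _ in range(num_models):\n",
    "        idx = rng.integers(0, n, size=n)  # bootstrap sample (with replacement)\n",
    "        oob = np.setdiff1d(np.arange(n), idx)  # points not drawn this round\n",
    "        if len(oob) == 0:\n",
    "            continue\n",
    "        model = DecisionTreeClassifier(max_depth=3, random_state=0).fit(x[idx], y[idx])\n",
    "        correct[oob] += model.predict(x[oob]) == y[oob]\n",
    "        counts[oob] += 1\n",
    "    # Points that were never out-of-bag keep a value of 0\n",
    "    return np.divide(correct, counts, out=np.zeros(n), where=counts > 0)\n",
    "\n",
    "x_syn, y_syn = make_classification(n_samples=200, n_features=5, random_state=0)\n",
    "oob_values = data_oob(x_syn, y_syn)\n",
    "print(oob_values[:5])"
   ]
  },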
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 3] Evaluate data values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from opendataval.experiment.exper_methods import (\n",
    "    discover_corrupted_sample,\n",
    "    noisy_detection,\n",
    "    remove_high_low,\n",
    "    save_dataval\n",
    ")\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Set the directory where evaluation outputs are saved\n",
    "output_dir = f\"../tmp/{dataset_name}_{noise_rate=}/\"\n",
    "exper_med.set_output_directory(output_dir)\n",
    "output_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Noisy data detection\n",
    "- `noisy_detection` performs the noisy data detection task. Each data valuation algorithm's data values are used to predict which training points are noisy (here, mislabeled), and the F1-score of that prediction is reported. Higher is better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med.evaluate(noisy_detection, save_output=True)"
   ]
  },
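  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The F1-score requires a rule that turns continuous data values into a noisy or clean prediction. The next cell sketches one such rule under an explicit assumption. First, the data values are split into two groups with scikit-learn's `KMeans`, and the lower-valued group is flagged as noisy. Then `f1_score` compares those flags with the indices of the truly mislabeled points. The helper `detect_noisy` and the toy values are hypothetical, and `noisy_detection` may use a different rule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Hypothetical helper: flag the lower-valued KMeans cluster as noisy, return F1.\n",
    "def detect_noisy(values, noisy_idx):\n",
    "    km = KMeans(n_clusters=2, n_init=10, random_state=0).fit(values.reshape(-1, 1))\n",
    "    low_cluster = np.argmin(km.cluster_centers_.ravel())\n",
    "    pred_noisy = (km.labels_ == low_cluster).astype(int)\n",
    "    true_noisy = np.zeros(len(values), dtype=int)\n",
    "    true_noisy[noisy_idx] = 1\n",
    "    return f1_score(true_noisy, pred_noisy)\n",
    "\n",
    "# Toy data values: 20 noisy points score near 0, 80 clean points score near 1\n",
    "rng = np.random.default_rng(0)\n",
    "toy_values = np.concatenate([rng.normal(0.0, 0.05, 20), rng.normal(1.0, 0.05, 80)])\n",
    "f1 = detect_noisy(toy_values, np.arange(20))\n",
    "print(f1)"
   ]
  },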
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discover noisy samples\n",
    "- `discover_corrupted_sample` plots the fraction of noisy data points found against the fraction of the dataset inspected. Points are inspected in ascending order of data value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 20))\n",
    "_, fig = exper_med.plot(discover_corrupted_sample, fig, col=2, save_output=True)"
   ]
  },
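  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A curve like the ones above can be computed directly from the data values and the indices of the corrupted points. The next cell is a NumPy sketch that assumes points are inspected from the lowest data value upward. The helper `discovery_curve` and the toy inputs are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: fraction of noisy points found vs. fraction inspected.\n",
    "def discovery_curve(values, noisy_idx):\n",
    "    order = np.argsort(values)  # ascending: lowest data values are inspected first\n",
    "    is_noisy = np.isin(order, noisy_idx)\n",
    "    found = np.cumsum(is_noisy) / len(noisy_idx)\n",
    "    inspected = np.arange(1, len(values) + 1) / len(values)\n",
    "    return inspected, found\n",
    "\n",
    "toy_vals = np.array([0.9, 0.1, 0.8, 0.2, 0.7])\n",
    "inspected, found = discovery_curve(toy_vals, noisy_idx=[1, 3])\n",
    "print(inspected, found)"
   ]
  },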
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Point removal experiment\n",
    "- `remove_high_low` performs the point removal experiment. For each data valuation algorithm, it plots two curves that show model performance as training points are removed. One curve removes points in descending order of data value (orange, high-value points first). The other removes them in ascending order (blue, low-value points first). A lower orange curve and a higher blue curve indicate better data values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 20))\n",
    "df_resp, fig = exper_med.plot(remove_high_low, fig, col=2, save_output=True)"
   ]
  },
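  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each curve above comes from retraining a model after removing a growing fraction of the training data in a fixed order. The next cell sketches this procedure with scikit-learn's `LogisticRegression` on synthetic data, using random placeholder data values. The helper `removal_curve`, the removal fractions (up to 50%), and the model are assumptions for illustration, not the exact settings of `remove_high_low`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Hypothetical helper: test accuracy after removing training points in value order.\n",
    "def removal_curve(values, x_train, y_train, x_test, y_test, descending=True,\n",
    "                  fracs=(0.0, 0.1, 0.2, 0.3, 0.4, 0.5)):\n",
    "    order = np.argsort(values)  # ascending: lowest values first\n",
    "    if descending:\n",
    "        order = order[::-1]  # remove highest-valued points first\n",
    "    accs = []\n",
    "    for frac in fracs:\n",
    "        keep = order[int(frac * len(values)):]  # drop the first `frac` of the order\n",
    "        model = LogisticRegression(max_iter=1000).fit(x_train[keep], y_train[keep])\n",
    "        accs.append(model.score(x_test, y_test))\n",
    "    return np.array(accs)\n",
    "\n",
    "x_rm, y_rm = make_classification(n_samples=300, n_features=5, random_state=0)\n",
    "placeholder_values = np.random.default_rng(0).random(200)  # stand-in data values\n",
    "high_first = removal_curve(placeholder_values, x_rm[:200], y_rm[:200], x_rm[200:], y_rm[200:])\n",
    "low_first = removal_curve(placeholder_values, x_rm[:200], y_rm[:200], x_rm[200:], y_rm[200:],\n",
    "                          descending=False)\n",
    "print(high_first, low_first)"
   ]
  },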
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_resp  # Contains the data used to draw the point-removal figures above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save data values\n",
    "- `save_dataval` stores computed data values at `{output_dir}/save_dataval.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med.evaluate(save_dataval, save_output=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
