{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adult dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Opendataval\n",
    "from opendataval.dataloader import Register, DataFetcher, mix_labels, add_gauss_noise\n",
    "from opendataval.dataval import (\n",
    "    AME,\n",
    "    DVRL,\n",
    "    BetaShapley,\n",
    "    DataBanzhaf,\n",
    "    DataOob,\n",
    "    DataShapley,\n",
    "    InfluenceSubsample,\n",
    "    InfluenceFunction,\n",
    "    GradientShapley,\n",
    "    KNNShapley,\n",
    "    LavaEvaluator,\n",
    "    LeaveOneOut,\n",
    "    RandomEvaluator\n",
    ")\n",
    "\n",
    "from opendataval.experiment import ExperimentMediator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 1] Set up an environment\n",
    "`ExperimentMediator` is the central object for setting up an `opendataval` experiment. It lets users configure the experiment, including the dataset, the type of synthetic noise, and the prediction model. With `ExperimentMediator`, users can then run several data valuation algorithms on the same setup.\n",
    "\n",
    "The following code cell demonstrates how to set up `ExperimentMediator` with a pre-registered dataset and a prediction model.\n",
    "- Dataset: adult\n",
    "- Model: sklearn's logistic regression model\n",
    "- Metric: Classification accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_name = \"adult\"\n",
    "train_count, valid_count, test_count = 1000, 100, 500\n",
    "noise_rate = 0.2\n",
    "noise_kwargs = {'noise_rate': noise_rate}\n",
    "model_name = \"sklogreg\"\n",
    "metric_name = \"accuracy\"\n",
    "\n",
    "exper_med = ExperimentMediator.model_factory_setup(\n",
    "    dataset_name=dataset_name,\n",
    "    cache_dir=\"../data_files/\",  \n",
    "    force_download=False,\n",
    "    train_count=train_count,\n",
    "    valid_count=valid_count,\n",
    "    test_count=test_count,\n",
    "    add_noise=mix_labels, \n",
    "    noise_kwargs=noise_kwargs,\n",
    "    train_kwargs={},\n",
    "    model_name=model_name,\n",
    "    metric_name=metric_name\n",
    ")"
   ]
  },
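  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `noise_rate` of 0.2 passed to `mix_labels` controls how much label noise is injected into the training set. As a rough illustration only, the sketch below reassigns the labels of a fixed fraction of points in plain Python. The helper `flip_labels` and its details are assumptions made for this example; it is not the `opendataval` implementation, which may choose and reassign labels differently.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def flip_labels(labels, noise_rate, classes, seed=0):\n",
    "    # Hypothetical helper: give a fraction `noise_rate` of points a different label.\n",
    "    rng = random.Random(seed)\n",
    "    num_noisy = round(noise_rate * len(labels))\n",
    "    noisy_indices = sorted(rng.sample(range(len(labels)), num_noisy))\n",
    "    noisy_labels = list(labels)\n",
    "    for i in noisy_indices:\n",
    "        # Pick uniformly among the classes other than the current one.\n",
    "        other_classes = [c for c in classes if c != labels[i]]\n",
    "        noisy_labels[i] = rng.choice(other_classes)\n",
    "    return noisy_labels, noisy_indices\n",
    "\n",
    "\n",
    "clean = [0, 1] * 5  # ten toy binary labels\n",
    "noisy, flipped = flip_labels(clean, noise_rate=0.2, classes=[0, 1])\n",
    "print(flipped)  # indices whose labels were changed\n",
    "```"
   ]
  },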
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A full list of registered datasets can be found [here](https://github.com/opendataval/opendataval/blob/main/opendataval/dataloader/fetcher.py#L121). A list of available prediction models can be found [here](https://github.com/opendataval/opendataval/blob/main/opendataval/model/__init__.py#L111)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 2] Compute data values\n",
    "`opendataval` provides a range of state-of-the-art data valuation algorithms. `ExperimentMediator.compute_data_values()` computes data values for each evaluator in the list passed to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_evaluators = [\n",
    "    RandomEvaluator(),  # Random baseline\n",
    "    LeaveOneOut(),  # Leave-one-out\n",
    "    InfluenceSubsample(num_models=1000),  # Influence function (subsampling)\n",
    "    # InfluenceFunction(),  # Influence function\n",
    "    DVRL(rl_epochs=2000),  # Data Valuation using Reinforcement Learning\n",
    "    KNNShapley(k_neighbors=valid_count),  # KNN-Shapley\n",
    "    DataShapley(cache_name=\"cached\"),  # Data-Shapley\n",
    "    BetaShapley(cache_name=\"cached\"),  # Beta-Shapley\n",
    "    DataBanzhaf(num_models=1000),  # Data-Banzhaf\n",
    "    AME(num_models=1000),  # Average Marginal Effect\n",
    "    DataOob(num_models=1000),  # Data-OOB\n",
    "    LavaEvaluator(),  # LAVA\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# compute data values.\n",
    "exper_med = exper_med.compute_data_values(data_evaluators=data_evaluators)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 3] Evaluate data values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from opendataval.experiment.exper_methods import (\n",
    "    discover_corrupted_sample,\n",
    "    noisy_detection,\n",
    "    remove_high_low,\n",
    "    save_dataval\n",
    ")\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Saving the results\n",
    "output_dir = f\"../tmp/{dataset_name}_{noise_rate=}/\"\n",
    "exper_med.set_output_directory(output_dir)\n",
    "output_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Noisy data detection\n",
    "- `noisy_detection` performs the noisy data detection task and reports the F1-score of each data valuation algorithm's predictions. Higher is better.\n",
    "  - Noisy data: mislabeled data points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med.evaluate(noisy_detection, save_output=True)"
   ]
  },
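  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the F1-score concrete, the sketch below flags the lowest-valued points as noisy and scores that guess against the known noisy indices. The flagging rule (take the `num_flagged` smallest values) and all names are assumptions chosen for illustration; `noisy_detection` may pick its threshold differently.\n",
    "\n",
    "```python\n",
    "def detection_f1(data_values, noisy_indices, num_flagged):\n",
    "    # Flag the `num_flagged` points with the smallest data values as noisy.\n",
    "    order = sorted(range(len(data_values)), key=lambda i: data_values[i])\n",
    "    flagged = set(order[:num_flagged])\n",
    "    truth = set(noisy_indices)\n",
    "    true_positives = len(flagged & truth)\n",
    "    if true_positives == 0:\n",
    "        return 0.0\n",
    "    precision = true_positives / len(flagged)\n",
    "    recall = true_positives / len(truth)\n",
    "    return 2 * precision * recall / (precision + recall)\n",
    "\n",
    "\n",
    "values = [0.9, 0.1, 0.8, 0.2, 0.7, 0.3]  # toy data values\n",
    "score = detection_f1(values, noisy_indices=[1, 3, 4], num_flagged=3)\n",
    "print(score)\n",
    "```\n",
    "\n",
    "Here two of the three flagged points are truly noisy, so precision and recall are both 2/3."
   ]
  },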
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discover noisy samples\n",
    "- `discover_corrupted_sample` visualizes how well noisy data points are identified when a fraction of the dataset is inspected."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 40))\n",
    "df, fig = exper_med.plot(discover_corrupted_sample, fig, col=2, save_output=True)"
   ]
  },
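  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to read these plots: inspect points from the lowest to the highest data value and track what share of the noisy points has been found so far. The sketch below computes such a curve on toy inputs; the function name and the normalization are assumptions for illustration, not the library's code.\n",
    "\n",
    "```python\n",
    "def discovery_curve(data_values, noisy_indices):\n",
    "    # Inspect points in ascending order of data value.\n",
    "    order = sorted(range(len(data_values)), key=lambda i: data_values[i])\n",
    "    truth = set(noisy_indices)\n",
    "    found = 0\n",
    "    curve = []\n",
    "    for index in order:\n",
    "        if index in truth:\n",
    "            found += 1\n",
    "        # Share of all noisy points discovered after this inspection.\n",
    "        curve.append(found / len(truth))\n",
    "    return curve\n",
    "\n",
    "\n",
    "values = [0.9, 0.1, 0.8, 0.2, 0.7, 0.3]  # toy data values\n",
    "curve = discovery_curve(values, noisy_indices=[1, 3])\n",
    "print(curve)\n",
    "```\n",
    "\n",
    "A curve that rises quickly means the noisy points sit among the lowest data values."
   ]
  },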
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Point removal experiment\n",
    "- `remove_high_low` performs the point removal experiment. For each data valuation algorithm, it produces two curves: one removes data points in descending order of value (orange), and the other removes them in ascending order (blue). For the orange curve, lower is better; for the blue curve, higher is better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 40))\n",
    "df_resp, fig = exper_med.plot(remove_high_low, fig, col=2, save_output=True)"
   ]
  },
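  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea behind point removal can be sketched on toy data: sort training points by data value, remove them in that order, retrain, and record test accuracy after each removal. The nearest-class-mean classifier, the toy data, the data values, and every function name below are assumptions for illustration only; the actual experiment uses the prediction model configured above.\n",
    "\n",
    "```python\n",
    "def class_means(xs, ys):\n",
    "    # Mean feature value per class: a tiny stand-in for a real model.\n",
    "    means = {}\n",
    "    for label in set(ys):\n",
    "        members = [x for x, y in zip(xs, ys) if y == label]\n",
    "        means[label] = sum(members) / len(members)\n",
    "    return means\n",
    "\n",
    "\n",
    "def accuracy(means, xs, ys):\n",
    "    # Predict the class whose mean is closest to each point.\n",
    "    hits = 0\n",
    "    for x, y in zip(xs, ys):\n",
    "        predicted = min(means, key=lambda label: abs(x - means[label]))\n",
    "        hits += predicted == y\n",
    "    return hits / len(xs)\n",
    "\n",
    "\n",
    "def removal_curve(xs, ys, values, test_xs, test_ys, descending, steps):\n",
    "    # Retrain after removing 0, 1, ..., steps points in value order.\n",
    "    order = sorted(range(len(xs)), key=lambda i: values[i], reverse=descending)\n",
    "    curve = []\n",
    "    for k in range(steps + 1):\n",
    "        keep = order[k:]\n",
    "        means = class_means([xs[i] for i in keep], [ys[i] for i in keep])\n",
    "        curve.append(accuracy(means, test_xs, test_ys))\n",
    "    return curve\n",
    "\n",
    "\n",
    "# Toy training set: the last two points are mislabeled as class 0.\n",
    "train_xs = [0.0, 1.0, 2.0, 8.0, 9.0, 10.0, 9.0, 10.0]\n",
    "train_ys = [0, 0, 0, 1, 1, 1, 0, 0]\n",
    "toy_values = [0.5, 0.6, 0.7, 0.5, 0.6, 0.7, -0.9, -0.8]  # made-up data values\n",
    "test_xs = [1.0, 2.0, 3.0, 6.0, 7.5, 8.0]\n",
    "test_ys = [0, 0, 0, 1, 1, 1]\n",
    "\n",
    "low_first = removal_curve(train_xs, train_ys, toy_values, test_xs, test_ys, False, 2)\n",
    "high_first = removal_curve(train_xs, train_ys, toy_values, test_xs, test_ys, True, 2)\n",
    "print(low_first, high_first)\n",
    "```\n",
    "\n",
    "In this toy setup, removing the two lowest-valued (mislabeled) points lets the classifier recover full test accuracy."
   ]
  },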
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_resp  # Contains the complete data needed to generate the point-removal experiment figures."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save data values\n",
    "- `save_dataval` stores computed data values at `{output_dir}/save_dataval.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med.evaluate(save_dataval, save_output=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
