{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# custom dataset "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import sys\n",
    "import os\n",
    "# Add the parent directory to the sys.path\n",
    "sys.path.append(os.path.abspath('../opendataval'))\n",
    "\n",
    "# Opendataval\n",
    "from dataloader import Register, DataFetcher, mix_labels, add_gauss_noise\n",
    "from dataval import (\n",
    "    AME,\n",
    "    DVRL,\n",
    "    BetaShapley,\n",
    "    DataBanzhaf,\n",
    "    DataOob,\n",
    "    DataShapley,\n",
    "    GradientShapley,\n",
    "    InfluenceSubsample,\n",
    "    KNNShapley,\n",
    "    LavaEvaluator,\n",
    "    LeaveOneOut,\n",
    "    RandomEvaluator,\n",
    "    RobustVolumeShapley,\n",
    ")\n",
    "\n",
    "from experiment import ExperimentMediator\n",
    "\n",
    "from model.api import ClassifierSkLearnWrapper\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.ensemble import RandomForestClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 1] Set up an environment"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Step 1-1] Register a dataset\n",
    "`ExperimentMediator.model_factory_setup()` is convenient, but it only works for datasets registered in `opendataval`. To apply `opendataval` to your custome datasets, a user first needs to register a dataset and define a `DataFetcher` from the registered dataset. The following codes demonstrate how to register a dataset from arrays of (random) features and (random) labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up hyperparameters\n",
    "dataset_name = \"random_dataset\"\n",
    "train_count, valid_count, test_count = 50, 10, 10\n",
    "noise_rate = 0.1\n",
    "# model_name = \"sklogreg\"\n",
    "model_name = \"LogisticRegression\"\n",
    "metric_name = \"accuracy\"\n",
    "\n",
    "# Generate a random dataset\n",
    "# Every element of X is generated from a standard Gaussian distribution\n",
    "X, y= np.random.normal(size=(100, 10)), np.random.choice([0,1], size=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register a dataset from arrays X and y\n",
    "pd_dataset = Register(dataset_name=dataset_name, one_hot=True).from_data(X, y)\n",
    "\n",
    "# After regitering a dataset, we can define `DataFetcher` by its name.\n",
    "fetcher = (\n",
    "    DataFetcher(dataset_name, '../data_files/', False)\n",
    "    .split_dataset_by_count(train_count,\n",
    "                            valid_count,\n",
    "                            test_count)  \n",
    "    .noisify(mix_labels, noise_rate=noise_rate)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Step 1-2] Set up a prediction model\n",
    "Next is to set up a prediction model. With `ClassifierSkLearnWrapper` and `RegressionSkLearnWrapper`, any `sklearn` models can be utilized as a prediction model. The following code uses a random forest classifier `RandomForestClassifier`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pred_model = ClassifierSkLearnWrapper(LogisticRegression, fetcher.label_dim[0]) # example of Logistic regression\n",
    "pred_model = ClassifierSkLearnWrapper(RandomForestClassifier, fetcher.label_dim[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Step 1-3] Combining all \n",
    "- Combining [Step 1-1] and [Step 1-2] with `ExperimentMediator`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med = ExperimentMediator(fetcher, pred_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 2] Compute data values\n",
    "`opendataval` provides various state-of-the-art data valuation algorithms. `ExperimentMediator.compute_data_values()` computes data values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_evaluators = [ \n",
    "    GradientShapley(),\n",
    "    RandomEvaluator(),\n",
    "    LeaveOneOut(), # leave one out\n",
    "    InfluenceSubsample(num_models=1000), # influence function\n",
    "    DVRL(rl_epochs=2000), # Data valuation using Reinforcement Learning\n",
    "    KNNShapley(k_neighbors=valid_count), # KNN-Shapley\n",
    "    DataShapley(gr_threshold=1.05, cache_name=f\"cached\"), # Data-Shapley\n",
    "    BetaShapley(gr_threshold=1.05, cache_name=f\"cached\"), # Beta-Shapley\n",
    "    DataBanzhaf(num_models=1000), # Data-Banzhaf\n",
    "    AME(num_models=1000), # Average Marginal Effects\n",
    "    DataOob(num_models=1000), # Data-OOB\n",
    "    LavaEvaluator(), # LAVA\n",
    "    # RobustVolumeShapley(gr_threshold=1.05) # VolumeShapley\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# compute data values.\n",
    "exper_med = exper_med.compute_data_values(data_evaluators=data_evaluators)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## [Step 3] Evaluate data values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from opendataval.experiment.exper_methods import (\n",
    "    discover_corrupted_sample,\n",
    "    noisy_detection,\n",
    "    remove_high_low,\n",
    "    save_dataval\n",
    ")\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# Saving the results\n",
    "output_dir = f\"../tmp/{dataset_name}_{noise_rate=}/\"\n",
    "exper_med.set_output_directory(output_dir)\n",
    "output_dir"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Noisy data detection \n",
    "- `noisy_detection` performs the noisy data detection task and evaluates the F1-score of each data valuation algorithm's prediction. The higher, the better.  \n",
    "  - noisy data: mislabeled data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med.evaluate(noisy_detection, save_output=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discover noisy samples\n",
    "- `discover_corrupted_sample` visualizes how well noisy data points are identified when a fraction of dataset is inspected. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 35))\n",
    "_, fig = exper_med.plot(discover_corrupted_sample, fig, col=2, save_output=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Point removal experiment\n",
    "- `remove_high_low` performs the point removal experiment. Each data valution algorithm, it provides two curves: one is removing data in a descending order (orange), the other is in an ascending order (blue). As for the orange (resp. blue) curve, the lower (resp. upper), the better. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(15, 35))\n",
    "df_resp, fig = exper_med.plot(remove_high_low, fig, col=2, save_output=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_resp # it provides complete information for generating point-removal experiment figures."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save data values\n",
    "- `save_dataval` stores computed data values at `{output_dir}/save_dataval.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exper_med.evaluate(save_dataval, save_output=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
