{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "825e919e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from pathlib import Path\n",
    "\n",
    "# --------------- Settings ---------------\n",
    "REMOVE_FRAC = 0.01                           # fraction of test samples to flag for removal\n",
    "RANK_METRIC = \"confidence_weighted_entropy\"  # ranking used by the find_label_issues fallback\n",
    "RANDOM_STATE = 0                             # seed for the SGD classifier\n",
    "OUTDIR = Path(\"./mnist_clean_test\")          # destination of the saved index files\n",
    "OUTDIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# --------------- 1) Read the data ---------------\n",
    "from torchvision.datasets import MNIST\n",
    "\n",
    "\n",
    "def _split_to_arrays(ds):\n",
    "    \"\"\"Turn an MNIST split into (features, labels) numpy arrays.\n",
    "\n",
    "    features: float32, one flattened image per row ([N, 784]), pixels divided by 255\n",
    "    labels:   integer class labels, shape [N]\n",
    "    \"\"\"\n",
    "    pixels = ds.data.numpy().astype(np.float32) / 255.0\n",
    "    return pixels.reshape(len(ds), -1), ds.targets.numpy().astype(int)\n",
    "\n",
    "\n",
    "train_ds = MNIST(root=\"./data\", train=True, download=True)\n",
    "test_ds = MNIST(root=\"./data\", train=False, download=True)\n",
    "N_train, N_test = len(train_ds), len(test_ds)\n",
    "X_train, y_train = _split_to_arrays(train_ds)\n",
    "X_test, y_test = _split_to_arrays(test_ds)\n",
    "\n",
    "# --------------- 2) Learning predictor ---------------\n",
    "# Linear classifier with logistic loss, trained by SGD on standardized pixels.\n",
    "# It is fit on the training split only, so its test-set probabilities are out-of-sample.\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.linear_model import SGDClassifier\n",
    "\n",
    "scaler = StandardScaler(with_mean=True, with_std=True)\n",
    "sgd = SGDClassifier(\n",
    "    loss=\"log_loss\",\n",
    "    max_iter=50,\n",
    "    tol=1e-3,\n",
    "    random_state=RANDOM_STATE,\n",
    ")\n",
    "clf = make_pipeline(scaler, sgd).fit(X_train, y_train)\n",
    "\n",
    "# --------------- 3) Output of the predictor ---------------\n",
    "# One row per test sample, one column per class: [N_test, 10]\n",
    "pred_probs_test = clf.predict_proba(X_test)\n",
    "\n",
    "# --------------- 4) Extracting ambiguous data by cleanlab ---------------\n",
    "# Flag roughly REMOVE_FRAC of the test samples whose given label looks least\n",
    "# trustworthy under the predictor, and derive the complementary kept indices.\n",
    "try:\n",
    "    from cleanlab.rank import get_label_quality_scores\n",
    "    label_quality = get_label_quality_scores(labels=y_test, pred_probs=pred_probs_test)\n",
    "\n",
    "    # 1) Samples with a NaN score are always removed\n",
    "    nan_mask = np.isnan(label_quality)\n",
    "    nan_idx = np.flatnonzero(nan_mask)\n",
    "\n",
    "    # 2) Among the scored samples, remove the lowest-scoring REMOVE_FRAC (at least one)\n",
    "    valid_idx = np.flatnonzero(~nan_mask)\n",
    "    k = max(1, int(np.ceil(REMOVE_FRAC * valid_idx.size)))\n",
    "    order = np.argsort(label_quality[valid_idx])\n",
    "    low_idx = valid_idx[order[:k]]\n",
    "\n",
    "    # 3) Union of both sets (np.unique also sorts the indices)\n",
    "    remove_idx = np.unique(np.concatenate([nan_idx, low_idx]))\n",
    "    scores_used = \"label_quality_score (cleanlab.rank)\"\n",
    "except Exception as e:\n",
    "    # Best-effort fallback, but report why the preferred scorer was not used\n",
    "    print(f\"get_label_quality_scores failed ({e!r}); falling back to find_label_issues\")\n",
    "    from cleanlab.filter import find_label_issues\n",
    "    ranked = find_label_issues(\n",
    "        labels=y_test,\n",
    "        pred_probs=pred_probs_test,\n",
    "        return_indices_ranked_by=RANK_METRIC,\n",
    "    )\n",
    "    # Same rounding rule as above; fewer than k are removed if fewer issues are found\n",
    "    k = max(1, int(np.ceil(REMOVE_FRAC * len(y_test))))\n",
    "    remove_idx = np.array(ranked[:k], dtype=int)\n",
    "    scores_used = f\"find_label_issues ranked by {RANK_METRIC}\"\n",
    "\n",
    "# Kept indices are derived once, whichever branch produced remove_idx\n",
    "keep_mask = np.ones(len(y_test), dtype=bool)\n",
    "keep_mask[remove_idx] = False\n",
    "keep_idx = np.flatnonzero(keep_mask)\n",
    "print(f\"Removing {remove_idx.size}/{len(y_test)} test samples ({scores_used})\")\n",
    "\n",
    "# --------------- 5) Save ---------------\n",
    "np.save(OUTDIR / \"mnist_test_remove_indices.npy\", remove_idx)\n",
    "# The evaluation cells below load the kept indices, so save them as well\n",
    "np.save(OUTDIR / \"mnist_test_keep_indices.npy\", keep_idx)\n",
    "\n",
    "\n",
    "# --------------- 6) Visualize ambiguous data ---------------\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Show up to N_ROWS * N_COLS flagged samples. Each title gives the given label y and\n",
    "# the predictor's two most probable classes as ŷ=<top-1>/<top-2>.\n",
    "N_ROWS, N_COLS = 5, 10\n",
    "n_show = min(N_ROWS * N_COLS, len(remove_idx))\n",
    "fig, axes = plt.subplots(N_ROWS, N_COLS, figsize=(N_COLS, N_ROWS))\n",
    "axes = axes.ravel()\n",
    "for ax in axes:\n",
    "    ax.axis(\"off\")  # also blanks grid cells left unused when n_show is small\n",
    "for ax, idx in zip(axes, remove_idx[:n_show]):\n",
    "    ax.imshow(test_ds.data[idx].numpy(), cmap=\"gray\")\n",
    "    top2 = pred_probs_test[idx].argsort()[-2:][::-1]\n",
    "    ax.set_title(f\"y={y_test[idx]}, ŷ={top2[0]}/{top2[1]}\", fontsize=8)\n",
    "fig.suptitle(f\"Test samples flagged for removal ({scores_used})\", fontsize=10)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c04b9a81",
   "metadata": {},
   "source": [
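    "## Evaluation on the cleaned test set\n",
    "\n",
    "Reload the saved checkpoints and measure their loss and error rate on the MNIST test set, with the samples flagged above removed."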
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88568c7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.nn as nn\n",
    "import statistics as st\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from ipywidgets import FloatProgress\n",
    "\n",
    "from data import dataset\n",
    "\n",
    "def validate(model, dataloader, criterion, device):\n",
    "    \"\"\"Evaluate a classifier on every batch of `dataloader` without gradients.\n",
    "\n",
    "    Args:\n",
    "        model: torch module returning one row of class scores per input; it is\n",
    "            moved to `device` and switched to eval mode.\n",
    "        dataloader: iterable yielding `(inputs, labels)` batches.\n",
    "        criterion: loss callable `criterion(scores, labels)`; applied once to the\n",
    "            concatenated scores of the whole dataset, not averaged per batch.\n",
    "        device: target device passed to `.to()`.\n",
    "\n",
    "    Returns:\n",
    "        (test_loss, err_rate): the tensor returned by `criterion`, and the\n",
    "        percentage of samples whose argmax prediction differs from the label.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `dataloader` yields no samples.\n",
    "    \"\"\"\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "    batch_outputs = []\n",
    "    batch_labels = []\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for xs, ys in dataloader:\n",
    "            xs = xs.to(device)\n",
    "            # int64 labels, as e.g. nn.CrossEntropyLoss expects for class indices\n",
    "            ys = ys.to(torch.int64).to(device)\n",
    "            output = model(xs)\n",
    "            predicted = output.argmax(dim=1)\n",
    "            correct += (predicted == ys).sum().item()\n",
    "            total += ys.size(0)\n",
    "            # Concatenate once after the loop: calling torch.cat on every batch\n",
    "            # re-copies everything collected so far (quadratic in #batches).\n",
    "            batch_outputs.append(output)\n",
    "            batch_labels.append(ys)\n",
    "\n",
    "        if total == 0:\n",
    "            raise ValueError(\"validate() got a dataloader with no samples\")\n",
    "\n",
    "        test_loss = criterion(torch.cat(batch_outputs), torch.cat(batch_labels))\n",
    "        err_rate = 100 * (1 - correct / total)\n",
    "\n",
    "    return test_loss, err_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6e42c91",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "from torch.utils.data import Subset\n",
    "\n",
    "train_dataset, valid_dataset = dataset()\n",
    "\n",
    "# Indices into the MNIST test split that survived the cleanlab filtering above.\n",
    "# NOTE(review): assumes valid_dataset is that split in torchvision's order; confirm in data.dataset().\n",
    "CLEAN_DIR = Path(\"mnist_clean_test\")\n",
    "keep_path = CLEAN_DIR / \"mnist_test_keep_indices.npy\"\n",
    "if keep_path.exists():\n",
    "    remain_indices = np.load(keep_path)\n",
    "else:\n",
    "    # Only the removed indices may have been saved; keep their complement instead.\n",
    "    removed_indices = np.load(CLEAN_DIR / \"mnist_test_remove_indices.npy\")\n",
    "    remain_indices = np.setdiff1d(np.arange(len(valid_dataset)), removed_indices)\n",
    "\n",
    "assert remain_indices.size > 0, \"no test samples left after filtering\"\n",
    "assert remain_indices.max() < len(valid_dataset), \"kept index out of range for valid_dataset\"\n",
    "valid_dataset_filtered = Subset(valid_dataset, remain_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2442bd34",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "from model import LeNet, TwoLayerNeuralNet, ThreeLayerNeuralNet, FourLayerNeuralNet\n",
    "\n",
    "# --------------- Settings ---------------\n",
    "model_name = \"lenet\"\n",
    "sample_num = 2000  # checkpoints (seeds 0..sample_num-1) evaluated per training-set size\n",
    "data_num_list = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000, 45000, 50000, 55000, 60000]\n",
    "device = 0\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "batch_size = 1024\n",
    "RESULT_ROOT = \"../large_experiment_result\"\n",
    "\n",
    "validloader = DataLoader(valid_dataset_filtered,\n",
    "                batch_size=batch_size,\n",
    "                shuffle=False,\n",
    "                num_workers=0)\n",
    "\n",
    "# Map each supported model_name to a constructor. An unknown name fails loudly here\n",
    "# instead of with a NameError inside the evaluation loop.\n",
    "MODEL_FACTORIES = {\n",
    "    \"lenet\": lambda: LeNet(1),\n",
    "    \"four_layer\": FourLayerNeuralNet,\n",
    "    \"three_layer\": ThreeLayerNeuralNet,\n",
    "    \"two_layer\": TwoLayerNeuralNet,\n",
    "}\n",
    "if model_name not in MODEL_FACTORIES:\n",
    "    raise ValueError(f\"Unknown model_name {model_name!r}; expected one of {sorted(MODEL_FACTORIES)}\")\n",
    "model = MODEL_FACTORIES[model_name]()\n",
    "\n",
    "for data_num in data_num_list:\n",
    "    test_loss_list = []\n",
    "    err_rate_list = []\n",
    "\n",
    "    for seed in tqdm(range(sample_num), desc=f\"data={data_num}\"):\n",
    "        checkpoint = torch.load(f\"{RESULT_ROOT}/sgd/{model_name}/{data_num}/model_weights_{seed}.pth\")\n",
    "        model.load_state_dict(checkpoint[\"model_state\"], strict=True)\n",
    "        # validate() switches the model to eval mode and disables gradients itself\n",
    "        test_loss, err_rate = validate(model, validloader, criterion, device)\n",
    "        test_loss_list.append(test_loss.item())\n",
    "        err_rate_list.append(err_rate)\n",
    "\n",
    "    loss_mean, loss_std = st.mean(test_loss_list), st.stdev(test_loss_list)\n",
    "    print(f\"Data={data_num}, mean={loss_mean}, std={loss_std}\")\n",
    "    result = {\n",
    "        \"test_loss_mean\": loss_mean,\n",
    "        \"test_loss_std\": loss_std,\n",
    "        # Error rates (in %) returned by validate(), aggregated the same way\n",
    "        \"err_rate_mean\": st.mean(err_rate_list),\n",
    "        \"err_rate_std\": st.stdev(err_rate_list),\n",
    "    }\n",
    "    folder = f\"{RESULT_ROOT}/summary_clean_data/sgd/{model_name}/{data_num}\"\n",
    "    os.makedirs(folder, exist_ok=True)\n",
    "    with open(f\"{folder}/result.pcl\", \"wb\") as f:\n",
    "        pickle.dump(result, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
