{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we present an example of **recovering smoothed labels from an untrained LeNet on CIFAR-10**.\n",
    "\n",
    "Due to time limits, instead of 1000 examples, here we run 100 experiments (10 samples from each of the 10 classes) and obtain **100%** accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/10 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 10%|█         | 1/10 [00:04<00:37,  4.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 2/10 [00:07<00:31,  3.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 7\n",
      "skip!\n",
      "epoch is 7\n",
      "epoch is 4\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 30%|███       | 3/10 [00:11<00:26,  3.77s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 4/10 [00:14<00:21,  3.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "epoch is 29\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 50%|█████     | 5/10 [00:19<00:19,  3.86s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 6/10 [00:23<00:15,  3.90s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "epoch is 30\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "epoch is 29\n",
      "epoch is 3\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "epoch is 30\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "epoch is 31\n",
      "flip!\n",
      "epoch is 7\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "epoch is 18\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 70%|███████   | 7/10 [00:29<00:14,  4.87s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 29\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 4\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 8/10 [00:33<00:08,  4.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 90%|█████████ | 9/10 [00:36<00:04,  4.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:40<00:00,  4.07s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import random\n",
    "from typing import OrderedDict\n",
    "from exp import cross_entropy_for_onehot\n",
    "from recovering import label_recovery\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from datetime import datetime\n",
    "import time\n",
    "seed=2023\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "torch.backends.cudnn.enabled = False\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True\n",
    "CONFIG=OrderedDict(device=torch.device('cpu'),\n",
    "    dataset=\"cifar10\",\n",
    "    network=\"lenet\",\n",
    "    opt=\"lbfgs\",\n",
    "    type='label_smooth',\n",
    "    pretrained=False,\n",
    "    criterion=cross_entropy_for_onehot,\n",
    "    lr=0.5,\n",
    "    bound=100,\n",
    "    iteration=200,\n",
    "    initia=1.,\n",
    "    coefficient=4)\n",
    "test=label_recovery(CONFIG)\n",
    "test.datadir='/home/yanbo.wang/'+test.datadir\n",
    "datalist=np.load('additional_files/mixup_list_cifar10.npy')\n",
    "# datalist=np.load('additional_files/mixup_list_imagenet.npy',allow_pickle=True)\n",
    "# datalist=np.load('additional_files/dataset_cifar100.csv',allow_pickle=True)\n",
    "exp=np.zeros((10,100,8))#index,prob,featureloss,real_scalar,reco_scalar,scalar_loss,success,time \n",
    "for i in tqdm(range(10)):\n",
    "    choice_index=np.random.choice(datalist[i],10)\n",
    "    for i_exp, ind in enumerate(choice_index):\n",
    "        prob=random.uniform(0,0.5)\n",
    "        #prob=0\n",
    "        test.setup(ind,prob)\n",
    "        exp[i,i_exp,0],exp[i,i_exp,1]=ind,prob\n",
    "        start_time=time.time()\n",
    "        exp[i,i_exp,6]=test.label_reco()\n",
    "        if exp[i,i_exp,6]==-1:\n",
    "            exp[i,i_exp,6]=test.pso()\n",
    "        exp[i,i_exp,7]=time.time()-start_time\n",
    "        exp[i,i_exp,3],exp[i,i_exp,4],exp[i,i_exp,5]=test.ground_truth,test.scalar,test.ground_truth-test.scalar\n",
    "        if exp[i,i_exp,6] ==1 or exp[i,i_exp,6] == 0:\n",
    "            exp[i,i_exp,2]=((test.recover_tensor-test.net.temp)**2).sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under these settings, the gradient-based algorithm (L-BFGS) is able to find the global optimum."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is a small experiment for mixup. It takes about 18 minutes to recover 135 examples (3 for each of the 45 class pairs).\n",
    "\n",
    "When L-BFGS does not find the optimal scalar, PSO is called to search iteratively. Adjusting the search bound, the seed population, and each search range changes both the running time and the search accuracy (and error); this is a time-accuracy trade-off."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/45 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  2%|▏         | 1/45 [00:01<00:46,  1.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  4%|▍         | 2/45 [00:02<00:54,  1.26s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  7%|▋         | 3/45 [00:03<00:48,  1.16s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|▉         | 4/45 [00:04<00:46,  1.13s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 11%|█         | 5/45 [00:05<00:46,  1.17s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "skip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 13%|█▎        | 6/45 [00:06<00:43,  1.12s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 10\n",
      "flip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "scalar is tensor(-129.6024, requires_grad=True) while gt is -2.5163229\n",
      "out of bound!\n",
      "ground_truth: -2.5163228511810303\n",
      "searching from 0.7 to 6.0!\n",
      "[6.] [12.90142536]\n",
      "searching from -6.3 to -1.0!\n",
      "[-2.51632335] [8.97517683e-13]\n",
      "successfully find the ground_truth [-2.51632335]\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 16%|█▌        | 7/45 [02:15<27:09, 42.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 18%|█▊        | 8/45 [02:16<18:15, 29.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 20%|██        | 9/45 [02:17<12:23, 20.66s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 22%|██▏       | 10/45 [02:18<08:31, 14.60s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 24%|██▍       | 11/45 [02:19<05:55, 10.45s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 27%|██▋       | 12/45 [02:20<04:08,  7.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 29%|██▉       | 13/45 [02:22<02:59,  5.62s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 31%|███       | 14/45 [02:23<02:13,  4.32s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 29\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 15/45 [02:24<01:41,  3.37s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 36%|███▌      | 16/45 [02:25<01:19,  2.75s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 38%|███▊      | 17/45 [02:26<01:02,  2.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "scalar is tensor(-129.5270, requires_grad=True) while gt is -2.0672958\n",
      "out of bound!\n",
      "ground_truth: -2.067295789718628\n",
      "searching from 0.7 to 6.0!\n",
      "[6.] [12.40940857]\n",
      "searching from -6.3 to -1.0!\n",
      "[-2.06728735] [1.35953809e-11]\n",
      "successfully find the ground_truth [-2.06728735]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 40%|████      | 18/45 [04:35<18:06, 40.24s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 42%|████▏     | 19/45 [04:36<12:22, 28.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 44%|████▍     | 20/45 [04:38<08:29, 20.39s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "skip!\n",
      "flip!\n",
      "epoch is 10\n",
      "skip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 47%|████▋     | 21/45 [04:39<05:52, 14.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 10\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "skip!\n",
      "scalar is tensor(-118.8092, requires_grad=True) while gt is -4.2833943\n",
      "out of bound!\n",
      "ground_truth: -4.2833943367004395\n",
      "searching from 0.7 to 6.0!\n",
      "[6.] [13.31690311]\n",
      "searching from -6.3 to -1.0!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 49%|████▉     | 22/45 [06:48<18:48, 49.06s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-4.28339318] [9.24150321e-16]\n",
      "successfully find the ground_truth [-4.28339318]\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 51%|█████     | 23/45 [06:49<12:41, 34.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 53%|█████▎    | 24/45 [06:50<08:35, 24.55s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 56%|█████▌    | 25/45 [06:51<05:49, 17.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 58%|█████▊    | 26/45 [06:52<03:57, 12.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 60%|██████    | 27/45 [06:53<02:42,  9.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 4\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 62%|██████▏   | 28/45 [06:54<01:51,  6.56s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 64%|██████▍   | 29/45 [06:55<01:18,  4.91s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|██████▋   | 30/45 [06:56<00:56,  3.74s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████▉   | 31/45 [06:57<00:42,  3.02s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "epoch is 29\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 71%|███████   | 32/45 [06:58<00:31,  2.41s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "epoch is 3\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 73%|███████▎  | 33/45 [06:59<00:22,  1.92s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 76%|███████▌  | 34/45 [07:00<00:18,  1.67s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 78%|███████▊  | 35/45 [07:01<00:14,  1.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 3\n",
      "skip!\n",
      "flip!\n",
      "epoch is 10\n",
      "skip!\n",
      "flip!\n",
      "epoch is 10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 80%|████████  | 36/45 [07:03<00:13,  1.48s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "skip!\n",
      "scalar is tensor(-192.0335, requires_grad=True) while gt is -7.2275977\n",
      "out of bound!\n",
      "ground_truth: -7.227597713470459\n",
      "searching from 0.7 to 6.0!\n",
      "[6.] [15.23548412]\n",
      "searching from -6.3 to -1.0!\n",
      "[-6.3] [0.02103045]\n",
      "searching from 5.7 to 16.0!\n",
      "[16.] [1.88450933]\n",
      "searching from -16.3 to -6.0!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 82%|████████▏ | 37/45 [11:28<10:45, 80.70s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-7.22757834] [2.70919528e-12]\n",
      "successfully find the ground_truth [-7.22757834]\n",
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 84%|████████▍ | 38/45 [11:29<06:37, 56.85s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 6\n",
      "epoch is 3\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 87%|████████▋ | 39/45 [11:30<04:00, 40.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "flip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "scalar is tensor(-130.0161, requires_grad=True) while gt is -6.4449916\n",
      "out of bound!\n",
      "ground_truth: -6.444991588592529\n",
      "searching from 0.7 to 6.0!\n",
      "[3.87603519] [1.33605087]\n",
      "searching from -6.3 to -1.0!\n",
      "[-6.3] [0.00056831]\n",
      "searching from 5.7 to 16.0!\n",
      "[16.] [0.59575492]\n",
      "searching from -16.3 to -6.0!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 89%|████████▉ | 40/45 [15:57<09:00, 108.09s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-6.44498116] [9.10413888e-13]\n",
      "successfully find the ground_truth [-6.44498116]\n",
      "skip!\n",
      "flip!\n",
      "epoch is 10\n",
      "skip!\n",
      "flip!\n",
      "epoch is 10\n",
      "skip!\n",
      "flip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 91%|█████████ | 41/45 [15:59<05:04, 76.23s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch is 10\n",
      "flip!\n",
      "epoch is 6\n",
      "skip!\n",
      "flip!\n",
      "epoch is 10\n",
      "skip!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 93%|█████████▎| 42/45 [16:01<02:41, 53.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 10\n",
      "epoch is 3\n",
      "epoch is 3\n",
      "flip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "skip!\n",
      "scalar is tensor(-129.5073, requires_grad=True) while gt is -2.0390737\n",
      "out of bound!\n",
      "ground_truth: -2.0390737056732178\n",
      "searching from 0.7 to 6.0!\n",
      "[6.] [12.89249039]\n",
      "searching from -6.3 to -1.0!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 96%|█████████▌| 43/45 [18:08<02:31, 75.82s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-2.03905833] [3.33778873e-11]\n",
      "successfully find the ground_truth [-2.03905833]\n",
      "flip!\n",
      "epoch is 6\n",
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 98%|█████████▊| 44/45 [18:09<00:53, 53.44s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n",
      "epoch is 3\n",
      "epoch is 3\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 45/45 [18:10<00:00, 24.23s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "flip!\n",
      "epoch is 6\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from itertools import combinations\n",
    "seed=2023\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "CONFIG=OrderedDict(device=torch.device('cpu'),\n",
    "    dataset=\"cifar10\",\n",
    "    network=\"lenet\",\n",
    "    opt=\"lbfgs\",\n",
    "    type='mixup',\n",
    "    pretrained=False,\n",
    "    criterion=cross_entropy_for_onehot,\n",
    "    lr=0.5,\n",
    "    bound=100,\n",
    "    iteration=200,\n",
    "    initia=1.,\n",
    "    coefficient=2)\n",
    "test=label_recovery(CONFIG)\n",
    "test.datadir='/home/yanbo.wang/'+test.datadir\n",
    "mixup_list=np.load('additional_files/mixup_list_cifar10.npy')\n",
    "exp=np.zeros((45,20,9))#index,prob,featureloss,real_scalar,reco_scalar,scalar_loss,success,time \n",
    "combination_list=list(combinations(range(10),2))\n",
    "for i in tqdm(range(45)):\n",
    "    item=combination_list[i]\n",
    "    for ii in range(3):# it is ok to change the number of experiments for every combination. Currently we have 3*45=135 experiments in total.\n",
    "        ind=[random.randint(0,999),random.randint(0,999)]\n",
    "        prob=random.uniform(0,1)\n",
    "        test.setup([mixup_list[item[0],ind[0]],mixup_list[item[1],ind[1]]],[1-prob,prob])\n",
    "        exp[i,ii,0],exp[i,ii,1],exp[i,ii,2]=ind[0],ind[1],prob\n",
    "        start_time=time.time()\n",
    "        exp[i,ii,7]=test.label_reco()\n",
    "        if exp[i,ii,7]==-1:\n",
    "            exp[i,ii,7]=test.pso(50)\n",
    "        exp[i,ii,8]=time.time()-start_time\n",
    "        exp[i,ii,4],exp[i,ii,5],exp[i,ii,6]=test.ground_truth,test.scalar,test.ground_truth-test.scalar\n",
    "        if exp[i,ii,7] ==1 or exp[i,ii,7] == 0:\n",
    "            exp[i,ii,3]=((test.recover_tensor-test.net.temp)**2).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "count=0\n",
    "for i in exp[:,:,7].reshape((-1)):\n",
    "    if i<0:\n",
    "        count+=1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "count"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "292c41a0a097fd957a8e595e8d36b0490112cfe5dc0f0435f389dbe31f767a6d"
  },
  "kernelspec": {
   "display_name": "Python 3.9.13 ('cripac': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
