{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = np.zeros([3,10])\n",
    "a[0,:] = np.random.randn(10)\n",
    "a[1,:] = np.array([10 for _ in range(10)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "aa = np.load('results_added_reg/feddf/cifar10/niid-labeldir/0.1/'\n",
    "             'logs_FedDF_resnet9_clients20_C20_10E_lr0.001_adam_1eKL_lr0.00001_adam_T3_publicCifar100_Gamma0.01_forgetting_rounds.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ -2.32750034,  -4.00249958,  -2.21249914,  -2.68249989,\n",
       "         1.02000046,   3.11000061,  -2.77499938,  -3.57000113,\n",
       "        -2.77500057,  -3.21000099,   4.84250021,   1.35500193,\n",
       "        12.32250023, -10.85000038,  -1.86999941,   4.19499826,\n",
       "        -8.34000111,  -9.68750095,  12.82249928,  -4.54749918,\n",
       "        -5.42749977,   3.18249989,  10.66499996,  11.31499863,\n",
       "         8.56999874,  24.39750099, -14.34500122,  13.27750015,\n",
       "        10.98999882,  13.72250175,  -0.87749958,  -9.96999931,\n",
       "         2.49249935,   2.26749992,  14.65499878,   3.49000168,\n",
       "        17.03749847,  17.42000198,  14.2699995 ,  11.60000038,\n",
       "        12.84499741,  -7.36000013,  22.42499924,  17.15249825,\n",
       "        -6.05749846,   3.43500042,   4.28750181,  -9.77250004,\n",
       "        -1.09999895,   5.60499859,  12.96000004,  13.62250233,\n",
       "         2.45000029,  13.96749973,   8.31500053,  -0.60250044,\n",
       "        18.61000061,   9.66249943,  -0.76750088,   8.20999718,\n",
       "        12.96749783,  -8.54999924,   0.89249897,  20.19999886,\n",
       "        -0.62499952,   8.71750069,  14.3949995 ,  10.52250004,\n",
       "        13.01249886,   7.71000004,   6.24749899,  21.4149971 ,\n",
       "        10.98499966,  21.47999954,   5.98499918,  -0.09000015,\n",
       "        13.6099987 ,  13.94000149,  12.06000137,  28.74250031,\n",
       "         8.34249878,   5.02499771,  14.79500008,  10.20250034,\n",
       "        12.5425005 ,  27.03999901,  10.75000095,  15.21749783,\n",
       "        12.91749954,  19.37749863,  14.26749897,   4.08500099,\n",
       "         6.81250095,  27.08750153,  30.02500153,   5.98250198,\n",
       "        15.6099987 ,  -0.54749823,   8.18749809])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "aa[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "34.0"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(90*2+20*8)/10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "40.0"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(40*10)/10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([354, 373, 987, 297, 714, 473, 250, 948, 779, 195, 414, 643, 564,\n",
       "       671, 398, 877, 826, 957, 831,  97, 947, 300, 633, 589, 728, 647,\n",
       "        19, 619, 319, 780, 672, 268, 383, 576, 257, 692, 673, 239, 260,\n",
       "        92, 339, 207, 325, 397, 541,  12, 937, 588, 193, 848, 923, 516,\n",
       "       349, 753, 705, 897, 867, 615, 411, 598, 655, 113, 849, 174, 667,\n",
       "        27, 631, 777, 772,  41,  76, 133, 924, 646, 103, 389, 597, 567,\n",
       "       519, 859, 289, 886, 478,  44, 962, 841, 258, 696, 620, 955, 856,\n",
       "       644,  71, 387, 320, 612, 592, 745, 139, 361, 159, 545, 874, 294,\n",
       "       602, 707, 680, 527, 968, 400, 364, 918, 575,  22, 759, 271, 578,\n",
       "       990, 345, 726,  86, 421,   5, 688, 255, 504, 689, 965, 976, 742,\n",
       "       560, 565, 447, 695, 109, 415, 181, 763, 862,  62, 678, 468, 507,\n",
       "       716,   8,  10, 106, 581, 348, 284, 407, 370, 248, 315, 332,   1,\n",
       "       996, 458, 138, 964, 529, 194, 160,  13, 221, 301, 899, 167, 548,\n",
       "       511, 896, 679, 380, 903, 730, 618, 249,   6,  58, 356, 944, 237,\n",
       "       238, 154, 773, 385, 977, 854, 699, 372, 582,  24, 926, 490, 192,\n",
       "       227, 898, 943, 915, 481, 438, 645, 869, 790, 857, 798, 342,  15,\n",
       "       233, 188, 218, 632, 316, 283, 823, 781, 254, 158, 467, 151, 343,\n",
       "       443, 277, 978, 608, 650,  28, 419, 934, 535, 963, 887, 938, 517,\n",
       "       690, 805, 795, 388, 128, 949, 486, 211, 881, 769, 952,  30, 701,\n",
       "       230, 126, 408, 102, 933,  46, 764, 549, 110])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.random.choice(1000, 256, replace=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0\n",
      "[0.0008181356214843422]\n",
      "epoch: 1\n",
      "[0.000473606797749979]\n",
      "epoch: 2\n",
      "[0.00018237254218789433]\n",
      "epoch: 3\n",
      "[2.6393202250021048e-05]\n",
      "epoch: 4\n",
      "[0.0]\n"
     ]
    }
   ],
   "source": [
    "model = torch.nn.Linear(10, 2)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n",
    "steps = 10\n",
    "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5)\n",
    "\n",
    "for epoch in range(5):\n",
    "    print(f'epoch: {epoch}')\n",
    "    scheduler.step()\n",
    "    print(scheduler.get_lr())\n",
    "    for idx in range(steps):\n",
    "        _"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.randn(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 0.1640,  0.7710, -0.0922, -1.2683, -0.5726, -1.5794,  0.9146, -0.1605,\n",
       "         1.9307,  1.0921])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-8-e76fc59bc7c9>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  torch.nn.Softmax()(a)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0636, 0.1167, 0.0492, 0.0152, 0.0304, 0.0111, 0.1347, 0.0460, 0.3721,\n",
       "        0.1609])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.nn.Softmax()(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-9-e2c247f300fd>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  torch.nn.Softmax()(a/2)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0897, 0.1215, 0.0789, 0.0438, 0.0621, 0.0375, 0.1306, 0.0763, 0.2170,\n",
       "        0.1427])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.nn.Softmax()(a/2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-10-5755270e7ac1>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  torch.nn.Softmax()(a/3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0957, 0.1171, 0.0879, 0.0594, 0.0749, 0.0535, 0.1229, 0.0859, 0.1724,\n",
       "        0.1304])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.nn.Softmax()(a/3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-11-05324815c00f>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  torch.nn.Softmax()(a/4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0978, 0.1138, 0.0917, 0.0684, 0.0814, 0.0633, 0.1180, 0.0902, 0.1521,\n",
       "        0.1233])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.nn.Softmax()(a/4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-12-077bc8decd5a>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  torch.nn.Softmax()(a/5)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0988, 0.1115, 0.0938, 0.0742, 0.0852, 0.0697, 0.1148, 0.0926, 0.1406,\n",
       "        0.1189])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.nn.Softmax()(a/5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-13-c45b8b08a4d1>:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  torch.nn.Softmax()(a/10)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.0999, 0.1062, 0.0974, 0.0866, 0.0928, 0.0839, 0.1077, 0.0967, 0.1192,\n",
       "        0.1096])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.nn.Softmax()(a/10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([1.2319e-04, 1.3614e-04, 1.5046e-04, 9.9818e-01, 3.3485e-04, 2.0310e-04,\n",
      "        3.0299e-04, 2.7415e-04, 1.1146e-04, 1.8377e-04])\n",
      "tensor([0.0844, 0.0852, 0.0861, 0.2075, 0.0932, 0.0887, 0.0923, 0.0914, 0.0835,\n",
      "        0.0878])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-3ccc98adfb05>:3: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  print(torch.nn.Softmax()(b/1))\n",
      "<ipython-input-16-3ccc98adfb05>:4: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
      "  print(torch.nn.Softmax()(b/10))\n"
     ]
    }
   ],
   "source": [
    "b = torch.randn(10)\n",
    "b = torch.Tensor([1, 1.1, 1.2, 10, 2, 1.5, 1.9, 1.8, 0.9, 1.4])\n",
    "print(torch.nn.Softmax()(b/1))\n",
    "print(torch.nn.Softmax()(b/10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
