{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "123 32561\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import copy\n",
    "from model_SSAGDA import Model\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from dataclass import Creatdata\n",
    "\n",
    "# --- Configuration ---\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# Only query the GPU name when CUDA is present: get_device_name(0) raises on a\n",
    "# CPU-only machine, which would defeat the 'cpu' fallback above.\n",
    "device_name = torch.cuda.get_device_name(0) if device == 'cuda' else 'cpu'\n",
    "\n",
    "data_name = 'a9a'\n",
    "\n",
    "# True: run ./data/<data_name>.py, which presumably defines `train_set`.\n",
    "# False: unpickle a previously saved copy from ./data/<data_name>/<data_name>.\n",
    "is_create_data = True\n",
    "\n",
    "# Size of the random subsample and the seed used to draw it.\n",
    "N_SAMPLES = 10000\n",
    "SAMPLE_SEED = 13\n",
    "# Set True to replace `train_set` with the subsample for the rest of the notebook.\n",
    "USE_SAMPLED_SUBSET = False\n",
    "\n",
    "# --- Load data ---\n",
    "if is_create_data:\n",
    "    data_path = './data/' + data_name + '.py'\n",
    "    # NOTE: exec runs arbitrary code from this file -- only use trusted scripts.\n",
    "    with open(data_path) as data_script:\n",
    "        exec(data_script.read())\n",
    "else:\n",
    "    file_name = './data/' + data_name + '/' + data_name\n",
    "    # NOTE: pickle.load can execute arbitrary code -- only load trusted files.\n",
    "    with open(file_name, \"rb\") as fp:   # Unpickling\n",
    "        train_set = pickle.load(fp)\n",
    "\n",
    "train_set.data = train_set.data.to(device)\n",
    "train_set.targets = train_set.targets.to(device)\n",
    "\n",
    "# --- Random subsample to reduce computational cost ---\n",
    "# Legacy global numpy seed is kept so the drawn subset matches earlier runs.\n",
    "np.random.seed(SAMPLE_SEED)\n",
    "data = train_set.data\n",
    "targets = train_set.targets\n",
    "\n",
    "indices = np.arange(len(targets))\n",
    "# Cap at the dataset size: choice(..., replace=False) raises if asked for more\n",
    "# items than exist.\n",
    "n_draw = min(N_SAMPLES, len(indices))\n",
    "random_indices = np.random.choice(indices, n_draw, replace=False)\n",
    "\n",
    "sampled_data = data[random_indices]\n",
    "sampled_targets = targets[random_indices]\n",
    "\n",
    "sampled_train_set = Creatdata(data=sampled_data, targets=sampled_targets)\n",
    "# Re-pin to `device` in case Creatdata copies or moves its inputs.\n",
    "sampled_train_set.data = sampled_train_set.data.to(device)\n",
    "sampled_train_set.targets = sampled_train_set.targets.to(device)\n",
    "\n",
    "if USE_SAMPLED_SUBSET:\n",
    "    train_set = sampled_train_set\n",
    "\n",
    "# Prints: number of features per observation, number of observations.\n",
    "print(len(train_set.data[0]), len(train_set.targets))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "Sample complexity: 31868, Epoch: 1, Accuracy: 0.7028346657752991, Loss: 0.6016443967819214\n",
      "Sample complexity: 63736, Epoch: 2, Accuracy: 0.7032032012939453, Loss: 0.6012693047523499\n",
      "Sample complexity: 95604, Epoch: 3, Accuracy: 0.7041245698928833, Loss: 0.6008415818214417\n",
      "Sample complexity: 127472, Epoch: 4, Accuracy: 0.704646646976471, Loss: 0.600554347038269\n",
      "Sample complexity: 159340, Epoch: 5, Accuracy: 0.7048923373222351, Loss: 0.6001132130622864\n",
      "Sample complexity: 191208, Epoch: 6, Accuracy: 0.7055987119674683, Loss: 0.5996447801589966\n",
      "Sample complexity: 223076, Epoch: 7, Accuracy: 0.7062436938285828, Loss: 0.5992267727851868\n",
      "Sample complexity: 254944, Epoch: 8, Accuracy: 0.7073186039924622, Loss: 0.5987934470176697\n",
      "Sample complexity: 286812, Epoch: 9, Accuracy: 0.7078713774681091, Loss: 0.5984992980957031\n",
      "Sample complexity: 318680, Epoch: 10, Accuracy: 0.708547055721283, Loss: 0.5980468392372131\n",
      "Sample complexity: 350548, Epoch: 11, Accuracy: 0.7094376683235168, Loss: 0.5975379347801208\n",
      "Sample complexity: 382416, Epoch: 12, Accuracy: 0.7095605134963989, Loss: 0.597245991230011\n",
      "Sample complexity: 414284, Epoch: 13, Accuracy: 0.7100826501846313, Loss: 0.5968517661094666\n",
      "Sample complexity: 446152, Epoch: 14, Accuracy: 0.7107582688331604, Loss: 0.5963341593742371\n",
      "Sample complexity: 478020, Epoch: 15, Accuracy: 0.7113724946975708, Loss: 0.5960264205932617\n",
      "Sample complexity: 509888, Epoch: 16, Accuracy: 0.7122324109077454, Loss: 0.5955594778060913\n",
      "Sample complexity: 541756, Epoch: 18, Accuracy: 0.7126317024230957, Loss: 0.5952829122543335\n",
      "Sample complexity: 573624, Epoch: 19, Accuracy: 0.7128773927688599, Loss: 0.5949166417121887\n",
      "Sample complexity: 605492, Epoch: 20, Accuracy: 0.7133073210716248, Loss: 0.5946007966995239\n",
      "Sample complexity: 637360, Epoch: 21, Accuracy: 0.7138601541519165, Loss: 0.5942851305007935\n",
      "Sample complexity: 669228, Epoch: 22, Accuracy: 0.7143515348434448, Loss: 0.5940051674842834\n",
      "Sample complexity: 701096, Epoch: 23, Accuracy: 0.7147507667541504, Loss: 0.5936076045036316\n",
      "Sample complexity: 732964, Epoch: 24, Accuracy: 0.7152114510536194, Loss: 0.5933058261871338\n",
      "Sample complexity: 764832, Epoch: 25, Accuracy: 0.7155799865722656, Loss: 0.5928938984870911\n",
      "Sample complexity: 796700, Epoch: 26, Accuracy: 0.7158871293067932, Loss: 0.5926135182380676\n",
      "Sample complexity: 828568, Epoch: 27, Accuracy: 0.7168391942977905, Loss: 0.5922033786773682\n",
      "Sample complexity: 860436, Epoch: 28, Accuracy: 0.7173612713813782, Loss: 0.5919708013534546\n",
      "Sample complexity: 892304, Epoch: 29, Accuracy: 0.7180061936378479, Loss: 0.5916435122489929\n",
      "Sample complexity: 924172, Epoch: 30, Accuracy: 0.7184668779373169, Loss: 0.5913814306259155\n",
      "Sample complexity: 956040, Epoch: 31, Accuracy: 0.7192960977554321, Loss: 0.5910992622375488\n",
      "Sample complexity: 987908, Epoch: 32, Accuracy: 0.720094621181488, Loss: 0.5907809734344482\n",
      "Sample complexity: 1019776, Epoch: 33, Accuracy: 0.7203095555305481, Loss: 0.5905025601387024\n",
      "Sample complexity: 1051644, Epoch: 35, Accuracy: 0.721108078956604, Loss: 0.5901113152503967\n",
      "Sample complexity: 1083512, Epoch: 36, Accuracy: 0.7220908403396606, Loss: 0.5897594094276428\n",
      "Sample complexity: 1115380, Epoch: 37, Accuracy: 0.7227358222007751, Loss: 0.5894366502761841\n",
      "Sample complexity: 1147248, Epoch: 38, Accuracy: 0.7234114408493042, Loss: 0.5890186429023743\n",
      "Sample complexity: 1179116, Epoch: 39, Accuracy: 0.7242406606674194, Loss: 0.5886539816856384\n",
      "Sample complexity: 1210984, Epoch: 40, Accuracy: 0.7246091961860657, Loss: 0.5882541537284851\n",
      "Sample complexity: 1242852, Epoch: 41, Accuracy: 0.7253155708312988, Loss: 0.5880146622657776\n",
      "Sample complexity: 1274720, Epoch: 42, Accuracy: 0.7257762551307678, Loss: 0.5877249240875244\n",
      "Sample complexity: 1306588, Epoch: 43, Accuracy: 0.7260526418685913, Loss: 0.5875520706176758\n",
      "Sample complexity: 1338456, Epoch: 44, Accuracy: 0.726482629776001, Loss: 0.587290346622467\n",
      "Sample complexity: 1370324, Epoch: 45, Accuracy: 0.7271275520324707, Loss: 0.5870878100395203\n",
      "Sample complexity: 1402192, Epoch: 46, Accuracy: 0.7276496291160583, Loss: 0.5867314338684082\n",
      "Sample complexity: 1434060, Epoch: 47, Accuracy: 0.7281717658042908, Loss: 0.5863534212112427\n",
      "Sample complexity: 1465928, Epoch: 48, Accuracy: 0.7289702296257019, Loss: 0.5861106514930725\n",
      "Sample complexity: 1497796, Epoch: 49, Accuracy: 0.7293695211410522, Loss: 0.5857903361320496\n",
      "Sample complexity: 1529664, Epoch: 50, Accuracy: 0.7301065921783447, Loss: 0.5855231881141663\n",
      "Sample complexity: 1561532, Epoch: 52, Accuracy: 0.730597972869873, Loss: 0.5851919651031494\n",
      "Sample complexity: 1593400, Epoch: 53, Accuracy: 0.73115074634552, Loss: 0.5847755670547485\n",
      "Sample complexity: 1625268, Epoch: 54, Accuracy: 0.7318571209907532, Loss: 0.5845256447792053\n",
      "Sample complexity: 1657136, Epoch: 55, Accuracy: 0.7324099540710449, Loss: 0.5843088030815125\n",
      "Sample complexity: 1689004, Epoch: 56, Accuracy: 0.7326863408088684, Loss: 0.5840791463851929\n",
      "Sample complexity: 1720872, Epoch: 57, Accuracy: 0.7333313226699829, Loss: 0.5837574601173401\n",
      "Sample complexity: 1752740, Epoch: 58, Accuracy: 0.7337919473648071, Loss: 0.5834732055664062\n",
      "Sample complexity: 1784608, Epoch: 59, Accuracy: 0.7342219352722168, Loss: 0.5831705331802368\n",
      "Sample complexity: 1816476, Epoch: 60, Accuracy: 0.7345290184020996, Loss: 0.5828700065612793\n",
      "Sample complexity: 1848344, Epoch: 61, Accuracy: 0.7349897027015686, Loss: 0.5825756788253784\n",
      "Sample complexity: 1880212, Epoch: 62, Accuracy: 0.7350818514823914, Loss: 0.5824406147003174\n",
      "Sample complexity: 1912080, Epoch: 63, Accuracy: 0.735603928565979, Loss: 0.5822958946228027\n",
      "Sample complexity: 1943948, Epoch: 64, Accuracy: 0.7360953092575073, Loss: 0.5820702314376831\n",
      "Sample complexity: 1975816, Epoch: 65, Accuracy: 0.7363103032112122, Loss: 0.5819084644317627\n",
      "Sample complexity: 2007684, Epoch: 66, Accuracy: 0.7369245290756226, Loss: 0.5817224979400635\n",
      "Sample complexity: 2039552, Epoch: 67, Accuracy: 0.737538754940033, Loss: 0.5814039707183838\n",
      "Sample complexity: 2071420, Epoch: 69, Accuracy: 0.738030195236206, Loss: 0.5810360312461853\n",
      "Sample complexity: 2103288, Epoch: 70, Accuracy: 0.7389515042304993, Loss: 0.5807533860206604\n",
      "Sample complexity: 2135156, Epoch: 71, Accuracy: 0.7394121885299683, Loss: 0.5804374814033508\n",
      "Sample complexity: 2167024, Epoch: 72, Accuracy: 0.7399035692214966, Loss: 0.5802150964736938\n",
      "Sample complexity: 2198892, Epoch: 73, Accuracy: 0.7399342656135559, Loss: 0.5801325440406799\n",
      "Sample complexity: 2230760, Epoch: 74, Accuracy: 0.7403028011322021, Loss: 0.5798237323760986\n",
      "Sample complexity: 2262628, Epoch: 75, Accuracy: 0.7401799559593201, Loss: 0.5796943306922913\n",
      "Sample complexity: 2294496, Epoch: 76, Accuracy: 0.7402414083480835, Loss: 0.5794860124588013\n",
      "Sample complexity: 2326364, Epoch: 77, Accuracy: 0.7406713366508484, Loss: 0.579300582408905\n",
      "Sample complexity: 2358232, Epoch: 78, Accuracy: 0.7409170866012573, Loss: 0.579181432723999\n",
      "Sample complexity: 2390100, Epoch: 79, Accuracy: 0.7413163185119629, Loss: 0.5789765119552612\n",
      "Sample complexity: 2421968, Epoch: 80, Accuracy: 0.741562008857727, Loss: 0.5787770748138428\n",
      "Sample complexity: 2453836, Epoch: 81, Accuracy: 0.7418383955955505, Loss: 0.5786786079406738\n",
      "Sample complexity: 2485704, Epoch: 82, Accuracy: 0.7422990798950195, Loss: 0.5783993005752563\n",
      "Sample complexity: 2517572, Epoch: 83, Accuracy: 0.7426062226295471, Loss: 0.5781759023666382\n",
      "Sample complexity: 2549440, Epoch: 84, Accuracy: 0.7433739900588989, Loss: 0.5778017044067383\n",
      "Sample complexity: 2581308, Epoch: 86, Accuracy: 0.7435889840126038, Loss: 0.5775876045227051\n",
      "Sample complexity: 2613176, Epoch: 87, Accuracy: 0.7442032098770142, Loss: 0.5772256255149841\n",
      "Sample complexity: 2645044, Epoch: 88, Accuracy: 0.7443260550498962, Loss: 0.5769753456115723\n",
      "Sample complexity: 2676912, Epoch: 89, Accuracy: 0.7448481321334839, Loss: 0.5766979455947876\n",
      "Sample complexity: 2708780, Epoch: 90, Accuracy: 0.7452781200408936, Loss: 0.576409637928009\n",
      "Sample complexity: 2740648, Epoch: 91, Accuracy: 0.7457695007324219, Loss: 0.5762325525283813\n",
      "Sample complexity: 2772516, Epoch: 92, Accuracy: 0.7458001971244812, Loss: 0.5759226679801941\n",
      "Sample complexity: 2804384, Epoch: 93, Accuracy: 0.7461687326431274, Loss: 0.5759099125862122\n",
      "Sample complexity: 2836252, Epoch: 94, Accuracy: 0.7465987205505371, Loss: 0.5758219361305237\n",
      "Sample complexity: 2868120, Epoch: 95, Accuracy: 0.746567964553833, Loss: 0.5756472945213318\n",
      "Sample complexity: 2899988, Epoch: 96, Accuracy: 0.7467522621154785, Loss: 0.5753886699676514\n",
      "Sample complexity: 2931856, Epoch: 97, Accuracy: 0.7469979524612427, Loss: 0.5752816200256348\n",
      "Sample complexity: 2963724, Epoch: 98, Accuracy: 0.7473664879798889, Loss: 0.575096070766449\n",
      "Sample complexity: 2995592, Epoch: 99, Accuracy: 0.7479193210601807, Loss: 0.5747303366661072\n",
      "Sample complexity: 3027460, Epoch: 100, Accuracy: 0.7483799457550049, Loss: 0.5743346214294434\n",
      "Sample complexity: 3059328, Epoch: 101, Accuracy: 0.7485949397087097, Loss: 0.5742196440696716\n",
      "Sample complexity: 3091196, Epoch: 103, Accuracy: 0.748871386051178, Loss: 0.5740634799003601\n",
      "Sample complexity: 3123064, Epoch: 104, Accuracy: 0.7493320107460022, Loss: 0.5736215710639954\n",
      "Sample complexity: 3154932, Epoch: 105, Accuracy: 0.7495163083076477, Loss: 0.5734434127807617\n",
      "Sample complexity: 3186800, Epoch: 106, Accuracy: 0.7498233914375305, Loss: 0.5733757019042969\n",
      "Sample complexity: 3218668, Epoch: 107, Accuracy: 0.7500383853912354, Loss: 0.5731350779533386\n",
      "Sample complexity: 3250536, Epoch: 108, Accuracy: 0.7502533793449402, Loss: 0.5728501677513123\n",
      "Sample complexity: 3282404, Epoch: 109, Accuracy: 0.7505912184715271, Loss: 0.5725864768028259\n",
      "Sample complexity: 3314272, Epoch: 110, Accuracy: 0.7507447600364685, Loss: 0.5724552869796753\n",
      "Sample complexity: 3346140, Epoch: 111, Accuracy: 0.7511132955551147, Loss: 0.572185218334198\n",
      "Sample complexity: 3378008, Epoch: 112, Accuracy: 0.7512668371200562, Loss: 0.5721029043197632\n",
      "Sample complexity: 3409876, Epoch: 113, Accuracy: 0.7513282895088196, Loss: 0.5718919634819031\n",
      "Sample complexity: 3441744, Epoch: 114, Accuracy: 0.7516046762466431, Loss: 0.5717766880989075\n",
      "Sample complexity: 3473612, Epoch: 115, Accuracy: 0.7516046762466431, Loss: 0.5716057419776917\n",
      "Sample complexity: 3505480, Epoch: 116, Accuracy: 0.75194251537323, Loss: 0.5714690089225769\n",
      "Sample complexity: 3537348, Epoch: 117, Accuracy: 0.7520960569381714, Loss: 0.5713621973991394\n",
      "Sample complexity: 3569216, Epoch: 118, Accuracy: 0.7523417472839355, Loss: 0.5712111592292786\n",
      "Sample complexity: 3601084, Epoch: 120, Accuracy: 0.7524953484535217, Loss: 0.5710111260414124\n",
      "Sample complexity: 3632952, Epoch: 121, Accuracy: 0.7527717351913452, Loss: 0.5709105134010315\n",
      "Sample complexity: 3664820, Epoch: 122, Accuracy: 0.7531402707099915, Loss: 0.5707228779792786\n",
      "Sample complexity: 3696688, Epoch: 123, Accuracy: 0.7534166574478149, Loss: 0.5705546736717224\n",
      "Sample complexity: 3728556, Epoch: 124, Accuracy: 0.7536009550094604, Loss: 0.5703352093696594\n",
      "Sample complexity: 3760424, Epoch: 125, Accuracy: 0.7535088062286377, Loss: 0.5701292753219604\n",
      "Sample complexity: 3792292, Epoch: 126, Accuracy: 0.7538773417472839, Loss: 0.5699745416641235\n",
      "Sample complexity: 3824160, Epoch: 127, Accuracy: 0.754000186920166, Loss: 0.5697688460350037\n",
      "Sample complexity: 3856028, Epoch: 128, Accuracy: 0.7543687224388123, Loss: 0.569639265537262\n",
      "Sample complexity: 3887896, Epoch: 129, Accuracy: 0.754460871219635, Loss: 0.5693970322608948\n",
      "Sample complexity: 3919764, Epoch: 130, Accuracy: 0.7548294067382812, Loss: 0.5693217515945435\n",
      "Sample complexity: 3951632, Epoch: 131, Accuracy: 0.7548908591270447, Loss: 0.5691862106323242\n",
      "Sample complexity: 3983500, Epoch: 132, Accuracy: 0.7550750970840454, Loss: 0.5690916776657104\n",
      "Sample complexity: 4015368, Epoch: 133, Accuracy: 0.7551979422569275, Loss: 0.5690122246742249\n",
      "Sample complexity: 4047236, Epoch: 134, Accuracy: 0.7553514838218689, Loss: 0.5688763856887817\n",
      "Sample complexity: 4079104, Epoch: 135, Accuracy: 0.7557200193405151, Loss: 0.568734884262085\n",
      "Sample complexity: 4110972, Epoch: 137, Accuracy: 0.7557200193405151, Loss: 0.5685872435569763\n",
      "Sample complexity: 4142840, Epoch: 138, Accuracy: 0.7558428645133972, Loss: 0.568535327911377\n",
      "Sample complexity: 4174708, Epoch: 139, Accuracy: 0.7558736205101013, Loss: 0.5682353377342224\n",
      "Sample complexity: 4206576, Epoch: 140, Accuracy: 0.7558736205101013, Loss: 0.5680826306343079\n",
      "Sample complexity: 4238444, Epoch: 141, Accuracy: 0.7558736205101013, Loss: 0.5680060386657715\n",
      "Sample complexity: 4270312, Epoch: 142, Accuracy: 0.7559657096862793, Loss: 0.5679141879081726\n",
      "Sample complexity: 4302180, Epoch: 143, Accuracy: 0.7560885548591614, Loss: 0.5677323341369629\n",
      "Sample complexity: 4334048, Epoch: 144, Accuracy: 0.7562114000320435, Loss: 0.5675981640815735\n",
      "Sample complexity: 4365916, Epoch: 145, Accuracy: 0.7563035488128662, Loss: 0.5673990249633789\n",
      "Sample complexity: 4397784, Epoch: 146, Accuracy: 0.756395697593689, Loss: 0.5671690106391907\n",
      "Sample complexity: 4429652, Epoch: 147, Accuracy: 0.7565799951553345, Loss: 0.5670924782752991\n",
      "Sample complexity: 4461520, Epoch: 148, Accuracy: 0.7566413879394531, Loss: 0.5671032071113586\n",
      "Sample complexity: 4493388, Epoch: 149, Accuracy: 0.7566720843315125, Loss: 0.5669963359832764\n",
      "Sample complexity: 4525256, Epoch: 150, Accuracy: 0.7567028403282166, Loss: 0.5668401122093201\n",
      "Sample complexity: 4557124, Epoch: 151, Accuracy: 0.7568870782852173, Loss: 0.5667643547058105\n",
      "Sample complexity: 4588992, Epoch: 152, Accuracy: 0.7570406198501587, Loss: 0.5666359663009644\n",
      "Sample complexity: 4620860, Epoch: 154, Accuracy: 0.7571634650230408, Loss: 0.5664663910865784\n",
      "Sample complexity: 4652728, Epoch: 155, Accuracy: 0.7571020722389221, Loss: 0.5662837028503418\n",
      "Sample complexity: 4684596, Epoch: 156, Accuracy: 0.7570713758468628, Loss: 0.5660613775253296\n",
      "Sample complexity: 4716464, Epoch: 157, Accuracy: 0.7571942210197449, Loss: 0.5659250020980835\n",
      "Sample complexity: 4748332, Epoch: 158, Accuracy: 0.7572863101959229, Loss: 0.5657804608345032\n",
      "Sample complexity: 4780200, Epoch: 159, Accuracy: 0.7574091553688049, Loss: 0.5656800866127014\n",
      "Sample complexity: 4812068, Epoch: 160, Accuracy: 0.7574706077575684, Loss: 0.5655263662338257\n",
      "Sample complexity: 4843936, Epoch: 161, Accuracy: 0.7575934529304504, Loss: 0.5654192566871643\n",
      "Sample complexity: 4875804, Epoch: 162, Accuracy: 0.7576856017112732, Loss: 0.5653104186058044\n",
      "Sample complexity: 4907672, Epoch: 163, Accuracy: 0.7577776908874512, Loss: 0.5651323199272156\n",
      "Sample complexity: 4939540, Epoch: 164, Accuracy: 0.7577776908874512, Loss: 0.565066397190094\n",
      "Sample complexity: 4971408, Epoch: 165, Accuracy: 0.7578391432762146, Loss: 0.5649166107177734\n",
      "Sample complexity: 5003276, Epoch: 166, Accuracy: 0.7579312920570374, Loss: 0.5648303031921387\n",
      "Sample complexity: 5035144, Epoch: 167, Accuracy: 0.757992684841156, Loss: 0.5646982789039612\n",
      "Sample complexity: 5067012, Epoch: 168, Accuracy: 0.7582076787948608, Loss: 0.5646569132804871\n",
      "Sample complexity: 5098880, Epoch: 169, Accuracy: 0.7582691311836243, Loss: 0.5647096037864685\n",
      "Sample complexity: 5130748, Epoch: 171, Accuracy: 0.7583305239677429, Loss: 0.5646061301231384\n",
      "Sample complexity: 5162616, Epoch: 172, Accuracy: 0.758453369140625, Loss: 0.5645179152488708\n",
      "Sample complexity: 5194484, Epoch: 173, Accuracy: 0.7585455179214478, Loss: 0.5644465684890747\n",
      "Sample complexity: 5226352, Epoch: 174, Accuracy: 0.7585762143135071, Loss: 0.5644772052764893\n",
      "Sample complexity: 5258220, Epoch: 175, Accuracy: 0.7586376667022705, Loss: 0.5641839504241943\n",
      "Sample complexity: 5290088, Epoch: 176, Accuracy: 0.7586990594863892, Loss: 0.5640590786933899\n",
      "Sample complexity: 5321956, Epoch: 177, Accuracy: 0.7587605118751526, Loss: 0.5639199018478394\n",
      "Sample complexity: 5353824, Epoch: 178, Accuracy: 0.7587605118751526, Loss: 0.5638093948364258\n",
      "Sample complexity: 5385692, Epoch: 179, Accuracy: 0.7587297558784485, Loss: 0.5636578798294067\n",
      "Sample complexity: 5417560, Epoch: 180, Accuracy: 0.7588526010513306, Loss: 0.5635649561882019\n",
      "Sample complexity: 5449428, Epoch: 181, Accuracy: 0.7589447498321533, Loss: 0.5635002255439758\n",
      "Sample complexity: 5481296, Epoch: 182, Accuracy: 0.7589754462242126, Loss: 0.5634432435035706\n",
      "Sample complexity: 5513164, Epoch: 183, Accuracy: 0.758914053440094, Loss: 0.5633431673049927\n",
      "Sample complexity: 5545032, Epoch: 184, Accuracy: 0.7589754462242126, Loss: 0.5632684826850891\n",
      "Sample complexity: 5576900, Epoch: 185, Accuracy: 0.7589447498321533, Loss: 0.5632100701332092\n",
      "Sample complexity: 5608768, Epoch: 186, Accuracy: 0.7589754462242126, Loss: 0.5631456971168518\n",
      "Sample complexity: 5640636, Epoch: 188, Accuracy: 0.758914053440094, Loss: 0.5631906390190125\n",
      "Sample complexity: 5672504, Epoch: 189, Accuracy: 0.758914053440094, Loss: 0.5631110668182373\n",
      "Sample complexity: 5704372, Epoch: 190, Accuracy: 0.758914053440094, Loss: 0.5629057884216309\n",
      "Sample complexity: 5736240, Epoch: 191, Accuracy: 0.7588833570480347, Loss: 0.5629354119300842\n",
      "Sample complexity: 5768108, Epoch: 192, Accuracy: 0.7588526010513306, Loss: 0.5628848075866699\n",
      "Sample complexity: 5799976, Epoch: 193, Accuracy: 0.7588526010513306, Loss: 0.5628804564476013\n",
      "Sample complexity: 5831844, Epoch: 194, Accuracy: 0.7588526010513306, Loss: 0.5627470016479492\n",
      "Sample complexity: 5863712, Epoch: 195, Accuracy: 0.7588833570480347, Loss: 0.5626670122146606\n",
      "Sample complexity: 5895580, Epoch: 196, Accuracy: 0.758914053440094, Loss: 0.5625126361846924\n",
      "Sample complexity: 5927448, Epoch: 197, Accuracy: 0.7589447498321533, Loss: 0.5624815821647644\n",
      "Sample complexity: 5959316, Epoch: 198, Accuracy: 0.7589754462242126, Loss: 0.5624808073043823\n",
      "Sample complexity: 5991184, Epoch: 199, Accuracy: 0.7589754462242126, Loss: 0.562477707862854\n",
      "Sample complexity: 6023052, Epoch: 200, Accuracy: 0.7590062022209167, Loss: 0.5624238848686218\n",
      "Sample complexity: 6054920, Epoch: 201, Accuracy: 0.7590368986129761, Loss: 0.5623645782470703\n",
      "Sample complexity: 6086788, Epoch: 202, Accuracy: 0.7590675950050354, Loss: 0.5623949766159058\n",
      "Sample complexity: 6118656, Epoch: 203, Accuracy: 0.7590675950050354, Loss: 0.5622801780700684\n",
      "Sample complexity: 6150524, Epoch: 205, Accuracy: 0.7590675950050354, Loss: 0.562201201915741\n",
      "Sample complexity: 6182392, Epoch: 206, Accuracy: 0.7590982913970947, Loss: 0.5621185302734375\n",
      "Sample complexity: 6214260, Epoch: 207, Accuracy: 0.7590368986129761, Loss: 0.5620277523994446\n",
      "Sample complexity: 6246128, Epoch: 208, Accuracy: 0.7590368986129761, Loss: 0.5618851184844971\n",
      "Sample complexity: 6277996, Epoch: 209, Accuracy: 0.7590368986129761, Loss: 0.5618093013763428\n",
      "Sample complexity: 6309864, Epoch: 210, Accuracy: 0.7590368986129761, Loss: 0.5617371797561646\n",
      "Sample complexity: 6341732, Epoch: 211, Accuracy: 0.7590368986129761, Loss: 0.5616647601127625\n",
      "Sample complexity: 6373600, Epoch: 212, Accuracy: 0.7590368986129761, Loss: 0.5617240071296692\n",
      "Sample complexity: 6405468, Epoch: 213, Accuracy: 0.7590982913970947, Loss: 0.5615451335906982\n",
      "Sample complexity: 6437336, Epoch: 214, Accuracy: 0.7590982913970947, Loss: 0.5614269375801086\n",
      "Sample complexity: 6469204, Epoch: 215, Accuracy: 0.7590675950050354, Loss: 0.5614265203475952\n",
      "Sample complexity: 6501072, Epoch: 216, Accuracy: 0.7590982913970947, Loss: 0.561428964138031\n",
      "Sample complexity: 6532940, Epoch: 217, Accuracy: 0.7590982913970947, Loss: 0.5613601207733154\n",
      "Sample complexity: 6564808, Epoch: 218, Accuracy: 0.7590982913970947, Loss: 0.5613985657691956\n",
      "Sample complexity: 6596676, Epoch: 219, Accuracy: 0.7590982913970947, Loss: 0.5612842440605164\n",
      "Sample complexity: 6628544, Epoch: 220, Accuracy: 0.7591290473937988, Loss: 0.5611519813537598\n",
      "Sample complexity: 6660412, Epoch: 222, Accuracy: 0.7591290473937988, Loss: 0.5612949728965759\n",
      "Sample complexity: 6692280, Epoch: 223, Accuracy: 0.7591290473937988, Loss: 0.5613279938697815\n",
      "Sample complexity: 6724148, Epoch: 224, Accuracy: 0.7591290473937988, Loss: 0.5613958239555359\n",
      "Sample complexity: 6756016, Epoch: 225, Accuracy: 0.7591290473937988, Loss: 0.5612917542457581\n",
      "Sample complexity: 6787884, Epoch: 226, Accuracy: 0.7591597437858582, Loss: 0.5612403750419617\n",
      "Sample complexity: 6819752, Epoch: 227, Accuracy: 0.7591597437858582, Loss: 0.5610606074333191\n",
      "Sample complexity: 6851620, Epoch: 228, Accuracy: 0.7591597437858582, Loss: 0.5608873963356018\n",
      "Sample complexity: 6883488, Epoch: 229, Accuracy: 0.7591597437858582, Loss: 0.560876727104187\n",
      "Sample complexity: 6915356, Epoch: 230, Accuracy: 0.7591597437858582, Loss: 0.5609641671180725\n",
      "Sample complexity: 6947224, Epoch: 231, Accuracy: 0.7591597437858582, Loss: 0.5608745813369751\n",
      "Sample complexity: 6979092, Epoch: 232, Accuracy: 0.7591597437858582, Loss: 0.560978889465332\n",
      "Sample complexity: 7010960, Epoch: 233, Accuracy: 0.7591597437858582, Loss: 0.5609560608863831\n",
      "Sample complexity: 7042828, Epoch: 234, Accuracy: 0.7591597437858582, Loss: 0.5609731078147888\n",
      "Sample complexity: 7074696, Epoch: 235, Accuracy: 0.7591597437858582, Loss: 0.5608170628547668\n",
      "Sample complexity: 7106564, Epoch: 236, Accuracy: 0.7591904401779175, Loss: 0.5608608722686768\n",
      "Sample complexity: 7138432, Epoch: 237, Accuracy: 0.7591904401779175, Loss: 0.56068354845047\n",
      "Sample complexity: 7170300, Epoch: 239, Accuracy: 0.7591904401779175, Loss: 0.5605677366256714\n",
      "Sample complexity: 7202168, Epoch: 240, Accuracy: 0.7591904401779175, Loss: 0.5604456663131714\n",
      "Sample complexity: 7234036, Epoch: 241, Accuracy: 0.7591904401779175, Loss: 0.560447633266449\n",
      "Sample complexity: 7265904, Epoch: 242, Accuracy: 0.7591904401779175, Loss: 0.5604768395423889\n",
      "Sample complexity: 7297772, Epoch: 243, Accuracy: 0.7591904401779175, Loss: 0.5604261755943298\n",
      "Sample complexity: 7329640, Epoch: 244, Accuracy: 0.7591597437858582, Loss: 0.5603004097938538\n",
      "Sample complexity: 7361508, Epoch: 245, Accuracy: 0.7591597437858582, Loss: 0.5602735280990601\n",
      "Sample complexity: 7393376, Epoch: 246, Accuracy: 0.7591597437858582, Loss: 0.5603201389312744\n",
      "Sample complexity: 7425244, Epoch: 247, Accuracy: 0.7591597437858582, Loss: 0.5602705478668213\n",
      "Sample complexity: 7457112, Epoch: 248, Accuracy: 0.7591597437858582, Loss: 0.5602829456329346\n",
      "Sample complexity: 7488980, Epoch: 249, Accuracy: 0.7591597437858582, Loss: 0.5601908564567566\n",
      "Sample complexity: 7520848, Epoch: 250, Accuracy: 0.7591597437858582, Loss: 0.5600989460945129\n",
      "Sample complexity: 7552716, Epoch: 251, Accuracy: 0.7591597437858582, Loss: 0.5600267648696899\n",
      "Sample complexity: 7584584, Epoch: 252, Accuracy: 0.7591597437858582, Loss: 0.5599513649940491\n",
      "Sample complexity: 7616452, Epoch: 253, Accuracy: 0.7591904401779175, Loss: 0.5598157644271851\n",
      "Sample complexity: 7648320, Epoch: 254, Accuracy: 0.7591904401779175, Loss: 0.5598354935646057\n",
      "Sample complexity: 7680188, Epoch: 256, Accuracy: 0.7591904401779175, Loss: 0.5598158240318298\n",
      "Sample complexity: 7712056, Epoch: 257, Accuracy: 0.7591904401779175, Loss: 0.559819221496582\n",
      "Sample complexity: 7743924, Epoch: 258, Accuracy: 0.7591904401779175, Loss: 0.5598945617675781\n",
      "Sample complexity: 7775792, Epoch: 259, Accuracy: 0.7591904401779175, Loss: 0.5598983764648438\n",
      "Sample complexity: 7807660, Epoch: 260, Accuracy: 0.7591904401779175, Loss: 0.5598266124725342\n",
      "Sample complexity: 7839528, Epoch: 261, Accuracy: 0.7591904401779175, Loss: 0.5597485303878784\n",
      "Sample complexity: 7871396, Epoch: 262, Accuracy: 0.7591904401779175, Loss: 0.559783935546875\n",
      "Sample complexity: 7903264, Epoch: 263, Accuracy: 0.7591904401779175, Loss: 0.5597478151321411\n",
      "Sample complexity: 7935132, Epoch: 264, Accuracy: 0.7591904401779175, Loss: 0.5597965121269226\n",
      "Sample complexity: 7967000, Epoch: 265, Accuracy: 0.7591904401779175, Loss: 0.5597212314605713\n",
      "Sample complexity: 7998868, Epoch: 266, Accuracy: 0.7591904401779175, Loss: 0.5596424341201782\n",
      "Sample complexity: 8030736, Epoch: 267, Accuracy: 0.7591904401779175, Loss: 0.5594823360443115\n",
      "Sample complexity: 8062604, Epoch: 268, Accuracy: 0.7591904401779175, Loss: 0.559391975402832\n",
      "Sample complexity: 8094472, Epoch: 269, Accuracy: 0.7591904401779175, Loss: 0.5594027042388916\n",
      "Sample complexity: 8126340, Epoch: 270, Accuracy: 0.7591904401779175, Loss: 0.559453547000885\n",
      "Sample complexity: 8158208, Epoch: 271, Accuracy: 0.7591904401779175, Loss: 0.559512734413147\n",
      "Sample complexity: 8190076, Epoch: 273, Accuracy: 0.7591904401779175, Loss: 0.5594894886016846\n",
      "Sample complexity: 8221944, Epoch: 274, Accuracy: 0.7591904401779175, Loss: 0.5595089197158813\n",
      "Sample complexity: 8253812, Epoch: 275, Accuracy: 0.7591904401779175, Loss: 0.5594900250434875\n",
      "Sample complexity: 8285680, Epoch: 276, Accuracy: 0.7591904401779175, Loss: 0.559492826461792\n",
      "Sample complexity: 8317548, Epoch: 277, Accuracy: 0.7591904401779175, Loss: 0.5595803260803223\n",
      "Sample complexity: 8349416, Epoch: 278, Accuracy: 0.7591904401779175, Loss: 0.5595189929008484\n",
      "Sample complexity: 8381284, Epoch: 279, Accuracy: 0.7591904401779175, Loss: 0.5595324635505676\n",
      "Sample complexity: 8413152, Epoch: 280, Accuracy: 0.7591904401779175, Loss: 0.5594751238822937\n",
      "Sample complexity: 8445020, Epoch: 281, Accuracy: 0.7591904401779175, Loss: 0.559508740901947\n",
      "Sample complexity: 8476888, Epoch: 282, Accuracy: 0.7591904401779175, Loss: 0.5595301389694214\n",
      "Sample complexity: 8508756, Epoch: 283, Accuracy: 0.7591904401779175, Loss: 0.5593854784965515\n",
      "Sample complexity: 8540624, Epoch: 284, Accuracy: 0.7591904401779175, Loss: 0.5592759251594543\n",
      "Sample complexity: 8572492, Epoch: 285, Accuracy: 0.7591904401779175, Loss: 0.5592960119247437\n",
      "Sample complexity: 8604360, Epoch: 286, Accuracy: 0.7591904401779175, Loss: 0.5592895746231079\n",
      "Sample complexity: 8636228, Epoch: 287, Accuracy: 0.7591904401779175, Loss: 0.5592965483665466\n",
      "Sample complexity: 8668096, Epoch: 288, Accuracy: 0.7591904401779175, Loss: 0.5592674612998962\n",
      "Sample complexity: 8699964, Epoch: 289, Accuracy: 0.7591904401779175, Loss: 0.5592315793037415\n",
      "Sample complexity: 8731832, Epoch: 291, Accuracy: 0.7591904401779175, Loss: 0.5592206716537476\n",
      "Sample complexity: 8763700, Epoch: 292, Accuracy: 0.7591904401779175, Loss: 0.5591609477996826\n",
      "Sample complexity: 8795568, Epoch: 293, Accuracy: 0.7591904401779175, Loss: 0.5591146945953369\n",
      "Sample complexity: 8827436, Epoch: 294, Accuracy: 0.7591904401779175, Loss: 0.5591974258422852\n",
      "Sample complexity: 8859304, Epoch: 295, Accuracy: 0.7591904401779175, Loss: 0.559169352054596\n",
      "Sample complexity: 8891172, Epoch: 296, Accuracy: 0.7591904401779175, Loss: 0.5591387748718262\n",
      "Sample complexity: 8923040, Epoch: 297, Accuracy: 0.7591904401779175, Loss: 0.5592196583747864\n",
      "Sample complexity: 8954908, Epoch: 298, Accuracy: 0.7591904401779175, Loss: 0.5592437386512756\n",
      "Sample complexity: 8986776, Epoch: 299, Accuracy: 0.7591904401779175, Loss: 0.5592167973518372\n",
      "Sample complexity: 9018644, Epoch: 300, Accuracy: 0.7591904401779175, Loss: 0.5591593384742737\n",
      "Sample complexity: 9050512, Epoch: 301, Accuracy: 0.7591904401779175, Loss: 0.5591214895248413\n",
      "Sample complexity: 9082380, Epoch: 302, Accuracy: 0.7591904401779175, Loss: 0.5591853260993958\n",
      "Sample complexity: 9114248, Epoch: 303, Accuracy: 0.7591904401779175, Loss: 0.5591698288917542\n",
      "Sample complexity: 9146116, Epoch: 304, Accuracy: 0.7591904401779175, Loss: 0.5590813755989075\n",
      "Sample complexity: 9177984, Epoch: 305, Accuracy: 0.7591904401779175, Loss: 0.5591263771057129\n",
      "Sample complexity: 9209852, Epoch: 306, Accuracy: 0.7591904401779175, Loss: 0.5591811537742615\n",
      "Sample complexity: 9241720, Epoch: 308, Accuracy: 0.7591904401779175, Loss: 0.5591520071029663\n",
      "Sample complexity: 9273588, Epoch: 309, Accuracy: 0.7591904401779175, Loss: 0.5591354966163635\n",
      "Sample complexity: 9305456, Epoch: 310, Accuracy: 0.7591904401779175, Loss: 0.5591086149215698\n",
      "Sample complexity: 9337324, Epoch: 311, Accuracy: 0.7591904401779175, Loss: 0.5590411424636841\n",
      "Sample complexity: 9369192, Epoch: 312, Accuracy: 0.7591904401779175, Loss: 0.5589591264724731\n",
      "Sample complexity: 9401060, Epoch: 313, Accuracy: 0.7591904401779175, Loss: 0.5589690208435059\n",
      "Sample complexity: 9432928, Epoch: 314, Accuracy: 0.7591904401779175, Loss: 0.5590247511863708\n",
      "Sample complexity: 9464796, Epoch: 315, Accuracy: 0.7591904401779175, Loss: 0.5591497421264648\n",
      "Sample complexity: 9496664, Epoch: 316, Accuracy: 0.7591904401779175, Loss: 0.5590291619300842\n",
      "Sample complexity: 9528532, Epoch: 317, Accuracy: 0.7591904401779175, Loss: 0.559087872505188\n",
      "Sample complexity: 9560400, Epoch: 318, Accuracy: 0.7591904401779175, Loss: 0.5590541958808899\n",
      "Sample complexity: 9592268, Epoch: 319, Accuracy: 0.7591904401779175, Loss: 0.5589832663536072\n",
      "Sample complexity: 9624136, Epoch: 320, Accuracy: 0.7591904401779175, Loss: 0.5591695308685303\n",
      "Sample complexity: 9656004, Epoch: 321, Accuracy: 0.7591904401779175, Loss: 0.5591171383857727\n",
      "Sample complexity: 9687872, Epoch: 322, Accuracy: 0.7591904401779175, Loss: 0.5590784549713135\n",
      "Sample complexity: 9719740, Epoch: 323, Accuracy: 0.7591904401779175, Loss: 0.5591344237327576\n",
      "Sample complexity: 9751608, Epoch: 325, Accuracy: 0.7591904401779175, Loss: 0.5591879487037659\n",
      "Sample complexity: 9783476, Epoch: 326, Accuracy: 0.7591904401779175, Loss: 0.5592560172080994\n",
      "Sample complexity: 9815344, Epoch: 327, Accuracy: 0.7591904401779175, Loss: 0.559208869934082\n",
      "Sample complexity: 9847212, Epoch: 328, Accuracy: 0.7591904401779175, Loss: 0.5591641664505005\n",
      "Sample complexity: 9879080, Epoch: 329, Accuracy: 0.7591904401779175, Loss: 0.5592088103294373\n",
      "Sample complexity: 9910948, Epoch: 330, Accuracy: 0.7591904401779175, Loss: 0.5591386556625366\n",
      "Sample complexity: 9942816, Epoch: 331, Accuracy: 0.7591904401779175, Loss: 0.559002697467804\n",
      "Sample complexity: 9974684, Epoch: 332, Accuracy: 0.7591904401779175, Loss: 0.5589684247970581\n",
      "Sample complexity: 10006552, Epoch: 333, Accuracy: 0.7591904401779175, Loss: 0.5590681433677673\n",
      "Sample complexity: 10038420, Epoch: 334, Accuracy: 0.7591904401779175, Loss: 0.5590422749519348\n",
      "Sample complexity: 10070288, Epoch: 335, Accuracy: 0.7591904401779175, Loss: 0.5589677691459656\n",
      "Sample complexity: 10102156, Epoch: 336, Accuracy: 0.7591904401779175, Loss: 0.5589467287063599\n",
      "Sample complexity: 10134024, Epoch: 337, Accuracy: 0.7591904401779175, Loss: 0.5589459538459778\n",
      "Sample complexity: 10165892, Epoch: 338, Accuracy: 0.7591904401779175, Loss: 0.5589176416397095\n",
      "Sample complexity: 10197760, Epoch: 339, Accuracy: 0.7591904401779175, Loss: 0.5589573383331299\n",
      "Sample complexity: 10229628, Epoch: 340, Accuracy: 0.7591904401779175, Loss: 0.5589500665664673\n",
      "Sample complexity: 10261496, Epoch: 342, Accuracy: 0.7591904401779175, Loss: 0.5590302348136902\n",
      "Sample complexity: 10293364, Epoch: 343, Accuracy: 0.7591904401779175, Loss: 0.5590436458587646\n",
      "Sample complexity: 10325232, Epoch: 344, Accuracy: 0.7591904401779175, Loss: 0.5589613318443298\n",
      "Sample complexity: 10357100, Epoch: 345, Accuracy: 0.7591904401779175, Loss: 0.5590695738792419\n",
      "Sample complexity: 10388968, Epoch: 346, Accuracy: 0.7591904401779175, Loss: 0.5589991211891174\n",
      "Sample complexity: 10420836, Epoch: 347, Accuracy: 0.7591904401779175, Loss: 0.5590680241584778\n",
      "Sample complexity: 10452704, Epoch: 348, Accuracy: 0.7591904401779175, Loss: 0.5591378211975098\n",
      "Sample complexity: 10484572, Epoch: 349, Accuracy: 0.7591904401779175, Loss: 0.5591370463371277\n",
      "Sample complexity: 10516440, Epoch: 350, Accuracy: 0.7591904401779175, Loss: 0.5591353178024292\n",
      "Sample complexity: 10548308, Epoch: 351, Accuracy: 0.7591904401779175, Loss: 0.5590991973876953\n",
      "Sample complexity: 10580176, Epoch: 352, Accuracy: 0.7591904401779175, Loss: 0.5591180920600891\n",
      "Sample complexity: 10612044, Epoch: 353, Accuracy: 0.7591904401779175, Loss: 0.559184193611145\n",
      "Sample complexity: 10643912, Epoch: 354, Accuracy: 0.7591904401779175, Loss: 0.5591780543327332\n",
      "Sample complexity: 10675780, Epoch: 355, Accuracy: 0.7591904401779175, Loss: 0.5591240525245667\n",
      "Sample complexity: 10707648, Epoch: 356, Accuracy: 0.7591904401779175, Loss: 0.559229850769043\n",
      "Sample complexity: 10739516, Epoch: 357, Accuracy: 0.7591904401779175, Loss: 0.5592496991157532\n",
      "Sample complexity: 10771384, Epoch: 359, Accuracy: 0.7591904401779175, Loss: 0.5594083666801453\n",
      "Sample complexity: 10803252, Epoch: 360, Accuracy: 0.7591904401779175, Loss: 0.5593905448913574\n",
      "Sample complexity: 10835120, Epoch: 361, Accuracy: 0.7591904401779175, Loss: 0.5594195127487183\n",
      "Sample complexity: 10866988, Epoch: 362, Accuracy: 0.7591904401779175, Loss: 0.5594021677970886\n",
      "Sample complexity: 10898856, Epoch: 363, Accuracy: 0.7591904401779175, Loss: 0.5592787265777588\n",
      "Sample complexity: 10930724, Epoch: 364, Accuracy: 0.7591904401779175, Loss: 0.5592076778411865\n",
      "Sample complexity: 10962592, Epoch: 365, Accuracy: 0.7591904401779175, Loss: 0.5591779947280884\n",
      "Sample complexity: 10994460, Epoch: 366, Accuracy: 0.7591904401779175, Loss: 0.5592284798622131\n",
      "Sample complexity: 11026328, Epoch: 367, Accuracy: 0.7591904401779175, Loss: 0.5592983961105347\n",
      "Sample complexity: 11058196, Epoch: 368, Accuracy: 0.7591904401779175, Loss: 0.5594096183776855\n",
      "Sample complexity: 11090064, Epoch: 369, Accuracy: 0.7591904401779175, Loss: 0.5593305230140686\n",
      "Sample complexity: 11121932, Epoch: 370, Accuracy: 0.7591904401779175, Loss: 0.5593031644821167\n",
      "Sample complexity: 11153800, Epoch: 371, Accuracy: 0.7591904401779175, Loss: 0.5593539476394653\n",
      "Sample complexity: 11185668, Epoch: 372, Accuracy: 0.7591904401779175, Loss: 0.5594099164009094\n",
      "Sample complexity: 11217536, Epoch: 373, Accuracy: 0.7591904401779175, Loss: 0.5595795512199402\n",
      "Sample complexity: 11249404, Epoch: 374, Accuracy: 0.7591904401779175, Loss: 0.5594687461853027\n",
      "Sample complexity: 11281272, Epoch: 376, Accuracy: 0.7591904401779175, Loss: 0.5594639778137207\n",
      "Sample complexity: 11313140, Epoch: 377, Accuracy: 0.7591904401779175, Loss: 0.5595265030860901\n",
      "Sample complexity: 11345008, Epoch: 378, Accuracy: 0.7591904401779175, Loss: 0.5594902038574219\n",
      "Sample complexity: 11376876, Epoch: 379, Accuracy: 0.7591904401779175, Loss: 0.5595337748527527\n",
      "Sample complexity: 11408744, Epoch: 380, Accuracy: 0.7591904401779175, Loss: 0.5595885515213013\n",
      "Sample complexity: 11440612, Epoch: 381, Accuracy: 0.7591904401779175, Loss: 0.5594738721847534\n",
      "Sample complexity: 11472480, Epoch: 382, Accuracy: 0.7591904401779175, Loss: 0.5595746636390686\n",
      "Sample complexity: 11504348, Epoch: 383, Accuracy: 0.7591904401779175, Loss: 0.5596155524253845\n",
      "Sample complexity: 11536216, Epoch: 384, Accuracy: 0.7591904401779175, Loss: 0.5595794916152954\n",
      "Sample complexity: 11568084, Epoch: 385, Accuracy: 0.7591904401779175, Loss: 0.5595827698707581\n",
      "Sample complexity: 11599952, Epoch: 386, Accuracy: 0.7591904401779175, Loss: 0.5595653653144836\n",
      "Sample complexity: 11631820, Epoch: 387, Accuracy: 0.7591904401779175, Loss: 0.5595930814743042\n",
      "Sample complexity: 11663688, Epoch: 388, Accuracy: 0.7591904401779175, Loss: 0.5595784187316895\n",
      "Sample complexity: 11695556, Epoch: 389, Accuracy: 0.7591904401779175, Loss: 0.5595607757568359\n",
      "Sample complexity: 11727424, Epoch: 390, Accuracy: 0.7591904401779175, Loss: 0.5596917867660522\n",
      "Sample complexity: 11759292, Epoch: 391, Accuracy: 0.7591904401779175, Loss: 0.5597907900810242\n",
      "Sample complexity: 11791160, Epoch: 393, Accuracy: 0.7591904401779175, Loss: 0.5598659515380859\n",
      "Sample complexity: 11823028, Epoch: 394, Accuracy: 0.7591904401779175, Loss: 0.5598022937774658\n",
      "Sample complexity: 11854896, Epoch: 395, Accuracy: 0.7591904401779175, Loss: 0.5598187446594238\n",
      "Sample complexity: 11886764, Epoch: 396, Accuracy: 0.7591904401779175, Loss: 0.5597901940345764\n",
      "Sample complexity: 11918632, Epoch: 397, Accuracy: 0.7591904401779175, Loss: 0.5597862601280212\n",
      "Sample complexity: 11950500, Epoch: 398, Accuracy: 0.7591904401779175, Loss: 0.5597395896911621\n",
      "Sample complexity: 11982368, Epoch: 399, Accuracy: 0.7591904401779175, Loss: 0.5598458051681519\n",
      "Sample complexity: 12014236, Epoch: 400, Accuracy: 0.7591904401779175, Loss: 0.55983966588974\n",
      "Sample complexity: 12046104, Epoch: 401, Accuracy: 0.7591904401779175, Loss: 0.559798002243042\n",
      "Sample complexity: 12077972, Epoch: 402, Accuracy: 0.7591904401779175, Loss: 0.5597901344299316\n",
      "Sample complexity: 12109840, Epoch: 403, Accuracy: 0.7591904401779175, Loss: 0.5598347783088684\n",
      "Sample complexity: 12141708, Epoch: 404, Accuracy: 0.7591904401779175, Loss: 0.5598647594451904\n",
      "Sample complexity: 12173576, Epoch: 405, Accuracy: 0.7591904401779175, Loss: 0.5599101781845093\n",
      "Sample complexity: 12205444, Epoch: 406, Accuracy: 0.7591904401779175, Loss: 0.5599951148033142\n",
      "Sample complexity: 12237312, Epoch: 407, Accuracy: 0.7591904401779175, Loss: 0.5600298047065735\n",
      "Sample complexity: 12269180, Epoch: 408, Accuracy: 0.7591904401779175, Loss: 0.5600988864898682\n",
      "Sample complexity: 12301048, Epoch: 410, Accuracy: 0.7591904401779175, Loss: 0.5601431727409363\n",
      "Sample complexity: 12332916, Epoch: 411, Accuracy: 0.7591904401779175, Loss: 0.5601740479469299\n",
      "Sample complexity: 12364784, Epoch: 412, Accuracy: 0.7591904401779175, Loss: 0.5602394938468933\n",
      "Sample complexity: 12396652, Epoch: 413, Accuracy: 0.7591904401779175, Loss: 0.5601792931556702\n",
      "Sample complexity: 12428520, Epoch: 414, Accuracy: 0.7591904401779175, Loss: 0.5603346228599548\n",
      "Sample complexity: 12460388, Epoch: 415, Accuracy: 0.7591904401779175, Loss: 0.5603850483894348\n",
      "Sample complexity: 12492256, Epoch: 416, Accuracy: 0.7591904401779175, Loss: 0.5604987740516663\n",
      "Sample complexity: 12524124, Epoch: 417, Accuracy: 0.7591904401779175, Loss: 0.5605304837226868\n",
      "Sample complexity: 12555992, Epoch: 418, Accuracy: 0.7591904401779175, Loss: 0.5605281591415405\n",
      "Sample complexity: 12587860, Epoch: 419, Accuracy: 0.7591904401779175, Loss: 0.56044602394104\n",
      "Sample complexity: 12619728, Epoch: 420, Accuracy: 0.7591904401779175, Loss: 0.5604696273803711\n",
      "Sample complexity: 12651596, Epoch: 421, Accuracy: 0.7591904401779175, Loss: 0.560532808303833\n",
      "Sample complexity: 12683464, Epoch: 422, Accuracy: 0.7591904401779175, Loss: 0.5605252385139465\n",
      "Sample complexity: 12715332, Epoch: 423, Accuracy: 0.7591904401779175, Loss: 0.5606008172035217\n",
      "Sample complexity: 12747200, Epoch: 424, Accuracy: 0.7591904401779175, Loss: 0.5605306625366211\n",
      "Sample complexity: 12779068, Epoch: 425, Accuracy: 0.7591904401779175, Loss: 0.5604875683784485\n",
      "Sample complexity: 12810936, Epoch: 427, Accuracy: 0.7591904401779175, Loss: 0.560461163520813\n",
      "Sample complexity: 12842804, Epoch: 428, Accuracy: 0.7591904401779175, Loss: 0.5605274438858032\n",
      "Sample complexity: 12874672, Epoch: 429, Accuracy: 0.7591904401779175, Loss: 0.5606029033660889\n",
      "Sample complexity: 12906540, Epoch: 430, Accuracy: 0.7591904401779175, Loss: 0.5606845617294312\n",
      "Sample complexity: 12938408, Epoch: 431, Accuracy: 0.7591904401779175, Loss: 0.5606716871261597\n",
      "Sample complexity: 12970276, Epoch: 432, Accuracy: 0.7591904401779175, Loss: 0.5606444478034973\n",
      "Sample complexity: 13002144, Epoch: 433, Accuracy: 0.7591904401779175, Loss: 0.5606659054756165\n",
      "Sample complexity: 13034012, Epoch: 434, Accuracy: 0.7591904401779175, Loss: 0.5605505108833313\n",
      "Sample complexity: 13065880, Epoch: 435, Accuracy: 0.7591904401779175, Loss: 0.5605497360229492\n",
      "Sample complexity: 13097748, Epoch: 436, Accuracy: 0.7591904401779175, Loss: 0.5606326460838318\n",
      "Sample complexity: 13129616, Epoch: 437, Accuracy: 0.7591904401779175, Loss: 0.5606606006622314\n",
      "Sample complexity: 13161484, Epoch: 438, Accuracy: 0.7591904401779175, Loss: 0.5607673525810242\n",
      "Sample complexity: 13193352, Epoch: 439, Accuracy: 0.7591904401779175, Loss: 0.5607104301452637\n",
      "Sample complexity: 13225220, Epoch: 440, Accuracy: 0.7591904401779175, Loss: 0.5607418417930603\n",
      "Sample complexity: 13257088, Epoch: 441, Accuracy: 0.7591904401779175, Loss: 0.5606532096862793\n",
      "Sample complexity: 13288956, Epoch: 442, Accuracy: 0.7591904401779175, Loss: 0.5606045722961426\n",
      "Sample complexity: 13320824, Epoch: 444, Accuracy: 0.7591904401779175, Loss: 0.56069415807724\n",
      "Sample complexity: 13352692, Epoch: 445, Accuracy: 0.7591904401779175, Loss: 0.5606316328048706\n",
      "Sample complexity: 13384560, Epoch: 446, Accuracy: 0.7591904401779175, Loss: 0.5606871247291565\n",
      "Sample complexity: 13416428, Epoch: 447, Accuracy: 0.7591904401779175, Loss: 0.5606799721717834\n",
      "Sample complexity: 13448296, Epoch: 448, Accuracy: 0.7591904401779175, Loss: 0.5606583952903748\n",
      "Sample complexity: 13480164, Epoch: 449, Accuracy: 0.7591904401779175, Loss: 0.5607151985168457\n",
      "Sample complexity: 13512032, Epoch: 450, Accuracy: 0.7591904401779175, Loss: 0.560716450214386\n",
      "Sample complexity: 13543900, Epoch: 451, Accuracy: 0.7591904401779175, Loss: 0.5608371496200562\n",
      "Sample complexity: 13575768, Epoch: 452, Accuracy: 0.7591904401779175, Loss: 0.5608441233634949\n",
      "Sample complexity: 13607636, Epoch: 453, Accuracy: 0.7591904401779175, Loss: 0.5608519315719604\n",
      "Sample complexity: 13639504, Epoch: 454, Accuracy: 0.7591904401779175, Loss: 0.5606972575187683\n",
      "Sample complexity: 13671372, Epoch: 455, Accuracy: 0.7591904401779175, Loss: 0.5607314705848694\n",
      "Sample complexity: 13703240, Epoch: 456, Accuracy: 0.7591904401779175, Loss: 0.5606632828712463\n",
      "Sample complexity: 13735108, Epoch: 457, Accuracy: 0.7591904401779175, Loss: 0.5606635808944702\n",
      "Sample complexity: 13766976, Epoch: 458, Accuracy: 0.7591904401779175, Loss: 0.5607408285140991\n",
      "Sample complexity: 13798844, Epoch: 459, Accuracy: 0.7591904401779175, Loss: 0.5608080625534058\n",
      "Sample complexity: 13830712, Epoch: 461, Accuracy: 0.7591904401779175, Loss: 0.5608499050140381\n",
      "Sample complexity: 13862580, Epoch: 462, Accuracy: 0.7591904401779175, Loss: 0.5608163475990295\n",
      "Sample complexity: 13894448, Epoch: 463, Accuracy: 0.7591904401779175, Loss: 0.5607472658157349\n",
      "Sample complexity: 13926316, Epoch: 464, Accuracy: 0.7591904401779175, Loss: 0.5607526302337646\n",
      "Sample complexity: 13958184, Epoch: 465, Accuracy: 0.7591904401779175, Loss: 0.5608576536178589\n",
      "Sample complexity: 13990052, Epoch: 466, Accuracy: 0.7591904401779175, Loss: 0.5609797835350037\n",
      "Sample complexity: 14021920, Epoch: 467, Accuracy: 0.7591904401779175, Loss: 0.5609920620918274\n",
      "Sample complexity: 14053788, Epoch: 468, Accuracy: 0.7591904401779175, Loss: 0.560836672782898\n",
      "Sample complexity: 14085656, Epoch: 469, Accuracy: 0.7591904401779175, Loss: 0.5608909130096436\n",
      "Sample complexity: 14117524, Epoch: 470, Accuracy: 0.7591904401779175, Loss: 0.560856819152832\n",
      "Sample complexity: 14149392, Epoch: 471, Accuracy: 0.7591904401779175, Loss: 0.5608318448066711\n",
      "Sample complexity: 14181260, Epoch: 472, Accuracy: 0.7591904401779175, Loss: 0.5608817338943481\n",
      "Sample complexity: 14213128, Epoch: 473, Accuracy: 0.7591904401779175, Loss: 0.5608533024787903\n",
      "Sample complexity: 14244996, Epoch: 474, Accuracy: 0.7591904401779175, Loss: 0.5609678626060486\n",
      "Sample complexity: 14276864, Epoch: 475, Accuracy: 0.7591904401779175, Loss: 0.5609726905822754\n",
      "Sample complexity: 14308732, Epoch: 476, Accuracy: 0.7591904401779175, Loss: 0.5610088109970093\n",
      "Sample complexity: 14340600, Epoch: 478, Accuracy: 0.7591904401779175, Loss: 0.5610498785972595\n",
      "Sample complexity: 14372468, Epoch: 479, Accuracy: 0.7591904401779175, Loss: 0.5611134171485901\n",
      "Sample complexity: 14404336, Epoch: 480, Accuracy: 0.7591904401779175, Loss: 0.5609219670295715\n",
      "Sample complexity: 14436204, Epoch: 481, Accuracy: 0.7591904401779175, Loss: 0.56087726354599\n",
      "Sample complexity: 14468072, Epoch: 482, Accuracy: 0.7591904401779175, Loss: 0.5608833432197571\n",
      "Sample complexity: 14499940, Epoch: 483, Accuracy: 0.7591904401779175, Loss: 0.5608729124069214\n",
      "Sample complexity: 14531808, Epoch: 484, Accuracy: 0.7591904401779175, Loss: 0.560875654220581\n",
      "Sample complexity: 14563676, Epoch: 485, Accuracy: 0.7591904401779175, Loss: 0.5608214139938354\n",
      "Sample complexity: 14595544, Epoch: 486, Accuracy: 0.7591904401779175, Loss: 0.5609084963798523\n",
      "Sample complexity: 14627412, Epoch: 487, Accuracy: 0.7591904401779175, Loss: 0.5608453154563904\n",
      "Sample complexity: 14659280, Epoch: 488, Accuracy: 0.7591904401779175, Loss: 0.5609289407730103\n",
      "Sample complexity: 14691148, Epoch: 489, Accuracy: 0.7591904401779175, Loss: 0.5608927607536316\n",
      "Sample complexity: 14723016, Epoch: 490, Accuracy: 0.7591904401779175, Loss: 0.5609394907951355\n",
      "Sample complexity: 14754884, Epoch: 491, Accuracy: 0.7591904401779175, Loss: 0.5609776973724365\n",
      "Sample complexity: 14786752, Epoch: 492, Accuracy: 0.7591904401779175, Loss: 0.5610620975494385\n",
      "Sample complexity: 14818620, Epoch: 493, Accuracy: 0.7591904401779175, Loss: 0.5609456300735474\n",
      "Sample complexity: 14850488, Epoch: 495, Accuracy: 0.7591904401779175, Loss: 0.5609051585197449\n",
      "Sample complexity: 14882356, Epoch: 496, Accuracy: 0.7591904401779175, Loss: 0.5607650279998779\n",
      "Sample complexity: 14914224, Epoch: 497, Accuracy: 0.7591904401779175, Loss: 0.5607767105102539\n",
      "Sample complexity: 14946092, Epoch: 498, Accuracy: 0.7591904401779175, Loss: 0.5606950521469116\n",
      "Sample complexity: 14977960, Epoch: 499, Accuracy: 0.7591904401779175, Loss: 0.560710072517395\n",
      "Sample complexity: 15009828, Epoch: 500, Accuracy: 0.7591904401779175, Loss: 0.5607395768165588\n",
      "Sample complexity: 15041696, Epoch: 501, Accuracy: 0.7591904401779175, Loss: 0.5607036352157593\n",
      "Sample complexity: 15073564, Epoch: 502, Accuracy: 0.7591904401779175, Loss: 0.5606595277786255\n",
      "Sample complexity: 15105432, Epoch: 503, Accuracy: 0.7591904401779175, Loss: 0.5607774257659912\n",
      "Sample complexity: 15137300, Epoch: 504, Accuracy: 0.7591904401779175, Loss: 0.560828685760498\n",
      "Sample complexity: 15169168, Epoch: 505, Accuracy: 0.7591904401779175, Loss: 0.5607878565788269\n",
      "Sample complexity: 15201036, Epoch: 506, Accuracy: 0.7591904401779175, Loss: 0.5608621835708618\n",
      "Sample complexity: 15232904, Epoch: 507, Accuracy: 0.7591904401779175, Loss: 0.5608817934989929\n",
      "Sample complexity: 15264772, Epoch: 508, Accuracy: 0.7591904401779175, Loss: 0.5607261061668396\n",
      "Sample complexity: 15296640, Epoch: 509, Accuracy: 0.7591904401779175, Loss: 0.5608500838279724\n",
      "Sample complexity: 15328508, Epoch: 510, Accuracy: 0.7591904401779175, Loss: 0.5610190629959106\n",
      "Sample complexity: 15360376, Epoch: 512, Accuracy: 0.7591904401779175, Loss: 0.5610384941101074\n",
      "Sample complexity: 15392244, Epoch: 513, Accuracy: 0.7591904401779175, Loss: 0.5610660910606384\n",
      "Sample complexity: 15424112, Epoch: 514, Accuracy: 0.7591904401779175, Loss: 0.5611136555671692\n",
      "Sample complexity: 15455980, Epoch: 515, Accuracy: 0.7591904401779175, Loss: 0.5610669851303101\n",
      "Sample complexity: 15487848, Epoch: 516, Accuracy: 0.7591904401779175, Loss: 0.5610312819480896\n",
      "Sample complexity: 15519716, Epoch: 517, Accuracy: 0.7591904401779175, Loss: 0.5609397292137146\n",
      "Sample complexity: 15551584, Epoch: 518, Accuracy: 0.7591904401779175, Loss: 0.5609180927276611\n",
      "Sample complexity: 15583452, Epoch: 519, Accuracy: 0.7591904401779175, Loss: 0.5609212517738342\n",
      "Sample complexity: 15615320, Epoch: 520, Accuracy: 0.7591904401779175, Loss: 0.5608005523681641\n",
      "Sample complexity: 15647188, Epoch: 521, Accuracy: 0.7591904401779175, Loss: 0.5608386993408203\n",
      "Sample complexity: 15679056, Epoch: 522, Accuracy: 0.7591904401779175, Loss: 0.5607389807701111\n",
      "Sample complexity: 15710924, Epoch: 523, Accuracy: 0.7591904401779175, Loss: 0.5607613325119019\n",
      "Sample complexity: 15742792, Epoch: 524, Accuracy: 0.7591904401779175, Loss: 0.5608552694320679\n",
      "Sample complexity: 15774660, Epoch: 525, Accuracy: 0.7591904401779175, Loss: 0.5607579946517944\n",
      "Sample complexity: 15806528, Epoch: 526, Accuracy: 0.7591904401779175, Loss: 0.5607221126556396\n",
      "Sample complexity: 15838396, Epoch: 527, Accuracy: 0.7591904401779175, Loss: 0.560745358467102\n",
      "Sample complexity: 15870264, Epoch: 529, Accuracy: 0.7591904401779175, Loss: 0.5608018040657043\n",
      "Sample complexity: 15902132, Epoch: 530, Accuracy: 0.7591904401779175, Loss: 0.5608563423156738\n",
      "Sample complexity: 15934000, Epoch: 531, Accuracy: 0.7591904401779175, Loss: 0.5608433485031128\n",
      "Sample complexity: 15965868, Epoch: 532, Accuracy: 0.7591904401779175, Loss: 0.5609005689620972\n",
      "Sample complexity: 15997736, Epoch: 533, Accuracy: 0.7591904401779175, Loss: 0.5608711242675781\n",
      "Sample complexity: 16029604, Epoch: 534, Accuracy: 0.7591904401779175, Loss: 0.5610618591308594\n",
      "Sample complexity: 16061472, Epoch: 535, Accuracy: 0.7591904401779175, Loss: 0.5611097812652588\n",
      "Sample complexity: 16093340, Epoch: 536, Accuracy: 0.7591904401779175, Loss: 0.5610604882240295\n",
      "Sample complexity: 16125208, Epoch: 537, Accuracy: 0.7591904401779175, Loss: 0.5610805749893188\n",
      "Sample complexity: 16157076, Epoch: 538, Accuracy: 0.7591904401779175, Loss: 0.5610332489013672\n",
      "Sample complexity: 16188944, Epoch: 539, Accuracy: 0.7591904401779175, Loss: 0.5612459182739258\n",
      "Sample complexity: 16220812, Epoch: 540, Accuracy: 0.7591904401779175, Loss: 0.5611861944198608\n",
      "Sample complexity: 16252680, Epoch: 541, Accuracy: 0.7591904401779175, Loss: 0.5610288381576538\n",
      "Sample complexity: 16284548, Epoch: 542, Accuracy: 0.7591904401779175, Loss: 0.5610080361366272\n",
      "Sample complexity: 16316416, Epoch: 543, Accuracy: 0.7591904401779175, Loss: 0.5610092282295227\n",
      "Sample complexity: 16348284, Epoch: 544, Accuracy: 0.7591904401779175, Loss: 0.5609748959541321\n",
      "Sample complexity: 16380152, Epoch: 546, Accuracy: 0.7591904401779175, Loss: 0.561001718044281\n",
      "Sample complexity: 16412020, Epoch: 547, Accuracy: 0.7591904401779175, Loss: 0.5610134601593018\n",
      "Sample complexity: 16443888, Epoch: 548, Accuracy: 0.7591904401779175, Loss: 0.5610821843147278\n",
      "Sample complexity: 16475756, Epoch: 549, Accuracy: 0.7591904401779175, Loss: 0.5610234141349792\n",
      "Sample complexity: 16507624, Epoch: 550, Accuracy: 0.7591904401779175, Loss: 0.5610852837562561\n",
      "Sample complexity: 16539492, Epoch: 551, Accuracy: 0.7591904401779175, Loss: 0.5611458420753479\n",
      "Sample complexity: 16571360, Epoch: 552, Accuracy: 0.7591904401779175, Loss: 0.561102032661438\n",
      "Sample complexity: 16603228, Epoch: 553, Accuracy: 0.7591904401779175, Loss: 0.5611359477043152\n",
      "Sample complexity: 16635096, Epoch: 554, Accuracy: 0.7591904401779175, Loss: 0.5610456466674805\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[12], line 11\u001b[0m\n\u001b[0;32m      7\u001b[0m sim_time \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m10\u001b[39m\n\u001b[0;32m      9\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01malg_SSAGDA_optimized\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m SSAGDA\n\u001b[1;32m---> 11\u001b[0m SSAGDA(train_set \u001b[38;5;241m=\u001b[39m train_set, data_name \u001b[38;5;241m=\u001b[39m data_name, p \u001b[38;5;241m=\u001b[39m p, tau_1 \u001b[38;5;241m=\u001b[39m tau_1, tau_2 \u001b[38;5;241m=\u001b[39m tau_2, beta \u001b[38;5;241m=\u001b[39m beta,\n\u001b[0;32m     12\u001b[0m         b \u001b[38;5;241m=\u001b[39m b, sim_time \u001b[38;5;241m=\u001b[39m sim_time, max_epoch \u001b[38;5;241m=\u001b[39m max_epoch, epoch_number \u001b[38;5;241m=\u001b[39m epoch_number, \n\u001b[0;32m     13\u001b[0m         is_show_result \u001b[38;5;241m=\u001b[39m is_show_result, is_save_data \u001b[38;5;241m=\u001b[39m is_save_data, is_save_grad_data \u001b[38;5;241m=\u001b[39m is_save_grad_data, device \u001b[38;5;241m=\u001b[39m device)\n",
      "File \u001b[1;32mc:\\Users\\sysa1\\Documents\\Research\\Optimization\\Research code\\Stochastic smoothed AGDA\\DRO\\alg_SSAGDA_optimized.py:53\u001b[0m, in \u001b[0;36mSSAGDA\u001b[1;34m(train_set, data_name, p, tau_1, tau_2, beta, b, sim_time, max_epoch, epoch_number, is_show_result, is_save_data, is_save_grad_data, device)\u001b[0m\n\u001b[0;32m     51\u001b[0m test1\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m     52\u001b[0m loss \u001b[38;5;241m=\u001b[39m test1\u001b[38;5;241m.\u001b[39mloss(test1(data), batch_index, target)\n\u001b[1;32m---> 53\u001b[0m loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[0;32m     55\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m (name, param), output_param \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(test1\u001b[38;5;241m.\u001b[39mnamed_parameters(), output):\n\u001b[0;32m     56\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mvariable_y\u001b[39m\u001b[38;5;124m'\u001b[39m:\n",
      "File \u001b[1;32mc:\\Users\\sysa1\\anaconda3\\Lib\\site-packages\\torch\\_tensor.py:525\u001b[0m, in \u001b[0;36mTensor.backward\u001b[1;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[0;32m    515\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m    516\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[0;32m    517\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[0;32m    518\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m    523\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[0;32m    524\u001b[0m     )\n\u001b[1;32m--> 525\u001b[0m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mbackward(\n\u001b[0;32m    526\u001b[0m     \u001b[38;5;28mself\u001b[39m, gradient, retain_graph, create_graph, inputs\u001b[38;5;241m=\u001b[39minputs\n\u001b[0;32m    527\u001b[0m )\n",
      "File \u001b[1;32mc:\\Users\\sysa1\\anaconda3\\Lib\\site-packages\\torch\\autograd\\__init__.py:267\u001b[0m, in \u001b[0;36mbackward\u001b[1;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[0;32m    262\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[0;32m    264\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[0;32m    265\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[0;32m    266\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[1;32m--> 267\u001b[0m _engine_run_backward(\n\u001b[0;32m    268\u001b[0m     tensors,\n\u001b[0;32m    269\u001b[0m     grad_tensors_,\n\u001b[0;32m    270\u001b[0m     retain_graph,\n\u001b[0;32m    271\u001b[0m     create_graph,\n\u001b[0;32m    272\u001b[0m     inputs,\n\u001b[0;32m    273\u001b[0m     allow_unreachable\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m    274\u001b[0m     accumulate_grad\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[0;32m    275\u001b[0m )\n",
      "File \u001b[1;32mc:\\Users\\sysa1\\anaconda3\\Lib\\site-packages\\torch\\autograd\\graph.py:744\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[1;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[0;32m    742\u001b[0m     unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[0;32m    743\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m--> 744\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m Variable\u001b[38;5;241m.\u001b[39m_execution_engine\u001b[38;5;241m.\u001b[39mrun_backward(  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m    745\u001b[0m         t_outputs, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[0;32m    746\u001b[0m     )  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[0;32m    747\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[0;32m    748\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Run the stochastic smoothed AGDA (SSAGDA) experiment on `train_set`.\n",
    "# NOTE(review): the saved output of this cell comes from an earlier version of it\n",
    "# (its traceback shows sim_time = 10) and ends in a manual KeyboardInterrupt; re-run to refresh.\n",
    "from alg_SSAGDA_optimized import SSAGDA\n",
    "\n",
    "# Reporting / persistence switches forwarded to SSAGDA\n",
    "is_show_result = True\n",
    "is_save_data = False\n",
    "is_save_grad_data = False\n",
    "\n",
    "# Algorithm hyperparameters; their meaning is defined in alg_SSAGDA_optimized (TODO document there)\n",
    "p = 160\n",
    "tau_1 = 0.1\n",
    "tau_2 = 0.002\n",
    "beta = 0.0001\n",
    "b = 1028  # NOTE(review): presumably the minibatch size; confirm 1028 (not 1024) is intended\n",
    "\n",
    "# Run-length settings\n",
    "max_epoch = 1000\n",
    "epoch_number = 30000\n",
    "sim_time = 5\n",
    "\n",
    "SSAGDA(\n",
    "    train_set=train_set, data_name=data_name,\n",
    "    p=p, tau_1=tau_1, tau_2=tau_2, beta=beta, b=b,\n",
    "    sim_time=sim_time, max_epoch=max_epoch, epoch_number=epoch_number,\n",
    "    is_show_result=is_show_result, is_save_data=is_save_data,\n",
    "    is_save_grad_data=is_save_grad_data, device=device,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
