{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Example of the data\n",
      "tensor([0., 1., 0.,  ..., 1., 0., 0.], device='cuda:0')\n",
      "12678\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import copy\n",
    "from model import Model\n",
    "from Optimization_Method import projection_simplex_sort as pj\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "from dataclass import Creatdata\n",
    "\n",
    "import os\n",
    "\n",
    "\n",
    "# Select the compute device, falling back to the CPU when CUDA is unavailable.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "if device == 'cuda':\n",
    "    # Guarded: get_device_name raises on CPU-only machines, which would defeat the fallback above.\n",
    "    print('GPU:', torch.cuda.get_device_name(0))\n",
    "\n",
    "data_name = 'sido0'\n",
    "\n",
    "# True  -> rebuild the dataset by running ./data/<data_name>.py, which is expected to define `train_set` (used below).\n",
    "# False -> load the previously pickled dataset from ./data/<data_name>/<data_name>.\n",
    "is_create_data = False\n",
    "\n",
    "if is_create_data:\n",
    "    data_path = './data/' + data_name + '.py'\n",
    "    print(data_path)\n",
    "    # Read via a context manager so the file handle is closed before the script runs.\n",
    "    with open(data_path) as script_file:\n",
    "        data_script = script_file.read()\n",
    "    exec(data_script)\n",
    "else:\n",
    "    file_name = './data/' + data_name + '/' + data_name\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load files you created yourself.\n",
    "    with open(file_name, 'rb') as fp:\n",
    "        train_set = pickle.load(fp)\n",
    "\n",
    "# Move the whole dataset onto the selected device once, up front.\n",
    "train_set.data = train_set.data.to(device)\n",
    "train_set.targets = train_set.targets.to(device)\n",
    "assert len(train_set.data) == len(train_set.targets), 'data/targets length mismatch'\n",
    "\n",
    "print('Example of the data')\n",
    "print(train_set.data[0])\n",
    "print(len(train_set.data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NVIDIA GeForce RTX 3050 6GB Laptop GPU\n",
      "0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\sysa1\\Documents\\Research\\Optimization\\Research code\\Stochastic smoothed AGDA\\DRO\\model_SSAGDA.py:10: UserWarning: nn.init.xavier_uniform is now deprecated in favor of nn.init.xavier_uniform_.\n",
      "  self.w = torch.nn.init.xavier_uniform(self.w)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n",
      "24\n",
      "25\n",
      "26\n",
      "27\n",
      "28\n",
      "29\n",
      "30\n",
      "31\n",
      "32\n",
      "33\n",
      "34\n",
      "35\n",
      "36\n",
      "37\n",
      "38\n",
      "39\n",
      "40\n",
      "41\n",
      "42\n",
      "43\n",
      "44\n",
      "45\n",
      "46\n",
      "47\n",
      "48\n",
      "49\n",
      "50\n",
      "51\n",
      "52\n",
      "53\n",
      "54\n",
      "55\n",
      "56\n",
      "57\n",
      "58\n",
      "59\n",
      "60\n",
      "61\n",
      "62\n",
      "63\n",
      "64\n",
      "65\n",
      "66\n",
      "67\n",
      "68\n",
      "69\n",
      "70\n",
      "71\n",
      "72\n",
      "73\n",
      "74\n",
      "75\n",
      "76\n",
      "77\n",
      "78\n",
      "79\n",
      "80\n",
      "81\n",
      "82\n",
      "83\n",
      "84\n",
      "85\n",
      "86\n",
      "87\n",
      "88\n",
      "89\n",
      "90\n",
      "91\n",
      "92\n",
      "93\n",
      "94\n",
      "95\n",
      "96\n",
      "97\n",
      "98\n",
      "99\n",
      "100\n",
      "101\n",
      "102\n",
      "103\n",
      "104\n",
      "105\n",
      "106\n",
      "107\n",
      "108\n",
      "109\n",
      "110\n",
      "111\n",
      "112\n",
      "113\n",
      "114\n",
      "115\n",
      "116\n",
      "117\n",
      "118\n",
      "119\n",
      "120\n",
      "121\n",
      "122\n",
      "123\n",
      "124\n",
      "125\n",
      "126\n",
      "127\n",
      "128\n",
      "129\n",
      "130\n",
      "131\n",
      "132\n",
      "133\n",
      "134\n",
      "135\n",
      "136\n",
      "137\n",
      "138\n",
      "139\n",
      "140\n",
      "141\n",
      "142\n",
      "143\n",
      "144\n",
      "145\n",
      "146\n",
      "147\n",
      "148\n",
      "149\n",
      "150\n",
      "151\n",
      "152\n",
      "153\n",
      "154\n",
      "155\n",
      "156\n",
      "157\n",
      "158\n",
      "159\n",
      "160\n",
      "161\n",
      "162\n",
      "163\n",
      "164\n",
      "165\n",
      "166\n",
      "167\n",
      "168\n",
      "169\n",
      "170\n",
      "171\n",
      "172\n",
      "173\n",
      "174\n",
      "175\n",
      "176\n",
      "177\n",
      "178\n",
      "179\n",
      "180\n",
      "181\n",
      "182\n",
      "183\n",
      "184\n",
      "185\n",
      "186\n",
      "187\n",
      "188\n",
      "189\n",
      "190\n",
      "191\n",
      "192\n",
      "193\n",
      "194\n",
      "195\n",
      "196\n",
      "197\n",
      "198\n",
      "199\n"
     ]
    }
   ],
   "source": [
    "# Run switches.\n",
    "is_show_result = False\n",
    "is_save_result = False\n",
    "is_save_grad_data = True\n",
    "\n",
    "# SSAGDA hyperparameters, passed straight through to SSAGDA. Meanings below are presumed from the\n",
    "# algorithm name and are not visible here - TODO confirm in alg_SSAGDA_optimized:\n",
    "#   p            - smoothing parameter\n",
    "#   tau_1, tau_2 - step sizes of the two update directions\n",
    "#   beta         - averaging coefficient\n",
    "#   b            - mini-batch size\n",
    "p, tau_1, tau_2, beta, b = 160, 0.001, 0.0002, 0.00001, 1028\n",
    "max_epoch = 20\n",
    "# Previously hard-coded as 12678, which is exactly len(train_set.data) for sido0 (see the first cell's output).\n",
    "# Deriving it keeps it consistent when data_name changes - assumes SSAGDA treats it as the dataset size.\n",
    "epoch_number = len(train_set.data)\n",
    "sim_time = 200\n",
    "\n",
    "from alg_SSAGDA_optimized import SSAGDA\n",
    "\n",
    "SSAGDA(train_set=train_set, data_name=data_name, p=p, tau_1=tau_1, tau_2=tau_2, beta=beta,\n",
    "       b=b, sim_time=sim_time, max_epoch=max_epoch, epoch_number=epoch_number,\n",
    "       is_show_result=is_show_result, is_save_data=is_save_result,\n",
    "       is_save_grad_data=is_save_grad_data, device=device)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
