{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18baa0b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from resnet_cifar import resnet20\n",
    "from train import tune_step_size\n",
    "from utils import create_exp\n",
    "from quant import clip_wrap\n",
    "import torch\n",
    "from torch.nn import CrossEntropyLoss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1db74ff2-25db-4f08-ba83-0b2208438fac",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_workers: int = 10  # number of workers; passed to create_exp as n_workers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bcf660f-69c9-4602-aa17-b1ec38b61f49",
   "metadata": {},
   "outputs": [],
   "source": [
    "model: str = 'resnet20'  # string label for the model (not the network itself); passed to create_exp as name/model_name and used in name_to_save"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37c01962-8925-420c-954d-40c8a5f8e8ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 32  # batch size\n",
    "beta = 0.1  # client momentum\n",
    "hbeta = 1  # server momentum parameter\n",
    "epochs = 5  # number of epochs\n",
    "seed = 1970  # fix random seed\n",
    "tau = 1e-3  # fix clipping threshold\n",
    "lr = 1e-2  # fix learning rate\n",
    "dataset = 'cifar10'  # use 'mnist' or 'cifar10'\n",
    "\n",
    "# Type of error feedback mechanism. Each key is a valid value for `ef`;\n",
    "# each value is the method label used when saving results.\n",
    "ef_method_names = {\n",
    "    'None': 'Clip-SGD',              # ef = 'None' runs Clip-SGD\n",
    "    'Clip21_SGD': 'Clip21-SGD',      # ef = 'Clip21_SGD' runs Clip21-SGD\n",
    "    'Clip21_SGD2M': 'Clip21-SGD2M',  # ef = 'Clip21_SGD2M' runs Clip21-SGD2M\n",
    "    'ANorm': 'alpha-NormEC-SGD',     # ef = 'ANorm' runs alpha-NormEC-SGD\n",
    "}\n",
    "ef = 'Clip21_SGD2M'  # pick one of the keys above\n",
    "assert ef in ef_method_names, f'unknown ef {ef!r}; expected one of {list(ef_method_names)}'\n",
    "\n",
    "# Derived from `ef` so the saved label always matches the algorithm that is actually run.\n",
    "method = ef_method_names[ef]  # name for the method when saving\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# the quantities below should not be changed to run algorithms in non-private regime\n",
    "noise = 0  # DP noise variance; should be zero for non-private training\n",
    "DP = None  # use DP = True to add Gaussian noise for privacy\n",
    "sch = None  # use sch = True to use a scheduler\n",
    "weight_decay = 0  # weight decay is not supported at the moment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd72fff0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the experiment from the configuration cells above.\n",
    "# Note the keyword mapping: create_exp's `momentum` receives the client momentum `beta`,\n",
    "# while its `beta` receives the server momentum parameter `hbeta`.\n",
    "exp = create_exp(\n",
    "    name=model,\n",
    "    dataset=dataset,\n",
    "    net=resnet20,\n",
    "    device=device,\n",
    "    model_name=model,\n",
    "    n_workers=n_workers,\n",
    "    epochs=epochs,\n",
    "    seed=seed,\n",
    "    batch_size=bs,\n",
    "    lrs=[lr],\n",
    "    tau=tau,\n",
    "    noise=noise,\n",
    "    DP=DP,\n",
    "    compression={'wrapper': False, 'compression': clip_wrap(h=tau)},\n",
    "    error_feedback=ef,\n",
    "    criterion=CrossEntropyLoss(),\n",
    "    master_compression=None,\n",
    "    momentum=beta,\n",
    "    beta=hbeta,\n",
    "    weight_decay=weight_decay,\n",
    ")\n",
    "\n",
    "# Suffix for saved results: method and model, then this run's hyperparameters, joined by '_'.\n",
    "name_to_save = '_'.join([\n",
    "    f'{method}',\n",
    "    f'{model}',\n",
    "    f'lr{lr}',\n",
    "    f'beta_{beta}',\n",
    "    f'hbeta{hbeta}',\n",
    "    f'tau{tau}',\n",
    "    f'ef{ef}',\n",
    "    f'sch{sch}',\n",
    "    f'seed{seed}',\n",
    "    f'bs{bs}',\n",
    "    f'epochs{epochs}',\n",
    "])\n",
    "\n",
    "# Run tune_step_size on the experiment; its result is unpacked into best_lr and best_acc_lr.\n",
    "best_lr, best_acc_lr = tune_step_size(exp, suffix=name_to_save, schedule=sch)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
