{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c11a04a6-45bc-4aef-bbae-0dd57f576dfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from library import load_datasets\n",
    "from library import metrics\n",
    "from library import misc\n",
    "from library import model_io\n",
    "from library import models\n",
    "from library import results_json\n",
    "from library import train\n",
    "from library import baseline_configs\n",
    "from library import configs\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Number of runs per dataset; the run index doubles as the random seed.\n",
    "NUM_RUNS = 1\n",
    "\n",
    "# Maps a short dataset name to (configs.Dataset member, number of classes).\n",
    "# Uncomment entries to include additional datasets in the sweep.\n",
    "datasets = {\n",
    "    'mnist': (configs.Dataset.MNIST, 10),\n",
    "    # 'fmnist': (configs.Dataset.FMNIST, 10),\n",
    "    # 'kmnist': (configs.Dataset.KMNIST, 10),\n",
    "    # 'qmnist': (configs.Dataset.QMNIST, 10),\n",
    "    # 'emnist_letters': (configs.Dataset.EMNIST_LETTERS, 26),\n",
    "    # 'emnist_balanced': (configs.Dataset.EMNIST_BALANCED, 47),\n",
    "    # 'cifar10': (configs.Dataset.CIFAR10, 10),\n",
    "    # 'cifar100': (configs.Dataset.CIFAR100, 100),\n",
    "}\n",
    "\n",
    "for key, (dataset_id, num_classes) in datasets.items():\n",
    "    for run in range(NUM_RUNS):\n",
    "        # Fetch the baseline configuration, then override the fields below.\n",
    "        # NOTE(review): the MNIST baseline is used for every dataset in the\n",
    "        # table - confirm this is intended before enabling other entries.\n",
    "        config = baseline_configs.Get_MNIST_Config()\n",
    "\n",
    "        ##### Experiment Specific Parameters #####\n",
    "        config.data_config.dataset = dataset_id\n",
    "        # Round the last-layer width down to the nearest multiple of the\n",
    "        # class count (presumably so neurons split evenly into one group\n",
    "        # per class - TODO confirm against the model code).\n",
    "        config.model_config.last_layer_neurons = (\n",
    "            config.model_config.num_neurons // num_classes\n",
    "        ) * num_classes\n",
    "\n",
    "        # NOTE(review): every dataset/run shares this experiment name;\n",
    "        # verify that results from different runs do not overwrite each other.\n",
    "        config.experiment_config.experiment_name = 'test'\n",
    "\n",
    "        config.train_config.save_model_on = 'bin'\n",
    "\n",
    "        config.train_config.extensive_eval = True\n",
    "        config.train_config.eval_freq = 4\n",
    "        config.train_config.learning_rate = 0.01\n",
    "        config.train_config.num_epochs = 100\n",
    "        config.train_config.extensive_eval_train = True\n",
    "        config.test_config.extensive_eval_test = True\n",
    "\n",
    "        # Model variant switches: only the group-sum option is enabled.\n",
    "        config.model_config.distanceLayer = False\n",
    "        config.model_config.use_mygroupsum = False\n",
    "        config.model_config.use_groupsum = True\n",
    "        config.model_config.full_ffn = False\n",
    "        config.model_config.use_ffn = False\n",
    "        config.model_config.use_ffbinary = False\n",
    "\n",
    "        config.model_config.seed = run\n",
    "        #############################################\n",
    "\n",
    "        model_config = config.model_config\n",
    "        print(model_config)\n",
    "\n",
    "        # Seed before loading data and building the model.\n",
    "        misc.set_seed(config.model_config.seed)\n",
    "\n",
    "        (\n",
    "            train_loader,\n",
    "            validation_loader,\n",
    "            test_loader,\n",
    "            bin_loader,\n",
    "            test_bin_loader,\n",
    "        ) = load_datasets.load_dataset(config)\n",
    "        network = models.create_model(config)\n",
    "\n",
    "        # Move the model to the GPU *before* constructing the optimizer, as\n",
    "        # recommended by the torch.optim documentation, so the optimizer is\n",
    "        # built over parameters that already live on their final device.\n",
    "        if config.data_config.device == 'cuda':\n",
    "            network = network.cuda()\n",
    "\n",
    "        loss_fn = torch.nn.CrossEntropyLoss()\n",
    "        optimizer = torch.optim.Adam(\n",
    "            network.parameters(), lr=config.train_config.learning_rate\n",
    "        )\n",
    "\n",
    "        results = results_json.ResultsJSON(config)\n",
    "        train.train(\n",
    "            model=network,\n",
    "            loss_fn=loss_fn,\n",
    "            optimizer=optimizer,\n",
    "            train_loader=train_loader,\n",
    "            validation_loader=validation_loader,\n",
    "            binarized_loader=bin_loader,\n",
    "            test_loader=test_loader,\n",
    "            test_loader_bin=test_bin_loader,\n",
    "            results=results,\n",
    "            config=config,\n",
    "        )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
