{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "edc3934c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from asyncore import write\n",
    "import datetime\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "from re import L\n",
    "import time\n",
    "from collections import OrderedDict\n",
    "from unittest import result\n",
    "from sqlalchemy import true\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm, trange\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from tensorboardX import SummaryWriter\n",
    "\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "import models\n",
    "from models.meta_sgd import MetaSGD\n",
    "import optimizers\n",
    "from config.configuration import *\n",
    "\n",
    "\n",
    "from util import enlist_transformation\n",
    "from data_generate.dataset import FewShotImageDataset\n",
    "from data_generate.sampler import SuppQueryBatchSampler\n",
    "from meta_train import meta_train\n",
    "from meta_test import meta_test\n",
    "import util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "595514ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path = '../config/1-shot/metasgd.json'\n",
    "# Load the experiment config; the `with` block guarantees the file handle is closed,\n",
    "# and an explicit encoding avoids depending on the platform's locale default\n",
    "with open(config_path, encoding='utf-8') as jsonfile:\n",
    "    config = json.load(jsonfile)\n",
    "\n",
    "# Run only the first `num_dataset_to_run` datasets, keep the classifier's n_way\n",
    "# in sync with `num_way`, and force CPU for this notebook\n",
    "config['dataset_ls'] = config['dataset_ls'][:config['num_dataset_to_run']]\n",
    "config['classifier_args']['n_way'] = config['num_way']\n",
    "config['device'] = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7985f892",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the Meta-SGD model from the config loaded above\n",
    "model = MetaSGD(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0cef94a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ordered mapping of parameter name -> Parameter for the Meta-SGD model\n",
    "meta_sgd_params = OrderedDict(model.named_parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cce67e51",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "odict_keys(['meta_learner.features.0.conv0.weight', 'meta_learner.features.0.conv0.bias', 'meta_learner.features.0.bn0.weight', 'meta_learner.features.0.bn0.bias', 'meta_learner.features.1.conv1.weight', 'meta_learner.features.1.conv1.bias', 'meta_learner.features.1.bn1.weight', 'meta_learner.features.1.bn1.bias', 'meta_learner.features.2.conv2.weight', 'meta_learner.features.2.conv2.bias', 'meta_learner.features.2.bn2.weight', 'meta_learner.features.2.bn2.bias', 'meta_learner.features.3.conv3.weight', 'meta_learner.features.3.conv3.bias', 'meta_learner.features.3.bn3.weight', 'meta_learner.features.3.bn3.bias', 'meta_learner.fc.weight', 'meta_learner.fc.bias'])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "meta_sgd_params.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6ee33966",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'bn_args': {'track_running_stats': False}}\n",
      "32\n",
      "{'bn_args': {'track_running_stats': False}}\n"
     ]
    }
   ],
   "source": [
    "# Build the model via `models.make`; note this rebinds `model`,\n",
    "# replacing the MetaSGD instance created earlier\n",
    "model = models.make(\n",
    "    config['encoder'], config['encoder_args'],\n",
    "    config['classifier'], config['classifier_args'],\n",
    "    config['img_resize'], config['device'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "756231b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ordered mapping of parameter name -> Parameter for the classifier head\n",
    "weight = OrderedDict(\n",
    "    (name, param) for name, param in model.classifier.named_parameters()\n",
    ")\n",
    "weight['linear.weight']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "089c0cd0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 800])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the classifier's linear weight (same torch.Size as .size())\n",
    "weight['linear.weight'].shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
