{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "pregnant-tiger",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.7.1\n"
     ]
    }
   ],
   "source": [
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.transforms as transforms\n",
    "import copy\n",
    "from utils import set_seed, load_backbone_state_dict_only\n",
    "from utils import plotSamples\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.auto import tqdm\n",
    "import sys\n",
    "from utils import set_seed\n",
    "import torch\n",
    "from tqdm.auto import tqdm\n",
    "from scipy.stats import invgamma\n",
    "from copy import deepcopy\n",
    "from IGML import ML_IG as estimator\n",
    "from Sampler import Sampler\n",
    "from test import evaluate, get_model_with_ml\n",
    "from dataset import Dataset\n",
    "import sys\n",
    "# Use the first CUDA device when torch reports one, otherwise fall back to the CPU.\n",
    "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "parliamentary-mongolia",
   "metadata": {},
   "outputs": [],
   "source": [
    "from model.ProtoNet import ProtoNet\n",
    "from model.Bayesian import Bayesian\n",
    "from model.MAP import MAP\n",
    "from model.PGLR import PGLR\n",
    "from model.LR import LR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "closed-beast",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "\n",
    "# Benchmark selector; must be a key of models_paths / lrs below\n",
    "# ('O' -> omniglot checkpoints, 'M' -> miniimagenet checkpoints).\n",
    "dataset_name = 'M'\n",
    "\n",
    "# Backbone checkpoint paths: models_paths[dataset][source backbone].\n",
    "# A source backbone is either the string \"Scratch\" / \"Pretrained\" or a model class.\n",
    "models_paths = {\n",
    "    'O': {\n",
    "        \"Scratch\":    \"../experiments/omniglot/omniglot_random/check_point\",\n",
    "        \"Pretrained\": \"../experiments/omniglot/omniglot_Pretrain_final/check_point\",\n",
    "        LR:           \"../experiments/omniglot/omniglot_MTLR_final/check_point\",\n",
    "        PGLR:         \"../experiments/omniglot/omniglot_PGLR_final/check_point\",\n",
    "        ProtoNet:     \"../experiments/omniglot/omniglot_Protonet_final/check_point\",\n",
    "        MAP:          \"../experiments/omniglot/omniglot_MAP_final/check_point\",\n",
    "        Bayesian:     \"../experiments/omniglot/omniglot_Bayesian_final/check_point\",\n",
    "    },\n",
    "    'M': {\n",
    "        \"Scratch\":    \"../experiments/miniimagenet/imagenet_random/check_point\",\n",
    "        \"Pretrained\": \"../experiments/miniimagenet/imagenet_pretrain_final/check_point\",\n",
    "        LR:           \"../experiments/miniimagenet/imagenet_MTLR_final/check_point\",\n",
    "        PGLR:         \"../experiments/miniimagenet/imagenet_PGLR_final/check_point\",\n",
    "        ProtoNet:     \"../experiments/miniimagenet/imagenet_protonet_final/check_point\",\n",
    "        MAP:          \"../experiments/miniimagenet/imagenet_map_final/check_point\",\n",
    "        Bayesian:     \"../experiments/miniimagenet/imagenet_bayesian_final/check_point\",\n",
    "    },\n",
    "}\n",
    "\n",
    "# Values later assigned to params_train.inner_lr, looked up as\n",
    "# lrs[dataset]['iid' | 'not_iid'][destination model][source backbone].\n",
    "# The 'iid' tables only cover the \"Pretrained\", \"Scratch\" and same-model backbones.\n",
    "lrs = {\n",
    "    'M': {\n",
    "        'not_iid': {\n",
    "            PGLR: {\n",
    "                \"Pretrained\": 10.0,\n",
    "                ProtoNet:     1000.0,\n",
    "                Bayesian:     1000.0,\n",
    "                MAP:          1000.0,\n",
    "                PGLR:         1.0,\n",
    "                LR:           10.0,\n",
    "                'Scratch':    1.0,\n",
    "            },\n",
    "            LR: {\n",
    "                \"Pretrained\": 1e-5,\n",
    "                ProtoNet:     1e-6,\n",
    "                Bayesian:     1e-6,\n",
    "                MAP:          1e-6,\n",
    "                PGLR:         1e-6,\n",
    "                LR:           1e-5,\n",
    "                'Scratch':    1e-6,\n",
    "            },\n",
    "        },\n",
    "        'iid': {\n",
    "            PGLR: {\"Pretrained\": 1.0,  \"Scratch\": 1.0,  PGLR: 0.1},\n",
    "            LR:   {\"Pretrained\": 1e-3, \"Scratch\": 1e-3, LR:   1e-3},\n",
    "        },\n",
    "    },\n",
    "    'O': {\n",
    "        'not_iid': {\n",
    "            PGLR: {\n",
    "                \"Pretrained\": 1.0,\n",
    "                ProtoNet:     10.0,\n",
    "                Bayesian:     100.0,\n",
    "                MAP:          100.0,\n",
    "                PGLR:         0.1,\n",
    "                LR:           1.0,\n",
    "                'Scratch':    10.0,\n",
    "            },\n",
    "            LR: {\n",
    "                \"Pretrained\": 1e-5,\n",
    "                ProtoNet:     1e-6,\n",
    "                Bayesian:     1e-5,\n",
    "                MAP:          1e-5,\n",
    "                PGLR:         1e-6,\n",
    "                LR:           1e-4,\n",
    "                'Scratch':    1e-6,\n",
    "            },\n",
    "        },\n",
    "        'iid': {\n",
    "            PGLR: {\"Pretrained\": 10.0, \"Scratch\": 10.0, PGLR: 1.0},\n",
    "            LR:   {\"Pretrained\": 1e-3, \"Scratch\": 1e-3, LR:   1e-3},\n",
    "        },\n",
    "    },\n",
    "}\n",
    "\n",
    "# Model-class groupings; Models_with_LR marks the classes that read an inner_lr from lrs.\n",
    "Models_with_ML = {Bayesian, MAP}\n",
    "Models_with_LR = {LR, PGLR}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "compatible-choir",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bizarre-vegetable",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.omniglot.TrainParams import MetaTrainParams as MetaTrainParams_OOO\n",
    "from datasets.omniglot.TestParams  import MetaTestParams as MetaTestParams_OOO\n",
    "\n",
    "from datasets.miniimagenet.TrainParams import MetaTrainParams as MetaTrainParams_MMM\n",
    "from datasets.miniimagenet.TestParams  import MetaTestParams  as MetaTestParams_MMM\n",
    "\n",
    "def get_params_sampler_datasets (dataset_name):\n",
    "    \"\"\"Create the parameter objects and datasets for one benchmark.\n",
    "\n",
    "    dataset_name == 'O' selects the omniglot params classes; any other value\n",
    "    selects the miniimagenet ones.\n",
    "\n",
    "    Returns the tuple (params_train, params_test, dataset_train, dataset_test).\n",
    "    \"\"\"\n",
    "    use_omniglot = dataset_name == 'O'\n",
    "\n",
    "    if use_omniglot:\n",
    "        params_test, params_train = MetaTestParams_OOO(), MetaTrainParams_OOO()\n",
    "        test_overrides = {\n",
    "            'query_num_train_tasks': 600,\n",
    "            'support_num_train_tasks': 600,\n",
    "            'support_inner_step': 15,\n",
    "            'query_train_inner_step': 5,\n",
    "            'meta_test_steps': 100,\n",
    "        }\n",
    "    else:\n",
    "        params_test, params_train = MetaTestParams_MMM(), MetaTrainParams_MMM()\n",
    "        test_overrides = {\n",
    "            'query_num_train_tasks': 20,\n",
    "            'support_num_train_tasks': 20,\n",
    "            'support_inner_step': 100,\n",
    "            'query_train_inner_step': 100,\n",
    "            'meta_test_steps': 100,\n",
    "        }\n",
    "\n",
    "    # Reuse params_train.transforms_for_test as the training-set transforms.\n",
    "    params_train.meta_train_transforms = params_train.transforms_for_test\n",
    "    dataset_train = Dataset.get_dataset(params_train, params_train.meta_train_transforms)\n",
    "    dataset_test  = Dataset.get_dataset(params_test,  transform=params_test.meta_transforms)\n",
    "\n",
    "    # The evaluation settings are written onto params_test only after both datasets exist.\n",
    "    for attr_name, attr_value in test_overrides.items():\n",
    "        setattr(params_test, attr_name, attr_value)\n",
    "\n",
    "    return params_train, params_test, dataset_train, dataset_test\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "willing-precipitation",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "../datasets/miniimagenet/ train\n",
      "Total classes =  64\n",
      "../datasets/miniimagenet/ test\n",
      "Total classes =  20\n"
     ]
    }
   ],
   "source": [
    "# Build the shared params objects and datasets for the benchmark chosen by dataset_name.\n",
    "params_train, params_test, dataset_train, dataset_test = get_params_sampler_datasets(dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "digital-routine",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backbones whose checkpoints get loaded; each entry is a key of models_paths[dataset_name].\n",
    "Source_Backbones   = [\"Scratch\", \"Pretrained\", LR, PGLR, ProtoNet, MAP, Bayesian]\n",
    "# Model classes that are instantiated around each loaded backbone.\n",
    "Destination_Models = [LR, PGLR, ProtoNet, MAP, Bayesian]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "illegal-bradford",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug shortcut: set to True to restrict the grid to one\n",
    "# (source backbone, destination model) pair for a quick run.\n",
    "# Kept False so that Restart & Run All evaluates the full lists defined in the\n",
    "# previous cell instead of silently overriding them.\n",
    "RUN_SINGLE_PAIR = False\n",
    "if RUN_SINGLE_PAIR:\n",
    "    Source_Backbones   = [Bayesian]\n",
    "    Destination_Models = [LR]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "operational-documentation",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cross-transfer grid: every source backbone checkpoint is loaded into every\n",
    "# destination model class and the resulting model is evaluated on dataset_test.\n",
    "# results[Source_Backbone][Destination_Model] -> (acc_mean, acc_std)\n",
    "results = {}\n",
    "for Source_Backbone in tqdm (Source_Backbones, \"source backbone\"):\n",
    "    results[Source_Backbone] = {}\n",
    "    for Destination_Model in Destination_Models:\n",
    "        print (f\"Source Backbone:{Source_Backbone} , Destination_Model:{Destination_Model}\")\n",
    "\n",
    "        # Only classes in Models_with_LR get a per-pair inner_lr from the lrs table.\n",
    "        # params_train is shared, so for the other classes inner_lr keeps whatever\n",
    "        # value was set last.\n",
    "        if Destination_Model in Models_with_LR:\n",
    "            lr = lrs[dataset_name]['not_iid'][Destination_Model][Source_Backbone]\n",
    "            print (f\"using lr{lr}\")\n",
    "            params_train.inner_lr = lr\n",
    "\n",
    "        model = Destination_Model(params_train).to(device)\n",
    "\n",
    "        # map_location=device lets a checkpoint that was saved from a GPU be\n",
    "        # loaded when `device` has fallen back to 'cpu'.\n",
    "        backbone_state = torch.load(models_paths[dataset_name][Source_Backbone], map_location=device)\n",
    "        model.back_bone.load_state_dict(backbone_state)\n",
    "\n",
    "        model = get_model_with_ml(model, params_train, dataset_train)\n",
    "\n",
    "        acc_mean, acc_std = evaluate(params_test, model, dataset_test, verbose=False)\n",
    "        print (acc_mean, acc_std)\n",
    "\n",
    "        results[Source_Backbone][Destination_Model] = acc_mean, acc_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "aquatic-waste",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scratch,<class 'model.LR.LR'>: (0.0044466666666666665, 0.0009276254033223157)\n",
      "Scratch,<class 'model.PGLR.PGLR'>: (0.3271066666666669, 0.008257371117842412)\n",
      "Scratch,<class 'model.ProtoNet.ProtoNet'>: (0.36240000000000044, 0.007625687582842032)\n",
      "Scratch,<class 'model.MAP.MAP'>: (0.38086666666666713, 0.007724707833381795)\n",
      "Scratch,<class 'model.Bayesian.Bayesian'>: (0.3761666666666671, 0.0077155686763840745)\n",
      "Pretrained,<class 'model.LR.LR'>: (0.25712, 0.006519291031665652)\n",
      "Pretrained,<class 'model.PGLR.PGLR'>: (0.6268666666666673, 0.008262095510354594)\n",
      "Pretrained,<class 'model.ProtoNet.ProtoNet'>: (0.7395000000000006, 0.00661253019316775)\n",
      "Pretrained,<class 'model.MAP.MAP'>: (0.7630800000000009, 0.00811433847512471)\n",
      "Pretrained,<class 'model.Bayesian.Bayesian'>: (0.7584133333333343, 0.00805413903254442)\n",
      "<class 'model.LR.LR'>,<class 'model.LR.LR'>: (0.6389600000000006, 0.006858454636432297)\n",
      "<class 'model.LR.LR'>,<class 'model.PGLR.PGLR'>: (0.7463600000000012, 0.005334099944903806)\n",
      "<class 'model.LR.LR'>,<class 'model.ProtoNet.ProtoNet'>: (0.8199866666666681, 0.005697937248591812)\n",
      "<class 'model.LR.LR'>,<class 'model.MAP.MAP'>: (0.8425133333333357, 0.0067941837871326296)\n",
      "<class 'model.LR.LR'>,<class 'model.Bayesian.Bayesian'>: (0.848346666666669, 0.006264879177694821)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.LR.LR'>: (0.3913333333333335, 0.007278583500776617)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.PGLR.PGLR'>: (0.7532400000000017, 0.007237568652523995)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.ProtoNet.ProtoNet'>: (0.8708266666666692, 0.005135601230625605)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.MAP.MAP'>: (0.8784733333333363, 0.005112985646589032)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.Bayesian.Bayesian'>: (0.8743466666666697, 0.004861165840961994)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.LR.LR'>: (0.019233333333333335, 0.0034501207708330078)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.PGLR.PGLR'>: (0.7960266666666681, 0.006671944577441102)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.ProtoNet.ProtoNet'>: (0.8371866666666685, 0.00593788589389288)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.MAP.MAP'>: (0.8357000000000021, 0.005223557748848631)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.Bayesian.Bayesian'>: (0.836900000000002, 0.0055318271042317626)\n",
      "<class 'model.MAP.MAP'>,<class 'model.LR.LR'>: (0.20021999999999998, 0.018681971107044493)\n",
      "<class 'model.MAP.MAP'>,<class 'model.PGLR.PGLR'>: (0.8272533333333351, 0.005567588745987591)\n",
      "<class 'model.MAP.MAP'>,<class 'model.ProtoNet.ProtoNet'>: (0.8640866666666692, 0.005428652380349019)\n",
      "<class 'model.MAP.MAP'>,<class 'model.MAP.MAP'>: (0.868773333333336, 0.005626303305803162)\n",
      "<class 'model.MAP.MAP'>,<class 'model.Bayesian.Bayesian'>: (0.8667066666666692, 0.005880529076726543)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.LR.LR'>: (0.19801333333333335, 0.0193119375862013)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.PGLR.PGLR'>: (0.8280733333333352, 0.006094912268077127)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.ProtoNet.ProtoNet'>: (0.8658266666666692, 0.005715569379627577)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.MAP.MAP'>: (0.8704000000000027, 0.005776581072649356)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.Bayesian.Bayesian'>: (0.8684066666666692, 0.005572268638335728)\n"
     ]
    }
   ],
   "source": [
    "# One line per (source backbone, destination model) pair with its stored result.\n",
    "for key, per_destination in results.items():\n",
    "    for key2, scores in per_destination.items():\n",
    "        print (f\"{key},{key2}:\", scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "promising-senator",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Scratch,<class 'model.LR.LR'>: (0.05537000000000001, 0.008713386253346054)\n",
      "Scratch,<class 'model.PGLR.PGLR'>: (0.153375, 0.01234065942322371)\n",
      "Scratch,<class 'model.ProtoNet.ProtoNet'>: (0.1782, 0.007603946343840153)\n",
      "Scratch,<class 'model.MAP.MAP'>: (0.18903999999999999, 0.008182200193102097)\n",
      "Scratch,<class 'model.Bayesian.Bayesian'>: (0.18893000000000001, 0.008245610953713496)\n",
      "Pretrained,<class 'model.LR.LR'>: (0.32811999999999997, 0.03127268136888809)\n",
      "Pretrained,<class 'model.PGLR.PGLR'>: (0.5419400000000002, 0.009918487787964444)\n",
      "Pretrained,<class 'model.ProtoNet.ProtoNet'>: (0.5556350000000001, 0.010868039151567302)\n",
      "Pretrained,<class 'model.MAP.MAP'>: (0.5591700000000001, 0.010623375169878928)\n",
      "Pretrained,<class 'model.Bayesian.Bayesian'>: (0.5588500000000001, 0.010683281331126704)\n",
      "<class 'model.LR.LR'>,<class 'model.LR.LR'>: (0.31112500000000004, 0.010676931909495355)\n",
      "<class 'model.LR.LR'>,<class 'model.PGLR.PGLR'>: (0.4344250000000001, 0.010727155960458494)\n",
      "<class 'model.LR.LR'>,<class 'model.ProtoNet.ProtoNet'>: (0.43154500000000007, 0.010138563754299716)\n",
      "<class 'model.LR.LR'>,<class 'model.MAP.MAP'>: (0.4322650000000002, 0.008823960278695732)\n",
      "<class 'model.LR.LR'>,<class 'model.Bayesian.Bayesian'>: (0.433855, 0.008948545971273763)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.LR.LR'>: (0.19596000000000002, 0.03862840146834968)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.PGLR.PGLR'>: (0.5318550000000001, 0.011284789541679557)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.ProtoNet.ProtoNet'>: (0.568315, 0.010064952806645442)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.MAP.MAP'>: (0.5699150000000001, 0.010719854243412066)\n",
      "<class 'model.PGLR.PGLR'>,<class 'model.Bayesian.Bayesian'>: (0.570255, 0.010773693656309332)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.LR.LR'>: (0.05000499999999999, 4.974937185533104e-05)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.PGLR.PGLR'>: (0.4671000000000001, 0.013666930891754728)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.ProtoNet.ProtoNet'>: (0.5477050000000001, 0.009300025537599337)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.MAP.MAP'>: (0.553915, 0.009360303146800312)\n",
      "<class 'model.ProtoNet.ProtoNet'>,<class 'model.Bayesian.Bayesian'>: (0.5541050000000001, 0.009511386597126623)\n",
      "<class 'model.MAP.MAP'>,<class 'model.LR.LR'>: (0.23614999999999997, 0.04180959818032219)\n",
      "<class 'model.MAP.MAP'>,<class 'model.PGLR.PGLR'>: (0.56316, 0.010465868334734578)\n",
      "<class 'model.MAP.MAP'>,<class 'model.ProtoNet.ProtoNet'>: (0.5837950000000001, 0.00962265425961048)\n",
      "<class 'model.MAP.MAP'>,<class 'model.MAP.MAP'>: (0.587, 0.010344805459746429)\n",
      "<class 'model.MAP.MAP'>,<class 'model.Bayesian.Bayesian'>: (0.587205, 0.010216186910976123)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.LR.LR'>: (0.25983, 0.04019783700648581)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.PGLR.PGLR'>: (0.56637, 0.010976023870236435)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.ProtoNet.ProtoNet'>: (0.5871350000000001, 0.009546427342205031)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.MAP.MAP'>: (0.59187, 0.009366861801051608)\n",
      "<class 'model.Bayesian.Bayesian'>,<class 'model.Bayesian.Bayesian'>: (0.59179, 0.009330375126435145)\n"
     ]
    }
   ],
   "source": [
    "# Same two-level printout of `results` as the previous cell.\n",
    "# NOTE(review): the saved output differs from the previous cell's, presumably\n",
    "# from a run with a different dataset_name - confirm.\n",
    "for key, inner_results in results.items():\n",
    "    for key2, value in inner_results.items():\n",
    "        print (f\"{key},{key2}:\", value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "desperate-saturday",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Second experiment: only the LR and PGLR destination models are evaluated.\n",
    "Destination_Models = [LR, PGLR]\n",
    "\n",
    "# results[IID_LR][IID_test][train_type][Destination_Model] -> (acc_mean, acc_std)\n",
    "#   IID_LR     : which lrs table ('iid' / 'not_iid') supplies inner_lr\n",
    "#   IID_test   : 'iid_test' passes IID_epochs=5 to evaluate, 'not_iid_test' passes 0\n",
    "#   train_type : 'Pretrained' loads the pretrained backbone, 'meta' loads the\n",
    "#                destination model's own checkpoint\n",
    "results = {}\n",
    "for IID_LR in [\"iid\",\"not_iid\"]:\n",
    "    results[IID_LR] = {}\n",
    "    for IID_test in [\"iid_test\",\"not_iid_test\"]:\n",
    "        results[IID_LR][IID_test] = {}\n",
    "        # Depends only on IID_test, so it is computed once per test setting.\n",
    "        IID_epochs = 5 if IID_test == \"iid_test\" else 0\n",
    "        for train_type in [\"Pretrained\",\"meta\"]:\n",
    "            results[IID_LR][IID_test][train_type] = {}\n",
    "            for Destination_Model in tqdm (Destination_Models, \"destination models\"):\n",
    "                print (f\"IID_lr:{IID_LR} train_type:{train_type}, IID_test:{IID_test} , Destination_Model:{Destination_Model}\")\n",
    "\n",
    "                # Skip before any model is built: these classes are not evaluated\n",
    "                # in the iid_test setting.\n",
    "                if IID_test == \"iid_test\" and (Destination_Model in [Bayesian, ProtoNet, MAP]):\n",
    "                    continue\n",
    "\n",
    "                Source_Backbone = Destination_Model if train_type == \"meta\" else train_type\n",
    "\n",
    "                if Destination_Model in Models_with_LR:\n",
    "                    lr = lrs[dataset_name][IID_LR][Destination_Model][Source_Backbone]\n",
    "                    print (f\"using lr{lr}\")\n",
    "                    params_train.inner_lr = lr\n",
    "\n",
    "                model = Destination_Model(params_train).to(device)\n",
    "                # map_location=device lets a checkpoint that was saved from a GPU be\n",
    "                # loaded when `device` has fallen back to 'cpu'.\n",
    "                backbone_state = torch.load(models_paths[dataset_name][Source_Backbone], map_location=device)\n",
    "                model.back_bone.load_state_dict(backbone_state)\n",
    "                model = get_model_with_ml(model, params_train, dataset_train)\n",
    "\n",
    "                acc_mean, acc_std = evaluate(params_test, model, dataset_test, IID_epochs = IID_epochs,verbose=False)\n",
    "                print (acc_mean, acc_std)\n",
    "\n",
    "                results[IID_LR][IID_test][train_type][Destination_Model] = acc_mean, acc_std"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "identical-sheriff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'iid': {'iid_test': {'Pretrained': {model.LR.LR: (0.7783666666666679,\n",
       "     0.007045960229489979),\n",
       "    model.PGLR.PGLR: (0.7459400000000009, 0.00760137122132882)},\n",
       "   'meta': {model.LR.LR: (0.8400800000000019, 0.006373054387201152),\n",
       "    model.PGLR.PGLR: (0.8684266666666693, 0.004898090330822014)}},\n",
       "  'not_iid_test': {'Pretrained': {model.LR.LR: (0.048699999999999986,\n",
       "     0.010356586739311786),\n",
       "    model.PGLR.PGLR: (0.6826400000000007, 0.006666946660786937)},\n",
       "   'meta': {model.LR.LR: (0.4662199999999994, 0.010636333954892514),\n",
       "    model.PGLR.PGLR: (0.7908400000000028, 0.006117094444623688)}}},\n",
       " 'not_iid': {'iid_test': {'Pretrained': {model.LR.LR: (0.40472000000000036,\n",
       "     0.007752092190714316),\n",
       "    model.PGLR.PGLR: (0.7417866666666673, 0.006206004261107938)},\n",
       "   'meta': {model.LR.LR: (0.7451133333333345, 0.005686869867412927),\n",
       "    model.PGLR.PGLR: (0.8641000000000024, 0.004618922433257257)}},\n",
       "  'not_iid_test': {'Pretrained': {model.LR.LR: (0.25712, 0.006519291031665652),\n",
       "    model.PGLR.PGLR: (0.6268666666666673, 0.008262095510354594)},\n",
       "   'meta': {model.LR.LR: (0.6389600000000006, 0.006858454636432297),\n",
       "    model.PGLR.PGLR: (0.7532400000000017, 0.007237568652523995)}}}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the nested dict built by the loop above:\n",
    "# results[IID_LR][IID_test][train_type][Destination_Model] -> (acc_mean, acc_std)\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "defined-pitch",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'iid': {'iid_test': {'Pretrained': {model.LR.LR: (0.6095700000000001,\n",
       "     0.009812242353305398),\n",
       "    model.PGLR.PGLR: (0.56184, 0.010349125566925943)},\n",
       "   'meta': {model.LR.LR: (0.48890000000000006, 0.011025425161870163),\n",
       "    model.PGLR.PGLR: (0.57363, 0.0100308573910708)}},\n",
       "  'not_iid_test': {'Pretrained': {model.LR.LR: (0.07098, 0.019289105733548147),\n",
       "    model.PGLR.PGLR: (0.53263, 0.00979428915235811)},\n",
       "   'meta': {model.LR.LR: (0.22328, 0.014263646097684842),\n",
       "    model.PGLR.PGLR: (0.51403, 0.013859080056049909)}}},\n",
       " 'not_iid': {'iid_test': {'Pretrained': {model.LR.LR: (0.3845300000000001,\n",
       "     0.014307833518740718),\n",
       "    model.PGLR.PGLR: (0.5598200000000001, 0.010279474694749723)},\n",
       "   'meta': {model.LR.LR: (0.31499000000000005, 0.010583236744966074),\n",
       "    model.PGLR.PGLR: (0.57012, 0.009944123893033511)}},\n",
       "  'not_iid_test': {'Pretrained': {model.LR.LR: (0.32345, 0.028813408337091954),\n",
       "    model.PGLR.PGLR: (0.54132, 0.008282366811485696)},\n",
       "   'meta': {model.LR.LR: (0.30989, 0.010856007553424056),\n",
       "    model.PGLR.PGLR: (0.53323, 0.010574596919031972)}}}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Displays the same `results` dict again.\n",
    "# NOTE(review): the saved output differs from the previous cell's, presumably\n",
    "# from a run with a different dataset_name - confirm.\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abstract-survivor",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
