{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets.utils.base_dataset import BaseDataset, KAND_get_loader\n",
    "from datasets.utils.kand_creation import miniKAND_Dataset\n",
    "from backbones.kand_encoder import  TripleMLP\n",
    "import time\n",
    "\n",
    "class MiniKandinsky(BaseDataset):\n",
    "    NAME = 'minikandinsky'\n",
    "\n",
    "    def get_data_loaders(self):\n",
    "        start = time.time()\n",
    "\n",
    "        dataset_train = miniKAND_Dataset(base_path='data/kand-3k',split='train') \n",
    "        dataset_val   = miniKAND_Dataset(base_path='data/kand-3k',split='val')      \n",
    "        dataset_test  = miniKAND_Dataset(base_path='data/kand-3k',split='test') \n",
    "        # dataset_ood   = KAND_Dataset(base_path='data/kandinsky/data',split='ood') \n",
    "\n",
    "        # dataset_train.mask_concepts('red-and-squares')\n",
    "\n",
    "        print(f'Loaded datasets in {time.time()-start} s.')        \n",
    "\n",
    "        print('Len loaders: \\n train:', len(dataset_train), '\\n val:', len(dataset_val))\n",
    "        print(' len test:', len(dataset_test)) #, '\\n len ood', len(dataset_ood))\n",
    "\n",
    "        \n",
    "        train_loader = KAND_get_loader(dataset_train, 64, val_test=True)\n",
    "        val_loader   = KAND_get_loader(dataset_val,   500, val_test=True)\n",
    "        test_loader  = KAND_get_loader(dataset_test,  500, val_test=True)\n",
    "        \n",
    "        # self.ood_loader = get_loader(dataset_ood,  self.args.batch_size, val_test=True)\n",
    "\n",
    "        return train_loader, val_loader, test_loader\n",
    "\n",
    "    def get_backbone(self):\n",
    "        return TripleMLP(latent_dim=6), 0\n",
    "        #return TripleCNNEncoder(latent_dim=6), 0\n",
    "    \n",
    "    def get_split(self):\n",
    "        return 3, ()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [],
   "source": [
    "dset = MiniKandinsky(None)\n",
    "encoder, _ = dset.get_backbone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_train = miniKAND_Dataset(base_path='data/kand-3k',split='train') \n",
    "dataset_val   = miniKAND_Dataset(base_path='data/kand-3k',split='val')      \n",
    "dataset_test  = miniKAND_Dataset(base_path='data/kand-3k',split='test') \n",
    "\n",
    "# dataset_train.list_images = dataset_train.list_images[0]\n",
    "# dataset_train.concepts = dataset_train.concepts[0]\n",
    "# dataset_train.labels = dataset_train.labels[0]\n",
    "\n",
    "train_loader = KAND_get_loader(dataset_train, 64, val_test=True)\n",
    "val_loader   = KAND_get_loader(dataset_val,   500, val_test=True)\n",
    "test_loader  = KAND_get_loader(dataset_test,  500, val_test=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(dataset_train.list_images )\n",
    "# print(dataset_train.concepts )\n",
    "# print(dataset_train.labels )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = encoder.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim import Adam, SGD\n",
    "\n",
    "opt = Adam(encoder.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "backbone.1.weight\n",
      "backbone.1.bias\n",
      "backbone.4.weight\n",
      "backbone.4.bias\n",
      "backbone.7.weight\n",
      "backbone.7.bias\n"
     ]
    }
   ],
   "source": [
    "for name, param in encoder.named_parameters():\n",
    "    print(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 122,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0 Iter: 0 Loss: 1.1249303817749023\n",
      "Epoch 1 Iter: 0 Loss: 0.004875645507127047\n",
      "Epoch 2 Iter: 0 Loss: 0.00699184276163578\n",
      "Epoch 3 Iter: 0 Loss: 0.005085902288556099\n",
      "Epoch 4 Iter: 0 Loss: 0.005486843176186085\n",
      "Epoch 5 Iter: 0 Loss: 0.004140877164900303\n",
      "Epoch 6 Iter: 0 Loss: 0.005515359342098236\n",
      "Epoch 7 Iter: 0 Loss: 0.00797035451978445\n",
      "Epoch 8 Iter: 0 Loss: 0.0092710480093956\n",
      "Epoch 9 Iter: 0 Loss: 0.005397433415055275\n",
      "Epoch 10 Iter: 0 Loss: 0.007073335349559784\n",
      "Epoch 11 Iter: 0 Loss: 0.009108457714319229\n",
      "Epoch 12 Iter: 0 Loss: 0.011702841147780418\n",
      "Epoch 13 Iter: 0 Loss: 0.006490848958492279\n",
      "Epoch 14 Iter: 0 Loss: 0.009613808244466782\n",
      "Epoch 15 Iter: 0 Loss: 0.0061861551366746426\n",
      "Epoch 16 Iter: 0 Loss: 0.006716057658195496\n",
      "Epoch 17 Iter: 0 Loss: 0.006654723547399044\n",
      "Epoch 18 Iter: 0 Loss: 0.008405258879065514\n",
      "Epoch 19 Iter: 0 Loss: 0.007115951273590326\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(20):\n",
    "    for iter, data in enumerate(train_loader):\n",
    "        img, label, concept = data\n",
    "        img, label, concept = img.to(device), label.to(device), concept.to(device)\n",
    "\n",
    "        opt.zero_grad()\n",
    "        \n",
    "        xs = torch.split(img, 28*3, dim=-1)\n",
    "        preds = []\n",
    "        for i in range(len(xs)):\n",
    "            out, _ = encoder(xs[i])    \n",
    "            preds.append(out.unsqueeze(1))\n",
    "        preds = torch.cat(preds, dim=1)\n",
    "\n",
    "        # print(iter)\n",
    "\n",
    "        loss = 0\n",
    "        for i in range(3):\n",
    "            c = preds[:, i, :]\n",
    "            g = concept[:, i, :]\n",
    "\n",
    "            cs = torch.split(c, 3, dim=-1)\n",
    "            gs = torch.split(g, 1, dim=-1)\n",
    "\n",
    "            assert len(cs) == len(gs), (cs[0].shape, gs[0].shape)\n",
    "            \n",
    "            # for j in range(3):\n",
    "            j = 2\n",
    "            # loss += torch.nn.functional.cross_entropy(cs[2*j],   gs[j].view(-1)) / 3\n",
    "            loss += torch.nn.functional.cross_entropy(cs[2*j+1], gs[3+j].view(-1)) / 3\n",
    "\n",
    "        \n",
    "        if iter % 100 == 0:\n",
    "            print('Epoch',epoch, 'Iter:', iter, 'Loss:', loss.item())\n",
    "        \n",
    "        loss.backward()\n",
    "        opt.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Adam (\n",
       "Parameter Group 0\n",
       "    amsgrad: False\n",
       "    betas: (0.9, 0.999)\n",
       "    capturable: False\n",
       "    differentiable: False\n",
       "    eps: 1e-08\n",
       "    foreach: None\n",
       "    fused: None\n",
       "    lr: 0.001\n",
       "    maximize: False\n",
       "    weight_decay: 0\n",
       ")"
      ]
     },
     "execution_count": 61,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rss",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
