{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d461464c-a93f-44b3-b76c-8cb6a42b2b15",
   "metadata": {},
   "source": [
    "# Deep InfoMax representation learning for images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c22d6fc-a528-40d2-b457-716ebccd6dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../python\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5eaf72-207e-44bd-bec1-02ba7b0bcdac",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchkld\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "963a91d7-6ac6-493c-8b43-5d151bd1fb8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import infomax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52e593dd-add1-4e4d-9621-6a3baea6c751",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the first CUDA device when PyTorch can see one; otherwise use the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda:0\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "# device = \"cpu\"  # uncomment to force CPU execution\n",
    "\n",
    "# Report the selected device together with the CUDA setup PyTorch reports.\n",
    "print(\"Device: \" + device)\n",
    "print(f\"Devices count: {torch.cuda.device_count()}\")\n",
    "print(f\"CUDA version: {torch.version.cuda}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9008d6e7-7282-44d0-9069-8651fa1a443a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from misc.modules import *\n",
    "from misc.plots import *\n",
    "from misc.training import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe561a8b-cbfe-4e8b-8e33-73a3378a3474",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# Absolute form of the data directory located two levels above the working directory.\n",
    "path = Path(\"../../data/\").resolve()\n",
    "# Sub-directory that the final cell extends per run and hands to save_results.\n",
    "experiments_path = path.joinpath(\"embeddings/\")\n",
    "#models_path = experiments_path / \"models/\"\n",
    "#results_path = experiments_path / \"resuts/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7d55996-a1c5-40df-a169-a87bc9872b61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment settings are accumulated in this dict by the cells below and\n",
    "# passed to save_results in the final cell.\n",
    "config = dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f86efe6-86de-46af-b6a3-2711d4e14706",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76090058-8c80-4be5-811c-351db16f333a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing applied to every image: conversion to a tensor is the only active step.\n",
    "image_transform = torchvision.transforms.Compose(\n",
    "    [\n",
    "        torchvision.transforms.ToTensor(),\n",
    "        # Per-channel normalization, currently disabled:\n",
    "        #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44b97b1f-9abb-4d46-868b-4dcf03b0cdbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataset is selected by name and looked up in torchvision.datasets.\n",
    "config[\"dataset\"] = \"MNIST\"  # alternative: \"CIFAR10\"\n",
    "config[\"n_classes\"] = 10\n",
    "\n",
    "_dataset_cls = getattr(torchvision.datasets, config[\"dataset\"])\n",
    "# Arguments shared by both splits; files are downloaded into ./.cache when missing.\n",
    "_dataset_kwargs = dict(root=\"./.cache\", download=True, transform=image_transform)\n",
    "\n",
    "# The training split relies on the dataset class's default `train` value;\n",
    "# the test split requests train=False explicitly.\n",
    "train_dataset = _dataset_cls(**_dataset_kwargs)\n",
    "test_dataset = _dataset_cls(train=False, **_dataset_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b73ff5a-9193-4b1d-8a14-c584b5c65e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "config[\"batch_size_train\"] = 1024\n",
    "config[\"batch_size_test\"] = 1024\n",
    "\n",
    "# Training batches are drawn in shuffled order.\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=config[\"batch_size_train\"],\n",
    "    shuffle=True,\n",
    ")\n",
    "# Test batches keep the dataset order.\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=config[\"batch_size_test\"],\n",
    "    shuffle=False,\n",
    ")\n",
    "# Evaluation reuses the test loader. Disabled alternative over the training set:\n",
    "#eval_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config[\"batch_size_train\"], shuffle=False)\n",
    "eval_dataloader = test_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d6ec80c-7f62-45d8-b191-d7c271a466d9",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35568128-6f85-4db6-9236-a8d59ca4f0c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "config[\"distribution\"] = \"normal\"  # alternative: \"uniform\"\n",
    "config[\"embedding_dim\"] = 2\n",
    "\n",
    "# Last layer of the embedder: batch normalization without learnable affine\n",
    "# parameters for the \"normal\" setting, a sigmoid for any other value.\n",
    "if config[\"distribution\"] == \"normal\":\n",
    "    normalization_layer = torch.nn.BatchNorm1d(config[\"embedding_dim\"], affine=False)\n",
    "else:\n",
    "    normalization_layer = torch.nn.Sigmoid()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "348692ce-1941-47ba-ba38-9ee09cb7cde4",
   "metadata": {},
   "outputs": [],
   "source": [
    "config[\"backbone\"] = \"convnet\"\n",
    "\n",
    "if config[\"backbone\"] != \"convnet\":\n",
    "    # Any other value is looked up by name in torchvision.models, with the\n",
    "    # embedding dimension passed as num_classes, and put in training mode.\n",
    "    backbone = getattr(torchvision.models, config[\"backbone\"])(num_classes=config[\"embedding_dim\"])\n",
    "    backbone = backbone.train()\n",
    "\n",
    "    if config[\"dataset\"] in (\"CIFAR10\", \"CIFAR100\"):\n",
    "        # For CIFAR: swap in a 3x3 stride-1 first convolution and disable the max-pooling.\n",
    "        # NOTE(review): kernel_size=3 with padding=2 enlarges each spatial dimension by 2;\n",
    "        # confirm that padding=1 was not intended.\n",
    "        backbone.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)\n",
    "        backbone.maxpool = torch.nn.Identity()\n",
    "else:\n",
    "    # Conv2dEmbedder is not defined in this notebook; it comes from the star-imported misc modules.\n",
    "    backbone = Conv2dEmbedder(embedding_dim=config[\"embedding_dim\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ad6e2c5-1973-4a32-8985-bc0364b038d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Backbone followed by the normalization layer, moved to the selected device.\n",
    "embedder_network = torch.nn.Sequential(backbone, normalization_layer)\n",
    "embedder_network = embedder_network.to(device)\n",
    "\n",
    "# Record the embedding size on the module itself\n",
    "# (presumably read by the infomax code; verify against its usage).\n",
    "embedder_network.embedding_dim = config[\"embedding_dim\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc6dfb88-0928-4f1b-89e5-66279e98ae1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "config[\"discriminator_network\"] = \"DenseT\"\n",
    "config[\"discriminator_network_inner_dim\"] = 256\n",
    "config[\"discriminator_network_output_dim\"] = 256\n",
    "\n",
    "\n",
    "def _build_separable_t():\n",
    "    \"\"\"Create a SeparableT with the configured inner and output widths on `device`.\"\"\"\n",
    "    return SeparableT(\n",
    "        config[\"embedding_dim\"],\n",
    "        config[\"embedding_dim\"],\n",
    "        inner_dim=config[\"discriminator_network_inner_dim\"],\n",
    "        output_dim=config[\"discriminator_network_output_dim\"],\n",
    "    ).to(device)\n",
    "\n",
    "\n",
    "def _build_dense_t():\n",
    "    \"\"\"Create a DenseT with the configured inner width on `device`.\"\"\"\n",
    "    return DenseT(\n",
    "        config[\"embedding_dim\"],\n",
    "        config[\"embedding_dim\"],\n",
    "        inner_dim=config[\"discriminator_network_inner_dim\"],\n",
    "    ).to(device)\n",
    "\n",
    "\n",
    "def _build_additive_gaussian_t():\n",
    "    \"\"\"Create an AdditiveGaussainT constructed with p=0.99 on `device`.\"\"\"\n",
    "    return AdditiveGaussainT(p=0.99).to(device)\n",
    "\n",
    "\n",
    "# Maps each supported name to a zero-argument builder, so only the selected\n",
    "# network is instantiated; an unknown name raises KeyError.\n",
    "_discriminator_network_factory = {\n",
    "    \"SeparableT\": _build_separable_t,\n",
    "    \"DenseT\": _build_dense_t,\n",
    "    \"AdditiveGaussainT\": _build_additive_gaussian_t,\n",
    "}\n",
    "\n",
    "discriminator_network = _discriminator_network_factory[config[\"discriminator_network\"]]()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94d7fcb0-c7fb-43e4-be75-ded99ba48c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "config[\"input_p\"] = 5.0e-1\n",
    "config[\"output_p\"] = 1.0e-2\n",
    "\n",
    "# Third positional argument of Embedder (presumably the input-side channel; verify).\n",
    "_input_channel = infomax.channels.BoundedVarianceGaussianChannel(config[\"input_p\"])\n",
    "# Disabled alternative for that argument: a torchvision augmentation pipeline.\n",
    "#_input_channel = torchvision.transforms.Compose([\n",
    "#    torchvision.transforms.RandomResizedCrop((32, 32), scale=(0.2, 1.)),\n",
    "#    torchvision.transforms.RandomHorizontalFlip(),\n",
    "#    torchvision.transforms.RandomApply([\n",
    "#        torchvision.transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened\n",
    "#        #torchvision.transforms.ColorJitter(0.5, 0.5, 0.5, 0.5)  # strengthened\n",
    "#    ], p=0.8),\n",
    "#    torchvision.transforms.RandomGrayscale(p=0.2),\n",
    "#    #infomax.channels.BoundedVarianceGaussianChannel(config[\"input_p\"]).to(device)\n",
    "#    #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),\n",
    "#])\n",
    "\n",
    "# Fourth positional argument: a Gaussian channel for the \"normal\" setting,\n",
    "# a bounded-support uniform channel for any other value.\n",
    "if config[\"distribution\"] == \"normal\":\n",
    "    _output_channel = infomax.channels.BoundedVarianceGaussianChannel(config[\"output_p\"])\n",
    "else:\n",
    "    _output_channel = infomax.channels.BoundedSupportUniformChannel(config[\"output_p\"])\n",
    "\n",
    "model = infomax.embeddings.Embedder(\n",
    "    embedder_network,\n",
    "    discriminator_network,\n",
    "    _input_channel,\n",
    "    _output_channel,\n",
    ").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14a03aad-3a13-4834-9ef7-7e2869796cc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Capacity reported by the model's output channel, scaled by the embedding dimension.\n",
    "config[\"capacity\"] = config[\"embedding_dim\"] * model.output_channel.capacity\n",
    "\n",
    "# Natural logarithm of the number of classes.\n",
    "config[\"min_capacity_for_classification\"] = math.log(config[\"n_classes\"])\n",
    "\n",
    "# Print both quantities with two decimals for a side-by-side comparison.\n",
    "print(f\"Capacity: {config['capacity']:.2f}\")\n",
    "print(f\"Min capacity required for class preservation: {config['min_capacity_for_classification']:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bc5ff5b-5ae8-445b-a48c-b54c44ef647e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training settings consumed by the training cell.\n",
    "config.update(\n",
    "    {\n",
    "        \"n_epochs\": 2001,\n",
    "        \"embedder_network_lr\": 1.0e-3,\n",
    "        \"discriminator_network_lr\": 1.0e-3,\n",
    "        # Name of the loss class that the training cell looks up in torchkld.loss.\n",
    "        \"loss\": \"InfoNCELoss\",\n",
    "        # Forwarded to train_infomax_embedder as its `marginalize` argument.\n",
    "        \"marginalize\": \"product\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0e92007-8c74-4dc1-863c-2b41b4debb99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _run_classification_callback(history, epoch, step, infomax_embedder, train_dataloader, test_dataloader, device):\n",
    "    \"\"\"Forward the callback arguments unchanged to classification_callback and return its result.\"\"\"\n",
    "    return classification_callback(\n",
    "        history, epoch, step, infomax_embedder, train_dataloader, test_dataloader, device,\n",
    "        # Optional overrides, currently disabled:\n",
    "        #period=20,\n",
    "        #distribution_tests={},\n",
    "        #clustering_metrics={},\n",
    "        #classifiers={\n",
    "        #    \"logistic_regression\": lambda: DenseClassifier(config[\"embedding_dim\"], config[\"n_classes\"], device).to(device),\n",
    "        #    #\"mlp\": lambda: DenseClassifier(config[\"embedding_dim\"], config[\"n_classes\"], device, n_layers=3).to(device),\n",
    "        #    #\"knn\": lambda: KNeighborsClassifier(metric='cosine'),\n",
    "        #    #\"mlp\": lambda: MLPClassifier(alpha=1.0, max_iter=1000),\n",
    "        #},\n",
    "    )\n",
    "\n",
    "\n",
    "def _make_embedder_optimizer(params):\n",
    "    \"\"\"Return Adam over `params` using the configured embedder learning rate.\"\"\"\n",
    "    return torch.optim.Adam(params, lr=config[\"embedder_network_lr\"])\n",
    "\n",
    "\n",
    "def _make_discriminator_optimizer(params):\n",
    "    \"\"\"Return Adam over `params` using the configured discriminator learning rate.\"\"\"\n",
    "    return torch.optim.Adam(params, lr=config[\"discriminator_network_lr\"])\n",
    "\n",
    "\n",
    "# The loss class is resolved by name in torchkld.loss and instantiated without arguments.\n",
    "history = train_infomax_embedder(\n",
    "    model,\n",
    "    train_dataloader,\n",
    "    test_dataloader,\n",
    "    device,\n",
    "    callback=_run_classification_callback,\n",
    "    optimizer_embedder_network=_make_embedder_optimizer,\n",
    "    optimizer_discriminator_network=_make_discriminator_optimizer,\n",
    "    loss=getattr(torchkld.loss, config[\"loss\"])(),\n",
    "    marginalize=config[\"marginalize\"],\n",
    "    distribution=config[\"distribution\"],\n",
    "    n_epochs=config[\"n_epochs\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e35647c2-19ba-49a6-860d-fffe7f17baac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the sigmoid of the discriminator's p_logit (presumably its current p; verify)\n",
    "# when the AdditiveGaussainT discriminator is selected.\n",
    "# The notebook only echoes a bare expression that is the last top-level statement,\n",
    "# so a value nested inside `if` has to be printed explicitly to be visible.\n",
    "if config[\"discriminator_network\"] == \"AdditiveGaussainT\":\n",
    "    print(torch.sigmoid(discriminator_network.p_logit))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b65a1d8-8c9f-4c61-8253-adadba9aa27e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run convert_to_embeddings with the embedder network over the training dataloader.\n",
    "# The plotting cell unpacks the result as positional arguments\n",
    "# (presumably embeddings and labels; verify against convert_to_embeddings).\n",
    "embeddings = convert_to_embeddings(\n",
    "    embedder_network,\n",
    "    train_dataloader,\n",
    "    device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f34a7679-8013-4697-9f0e-81382cccc63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.special import ndtr\n",
    "\n",
    "# Axis limits: [-3, 3] for the \"normal\" setting, slightly beyond [0, 1] otherwise.\n",
    "#config[\"distribution\"] = \"uniform\"\n",
    "if config[\"distribution\"] == \"normal\":\n",
    "    xy_lim = (-3.0, 3.0)\n",
    "else:\n",
    "    xy_lim = (-0.1, 1.1)\n",
    "\n",
    "# Disabled variant: map the first element through the standard normal CDF before plotting.\n",
    "#plot_embeddings(ndtr(embeddings[0]), embeddings[1], size=4, alpha=1.0, x_lim=xy_lim, y_lim=xy_lim)\n",
    "plot_embeddings(\n",
    "    *embeddings,\n",
    "    size=4,\n",
    "    alpha=1.0,\n",
    "    x_lim=xy_lim,\n",
    "    y_lim=xy_lim,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b82743c1-e40b-4e1f-bcd1-0ef4282e43f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target directory handed to save_results:\n",
    "# <experiments_path>/<dataset>/<discriminator_network>/<embedding_dim>.\n",
    "# experiments_path is already a pathlib.Path built in the path-setup cell,\n",
    "# so Path does not need to be imported again here.\n",
    "run_path = experiments_path / config[\"dataset\"] / config[\"discriminator_network\"] / str(config[\"embedding_dim\"])\n",
    "\n",
    "save_results(model, config, history, run_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e849c2f-8a99-47c4-a7f6-2a4bd0c4b540",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
