{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
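    "# Train transformations between latent spaces\n",
    "\n",
    "For each pair of distinct stored embedding spaces $(X, Y)$, fit a parametric transformation $T$ (isotropic scaling, translation, or a linear map) by minimizing $\\mathrm{MSE}(T(Y), X)$ with Adam, then record the last loss and the mean cosine similarity between $T(Y)$ and $X$."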
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device in use: cuda\n"
     ]
    }
   ],
   "source": [
    "import copy\n",
    "import itertools\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import pandas as pd\n",
    "import torch\n",
    "from datasets.dataset_dict import DatasetDict\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch import nn\n",
    "from torch.nn.functional import cosine_similarity\n",
    "from torch.optim import Adam\n",
    "from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts\n",
    "from tqdm import tqdm\n",
    "\n",
    "from latent_invariances.utils import transformations\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Device in use:\", device)\n",
    "\n",
    "WEIGHTS_DIR = PROJECT_ROOT / \"data\" / \"weights\"\n",
    "EMBEDDINGS_DIR = PROJECT_ROOT / \"data\" / \"embeddings\"\n",
    "\n",
    "# Results directory (currently not referenced elsewhere in this notebook).\n",
    "OUT_DIR = PROJECT_ROOT / \"notebooks\" / \"transformations\" / \"results\"\n",
    "\n",
    "DATASET = \"mnist\"\n",
    "\n",
    "SPACES_PATH = Path(EMBEDDINGS_DIR / DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Absolute base directory; not referenced elsewhere in this notebook. Override with the ABS_DIR env var.\n",
    "ABS_DIR = os.environ.get(\"ABS_DIR\", \"/root/\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_transformation(fix_space, to_transf_space, transformation, lr, num_opt_steps, enable_scheduler):\n",
    "    \"\"\"Fit `transformation` so that `transformation(to_transf_space)` matches `fix_space` under MSE.\n",
    "\n",
    "    The parameters of `transformation` are updated in place with full-batch Adam.\n",
    "\n",
    "    Args:\n",
    "        fix_space: target tensor that the transformed space is compared against.\n",
    "        to_transf_space: tensor passed through `transformation`.\n",
    "        transformation: module whose `.parameters()` are optimized.\n",
    "        lr: Adam learning rate.\n",
    "        num_opt_steps: number of optimization steps; must be positive.\n",
    "        enable_scheduler: if truthy, anneal the learning rate with\n",
    "            CosineAnnealingWarmRestarts(T_0=10, T_mult=2, eta_min=0), stepped once per step.\n",
    "\n",
    "    Returns:\n",
    "        dict with the loss and transformed space of the last forward pass (detached, on CPU),\n",
    "        the two input spaces as passed in, and the per-step loss history.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `num_opt_steps` is not positive.\n",
    "    \"\"\"\n",
    "    if num_opt_steps <= 0:\n",
    "        # Without at least one step there is no loss or transformed space to report.\n",
    "        raise ValueError(f\"num_opt_steps must be positive, got {num_opt_steps}\")\n",
    "\n",
    "    loss_fn = nn.MSELoss()\n",
    "    loss_history = []\n",
    "\n",
    "    opt = Adam(transformation.parameters(), lr=lr)\n",
    "    # `verbose` is omitted: it is deprecated in recent PyTorch releases and defaults to off.\n",
    "    scheduler = CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2, eta_min=0) if enable_scheduler else None\n",
    "\n",
    "    progress = tqdm(range(num_opt_steps), desc=\"Optimization\")\n",
    "    for _ in progress:\n",
    "        # Clear gradients before the backward pass so stale gradients on the parameters cannot leak in.\n",
    "        opt.zero_grad()\n",
    "        transformed_space = transformation(to_transf_space)\n",
    "        out = loss_fn(transformed_space, fix_space)\n",
    "        out.backward()\n",
    "        opt.step()\n",
    "\n",
    "        if scheduler is not None:\n",
    "            scheduler.step()\n",
    "\n",
    "        step_loss = out.detach().cpu()\n",
    "        loss_history.append(step_loss)\n",
    "        progress.set_description(f\"mse: {step_loss:.7f}\")\n",
    "\n",
    "    # `out` / `transformed_space` come from the last forward pass, i.e. before the final update.\n",
    "    return {\n",
    "        \"loss\": out.detach().cpu(),\n",
    "        \"to_transf_space\": to_transf_space,\n",
    "        \"fix_space\": fix_space,\n",
    "        \"transformed_space\": transformed_space.detach().cpu(),\n",
    "        \"loss_history\": loss_history,\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_cosine_matches(space_x, space_y):\n",
    "    \"\"\"Return the mean cosine similarity between `space_x` and `space_y` along dim 1.\n",
    "\n",
    "    Both inputs are detached and moved to the CPU first; for 2-D inputs dim 1 compares\n",
    "    corresponding rows, so the result is the average per-row similarity (a 0-dim tensor).\n",
    "    \"\"\"\n",
    "    x_cpu = space_x.detach().cpu()\n",
    "    y_cpu = space_y.detach().cpu()\n",
    "    per_row_similarity = cosine_similarity(x_cpu, y_cpu, dim=1)\n",
    "    return per_row_similarity.mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Column names of the results table; `info` holds one list per column, appended to once per run.\n",
    "columns = (\n",
    "    \"SEED PATH TRANSFORMATION DATASET SPACE_X SPACE_Y TRANSFORMED_SPACE \"\n",
    "    \"NUM_OPT_STEPS LR SCHEDULER LOSS COSINE_SIM TIMESTAMP\"\n",
    ").split()\n",
    "\n",
    "info = {column_name: [] for column_name in columns}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameter grid: the optimization loop runs once per combination of these values.\n",
    "seed_list = [0]\n",
    "lr_list = [0.02]\n",
    "opt_steps_list = [5_000]\n",
    "scheduler_enabled = [False]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['ae_run0'], ['ae_run1'])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "spaces = DatasetDict.load_from_disk(str(SPACES_PATH))[\"train\"]\n",
    "\n",
    "# Pick which stored spaces to align (widen both slices, e.g. to [2:-5], to sweep more pairs).\n",
    "space_x_ids = [*spaces[0].keys()][2:3]\n",
    "space_y_ids = [*spaces[0].keys()][3:4]\n",
    "\n",
    "space_x_ids, space_y_ids\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Latent dimensionality, read from the first row of the first X space\n",
    "# (same value as torch.tensor(spaces[space_x_ids[0]]).shape[1], without loading the whole column).\n",
    "shape = len(spaces[0][space_x_ids[0]])\n",
    "assert all(\n",
    "    len(spaces[0][space_id]) == shape for space_id in space_x_ids + space_y_ids\n",
    "), \"all selected spaces are expected to share the same dimensionality\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Candidate transformations, keyed by name; each module is instantiated once, here.\n",
    "TRANSFORMATION_DICT = dict(\n",
    "    isotropic_scaling=transformations.IsotropicScaling(),\n",
    "    translation=transformations.Translation(shape),\n",
    "    linear_transformation=transformations.LinearTransformation(shape),\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct name so the imported `transformations` module is not shadowed on re-runs.\n",
    "transformation_list = list(TRANSFORMATION_DICT.values())\n",
    "transformation_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for (\n",
    "    transformation,\n",
    "    space_x_id,\n",
    "    space_y_id,\n",
    "    seed,\n",
    "    lr,\n",
    "    opt_steps,\n",
    "    scheduler,\n",
    ") in itertools.product(\n",
    "    transformation_list,\n",
    "    space_x_ids,\n",
    "    space_y_ids,\n",
    "    seed_list,\n",
    "    lr_list,\n",
    "    opt_steps_list,\n",
    "    scheduler_enabled,\n",
    "):\n",
    "    if space_x_id == space_y_id:\n",
    "        continue\n",
    "\n",
    "    seed_everything(seed)\n",
    "\n",
    "    space_x = torch.tensor(spaces[space_x_id])\n",
    "    space_y = torch.tensor(spaces[space_y_id])\n",
    "\n",
    "    assert len(space_x) == len(space_y)\n",
    "\n",
    "    print(f\"Optimization {transformation} - {space_x_id} -> {space_y_id}\")\n",
    "\n",
    "    # Optimize a fresh copy: the modules in `transformation_list` are shared by every\n",
    "    # combination, so training them in place would warm-start each later run from the\n",
    "    # previous run's result.\n",
    "    run_transformation = copy.deepcopy(transformation)\n",
    "\n",
    "    opt_dict = optimize_transformation(\n",
    "        fix_space=space_x,\n",
    "        to_transf_space=space_y,\n",
    "        transformation=run_transformation,\n",
    "        lr=lr,\n",
    "        num_opt_steps=opt_steps,\n",
    "        enable_scheduler=scheduler,  # per-run flag, not the whole `scheduler_enabled` list\n",
    "    )\n",
    "\n",
    "    info[\"SEED\"].append(seed)\n",
    "    info[\"PATH\"].append(SPACES_PATH)\n",
    "    info[\"TRANSFORMATION\"].append(transformation.__class__.__name__)\n",
    "    info[\"DATASET\"].append(DATASET)\n",
    "    info[\"SPACE_X\"].append(space_x_id)\n",
    "    info[\"SPACE_Y\"].append(space_y_id)\n",
    "    info[\"TRANSFORMED_SPACE\"].append(space_y_id)\n",
    "    info[\"NUM_OPT_STEPS\"].append(opt_steps)\n",
    "    info[\"LR\"].append(lr)\n",
    "    info[\"SCHEDULER\"].append(scheduler)\n",
    "    info[\"LOSS\"].append(opt_dict[\"loss\"].item())\n",
    "    info[\"COSINE_SIM\"].append(\n",
    "        calc_cosine_matches(space_x=opt_dict[\"fix_space\"], space_y=opt_dict[\"transformed_space\"]).item()\n",
    "    )\n",
    "    info[\"TIMESTAMP\"].append(pd.to_datetime(\"now\").replace(microsecond=0))\n",
    "\n",
    "    # Checkpoint after every run so partial results survive an interrupted sweep.\n",
    "    pd.DataFrame(info).to_csv(\"optimization_results\", index=False)\n",
    "\n",
    "# Built after the loop so `results` exists even when no pair was optimized.\n",
    "results = pd.DataFrame(info)\n",
    "results"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
