{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e929093a-5bdf-4568-9f37-95c51518fb1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a467a154-0d43-4d28-9811-8324e395bcd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_combinations_path = Path(\"/home/XXXX-1/projects/rep_merging/rep2rep/scripts/configs/model_combinations_layers_equally_spaced_large_sized_models.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f26e0bf0-a386-429d-ad86-b9b44700905f",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(model_combinations_path) as f:\n",
    "    model_combinations = [[m.strip() for m in line.split(\";\")] for line in f if line.strip()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "275622b0-92c5-4c9c-9a8e-44682327d71c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24\n",
      "24\n",
      "24\n",
      "OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.2.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.2.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.4.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.4.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.6.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.6.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.8.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.8.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.10.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.10.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.12.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.12.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.14.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.14.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.16.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.16.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.18.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.18.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.20.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.20.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual.transformer.resblocks.22.ln_2;OpenCLIP_ViT-L-14_openai_ap@visual.transformer.resblocks.22.ln_2;OpenCLIP_ViT-L-14_openai_cls@visual;OpenCLIP_ViT-L-14_openai_ap@visual\n",
      "dinov2-vit-large-p14_cls@blocks.2.norm2;dinov2-vit-large-p14_ap@blocks.2.norm2;dinov2-vit-large-p14_cls@blocks.4.norm2;dinov2-vit-large-p14_ap@blocks.4.norm2;dinov2-vit-large-p14_cls@blocks.6.norm2;dinov2-vit-large-p14_ap@blocks.6.norm2;dinov2-vit-large-p14_cls@blocks.8.norm2;dinov2-vit-large-p14_ap@blocks.8.norm2;dinov2-vit-large-p14_cls@blocks.10.norm2;dinov2-vit-large-p14_ap@blocks.10.norm2;dinov2-vit-large-p14_cls@blocks.12.norm2;dinov2-vit-large-p14_ap@blocks.12.norm2;dinov2-vit-large-p14_cls@blocks.14.norm2;dinov2-vit-large-p14_ap@blocks.14.norm2;dinov2-vit-large-p14_cls@blocks.16.norm2;dinov2-vit-large-p14_ap@blocks.16.norm2;dinov2-vit-large-p14_cls@blocks.18.norm2;dinov2-vit-large-p14_ap@blocks.18.norm2;dinov2-vit-large-p14_cls@blocks.20.norm2;dinov2-vit-large-p14_ap@blocks.20.norm2;dinov2-vit-large-p14_cls@blocks.22.norm2;dinov2-vit-large-p14_ap@blocks.22.norm2;dinov2-vit-large-p14_cls@norm;dinov2-vit-large-p14_ap@norm\n",
      "vit_large_patch16_224_cls@blocks.2.norm2;vit_large_patch16_224_ap@blocks.2.norm2;vit_large_patch16_224_cls@blocks.4.norm2;vit_large_patch16_224_ap@blocks.4.norm2;vit_large_patch16_224_cls@blocks.6.norm2;vit_large_patch16_224_ap@blocks.6.norm2;vit_large_patch16_224_cls@blocks.8.norm2;vit_large_patch16_224_ap@blocks.8.norm2;vit_large_patch16_224_cls@blocks.10.norm2;vit_large_patch16_224_ap@blocks.10.norm2;vit_large_patch16_224_cls@blocks.12.norm2;vit_large_patch16_224_ap@blocks.12.norm2;vit_large_patch16_224_cls@blocks.14.norm2;vit_large_patch16_224_ap@blocks.14.norm2;vit_large_patch16_224_cls@blocks.16.norm2;vit_large_patch16_224_ap@blocks.16.norm2;vit_large_patch16_224_cls@blocks.18.norm2;vit_large_patch16_224_ap@blocks.18.norm2;vit_large_patch16_224_cls@blocks.20.norm2;vit_large_patch16_224_ap@blocks.20.norm2;vit_large_patch16_224_cls@blocks.22.norm2;vit_large_patch16_224_ap@blocks.22.norm2;vit_large_patch16_224_cls@norm;vit_large_patch16_224_ap@norm\n",
      "\n"
     ]
    }
   ],
   "source": [
    "new_model_combinations  = []\n",
    "for model_set in model_combinations:\n",
    "    new_model_set = []\n",
    "    for i in range(2, len(model_set),4):\n",
    "        new_model_set += model_set[i:(i+2)]\n",
    "    print(len(new_model_set))\n",
    "    new_model_combinations.append(new_model_set)\n",
    "\n",
    "str_to_write = \"\"\n",
    "for new_model_set in new_model_combinations: \n",
    "        str_to_write += (\";\".join(new_model_set)+'\\n')\n",
    "\n",
    "print(str_to_write)\n",
    "\n",
    "with open(model_combinations_path, 'w') as f:\n",
    "    f.write(str_to_write)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2db9fb5e-27ae-4dc0-a10b-134864cd33f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "storing_path = Path(\"/home/XXXX-1/projects/rep_merging/rep2rep/scripts/configs\")\n",
    "\n",
    "path_to_config = \"/home/XXXX-1/projects/rep_merging/rep2rep/scripts/configs/models_config_single_model_layer_combination.json\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "40b180ab-fabd-460e-b556-a33c53234dd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(path_to_config) as f:\n",
    "    models = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c232a527-be36-4be3-8c5f-151878408663",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'model_name': 'vit_small_patch16_224.augreg_in21k_ft_in1k',\n",
       " 'source': 'timm',\n",
       " 'model_parameters': {'token_extraction': 'cls_token'},\n",
       " 'module_names': ['blocks.0.norm2',\n",
       "  'blocks.1.norm2',\n",
       "  'blocks.2.norm2',\n",
       "  'blocks.3.norm2',\n",
       "  'blocks.4.norm2',\n",
       "  'blocks.5.norm2',\n",
       "  'blocks.6.norm2',\n",
       "  'blocks.7.norm2',\n",
       "  'blocks.8.norm2',\n",
       "  'blocks.9.norm2',\n",
       "  'blocks.10.norm2',\n",
       "  'blocks.11.norm2',\n",
       "  'norm'],\n",
       " 'objective': 'Supervised',\n",
       " 'dataset': 'ImageNet21k + finetuned on ImageNet1k',\n",
       " 'architecture_class': 'Transformer',\n",
       " 'architecture': 'ViT',\n",
       " 'embedding_dim': 384,\n",
       " 'alignment': None,\n",
       " 'dataset_class': 'ImageNet21k + finetuned on ImageNet1k',\n",
       " 'size': 22050664,\n",
       " 'size_fmt': '22.1M',\n",
       " 'size_class': 'small',\n",
       " 'set_length': 13}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "models[next(iter(models))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4c639a28-9836-4727-906a-a96e18d1ac08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['OpenCLIP_ViT-B-16_openai',\n",
       " 'OpenCLIP_ViT-B-32_openai',\n",
       " 'OpenCLIP_ViT-L-14_openai',\n",
       " 'dinov2-vit-base-p14',\n",
       " 'dinov2-vit-large-p14',\n",
       " 'dinov2-vit-small-p14',\n",
       " 'vit_base_patch16_224',\n",
       " 'vit_large_patch16_224',\n",
       " 'vit_small_patch16_224']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_models = sorted(list(set([mid.rsplit(\"_\", 1)[0] for mid in models.keys()])))\n",
    "base_models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "751a7682-07c5-462a-a33f-e949d88d7022",
   "metadata": {},
   "source": [
    "## Create model combinations: cls + avg pool last layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8fdcaaf6-5e22-4e69-875e-2446b80dd823",
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5e400c58-b086-4d63-b644-f52ddecb2464",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, mid in enumerate(base_models):\n",
    "    last_module = models[f\"{mid}_cls\"]['module_names'][-1]\n",
    "    new_line = f\"{mid}_cls@{last_module};{mid}_ap@{last_module}\"\n",
    "    if i < (len(base_models)-1):\n",
    "        new_line += \"\\n\"\n",
    "    lines.append(new_line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d98405d1-e566-4822-b700-dde6c0ff9ccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "with (storing_path / \"model_combinations_last_layer_cls_ap.txt\").open(mode='w') as f:\n",
    "    f.writelines(lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2024cd0d-6c6e-4ed5-921e-c753610deaaa",
   "metadata": {},
   "source": [
    "## Create model combinations different number of intermediate layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "38be1617-0f90-4270-ba81-17f2026579cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[\n",
      "  \"OpenCLIP_ViT-B-32_openai_cls@visual.transformer.resblocks.6.ln_2;OpenCLIP_ViT-B-32_openai_ap@visual.transformer.resblocks.6.ln_2;OpenCLIP_ViT-B-32_openai_cls@visual;OpenCLIP_ViT-B-32_openai_ap@visual\\n\",\n",
      "  \"OpenCLIP_ViT-B-32_openai_cls@visual.transformer.resblocks.3.ln_2;OpenCLIP_ViT-B-32_openai_ap@visual.transformer.resblocks.3.ln_2;OpenCLIP_ViT-B-32_openai_cls@visual.transformer.resblocks.6.ln_2;OpenCLIP_ViT-B-32_openai_ap@visual.transformer.resblocks.6.ln_2;OpenCLIP_ViT-B-32_openai_cls@visual.transformer.resblocks.9.ln_2;OpenCLIP_ViT-B-32_openai_ap@visual.transformer.resblocks.9.ln_2;OpenCLIP_ViT-B-32_openai_cls@visual;OpenCLIP_ViT-B-32_openai_ap@visual\\n\",\n",
      "  \"dinov2-vit-base-p14_cls@blocks.6.norm2;dinov2-vit-base-p14_ap@blocks.6.norm2;dinov2-vit-base-p14_cls@norm;dinov2-vit-base-p14_ap@norm\\n\",\n",
      "  \"dinov2-vit-base-p14_cls@blocks.3.norm2;dinov2-vit-base-p14_ap@blocks.3.norm2;dinov2-vit-base-p14_cls@blocks.6.norm2;dinov2-vit-base-p14_ap@blocks.6.norm2;dinov2-vit-base-p14_cls@blocks.9.norm2;dinov2-vit-base-p14_ap@blocks.9.norm2;dinov2-vit-base-p14_cls@norm;dinov2-vit-base-p14_ap@norm\\n\",\n",
      "  \"vit_base_patch16_224_cls@blocks.6.norm2;vit_base_patch16_224_ap@blocks.6.norm2;vit_base_patch16_224_cls@norm;vit_base_patch16_224_ap@norm\\n\",\n",
      "  \"vit_base_patch16_224_cls@blocks.3.norm2;vit_base_patch16_224_ap@blocks.3.norm2;vit_base_patch16_224_cls@blocks.6.norm2;vit_base_patch16_224_ap@blocks.6.norm2;vit_base_patch16_224_cls@blocks.9.norm2;vit_base_patch16_224_ap@blocks.9.norm2;vit_base_patch16_224_cls@norm;vit_base_patch16_224_ap@norm\"\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "base_sized_modes = ['OpenCLIP_ViT-B-32_openai', 'dinov2-vit-base-p14', 'vit_base_patch16_224']\n",
    "nr_layers_comb = [2, 4]\n",
    "lines = []\n",
    "cnt = 0\n",
    "for i, mid in enumerate(base_sized_modes):\n",
    "    curr_modules = models[f\"{mid}_cls\"]['module_names'][1:]\n",
    "    nr_layers = len(curr_modules)\n",
    "    for j in nr_layers_comb:\n",
    "        curr_layers_list = []\n",
    "        for k in range(j):\n",
    "            curr_layer = curr_modules[(nr_layers // j * (k+1)) - 1]\n",
    "            curr_layers_list.append(f\"{mid}_cls@{curr_layer}\")\n",
    "            curr_layers_list.append(f\"{mid}_ap@{curr_layer}\")\n",
    "        new_line = \";\".join(curr_layers_list)\n",
    "        if cnt < (len(base_sized_modes)*len(nr_layers_comb)-1):\n",
    "            new_line += \"\\n\"\n",
    "        lines.append(new_line)\n",
    "        cnt += 1\n",
    "\n",
    "print(json.dumps(lines, indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d77b5ddd-d780-40c7-8ac4-9975f6cdb04d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with (storing_path / \"model_combinations_layers_equally_spaced_base_sized_models.txt\").open(mode='w') as f:\n",
    "    f.writelines(lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05ef8671-82a8-4e01-8910-8be43164ddbd",
   "metadata": {},
   "source": [
    "## Create model combinations: cls + avg pool all intermediate layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "0d82e0d3-01f6-4253-ad55-e6a17c6cc87b",
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "99371015-194d-449a-9cb5-2e6132966337",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, mid in enumerate(base_models):\n",
    "    new_line = []\n",
    "    for k, module_name in enumerate(models[f\"{mid}_cls\"]['module_names']):\n",
    "        if k == 0:\n",
    "            continue\n",
    "        new_line.append(f\"{mid}_cls@{module_name}\")\n",
    "        new_line.append(f\"{mid}_ap@{module_name}\")\n",
    "    new_line = \";\".join(new_line)\n",
    "    if i < (len(base_models)-1):\n",
    "        new_line += \"\\n\"\n",
    "    lines.append(new_line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f5cdaeb7-5224-48e2-a80b-9512170134ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "with (storing_path / \"model_combinations_layers_all_blocks_cls_ap.txt\").open(mode='w') as f:\n",
    "    f.writelines(lines)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "642a7cbf-5844-4918-bc24-b5b4e1c4f293",
   "metadata": {},
   "source": [
    "## Create model combinations: all tokens last layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "577405b5-bd73-41c7-8a89-fe9ac3a067ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "lines = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "bb0abc53-75ff-4441-aea1-7cfa5178ad2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, mid in enumerate(base_models):\n",
    "    last_module = models[f\"{mid}_cls\"]['module_names'][-1]\n",
    "    new_line = f\"{mid}_at@{last_module}\"\n",
    "    if i < (len(base_models)-1):\n",
    "        new_line += \"\\n\"\n",
    "    lines.append(new_line)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "464ee006-6a40-465a-a5a9-045f644e6cd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "with (storing_path / \"model_combinations_all_tokens_last_layer.txt\").open(mode='w') as f:\n",
    "    f.writelines(lines)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
