{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from latent_invariances.stitching import DatasetConfig, data_config, test\n",
    "from datasets import DatasetDict, load_from_disk\n",
    "from latent_invariances.modules.aggregation_modules import LinearSelfAttentionLayer, LayerNorm, NonLinearSumAggregation\n",
    "from latent_invariances.utils.relreps import SIMPLE_PROJECTION_TYPE\n",
    "from latent_invariances.modules.simple_classifier import Classifier, SVCModel\n",
    "from latent_invariances.utils.relreps import SIMPLE_PROJECTION_TYPE\n",
    "import torch\n",
    "from typing import Callable, Mapping, Sequence, Tuple, Type\n",
    "from pytorch_lightning import Trainer, seed_everything\n",
    "from pytorch_lightning.callbacks import EarlyStopping\n",
    "from torch.utils.data import DataLoader\n",
    "from latent_invariances.utils.space import LatentSpace\n",
    "from typing import Mapping\n",
    "from torch import nn\n",
    "from typing import Type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_datataset(dataset: str) -> DatasetDict:\n",
    "    \"\"\"Load a dataset from disk and arrange it into train/val/test splits.\n",
    "\n",
    "    Args:\n",
    "        dataset: Dataset identifier, resolved through ``data_config``.\n",
    "\n",
    "    Returns:\n",
    "        A ``DatasetDict`` with ``train``, ``val`` and ``test`` splits. The\n",
    "        attributes ``num_classes`` and ``dataset_config`` are attached to it.\n",
    "    \"\"\"\n",
    "    dataset_config: DatasetConfig = data_config(dataset)\n",
    "    label_column = dataset_config.label_column\n",
    "    data: DatasetDict = load_from_disk(dataset_path=str(dataset_config.directory))\n",
    "\n",
    "    if dataset_config.key.startswith(\"dbpedia_14\"):\n",
    "        # Keep a stratified 10% of each split. The subsample is seeded so that\n",
    "        # re-running the notebook selects the same samples.\n",
    "        data = DatasetDict(\n",
    "            {\n",
    "                split: data[split].train_test_split(\n",
    "                    train_size=0.1, seed=42, stratify_by_column=label_column\n",
    "                )[\"train\"]\n",
    "                for split in (\"train\", \"test\")\n",
    "            }\n",
    "        )\n",
    "\n",
    "    # Columns whose name starts with an encoder name, plus the label column,\n",
    "    # are returned as torch tensors; every other column is still output.\n",
    "    tensor_columns = {\n",
    "        column\n",
    "        for column in data[\"train\"].column_names\n",
    "        if any(column.startswith(encoder) for encoder in dataset_config.encoders)\n",
    "    }\n",
    "    tensor_columns.add(label_column)\n",
    "    # Pass a list (sorted for a deterministic order) rather than a set.\n",
    "    data.set_format(columns=sorted(tensor_columns), output_all_columns=True, type=\"torch\")\n",
    "\n",
    "    # Hold out a stratified 10% of the training split as validation data.\n",
    "    fit_data = data[\"train\"].train_test_split(train_size=0.9, seed=42, stratify_by_column=label_column)\n",
    "    train_data = fit_data[\"train\"]\n",
    "    data = DatasetDict({\"train\": train_data, \"val\": fit_data[\"test\"], \"test\": data[\"test\"]})\n",
    "    data.num_classes = train_data.features[label_column].num_classes\n",
    "    data.dataset_config = dataset_config\n",
    "    return data\n",
    "\n",
    "\n",
    "def load_enc_name2abs_space(data: DatasetDict) -> Mapping[str, Mapping[str, LatentSpace]]:\n",
    "    \"\"\"Wrap every encoder's vectors in one absolute ``LatentSpace`` per split.\n",
    "\n",
    "    Args:\n",
    "        data: Dataset with ``train``/``val``/``test`` splits that carries the\n",
    "            ``dataset_config`` and ``num_classes`` attributes read below.\n",
    "\n",
    "    Returns:\n",
    "        Nested mapping ``encoder name -> split name -> LatentSpace``.\n",
    "    \"\"\"\n",
    "    dataset_config = data.dataset_config\n",
    "    splits = [\"train\", \"val\", \"test\"]\n",
    "\n",
    "    enc_name2abs_space: Mapping[str, Mapping[str, LatentSpace]] = {\n",
    "        enc_name: {\n",
    "            split: LatentSpace(\n",
    "                encoding_type=\"absolute\",\n",
    "                vectors=data[split][dataset_config.enc_name2column[enc_name]],\n",
    "                encoder=enc_name,\n",
    "                keys=data[split][\"index\"],\n",
    "                labels=data[split][dataset_config.label_column],\n",
    "                num_classes=data.num_classes,\n",
    "            )\n",
    "            for split in splits\n",
    "        }\n",
    "        for enc_name in dataset_config.encoders\n",
    "    }\n",
    "    return enc_name2abs_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "class RelativeBlock(torch.nn.Module):\n",
    "    \"\"\"Project inputs onto a set of anchors and aggregate the projections.\n",
    "\n",
    "    ``encode`` applies every projection function to the input and concatenates\n",
    "    the results along the last dimension; ``decode`` feeds that concatenation\n",
    "    to the aggregation module. ``forward`` chains the two.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        anchors: torch.Tensor,\n",
    "        projection_names: Sequence[str],\n",
    "        aggregation_module: torch.nn.Module,\n",
    "    ):\n",
    "        \"\"\"Store the anchors as a buffer and look up the projection functions.\n",
    "\n",
    "        Args:\n",
    "            anchors: Anchor tensor handed to every projection function.\n",
    "            projection_names: Keys into ``SIMPLE_PROJECTION_TYPE``.\n",
    "            aggregation_module: Module applied to the concatenated projections.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.register_buffer(\"anchors\", anchors)\n",
    "        self.projection_names = projection_names\n",
    "        self.projection_funcs: Sequence[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = [\n",
    "            SIMPLE_PROJECTION_TYPE[projection_name] for projection_name in projection_names\n",
    "        ]\n",
    "        self.aggregation_module = aggregation_module\n",
    "\n",
    "    def get_attention_weights(self, encoded_x: torch.Tensor, attention_idx: int):\n",
    "        \"\"\"Delegate to the aggregation module's ``get_attention_weights``.\"\"\"\n",
    "        aggregator = self.aggregation_module\n",
    "        return aggregator.get_attention_weights(encoded_x, attention_idx=attention_idx)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Encode ``x`` against the stored anchors, then aggregate.\"\"\"\n",
    "        return self.decode(self.encode(x, self.anchors))\n",
    "\n",
    "    def encode(self, x, anchors: Optional[torch.Tensor] = None) -> torch.Tensor:\n",
    "        \"\"\"Concatenate all projections of ``x``; ``anchors`` defaults to the stored buffer.\"\"\"\n",
    "        reference = self.anchors if anchors is None else anchors\n",
    "        projections = [project(anchors=reference, points=x) for project in self.projection_funcs]\n",
    "        return torch.cat(projections, dim=-1)\n",
    "\n",
    "    def decode(self, rel_x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Apply the aggregation module to already-encoded inputs.\"\"\"\n",
    "        return self.aggregation_module(rel_x)\n",
    "\n",
    "    def __repr__(self):\n",
    "        names, aggregator = self.projection_names, self.aggregation_module\n",
    "        return f\"RelativeBlock({names}, {aggregator})\"\n",
    "\n",
    "    def __str__(self):\n",
    "        return f\"RelativeBlock({self.projection_names=}, {self.aggregation_module=})\"\n",
    "\n",
    "\n",
    "def get_model(\n",
    "    anchors: torch.Tensor,\n",
    "    projection_names: Sequence[str],\n",
    "    aggregation_module: Type[torch.nn.Module],\n",
    "    num_anchors: int,\n",
    "    lr: float,\n",
    "    seed: int,\n",
    "    num_classes: int,\n",
    "):\n",
    "    \"\"\"Build a ``Classifier`` on top of a ``RelativeBlock``.\n",
    "\n",
    "    Args:\n",
    "        anchors: Anchor tensor stored inside the relative block.\n",
    "        projection_names: Names of the projections to combine.\n",
    "        aggregation_module: Aggregator class; instantiated with\n",
    "            ``subspace_dim=num_anchors`` and ``num_subspaces=len(projection_names)``.\n",
    "        num_anchors: Number of anchors.\n",
    "        lr: Learning rate forwarded to the classifier.\n",
    "        seed: Seed forwarded to the classifier.\n",
    "        num_classes: Number of target classes.\n",
    "\n",
    "    Returns:\n",
    "        The assembled ``Classifier``.\n",
    "    \"\"\"\n",
    "    aggregator_instance = aggregation_module(subspace_dim=num_anchors, num_subspaces=len(projection_names))\n",
    "    block = RelativeBlock(\n",
    "        anchors=anchors,\n",
    "        projection_names=projection_names,\n",
    "        aggregation_module=aggregator_instance,\n",
    "    )\n",
    "    # The classifier input size is whatever the aggregator reports as `out_dim`.\n",
    "    return Classifier(\n",
    "        aggregation_module=block,\n",
    "        input_dim=aggregator_instance.out_dim,\n",
    "        num_classes=num_classes,\n",
    "        lr=lr,\n",
    "        deep=\"linear\",\n",
    "        seed=seed,\n",
    "    )\n",
    "\n",
    "\n",
    "class StitchedModel(torch.nn.Module):\n",
    "    \"\"\"Combine the ``encode`` step of one model with the ``decode`` step of another.\n",
    "\n",
    "    Both sub-models are switched to eval mode when the stitched model is built.\n",
    "    The ``x_feature``/``y_feature`` attributes are copied from the encoder model.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        encoder: torch.nn.Module,\n",
    "        decoder: torch.nn.Module,\n",
    "        encoder_name: str,\n",
    "        decoder_name: str,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.encoder, self.decoder = encoder.eval(), decoder.eval()\n",
    "        self.encoder_name, self.decoder_name = encoder_name, decoder_name\n",
    "\n",
    "        self.x_feature, self.y_feature = self.encoder.x_feature, self.encoder.y_feature\n",
    "\n",
    "    def get_attention_weights(self, x: torch.Tensor, attention_idx: int):\n",
    "        \"\"\"Encode ``x`` with the encoder model and query the decoder's attention weights.\"\"\"\n",
    "        encoded = self.encoder.encode(x)\n",
    "        return self.decoder.get_attention_weights(encoded, attention_idx=attention_idx)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        \"\"\"Encode ``x`` with the encoder model and decode it with the decoder model.\"\"\"\n",
    "        return self.decoder.decode(self.encoder.encode(x))\n",
    "\n",
    "    def __repr__(self):\n",
    "        encoder, decoder = self.encoder, self.decoder\n",
    "        return f\"StitchedModel({encoder}, {decoder})\"\n",
    "\n",
    "\n",
    "def plot_attention_weights(\n",
    "    stitched_model: StitchedModel,\n",
    "    enc_name2abs_space,\n",
    "    batch_size,\n",
    "    pin_memory,\n",
    "    num_workers,\n",
    "    persistent_workers,\n",
    "    device,\n",
    "    attention_index: int = 0,\n",
    "):\n",
    "    \"\"\"Average the attention weights over the training split and draw them as a heatmap.\n",
    "\n",
    "    The model is put in eval mode (and left there) and the weights are computed\n",
    "    without gradient tracking, so the plot does not depend on whether a\n",
    "    training loop previously left the model in train mode.\n",
    "\n",
    "    Args:\n",
    "        stitched_model: Model whose attention weights are inspected.\n",
    "        enc_name2abs_space: Mapping ``encoder name -> split name -> dataset``;\n",
    "            the ``train`` split of ``stitched_model.encoder_name`` is used.\n",
    "        batch_size: Forwarded to the ``DataLoader``.\n",
    "        pin_memory: Forwarded to the ``DataLoader``.\n",
    "        num_workers: Forwarded to the ``DataLoader``.\n",
    "        persistent_workers: Forwarded to the ``DataLoader``.\n",
    "        device: Device the input batches are moved to.\n",
    "        attention_index: Index forwarded to ``get_attention_weights``.\n",
    "\n",
    "    Returns:\n",
    "        ``(fig, w)``: the object returned by ``sns.heatmap`` and the attention\n",
    "        weights averaged over the first dimension, as a CPU tensor.\n",
    "    \"\"\"\n",
    "    train_loader = DataLoader(\n",
    "        enc_name2abs_space[stitched_model.encoder_name][\"train\"],\n",
    "        shuffle=True,\n",
    "        batch_size=batch_size,\n",
    "        pin_memory=pin_memory,\n",
    "        num_workers=num_workers,\n",
    "        persistent_workers=persistent_workers,\n",
    "    )\n",
    "\n",
    "    # Pure inference: disable train-only behaviour and autograd bookkeeping.\n",
    "    stitched_model.eval()\n",
    "    ws = []\n",
    "    with torch.no_grad():\n",
    "        for batch in train_loader:\n",
    "            x = batch[stitched_model.x_feature].to(device)\n",
    "            _, w = stitched_model.get_attention_weights(x, attention_index)\n",
    "            ws.append(w.cpu())\n",
    "    w = torch.cat(ws, dim=0).mean(0)\n",
    "\n",
    "    projection_names = stitched_model.encoder.relative_block.projection_names\n",
    "    fig = sns.heatmap(\n",
    "        w.numpy(),\n",
    "        cmap=\"Reds\",\n",
    "        annot=True,\n",
    "        fmt=\".2f\",\n",
    "        xticklabels=projection_names,\n",
    "        yticklabels=projection_names,\n",
    "    )\n",
    "    return fig, w"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit(\n",
    "    projection_names: Sequence[str],\n",
    "    aggregation_module: Type[nn.Module],\n",
    "    num_classes: int,\n",
    "    num_anchors: int,\n",
    "    split2abs_space: Mapping[str, DatasetDict],\n",
    "    device: torch.device,\n",
    "    seed: int,\n",
    "    lr: float,\n",
    "    epochs: int,\n",
    "    batch_size: int,\n",
    "    num_workers: int,\n",
    "    pin_memory: bool,\n",
    "    persistent_workers: bool,\n",
    "):\n",
    "    \"\"\"Train a classifier on one encoder's space and evaluate it on the test split.\n",
    "\n",
    "    Args:\n",
    "        projection_names: Names of the projections combined by the model.\n",
    "        aggregation_module: Aggregator class handed to ``get_model``.\n",
    "        num_classes: Number of target classes.\n",
    "        num_anchors: Number of anchors sampled from the training split.\n",
    "        split2abs_space: Mapping with ``train``, ``val`` and ``test`` entries.\n",
    "        device: Device the model is moved to before training.\n",
    "        seed: Seed for ``seed_everything`` and for the model.\n",
    "        lr: Learning rate.\n",
    "        epochs: Maximum number of training epochs.\n",
    "        batch_size: Batch size of the train and validation loaders.\n",
    "        num_workers: Worker count of the train and validation loaders.\n",
    "        pin_memory: ``pin_memory`` flag of all loaders.\n",
    "        persistent_workers: ``persistent_workers`` flag of the train loader.\n",
    "\n",
    "    Returns:\n",
    "        ``dict(model=..., info=...)``: the model in eval mode on CPU and the\n",
    "        value returned by ``test``.\n",
    "    \"\"\"\n",
    "    seed_everything(seed)\n",
    "\n",
    "    # Anchor sampling always uses seed 0, independently of `seed`.\n",
    "    anchors = split2abs_space[\"train\"].get_anchors(anchor_choice=\"uniform\", seed=0, num_anchors=num_anchors)\n",
    "    anchors = torch.stack(list(anchors.values()))  # FIXME: `get_anchors` returns a dict here\n",
    "\n",
    "    shared_loader_kwargs = dict(batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers)\n",
    "    train_loader = DataLoader(\n",
    "        split2abs_space[\"train\"],\n",
    "        shuffle=True,\n",
    "        persistent_workers=persistent_workers,\n",
    "        **shared_loader_kwargs,\n",
    "    )\n",
    "    val_loader = DataLoader(split2abs_space[\"val\"], shuffle=False, **shared_loader_kwargs)\n",
    "    # The test loader uses its own fixed batch size and worker count.\n",
    "    test_loader = DataLoader(\n",
    "        split2abs_space[\"test\"], batch_size=3000, pin_memory=pin_memory, shuffle=False, num_workers=8\n",
    "    )\n",
    "\n",
    "    model = get_model(\n",
    "        anchors=anchors,\n",
    "        projection_names=projection_names,\n",
    "        aggregation_module=aggregation_module,\n",
    "        num_anchors=num_anchors,\n",
    "        lr=lr,\n",
    "        seed=seed,\n",
    "        num_classes=num_classes,\n",
    "    )\n",
    "    model = model.to(device)\n",
    "\n",
    "    # Validation runs every 10 epochs; training stops once the monitored\n",
    "    # \"accuracy\" fails to improve for one validation check.\n",
    "    early_stopping = EarlyStopping(monitor=\"accuracy\", verbose=True, patience=1, mode=\"max\")\n",
    "    trainer = Trainer(\n",
    "        accelerator=\"auto\",\n",
    "        devices=1,\n",
    "        max_epochs=epochs,\n",
    "        logger=None,\n",
    "        check_val_every_n_epoch=10,\n",
    "        callbacks=[early_stopping],\n",
    "        enable_progress_bar=True,\n",
    "    )\n",
    "    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)\n",
    "\n",
    "    # Evaluate the trained model on the held-out test split.\n",
    "    test_info = test(num_classes=num_classes, test_loader=test_loader, model=model.eval())\n",
    "\n",
    "    return {\"model\": model.eval().cpu(), \"info\": test_info}\n",
    "\n",
    "\n",
    "def test_model(\n",
    "    model,\n",
    "    split2abs_space: Mapping[str, LatentSpace],\n",
    "    num_classes,\n",
    "    pin_memory=True,\n",
    "    batch_size: int = 3000,\n",
    "    num_workers: int = 8,\n",
    "):\n",
    "    \"\"\"Evaluate ``model`` on the ``test`` entry of ``split2abs_space``.\n",
    "\n",
    "    Args:\n",
    "        model: Model to evaluate; it is switched to eval mode.\n",
    "        split2abs_space: Mapping with a ``test`` entry usable by a ``DataLoader``.\n",
    "        num_classes: Number of target classes, forwarded to ``test``.\n",
    "        pin_memory: ``pin_memory`` flag of the test loader.\n",
    "        batch_size: Batch size of the test loader (defaults to the previously\n",
    "            hard-coded 3000).\n",
    "        num_workers: Worker count of the test loader (defaults to the previously\n",
    "            hard-coded 8).\n",
    "\n",
    "    Returns:\n",
    "        The value returned by ``test``.\n",
    "    \"\"\"\n",
    "    test_loader = DataLoader(\n",
    "        split2abs_space[\"test\"],\n",
    "        batch_size=batch_size,\n",
    "        pin_memory=pin_memory,\n",
    "        shuffle=False,\n",
    "        num_workers=num_workers,\n",
    "    )\n",
    "    return test(\n",
    "        num_classes=num_classes,\n",
    "        test_loader=test_loader,\n",
    "        model=model.eval(),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset used by every cell below.\n",
    "dataset = \"cifar100\"\n",
    "data = load_datataset(dataset=dataset)\n",
    "enc_name2abs_space = load_enc_name2abs_space(data=data)\n",
    "enc_name2abs_space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def normalize_name(name) -> str:\n",
    "    \"\"\"Return a lowercase identifier for a module, a class, a string or a sequence of strings.\n",
    "\n",
    "    Args:\n",
    "        name: An ``nn.Module`` instance or a class (its class name is used), a\n",
    "            plain string, or a sequence of strings (sorted, then comma-joined).\n",
    "\n",
    "    Returns:\n",
    "        The lowercase name.\n",
    "    \"\"\"\n",
    "    if isinstance(name, torch.nn.Module):\n",
    "        return type(name).__name__.lower()\n",
    "    if isinstance(name, type):\n",
    "        return name.__name__.lower()\n",
    "    # A `str` is itself a `Sequence`: handle it before the generic branch,\n",
    "    # otherwise it would be split into sorted, comma-joined characters.\n",
    "    if isinstance(name, str):\n",
    "        return name.lower()\n",
    "    if isinstance(name, Sequence):\n",
    "        return \",\".join(x.lower() for x in sorted(name))\n",
    "    return name.lower()\n",
    "\n",
    "\n",
    "def get_run_name(aggregator: Type, projection_names: Sequence[str]):\n",
    "    \"\"\"Build the run identifier ``<aggregator>(<projection names>)`` from normalized names.\"\"\"\n",
    "    aggregator_part = normalize_name(aggregator)\n",
    "    projections_part = normalize_name(projection_names)\n",
    "    return f\"{aggregator_part}({projections_part})\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train models end2end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import itertools\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Encoders whose latent spaces are used; commented entries are disabled.\n",
    "ENCODER = (\n",
    "    # \"vit_base_patch16_224\",\n",
    "    \"rexnet_100\",\n",
    "    \"vit_base_patch16_384\",\n",
    "    # \"vit_small_patch16_224\",\n",
    "    # \"vit_base_resnet50_384\",\n",
    "    # \"openai/clip-vit-base-patch32\",\n",
    ")\n",
    "SEEDS = [0]\n",
    "# Each entry pairs the projections to combine with the aggregation module class.\n",
    "PROJECTION_NAMES_LIST = [\n",
    "    ([\"Cosine\", \"Euclidean\", \"L1\", \"Linf\"], LinearSelfAttentionLayer),\n",
    "    ([\"Cosine\", \"Euclidean\", \"L1\", \"Linf\"], NonLinearSumAggregation),\n",
    "    ([\"Cosine\"], LayerNorm),\n",
    "    ([\"Euclidean\"], LayerNorm),\n",
    "    ([\"L1\"], LayerNorm),\n",
    "    ([\"Linf\"], LayerNorm),\n",
    "]\n",
    "NUM_ANCHORS = 1280\n",
    "LR = 1e-3\n",
    "EPOCHS = 50\n",
    "BATCH_SIZE = 10000\n",
    "NUM_WORKERS = 0\n",
    "PIN_MEMORY = True\n",
    "PERSISTENT_WORKERS = False\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "# run name -> encoder -> seed -> output of `fit`\n",
    "runname2encoder2seed2info = defaultdict(dict)\n",
    "results = []\n",
    "for (projection_names, aggregation_module), seed, encoder in tqdm(\n",
    "    list(itertools.product(PROJECTION_NAMES_LIST, SEEDS, ENCODER))\n",
    "):\n",
    "    run_name = get_run_name(aggregator=aggregation_module, projection_names=projection_names)\n",
    "    out = fit(\n",
    "        projection_names=projection_names,\n",
    "        aggregation_module=aggregation_module,\n",
    "        num_classes=data.num_classes,\n",
    "        num_anchors=NUM_ANCHORS,\n",
    "        split2abs_space=enc_name2abs_space[encoder],\n",
    "        device=device,\n",
    "        seed=seed,\n",
    "        lr=LR,\n",
    "        epochs=EPOCHS,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        num_workers=NUM_WORKERS,\n",
    "        pin_memory=PIN_MEMORY,\n",
    "        persistent_workers=PERSISTENT_WORKERS,\n",
    "    )\n",
    "\n",
    "    # Create the per-run mapping on first use, then file the output under encoder/seed.\n",
    "    runname2encoder2seed2info.setdefault(run_name, defaultdict(dict))[encoder][seed] = out\n",
    "    results.append(dict(runname=run_name, seed=seed, score=out[\"info\"][\"score\"], encoder=encoder))\n",
    "end2end_results = pd.DataFrame(results)\n",
    "end2end_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Measure stitching performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every (encoder, decoder) pairing of the trained models on the\n",
    "# encoder's test split.\n",
    "results = []\n",
    "for (projection_names, aggregation_module), seed, encoder_name, decoder_name in tqdm(\n",
    "    list(itertools.product(PROJECTION_NAMES_LIST, SEEDS, ENCODER, ENCODER))\n",
    "):\n",
    "    run_name = get_run_name(aggregator=aggregation_module, projection_names=projection_names)\n",
    "    encoder2seed2info = runname2encoder2seed2info[run_name]\n",
    "\n",
    "    model1 = encoder2seed2info[encoder_name][seed][\"model\"].to(device).eval()\n",
    "    model2 = encoder2seed2info[decoder_name][seed][\"model\"].to(device).eval()\n",
    "    stitched_model = StitchedModel(\n",
    "        encoder=model1, decoder=model2, encoder_name=encoder_name, decoder_name=decoder_name\n",
    "    )\n",
    "    out = test_model(\n",
    "        model=stitched_model,\n",
    "        split2abs_space=enc_name2abs_space[encoder_name],\n",
    "        num_classes=data.num_classes,\n",
    "    )\n",
    "    results.append(\n",
    "        dict(\n",
    "            runname=run_name,\n",
    "            seed=seed,\n",
    "            score=out[\"score\"],\n",
    "            encoder_name=encoder_name,\n",
    "            decoder_name=decoder_name,\n",
    "        )\n",
    "    )\n",
    "stitching_results = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of the remaining columns per (encoder, decoder, run) combination.\n",
    "stitching_results.groupby(by=[\"encoder_name\", \"decoder_name\", \"runname\"]).agg(func=\"mean\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select encoder and decoder to stitch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration of the encoder/decoder pair inspected in the rest of the notebook.\n",
    "ATTENTION_INDEX = 0\n",
    "SEED = 0\n",
    "ENCODER_1 = \"rexnet_100\"\n",
    "ENCODER_2 = \"vit_base_patch16_384\"\n",
    "AGGREGATOR = LinearSelfAttentionLayer\n",
    "PROJECTION_NAMES = [\"Cosine\", \"Euclidean\", \"L1\", \"Linf\"]\n",
    "\n",
    "###\n",
    "\n",
    "run_name = get_run_name(aggregator=AGGREGATOR, projection_names=PROJECTION_NAMES)\n",
    "model1 = runname2encoder2seed2info[run_name][ENCODER_1][SEED][\"model\"].to(device).eval()\n",
    "model2 = runname2encoder2seed2info[run_name][ENCODER_2][SEED][\"model\"].to(device).eval()\n",
    "stitched_model = StitchedModel(encoder=model1, decoder=model2, encoder_name=ENCODER_1, decoder_name=ENCODER_2)\n",
    "\n",
    "print(\n",
    "    test_model(\n",
    "        model=stitched_model,\n",
    "        split2abs_space=enc_name2abs_space[ENCODER_1],\n",
    "        num_classes=data.num_classes,\n",
    "    )\n",
    ")\n",
    "\n",
    "# Rows of the stitching table that belong to the selected pair.\n",
    "is_selected_pair = (stitching_results.encoder_name == stitched_model.encoder_name) & (\n",
    "    stitching_results.decoder_name == stitched_model.decoder_name\n",
    ")\n",
    "stitching_results[is_selected_pair].groupby([\"encoder_name\", \"decoder_name\", \"runname\"]).agg(\"mean\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot Attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention weights of the stitched pair before any further optimization.\n",
    "_, w_orig = plot_attention_weights(\n",
    "    stitched_model=stitched_model,\n",
    "    enc_name2abs_space=enc_name2abs_space,\n",
    "    attention_index=ATTENTION_INDEX,\n",
    "    device=device,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=PIN_MEMORY,\n",
    "    persistent_workers=PERSISTENT_WORKERS,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimize attention weights for this encoder-decoder pair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # # Debug `attention.in_proj_weight`, it has shape [3*embed_dim, embed_dim]\n",
    "# # # but the order of qkv in the first dimension is undocumented.\n",
    "\n",
    "# from torch import nn\n",
    "# import torch\n",
    "# import math\n",
    "# embed_dim = 3\n",
    "\n",
    "# query = torch.arange(math.prod([1, 2, embed_dim])).reshape([1, 2, embed_dim]).float()\n",
    "# keys = torch.arange(math.prod([1, 2, embed_dim])).reshape([1, 2, embed_dim]).float()\n",
    "# values = torch.ones([1, 2, embed_dim]).float() * 2\n",
    "\n",
    "\n",
    "# q_mask = slice(None, 1*embed_dim)\n",
    "# k_mask = slice(1*embed_dim, 2*embed_dim)\n",
    "# v_mask = slice(2*embed_dim ,None)\n",
    "\n",
    "# with torch.no_grad():\n",
    "#     attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1, dropout=0.0, bias=False, batch_first=True)\n",
    "#     attention.out_proj.weight.fill_(0).add_(torch.eye(attention.out_proj.weight.shape[0]))\n",
    "#     attention.in_proj_weight[q_mask, :] = 1\n",
    "#     attention.in_proj_weight[k_mask, :] = 2\n",
    "#     attention.in_proj_weight[v_mask, :] = 0\n",
    "#     out, w = attention(query=query, key=keys, value=values)\n",
    "\n",
    "# attention.in_proj_weight, out, w"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimize only key, value and query projection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from itertools import chain\n",
    "\n",
    "import copy\n",
    "\n",
    "OPT_EPOCHS = 50\n",
    "# NOTE: this rebinds the global `LR` that the end-to-end training cell also uses.\n",
    "LR = 1e-4\n",
    "\n",
    "# Work on a copy so that the original stitched model stays untouched.\n",
    "stitched_model_qkv_opt = copy.deepcopy(stitched_model).to(device)\n",
    "\n",
    "# Only the attention input projection (weight and bias) is handed to the optimizer.\n",
    "qkv_attention = stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention\n",
    "optimizer = torch.optim.Adam(\n",
    "    params=[qkv_attention.in_proj_weight, qkv_attention.in_proj_bias],\n",
    "    lr=LR,\n",
    ")\n",
    "\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    enc_name2abs_space[stitched_model_qkv_opt.encoder_name][\"train\"],\n",
    "    shuffle=True,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=PIN_MEMORY,\n",
    "    persistent_workers=PERSISTENT_WORKERS,\n",
    ")\n",
    "\n",
    "\n",
    "# q_mask = slice(None, 1 * NUM_ANCHORS)\n",
    "# k_mask = slice(1 * NUM_ANCHORS, 2 * NUM_ANCHORS)\n",
    "# v_mask = slice(2 * NUM_ANCHORS, None)\n",
    "\n",
    "\n",
    "def train(model, batch, optimizer):\n",
    "    \"\"\"Run one optimization step on ``batch`` and return the loss as a float.\n",
    "\n",
    "    The batch is indexed with ``model.x_feature`` / ``model.y_feature`` and\n",
    "    moved to the global ``device``; the loss is the cross-entropy of the\n",
    "    model output against the targets.\n",
    "    \"\"\"\n",
    "    inputs = batch[model.x_feature].to(device)\n",
    "    targets = batch[model.y_feature].to(device)\n",
    "\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    loss = F.cross_entropy(model(inputs), targets)\n",
    "    loss.backward()\n",
    "\n",
    "    # Disabled alternatives: zero part of the gradient to restrict the update.\n",
    "\n",
    "    # # Optimize Key and Query only\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_weight.grad[v_mask, :] = 0\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_bias.grad[v_mask] = 0\n",
    "\n",
    "    # # Optimize Key only\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_weight.grad[q_mask, :] = 0\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_weight.grad[v_mask, :] = 0\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_bias.grad[q_mask] = 0\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_bias.grad[v_mask] = 0\n",
    "\n",
    "    # # Optimize Value only\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_weight.grad[q_mask, :] = 0\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_weight.grad[k_mask, :] = 0\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_bias.grad[q_mask] = 0\n",
    "    # stitched_model_qkv_opt.decoder.relative_block.aggregation_module.attention.in_proj_bias.grad[k_mask] = 0\n",
    "\n",
    "    optimizer.step()\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "bar = tqdm(range(OPT_EPOCHS))\n",
    "for epoch in bar:\n",
    "    for batch in train_loader:\n",
    "        loss = train(stitched_model_qkv_opt, batch, optimizer)\n",
    "\n",
    "    # The test score is refreshed every 10 epochs (first at epoch 0) and the\n",
    "    # latest value is shown in between.\n",
    "    if epoch % 10 == 0:\n",
    "        test_info = test_model(\n",
    "            model=stitched_model_qkv_opt,\n",
    "            split2abs_space=enc_name2abs_space[stitched_model_qkv_opt.encoder_name],\n",
    "            num_classes=data.num_classes,\n",
    "        )\n",
    "        test_score = test_info[\"score\"]\n",
    "\n",
    "    bar.set_description(f\"[{epoch}: {loss:.4f}] (test: {test_score:.4f})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot Attention once optimized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention weights and test score after optimizing the attention input projection.\n",
    "_, w_qkv_opt = plot_attention_weights(\n",
    "    stitched_model=stitched_model_qkv_opt,\n",
    "    enc_name2abs_space=enc_name2abs_space,\n",
    "    attention_index=ATTENTION_INDEX,\n",
    "    device=device,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=PIN_MEMORY,\n",
    "    persistent_workers=PERSISTENT_WORKERS,\n",
    ")\n",
    "qkv_opt_info = test_model(\n",
    "    model=stitched_model_qkv_opt,\n",
    "    split2abs_space=enc_name2abs_space[stitched_model_qkv_opt.encoder_name],\n",
    "    num_classes=data.num_classes,\n",
    ")\n",
    "qkv_opt_score = qkv_opt_info[\"score\"]\n",
    "qkv_opt_score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimize classifier "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from itertools import chain\n",
    "\n",
    "import copy\n",
    "\n",
    "OPT_EPOCHS = 50\n",
    "# NOTE: this rebinds the global `LR` that the end-to-end training cell also uses.\n",
    "LR = 1e-4\n",
    "\n",
    "# Work on a copy so that the original stitched model stays untouched.\n",
    "stitched_model_classifier_opt = copy.deepcopy(stitched_model).to(device)\n",
    "\n",
    "# Only the parameters of the decoder's `class_proj` are handed to the optimizer.\n",
    "classifier_head = stitched_model_classifier_opt.decoder.class_proj\n",
    "optimizer = torch.optim.Adam(params=classifier_head.parameters(), lr=LR)\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    enc_name2abs_space[stitched_model_classifier_opt.encoder_name][\"train\"],\n",
    "    shuffle=True,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=PIN_MEMORY,\n",
    "    persistent_workers=PERSISTENT_WORKERS,\n",
    ")\n",
    "\n",
    "\n",
    "# Row slices of size NUM_ANCHORS; no active code in this cell uses them.\n",
    "q_mask = slice(None, NUM_ANCHORS)\n",
    "k_mask = slice(NUM_ANCHORS, 2 * NUM_ANCHORS)\n",
    "v_mask = slice(2 * NUM_ANCHORS, None)\n",
    "\n",
    "\n",
    "def train(model, batch, optimizer):\n",
    "    \"\"\"Run one optimization step on ``batch`` and return the loss as a float.\n",
    "\n",
    "    The batch is indexed with ``model.x_feature`` / ``model.y_feature`` and\n",
    "    moved to the global ``device``; the loss is the cross-entropy of the\n",
    "    model output against the targets.\n",
    "    \"\"\"\n",
    "    inputs = batch[model.x_feature].to(device)\n",
    "    targets = batch[model.y_feature].to(device)\n",
    "\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    loss = F.cross_entropy(model(inputs), targets)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "bar = tqdm(range(OPT_EPOCHS))\n",
    "for epoch in bar:\n",
    "    for batch in train_loader:\n",
    "        loss = train(stitched_model_classifier_opt, batch, optimizer)\n",
    "\n",
    "    # The test score is refreshed every 10 epochs (first at epoch 0) and the\n",
    "    # latest value is shown in between.\n",
    "    if epoch % 10 == 0:\n",
    "        test_info = test_model(\n",
    "            model=stitched_model_classifier_opt,\n",
    "            split2abs_space=enc_name2abs_space[stitched_model_classifier_opt.encoder_name],\n",
    "            num_classes=data.num_classes,\n",
    "        )\n",
    "        test_score = test_info[\"score\"]\n",
    "\n",
    "    bar.set_description(f\"[{epoch}: {loss:.4f}] (test: {test_score:.4f})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Plot Attention once optimized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attention weights and test score after optimizing the classifier head.\n",
    "_, w_classifier_opt = plot_attention_weights(\n",
    "    stitched_model=stitched_model_classifier_opt,\n",
    "    enc_name2abs_space=enc_name2abs_space,\n",
    "    attention_index=ATTENTION_INDEX,\n",
    "    device=device,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    num_workers=NUM_WORKERS,\n",
    "    pin_memory=PIN_MEMORY,\n",
    "    persistent_workers=PERSISTENT_WORKERS,\n",
    ")\n",
    "classifier_opt_info = test_model(\n",
    "    model=stitched_model_classifier_opt,\n",
    "    split2abs_space=enc_name2abs_space[stitched_model_classifier_opt.encoder_name],\n",
    "    num_classes=data.num_classes,\n",
    ")\n",
    "classifier_opt_score = classifier_opt_info[\"score\"]\n",
    "classifier_opt_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the classifier optimization was only given the `class_proj`\n",
    "# parameters, so the averaged attention weights are expected to be unchanged.\n",
    "assert torch.allclose(\n",
    "    w_orig, w_classifier_opt\n",
    "), \"Attention weights changed although only the classifier head was optimized\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nn_core.common import PROJECT_ROOT\n",
    "\n",
    "# Output directory for everything saved below; created together with its parents.\n",
    "results_dir = PROJECT_ROOT / \"results_paper\" / \"attention_opt\"\n",
    "results_dir.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write both result tables as tab-separated files without the index column.\n",
    "for table, file_name in (\n",
    "    (end2end_results, \"end2end_results.tsv\"),\n",
    "    (stitching_results, \"stitching_results.tsv\"),\n",
    "):\n",
    "    table.to_csv(results_dir / file_name, sep=\"\\t\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score of the untouched stitched model, used as the \"none\" baseline row.\n",
    "orig_info = test_model(\n",
    "    model=stitched_model,\n",
    "    split2abs_space=enc_name2abs_space[ENCODER_1],\n",
    "    num_classes=data.num_classes,\n",
    ")\n",
    "orig_score = orig_info[\"score\"]\n",
    "\n",
    "# One row per optimization mode; the baseline row has 0 epochs and 0 learning rate.\n",
    "opt_results = pd.DataFrame(\n",
    "    {\n",
    "        \"opt_mode\": [\"qkv_opt_score\", \"classifier_opt_score\", \"none\"],\n",
    "        \"score\": [qkv_opt_score, classifier_opt_score, orig_score],\n",
    "        \"seed\": [SEED] * 3,\n",
    "        \"encoder_name\": [ENCODER_1] * 3,\n",
    "        \"decoder_name\": [ENCODER_2] * 3,\n",
    "        \"runname\": [run_name] * 3,\n",
    "        \"epochs\": [OPT_EPOCHS, OPT_EPOCHS, 0],\n",
    "        \"lr\": [LR, LR, 0],\n",
    "        \"dataset\": [dataset] * 3,\n",
    "    }\n",
    ")\n",
    "opt_results.to_csv(results_dir / \"opt_results.tsv\", sep=\"\\t\", index=False)\n",
    "opt_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the three averaged attention-weight tensors next to the result tables.\n",
    "for weights, file_name in (\n",
    "    (w_orig, \"w_orig.pt\"),\n",
    "    (w_qkv_opt, \"w_qkv_opt.pt\"),\n",
    "    (w_classifier_opt, \"w_classifier_opt.pt\"),\n",
    "):\n",
    "    torch.save(weights, results_dir / file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Store the projection order that labels the axes of the attention heatmaps.\n",
    "projection_names_order = json.dumps(stitched_model.encoder.relative_block.projection_names, indent=4)\n",
    "(results_dir / \"projection_names_order.json\").write_text(projection_names_order)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent-invariances",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
