{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3afba055",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e2fde5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "from rae import PROJECT_ROOT\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bddf3583",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a60d9f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.attention import RelativeAttention, AttentionOutput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f2829d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration: the values a reader is most likely to tune.\n",
    "# Fall back to the CPU when no CUDA device is visible so the notebook still runs end to end.\n",
    "device: str = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "target_key: str = \"label\"  # dataset column used as the classification target\n",
    "data_key: str = \"image\"  # dataset column holding the input images\n",
    "dataset_name: str = \"imagenet-1k\"  # name handed to `load_dataset`\n",
    "num_anchors: int = 768  # number of anchor samples drawn from the train subset\n",
    "train_perc: float = 0.2  # fraction kept from each loaded split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54fdaa84",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, ClassLabel\n",
    "\n",
    "\n",
    "def get_dataset(split: str, perc: float):\n",
    "    \"\"\"Load one split of `dataset_name` and keep a seeded random fraction of it.\n",
    "\n",
    "    Args:\n",
    "        split: name of the split to load (e.g. \"train\" or \"validation\").\n",
    "        perc: fraction of the split to keep, in the interval (0, 1].\n",
    "\n",
    "    Returns:\n",
    "        The split restricted (via `select`) to the first `int(len * perc)` entries of a\n",
    "        shuffled index list.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `perc` is outside (0, 1].\n",
    "    \"\"\"\n",
    "    # Re-seed on every call so the same (split, perc) always yields the same subset.\n",
    "    seed_everything(42)\n",
    "    assert 0 < perc <= 1, f\"perc must be in (0, 1], got {perc}\"\n",
    "    # Ask for the needed split only, instead of building every split and indexing one.\n",
    "    dataset = load_dataset(dataset_name, split=split)\n",
    "\n",
    "    # Select a random subset\n",
    "    indices = list(range(len(dataset)))\n",
    "    random.shuffle(indices)\n",
    "    indices = indices[: int(len(indices) * perc)]\n",
    "    return dataset.select(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49e4af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded random subset of the training split; `train_perc` sets the fraction kept.\n",
    "train_dataset = get_dataset(perc=train_perc, split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a41a58f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the label feature's `str2int` converter and show (number of classes, subset size).\n",
    "class2idx = train_dataset.features[target_key].str2int\n",
    "(train_dataset.features[target_key].num_classes, len(train_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b78c6553",
   "metadata": {},
   "outputs": [],
   "source": [
    "class2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd815a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "from timm.data import resolve_data_config\n",
    "from timm.data.transforms_factory import create_transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981b6fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoFeatureExtractor, AutoModelForImageClassification, AutoModel\n",
    "\n",
    "\n",
    "def load_transformer(transformer_name):\n",
    "    \"\"\"Create the pretrained timm model `transformer_name`, frozen and in eval mode.\n",
    "\n",
    "    The model is built with `num_classes=0`, then gradients are disabled on all of its\n",
    "    parameters and it is switched to evaluation mode before being returned.\n",
    "    \"\"\"\n",
    "    backbone = timm.create_model(transformer_name, pretrained=True, num_classes=0)\n",
    "    backbone.requires_grad_(False)\n",
    "    backbone.eval()\n",
    "    return backbone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0b9f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeded random subset of the validation split; it reuses `train_perc` as its fraction.\n",
    "test_dataset = get_dataset(perc=train_perc, split=\"validation\")\n",
    "len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6beba308",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def call_transformer(batch, transformer):\n",
    "    \"\"\"Run `transformer` on the batch's \"encoding\" tensor after moving it to `device`.\n",
    "\n",
    "    Args:\n",
    "        batch: mapping holding the model input under the \"encoding\" key.\n",
    "        transformer: model to call; it is not moved here, so the caller places it on `device`.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the raw model output stored under \"hidden\".\n",
    "    \"\"\"\n",
    "    # TODO: aggregation mode (e.g. CLS token or masked mean) for models returning per-token states\n",
    "    model_input = batch[\"encoding\"].to(device)\n",
    "    return {\"hidden\": transformer(model_input)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fbfb2a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t = load_transformer(transformer_name=\"vit_base_patch16_224\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c588ce56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# config = resolve_data_config({}, model=t)\n",
    "# transform = create_transform(**config)\n",
    "# call_transformer(collate_fn(train_dataset.select(range(2)), None, transform), t.to(device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3649cedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import *\n",
    "\n",
    "\n",
    "# Anchors are drawn from the train subset, so it must hold at least `num_anchors` samples.\n",
    "assert len(train_dataset) >= num_anchors\n",
    "\n",
    "# Fixed seed: the same anchors are picked on every run.\n",
    "seed_everything(42)\n",
    "anchor_idxs = [*range(len(train_dataset))]\n",
    "random.shuffle(anchor_idxs)\n",
    "del anchor_idxs[num_anchors:]  # keep only the first `num_anchors` shuffled indices\n",
    "\n",
    "anchor_dataset = train_dataset.select(anchor_idxs)\n",
    "len(anchor_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f4b15a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# timm model names to encode and stitch.\n",
    "# An ordered list (rather than `list({...})`) keeps the iteration order identical across\n",
    "# kernels: the order of a set of strings depends on per-process hash randomization.\n",
    "transformer_names = [\n",
    "    \"vit_base_patch16_224\",\n",
    "    \"vit_small_patch16_224\",\n",
    "    \"vit_base_resnet50_384\",\n",
    "    \"rexnet_100\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290719be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# relative_projection = RelativeAttention(\n",
    "#     n_anchors=num_anchors,\n",
    "#     normalization_mode=\"l2\",\n",
    "#     similarity_mode=\"inner\",\n",
    "#     values_mode=\"similarities\",\n",
    "#     n_classes=train_dataset.features[target_key].num_classes,\n",
    "#     output_normalization_mode=None,\n",
    "# ).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a88f2425",
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_projection(x, anchors):\n",
    "    \"\"\"Cosine similarity between every row of `x` and every row of `anchors`.\n",
    "\n",
    "    Both inputs are L2-normalized along their last dimension; the einsum \"bm, am -> ba\"\n",
    "    then takes the inner product of each (sample row, anchor row) pair.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D tensor, one sample per row.\n",
    "        anchors: 2-D tensor, one anchor per row, with as many columns as `x`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor whose entry [b, a] is the similarity between x[b] and anchors[a].\n",
    "    \"\"\"\n",
    "    unit_x = F.normalize(x, p=2, dim=-1)\n",
    "    unit_anchors = F.normalize(anchors, p=2, dim=-1)\n",
    "    return torch.einsum(\"bm, am -> ba\", unit_x, unit_anchors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7809500",
   "metadata": {},
   "outputs": [],
   "source": [
    "# dummy_x = torch.randn(32, 512, 16, 16)\n",
    "# dummy_anchors = torch.randn(42, 512)\n",
    "# relative_projection(x=dummy_x, anchors=dummy_anchors).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "397f6d06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(batch, feature_extractor, transform):\n",
    "    \"\"\"Stack the transformed RGB images of `batch` into a single tensor.\n",
    "\n",
    "    Args:\n",
    "        batch: iterable of samples, each exposing an image under the \"image\" key.\n",
    "        feature_extractor: not used by this implementation; kept in the signature.\n",
    "        transform: callable applied to each RGB-converted image; its outputs must be\n",
    "            tensors that `torch.stack` can stack.\n",
    "\n",
    "    Returns:\n",
    "        A dict with the stacked tensor under \"encoding\" (samples along dim 0).\n",
    "    \"\"\"\n",
    "    images = []\n",
    "    for sample in batch:\n",
    "        images.append(transform(sample[\"image\"].convert(\"RGB\")))\n",
    "    return {\"encoding\": torch.stack(images, dim=0)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a3d7735",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:\n",
    "    \"\"\"Encode every batch of `dataloader` and collect absolute (and optionally relative) latents.\n",
    "\n",
    "    Args:\n",
    "        dataloader: iterable of batches accepted by `call_transformer`.\n",
    "        anchors: anchor latents handed to `relative_projection`, or None to skip the relative\n",
    "            representation.\n",
    "        split: label shown in the progress-bar description.\n",
    "        transformer: model to run; it is moved to `device` for the loop and back to the CPU\n",
    "            afterwards.\n",
    "\n",
    "    Returns:\n",
    "        Dict with \"absolute\" (batch outputs concatenated along dim 0, on CPU) and \"relative\"\n",
    "        (concatenated projections on CPU, or an empty list when `anchors` is None).\n",
    "    \"\"\"\n",
    "    absolute_latents: List = []\n",
    "    relative_latents: List = []\n",
    "\n",
    "    transformer = transformer.to(device)\n",
    "    try:\n",
    "        # `call_transformer` already disables gradients; this context also covers the projection.\n",
    "        with torch.no_grad():\n",
    "            for batch in tqdm(dataloader, desc=f\"[{split}] Computing latents\"):\n",
    "                transformer_out = call_transformer(batch=batch, transformer=transformer)\n",
    "                absolute_latents.append(transformer_out[\"hidden\"].cpu())\n",
    "\n",
    "                if anchors is not None:\n",
    "                    batch_rel_latents = relative_projection(x=transformer_out[\"hidden\"], anchors=anchors)\n",
    "                    relative_latents.append(batch_rel_latents.cpu())\n",
    "    finally:\n",
    "        # Move the model off the accelerator even when a batch fails, so a retry starts clean.\n",
    "        transformer = transformer.cpu()\n",
    "\n",
    "    return {\n",
    "        # Every collected chunk is already on the CPU, so no further transfer is needed.\n",
    "        \"absolute\": torch.cat(absolute_latents, dim=0),\n",
    "        \"relative\": torch.cat(relative_latents, dim=0) if len(relative_latents) > 0 else relative_latents,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f89a145",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae import PROJECT_ROOT\n",
    "\n",
    "# Cache directory for the latents; the subset fraction is part of the path.\n",
    "LATENTS_DIR: Path = PROJECT_ROOT / \"data\" / \"latents\" / \"imagenet\" / f\"{train_perc}\"\n",
    "LATENTS_DIR.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f5398a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_latents(split: str, transformer_names: Sequence[str]):\n",
    "    \"\"\"Load the cached latents of `split` for every model that has a cache file.\n",
    "\n",
    "    Cache files are looked up at `LATENTS_DIR / split / \"<name>.pt\"`, where `<name>` is the\n",
    "    model name with \"/\" replaced by \"-\".\n",
    "\n",
    "    Args:\n",
    "        split: sub-directory of `LATENTS_DIR` to read from.\n",
    "        transformer_names: model names to look up.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping each model name with an existing file to the object returned by\n",
    "        `torch.load`; models without a file are left out.\n",
    "    \"\"\"\n",
    "    split_dir = LATENTS_DIR / split\n",
    "    candidate_paths = {name: split_dir / f\"{name.replace('/', '-')}.pt\" for name in transformer_names}\n",
    "    return {name: torch.load(path) for name, path in candidate_paths.items() if path.exists()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d29b727",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "from torchvision.transforms import (\n",
    "    CenterCrop,\n",
    "    Compose,\n",
    "    Normalize,\n",
    "    Resize,\n",
    "    ToTensor,\n",
    ")\n",
    "\n",
    "\n",
    "def encode_latents(transformer_names: Sequence[str], dataset, transformer_name2latents, split: str):\n",
    "    \"\"\"Encode `dataset` with every listed model and store the result in `transformer_name2latents`.\n",
    "\n",
    "    For each model the anchor samples (`anchor_dataset`) are encoded first; their absolute\n",
    "    latents become the anchors of the relative representation computed for `dataset`.\n",
    "\n",
    "    Args:\n",
    "        transformer_names: names of the models to load through `load_transformer`.\n",
    "        dataset: samples to encode.\n",
    "        transformer_name2latents: dict updated in place; each model name maps to a dict with\n",
    "            \"anchors_latents\" plus the entries returned by `get_latents`.\n",
    "        split: split label, used in progress-bar descriptions and in the cache path.\n",
    "\n",
    "    Returns:\n",
    "        None. When the global `CACHE_LATENTS` is truthy, each model's entry is also saved to\n",
    "        `LATENTS_DIR / split / \"<name>.pt\"`, with \"/\" in the name replaced by \"-\".\n",
    "    \"\"\"\n",
    "\n",
    "    def build_loader(samples, transform):\n",
    "        # Same loader settings for the anchors and for the dataset being encoded.\n",
    "        return DataLoader(\n",
    "            samples,\n",
    "            num_workers=4,\n",
    "            pin_memory=True,\n",
    "            collate_fn=partial(collate_fn, feature_extractor=None, transform=transform),\n",
    "            batch_size=32,\n",
    "        )\n",
    "\n",
    "    for transformer_name in transformer_names:\n",
    "        transformer = load_transformer(transformer_name=transformer_name)\n",
    "        # Preprocessing is derived from the model's own data configuration.\n",
    "        transform = create_transform(**resolve_data_config({}, model=transformer))\n",
    "\n",
    "        anchors_latents = get_latents(\n",
    "            dataloader=build_loader(anchor_dataset, transform),\n",
    "            split=f\"{transformer_name}, anchor, {split}\",\n",
    "            anchors=None,\n",
    "            transformer=transformer,\n",
    "        )[\"absolute\"]\n",
    "        encoded = get_latents(\n",
    "            dataloader=build_loader(dataset, transform),\n",
    "            split=f\"{split}/{transformer_name}\",\n",
    "            anchors=anchors_latents.to(device),\n",
    "            transformer=transformer,\n",
    "        )\n",
    "        transformer_name2latents[transformer_name] = {\"anchors_latents\": anchors_latents, **encoded}\n",
    "\n",
    "        # Save latents\n",
    "        if not CACHE_LATENTS:\n",
    "            continue\n",
    "        transformer_path = LATENTS_DIR / split / f\"{transformer_name.replace('/', '-')}.pt\"\n",
    "        transformer_path.parent.mkdir(exist_ok=True, parents=True)\n",
    "        torch.save(transformer_name2latents[transformer_name], transformer_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b41b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute test latents\n",
    "\n",
    "FORCE_RECOMPUTE: bool = False\n",
    "CACHE_LATENTS: bool = True\n",
    "\n",
    "\n",
    "def load_or_encode_latents(split: str, dataset, force_recompute: bool) -> Dict[str, Mapping[str, torch.Tensor]]:\n",
    "    \"\"\"Return the latents of `split` for every model in `transformer_names`.\n",
    "\n",
    "    Cached latents are loaded first. The models still missing (or all of them when\n",
    "    `force_recompute` is True) are then encoded from `dataset` by `encode_latents`, which\n",
    "    adds them to the returned dict.\n",
    "\n",
    "    Args:\n",
    "        split: cache sub-directory / split label (\"test\" or \"train\" in this notebook).\n",
    "        dataset: samples to encode for the models that are not cached.\n",
    "        force_recompute: when True, re-encode every model even if a cache file exists.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping model name to its latents.\n",
    "    \"\"\"\n",
    "    transformer2latents = load_latents(split=split, transformer_names=transformer_names)\n",
    "    if force_recompute:\n",
    "        missing_transformers = transformer_names\n",
    "    else:\n",
    "        missing_transformers = [t_name for t_name in transformer_names if t_name not in transformer2latents]\n",
    "    encode_latents(\n",
    "        transformer_names=missing_transformers,\n",
    "        dataset=dataset,\n",
    "        transformer_name2latents=transformer2latents,\n",
    "        split=split,\n",
    "    )\n",
    "    return transformer2latents\n",
    "\n",
    "\n",
    "transformer2test_latents: Dict[str, Mapping[str, torch.Tensor]] = load_or_encode_latents(\n",
    "    split=\"test\", dataset=test_dataset, force_recompute=FORCE_RECOMPUTE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74cb3ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute train latents\n",
    "\n",
    "FORCE_RECOMPUTE: bool = False\n",
    "CACHE_LATENTS: bool = True\n",
    "\n",
    "transformer2train_latents: Dict[str, Mapping[str, torch.Tensor]] = load_or_encode_latents(\n",
    "    split=\"train\", dataset=train_dataset, force_recompute=FORCE_RECOMPUTE\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc6b963",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of dim 0 of the first stored absolute latent, per model.\n",
    "transformer_name2hidden_dim = {\n",
    "    name: encoded[\"absolute\"][0].size(0) for name, encoded in transformer2train_latents.items()\n",
    "}\n",
    "transformer_name2hidden_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd41a8b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# When True, latents are L2-normalized before they are used for fitting and for evaluation.\n",
    "latent_normalize: bool = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943bb3a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "\n",
    "\n",
    "class Lambda(nn.Module):\n",
    "    \"\"\"Wrap a plain callable so it can be used as a layer inside `nn.Sequential`.\"\"\"\n",
    "\n",
    "    def __init__(self, func):\n",
    "        super().__init__()\n",
    "        self.func = func\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.func(x)\n",
    "\n",
    "\n",
    "def fit(X: torch.Tensor, y, seed, normalize: bool, hidden_dim: int):\n",
    "    \"\"\"Train a small MLP classifier on latents and return a prediction function.\n",
    "\n",
    "    Args:\n",
    "        X: latents, one sample per row, with `hidden_dim` columns.\n",
    "        y: integer class labels aligned with the rows of `X`.\n",
    "        seed: passed to `seed_everything` before the loader and the model are built.\n",
    "        normalize: when True, L2-normalize every row of `X` before training.\n",
    "        hidden_dim: input feature size of the classifier.\n",
    "\n",
    "    Returns:\n",
    "        A callable mapping a batch of latents to predicted class indices (CPU tensor). It runs\n",
    "        the trained model as-is, without normalizing its input.\n",
    "    \"\"\"\n",
    "    seed_everything(seed)\n",
    "    if normalize:\n",
    "        X = F.normalize(X, p=2, dim=-1)\n",
    "    dataset = TensorDataset(X, torch.as_tensor(y))\n",
    "    loader = DataLoader(dataset, batch_size=32, pin_memory=True, shuffle=True, num_workers=4)\n",
    "\n",
    "    # NOTE(review): each InstanceNorm1d receives a permuted 2-D tensor (features, batch), so it\n",
    "    # presumably normalizes every feature across the batch - confirm this is intended.\n",
    "    model = nn.Sequential(\n",
    "        nn.LayerNorm(normalized_shape=hidden_dim),\n",
    "        nn.Linear(in_features=hidden_dim, out_features=num_anchors),\n",
    "        nn.SiLU(),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=num_anchors),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=num_anchors, out_features=num_anchors),\n",
    "        nn.SiLU(),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=num_anchors),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=num_anchors, out_features=train_dataset.features[target_key].num_classes),\n",
    "    ).to(device)\n",
    "    opt = Adam(model.parameters(), lr=1e-4)\n",
    "    loss_fn = CrossEntropyLoss()\n",
    "    for epoch in tqdm(range(1), leave=False, desc=\"epoch\"):\n",
    "        for batch_x, batch_y in loader:\n",
    "            batch_x = batch_x.to(device)\n",
    "            batch_y = batch_y.to(device)\n",
    "            pred_y = model(batch_x)\n",
    "            loss = loss_fn(pred_y, batch_y)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            opt.zero_grad()\n",
    "    model = model.cpu().eval()\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def predict(x):\n",
    "        \"\"\"Return the arg-max class index for every row of `x`, computed on the CPU.\"\"\"\n",
    "        # Inference only: skip autograd bookkeeping, which matters when `x` is a whole test set.\n",
    "        return model(x).argmax(-1).cpu()\n",
    "\n",
    "    return predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7951c369",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEEDS = list(range(3))\n",
    "# Read the label column once instead of once per `fit` call inside the comprehension.\n",
    "train_labels = train_dataset[target_key]\n",
    "\n",
    "# Nested mapping: seed -> embedding type -> model name -> predictor returned by `fit`.\n",
    "train_classifiers = {\n",
    "    seed: {\n",
    "        embedding_type: {\n",
    "            transformer_name: fit(\n",
    "                train_latents[embedding_type],\n",
    "                train_labels,\n",
    "                seed=seed,\n",
    "                normalize=latent_normalize,\n",
    "                # Relative latents have one column per anchor; absolute ones depend on the model.\n",
    "                hidden_dim=transformer_name2hidden_dim[transformer_name]\n",
    "                if embedding_type == \"absolute\"\n",
    "                else num_anchors,\n",
    "            )\n",
    "            for transformer_name, train_latents in tqdm(\n",
    "                transformer2train_latents.items(), leave=False, desc=\"transformer\"\n",
    "            )\n",
    "        }\n",
    "        for embedding_type in tqdm([\"absolute\", \"relative\"], leave=False, desc=\"embedding_type\")\n",
    "    }\n",
    "    for seed in tqdm(SEEDS, leave=False, desc=\"seed\")\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc23b282",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca059a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error\n",
    "import itertools\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# One row per (seed, embedding type, train model, test model): the classifier fit on the\n",
    "# train model's latents is scored on the test split encoded by the test model.\n",
    "numeric_results = {\n",
    "    \"seed\": [],\n",
    "    \"embed_type\": [],\n",
    "    \"train_model\": [],\n",
    "    \"test_model\": [],\n",
    "    \"precision\": [],\n",
    "    \"recall\": [],\n",
    "    \"fscore\": [],\n",
    "    \"stitched\": [],\n",
    "}\n",
    "# The test labels are the same for every pair, so build the array once.\n",
    "test_y = np.array(test_dataset[target_key])\n",
    "\n",
    "for seed, embed_type2transformer2classifier in train_classifiers.items():\n",
    "    for embed_type, transformer2classifier in embed_type2transformer2classifier.items():\n",
    "        for (transformer_name1, classifier1), transformer_name2 in itertools.product(\n",
    "            transformer2classifier.items(), transformer2classifier.keys()\n",
    "        ):\n",
    "            if embed_type == \"absolute\" and (\n",
    "                transformer_name2hidden_dim[transformer_name1] != transformer_name2hidden_dim[transformer_name2]\n",
    "            ):\n",
    "                # Absolute latents of a different size cannot be fed to this classifier.\n",
    "                precision = recall = fscore = np.nan\n",
    "            else:\n",
    "                # `classifier1` was fit on `transformer_name1` (\"train_model\"); score it on the test\n",
    "                # latents of `transformer_name2` (\"test_model\") so each row matches its labels.\n",
    "                test_latents = transformer2test_latents[transformer_name2][embed_type]\n",
    "                if latent_normalize:\n",
    "                    test_latents = F.normalize(test_latents, p=2, dim=-1)\n",
    "                preds = classifier1(test_latents)\n",
    "\n",
    "                precision, recall, fscore, _support = precision_recall_fscore_support(\n",
    "                    test_y, preds, average=\"weighted\"\n",
    "                )\n",
    "            numeric_results[\"embed_type\"].append(embed_type)\n",
    "            numeric_results[\"train_model\"].append(transformer_name1)\n",
    "            numeric_results[\"test_model\"].append(transformer_name2)\n",
    "            numeric_results[\"precision\"].append(precision)\n",
    "            numeric_results[\"recall\"].append(recall)\n",
    "            numeric_results[\"fscore\"].append(fscore)\n",
    "            numeric_results[\"stitched\"].append(transformer_name1 != transformer_name2)\n",
    "            numeric_results[\"seed\"].append(seed)\n",
    "\n",
    "pd.options.display.max_columns = None\n",
    "pd.options.display.max_rows = None\n",
    "df = pd.DataFrame(numeric_results)\n",
    "# Persist the per-seed rows before aggregating, so later cells can reload them.\n",
    "df.to_csv(\n",
    "    f\"vision_transformer-stitching-{dataset_name}-{train_perc}.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "\n",
    "# The string \"mean\" selects the groupby mean without passing a NumPy callable to `agg`.\n",
    "df = df.groupby(\n",
    "    [\n",
    "        \"embed_type\",\n",
    "        \"stitched\",\n",
    "        \"train_model\",\n",
    "        \"test_model\",\n",
    "    ]\n",
    ").agg([\"mean\"])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "372200db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the per-seed results from the TSV file and summarise them per group.\n",
    "full_df = pd.read_csv(\n",
    "    f\"vision_transformer-stitching-{dataset_name}-{train_perc}.tsv\",\n",
    "    index_col=0,\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "\n",
    "# Mean and number of rows for every remaining column of each group.\n",
    "df = full_df.groupby(by=[\"embed_type\", \"stitched\", \"train_model\", \"test_model\"]).agg([np.mean, \"count\"])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e2d398f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Name of the results file; the trailing comma makes this display as a 1-tuple.\n",
    "f\"vision_transformer-stitching-{dataset_name}-{train_perc}.tsv\","
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d3e07f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean F-score per (embedding type, train model, test model), rounded to 3 decimals.\n",
    "(\n",
    "    full_df.drop(columns=[\"stitched\", \"seed\", \"precision\", \"recall\"])\n",
    "    .groupby(by=[\"embed_type\", \"train_model\", \"test_model\"])\n",
    "    .agg([np.mean])\n",
    "    .round(3)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d301837a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# it_dataset = get_samples(lang=\"it\", sample_idxs=list(range(1000)))\n",
    "# it_transformer_name: str = \"dbmdz/bert-base-italian-cased\"\n",
    "# transformer, tokenizer = load_transformer(transformer_name=it_transformer_name)\n",
    "# it_anchor_latents = get_latents(\n",
    "#     dataloader=DataLoader(\n",
    "#         get_samples(\"it\", sample_idxs=anchor_idxs),\n",
    "#         num_workers=16,\n",
    "#         pin_memory=True,\n",
    "#         collate_fn=partial(collate_fn, tokenizer=tokenizer),\n",
    "#         batch_size=32,\n",
    "#     ),\n",
    "#     split=f\"{it_transformer_name}\",\n",
    "#     anchors=None,\n",
    "#     transformer=transformer,\n",
    "# )\n",
    "# it_latents = get_latents(\n",
    "#     dataloader=DataLoader(\n",
    "#         it_dataset,\n",
    "#         num_workers=16,\n",
    "#         pin_memory=True,\n",
    "#         collate_fn=partial(collate_fn, tokenizer=tokenizer),\n",
    "#         batch_size=32,\n",
    "#     ),\n",
    "#     split=f\"{it_transformer_name}\",\n",
    "#     anchors=it_anchor_latents[\"absolute\"].to(device),\n",
    "#     transformer=transformer,\n",
    "# )\n",
    "# subsample_anchors = it_latents[\"relative\"][:31, :]\n",
    "# for i_sample, sample in enumerate(it_samples):\n",
    "#     if sample[\"target\"] == 3:\n",
    "#         continue\n",
    "#     for embed_type in (\"relative\", \"absolute\"):\n",
    "#         latents = it_latents[embed_type]\n",
    "#         latents = torch.cat([latents[i_sample, :].unsqueeze(0), subsample_anchors], dim=0)\n",
    "#         classifier = train_classifiers[SEEDS[0]][embed_type][\"en\"]\n",
    "#         print(\n",
    "#             embed_type,\n",
    "#             classifier(latents)[0].item(),\n",
    "#             sample[\"class\"],\n",
    "#         )\n",
    "#     print()\n",
    "#     if i_sample > 100:\n",
    "#         break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
