{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acff9ee4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "from typing import List, Tuple\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bb6f7ed",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "import sklearn.pipeline\n",
    "import torch\n",
    "from nn_core.serialization import load_model, NNCheckpointIO\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288e9290",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.data.text import TREC\n",
    "from rae.modules.attention import RelativeAttention, AttentionOutput\n",
    "from rae.pl_modules.pl_text_classifier import LightningTextClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afde0302",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "device: str = \"cuda\"\n",
    "\n",
    "fine_grained: bool = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b48e2eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, ClassLabel\n",
    "\n",
    "# Positional arguments forwarded to `load_dataset`: (name,) or (name, config).\n",
    "dataset_key: Tuple[str, ...] = (\"trec\",)\n",
    "# dataset_key: Tuple[str, ...] = (\"amazon_reviews_multi\", \"en\")\n",
    "datasets = load_dataset(*dataset_key)\n",
    "\n",
    "if dataset_key[0] == \"dbpedia_14\":\n",
    "\n",
    "    def clean_sample(example):\n",
    "        \"\"\"Strip surrounding double quotes and whitespace from the `content` field.\"\"\"\n",
    "        example[\"content\"] = example[\"content\"].strip('\"').strip()\n",
    "        return example\n",
    "\n",
    "    datasets = datasets.map(clean_sample)\n",
    "    target_key: str = \"label\"\n",
    "    data_key: str = \"content\"\n",
    "\n",
    "elif dataset_key[0] == \"trec\":\n",
    "    # NOTE(review): `fine_grained` is not consulted here; the coarse label is always used.\n",
    "    target_key: str = \"label-coarse\"\n",
    "    data_key: str = \"text\"\n",
    "\n",
    "elif dataset_key[0] == \"amazon_reviews_multi\":\n",
    "\n",
    "    def clean_sample(sample):\n",
    "        \"\"\"Merge title and body into `content` and replace the star rating by a class index.\"\"\"\n",
    "        title: str = sample[\"review_title\"].strip('\"').strip(\".\").strip()\n",
    "        body: str = sample[\"review_body\"].strip('\"').strip(\".\").strip()\n",
    "\n",
    "        # Drop the title when the body already starts with it.\n",
    "        if body.lower().startswith(title.lower()):\n",
    "            title = \"\"\n",
    "\n",
    "        # Close the title with a period when it ends with a letter.\n",
    "        if len(title) > 0 and title[-1].isalpha():\n",
    "            title = f\"{title}.\"\n",
    "\n",
    "        sample[\"content\"] = f\"{title} {body}\".lstrip(\".\").strip()\n",
    "        # Store 0-based integer class indices. The earlier string labels \"0\"..\"4\" did not\n",
    "        # match the ClassLabel names \"1\"..\"5\" declared below; integers are cast as indices,\n",
    "        # so index i corresponds to the name str(i + 1), i.e. the star rating.\n",
    "        if fine_grained:\n",
    "            sample[target_key] = sample[\"stars\"] - 1\n",
    "        else:\n",
    "            sample[target_key] = int(sample[\"stars\"] > 3)\n",
    "        return sample\n",
    "\n",
    "    # `clean_sample` reads `target_key` at call time, so it must be set before `map` runs.\n",
    "    target_key: str = \"stars\"\n",
    "    data_key: str = \"content\"\n",
    "    datasets = datasets.map(clean_sample)\n",
    "    datasets = datasets.cast_column(\n",
    "        target_key,\n",
    "        ClassLabel(num_classes=5 if fine_grained else 2, names=list(map(str, range(1, 6) if fine_grained else (0, 1)))),\n",
    "    )\n",
    "\n",
    "\n",
    "else:\n",
    "    # Explicit error instead of `assert False`, which is skipped under `python -O`.\n",
    "    raise ValueError(f\"Unsupported dataset: {dataset_key[0]!r}\")\n",
    "\n",
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8612d26f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train/test splits used throughout the notebook\n",
    "# (append .select(range(1000)) to either side for a quick run).\n",
    "train_dataset, test_dataset = datasets[\"train\"], datasets[\"test\"]\n",
    "train_dataset, test_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02172ef1",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Label feature of the target column and its name -> index lookup method.\n",
    "target_feature = train_dataset.features[target_key]\n",
    "class2idx = target_feature.str2int\n",
    "target_feature, class2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53eb22c3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def load_transformer(transformer_name):\n",
    "    \"\"\"Load a frozen pretrained encoder together with its tokenizer.\n",
    "\n",
    "    Args:\n",
    "        transformer_name: model identifier passed to `from_pretrained`.\n",
    "\n",
    "    Returns:\n",
    "        A `(model, tokenizer)` tuple; the model has gradients disabled, is in eval\n",
    "        mode and is configured to return its hidden states.\n",
    "    \"\"\"\n",
    "    model = AutoModel.from_pretrained(transformer_name, output_hidden_states=True, return_dict=True)\n",
    "    model.requires_grad_(False)\n",
    "    model.eval()\n",
    "    tokenizer = AutoTokenizer.from_pretrained(transformer_name)\n",
    "    return model, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3b59419",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Encoders compared in the stitching experiment.\n",
    "transformer_names: List[str] = [\n",
    "    \"bert-base-cased\",\n",
    "    \"bert-base-uncased\",\n",
    "    \"google/electra-base-discriminator\",\n",
    "    \"roberta-base\",\n",
    "    # \"albert-base-v2\",\n",
    "    # \"distilbert-base-uncased\",\n",
    "    # \"distilbert-base-cased\",\n",
    "    \"xlm-roberta-base\",\n",
    "]\n",
    "\n",
    "# name -> (frozen model, tokenizer), in the order listed above\n",
    "transformers = dict(\n",
    "    (name, load_transformer(transformer_name=name))\n",
    "    for name in transformer_names  # all these have latents already cached in latents.pt\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c280da62",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.preprocessing import StandardScaler, Normalizer\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb7f7213",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Integer targets of both splits as NumPy arrays.\n",
    "train_y = np.array(train_dataset[target_key])\n",
    "test_y = np.array(test_dataset[target_key])\n",
    "# Number of distinct classes present in each split.\n",
    "n_train_classes, n_test_classes = len(set(train_y)), len(set(test_y))\n",
    "n_train_classes, n_test_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82c9d35e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def call_transformer(batch, transformer):\n",
    "    \"\"\"Encode a batch and return one vector per sample.\n",
    "\n",
    "    Args:\n",
    "        batch: mapping whose \"encoding\" entry is moved to `device` and unpacked\n",
    "            as keyword arguments into `transformer`.\n",
    "        transformer: model whose output exposes \"hidden_states\".\n",
    "\n",
    "    Returns:\n",
    "        The last hidden state at sequence position 0 of every sample (CLS).\n",
    "    \"\"\"\n",
    "    model_inputs = batch[\"encoding\"].to(device)\n",
    "    last_hidden_state = transformer(**model_inputs)[\"hidden_states\"][-1]\n",
    "    # TODO: aggregation mode (e.g. a masked mean over the tokens instead of CLS only)\n",
    "    cls_position = 0\n",
    "    return last_hidden_state[:, cls_position, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45bd0fbc",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from rae.data.text.datamodule import AnchorsMode\n",
    "from sklearn.utils import shuffle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad42de90",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from typing import *\n",
    "\n",
    "\n",
    "def get_anchors(dataset, anchors_mode, anchors_num) -> Dict[str, Any]:\n",
    "    \"\"\"Select anchor samples from `dataset` according to `anchors_mode`.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset exposing `features[target_key]` with `int2str`, `names`\n",
    "            and `num_classes` (a ClassLabel feature), indexable by integer.\n",
    "        anchors_mode: `AnchorsMode` member choosing the selection strategy.\n",
    "        anchors_num: number of anchors to select; not used by `AnchorsMode.DATASET`,\n",
    "            which returns every sample.\n",
    "\n",
    "    Returns:\n",
    "        Dict with the keys \"anchor_idxs\", \"anchor_samples\", \"anchor_targets\",\n",
    "        \"anchor_classes\" and \"anchor_latents\" (always None here).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: in STRATIFIED_SUBSET mode when `anchors_num` exceeds the dataset size.\n",
    "        NotImplementedError: for `AnchorsMode.RANDOM_LATENTS`.\n",
    "        RuntimeError: for any other unknown mode.\n",
    "    \"\"\"\n",
    "    label_feature = dataset.features[target_key]\n",
    "\n",
    "    def _pack(anchor_indices) -> Dict[str, Any]:\n",
    "        # Build the result dict shared by every index-based mode.\n",
    "        anchor_indices = [int(idx) for idx in anchor_indices]  # plain ints, also for NumPy input\n",
    "        samples = [dataset[idx] for idx in anchor_indices]\n",
    "        return {\n",
    "            \"anchor_idxs\": anchor_indices,\n",
    "            \"anchor_samples\": samples,\n",
    "            \"anchor_targets\": [sample[target_key] for sample in samples],\n",
    "            \"anchor_classes\": [label_feature.int2str(sample[target_key]) for sample in samples],\n",
    "            \"anchor_latents\": None,\n",
    "        }\n",
    "\n",
    "    if anchors_mode == AnchorsMode.DATASET:\n",
    "        return {\n",
    "            \"anchor_idxs\": list(range(len(dataset))),\n",
    "            \"anchor_samples\": list(dataset),\n",
    "            \"anchor_targets\": dataset[target_key],\n",
    "            # Previously `dataset.classes`, an attribute Hugging Face datasets do not have;\n",
    "            # the class names live on the ClassLabel feature.\n",
    "            \"anchor_classes\": label_feature.names,\n",
    "            \"anchor_latents\": None,\n",
    "        }\n",
    "\n",
    "    if anchors_mode == AnchorsMode.STRATIFIED_SUBSET:\n",
    "        # Without this guard the round-robin loop below would never terminate.\n",
    "        if anchors_num > len(dataset):\n",
    "            raise ValueError(f\"Requested {anchors_num} anchors but the dataset has only {len(dataset)} samples\")\n",
    "\n",
    "        shuffled_idxs, shuffled_targets = shuffle(\n",
    "            np.asarray(list(range(len(dataset)))),\n",
    "            np.asarray(dataset[target_key]),\n",
    "            random_state=0,\n",
    "        )\n",
    "        class2idxs = {\n",
    "            target: shuffled_idxs[shuffled_targets == target] for target in sorted(set(shuffled_targets))\n",
    "        }\n",
    "\n",
    "        # Round-robin over the classes: take the rank-th shuffled sample of each class\n",
    "        # in turn until enough anchors are collected.\n",
    "        anchor_indices = []\n",
    "        rank = 0\n",
    "        while len(anchor_indices) < anchors_num:\n",
    "            for target_idxs in class2idxs.values():\n",
    "                if rank < len(target_idxs):\n",
    "                    anchor_indices.append(target_idxs[rank])\n",
    "                if len(anchor_indices) == anchors_num:\n",
    "                    break\n",
    "            rank += 1\n",
    "        return _pack(anchor_indices)\n",
    "\n",
    "    if anchors_mode == AnchorsMode.STRATIFIED:\n",
    "        if anchors_num >= label_feature.num_classes:\n",
    "            # Imported here because the notebook does not import it at the top.\n",
    "            from sklearn.model_selection import train_test_split\n",
    "\n",
    "            _, anchor_indices = train_test_split(\n",
    "                list(range(len(dataset))),\n",
    "                test_size=anchors_num,\n",
    "                stratify=dataset[target_key],\n",
    "                random_state=0,\n",
    "            )\n",
    "        else:\n",
    "            # NOTE(review): HARDCODED_ANCHORS is not defined anywhere in this notebook;\n",
    "            # this path raises NameError unless it is provided elsewhere.\n",
    "            anchor_indices = HARDCODED_ANCHORS[:anchors_num]\n",
    "        return _pack(anchor_indices)\n",
    "\n",
    "    if anchors_mode == AnchorsMode.RANDOM_SAMPLES:\n",
    "        anchor_indices = list(range(len(dataset)))\n",
    "        random.shuffle(anchor_indices)\n",
    "        # Keep only the requested number of anchors (previously every sample was returned).\n",
    "        return _pack(anchor_indices[:anchors_num])\n",
    "\n",
    "    if anchors_mode == AnchorsMode.RANDOM_LATENTS:\n",
    "        raise NotImplementedError\n",
    "\n",
    "    raise RuntimeError()\n",
    "\n",
    "\n",
    "# Number of anchor samples drawn from the training split.\n",
    "anchors_num: int = 768\n",
    "anchor_selection = get_anchors(train_dataset, anchors_mode=AnchorsMode.STRATIFIED_SUBSET, anchors_num=anchors_num)\n",
    "# Plain Python ints so the indices can index the dataset directly.\n",
    "anchor_idxs = list(map(int, anchor_selection[\"anchor_idxs\"]))\n",
    "anchors = [train_dataset[idx] for idx in anchor_idxs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65f8bc98",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Module used by `get_latents` to project latents onto the anchors; its\n",
    "# AttentionOutput.SIMILARITIES output is what gets stored as \"relative\" latents.\n",
    "relative_projection = RelativeAttention(\n",
    "    n_anchors=anchors_num,\n",
    "    n_classes=train_dataset.features[target_key].num_classes,\n",
    "    similarity_mode=\"inner\",\n",
    "    normalization_mode=\"l2\",\n",
    "    values_mode=\"similarities\",\n",
    "    output_normalization_mode=None,\n",
    ")\n",
    "relative_projection = relative_projection.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54fe4d83",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def collate_fn(batch, tokenizer):\n",
    "    \"\"\"Tokenize the text of a list of samples into one padded batch.\n",
    "\n",
    "    Args:\n",
    "        batch: list of samples; the text is read from each sample's `data_key` entry.\n",
    "        tokenizer: tokenizer called on the list of texts.\n",
    "\n",
    "    Returns:\n",
    "        `{\"encoding\": ...}` holding the tokenizer output as PyTorch tensors, with the\n",
    "        \"special_tokens_mask\" entry removed.\n",
    "    \"\"\"\n",
    "    texts = [sample[data_key] for sample in batch]\n",
    "    encoding = tokenizer(\n",
    "        texts,\n",
    "        padding=True,\n",
    "        truncation=True,\n",
    "        return_special_tokens_mask=True,\n",
    "        return_tensors=\"pt\",\n",
    "    )\n",
    "    # A token-level mask could be derived here if a masked aggregation is added:\n",
    "    # mask = encoding[\"attention_mask\"] * encoding[\"special_tokens_mask\"].bool().logical_not()\n",
    "    # `call_transformer` unpacks the encoding into the model call, so the extra entry is dropped.\n",
    "    del encoding[\"special_tokens_mask\"]\n",
    "    return {\"encoding\": encoding}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31a47936",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:\n",
    "    \"\"\"Compute absolute and, when anchors are given, relative latents for a dataloader.\n",
    "\n",
    "    Args:\n",
    "        dataloader: yields batches accepted by `call_transformer`.\n",
    "        anchors: anchor latents passed to `relative_projection.encode`, or None to\n",
    "            skip the relative projection.\n",
    "        split: label shown in the progress bar.\n",
    "        transformer: encoder; moved to `device` for the computation and back to CPU after.\n",
    "\n",
    "    Returns:\n",
    "        Dict with \"absolute\" (concatenated batch latents on CPU) and \"relative\"\n",
    "        (concatenated similarities on CPU, or an empty list when `anchors` is None).\n",
    "    \"\"\"\n",
    "    absolute_chunks: List = []\n",
    "    relative_chunks: List = []\n",
    "\n",
    "    transformer = transformer.to(device)\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(dataloader, desc=f\"[{split}] Computing latents\"):\n",
    "            batch_latents = call_transformer(batch=batch, transformer=transformer)\n",
    "            absolute_chunks.append(batch_latents.cpu())\n",
    "\n",
    "            if anchors is None:\n",
    "                continue\n",
    "\n",
    "            projection = relative_projection.encode(x=batch_latents, anchors=anchors)\n",
    "            relative_chunks.append(projection[AttentionOutput.SIMILARITIES].cpu())\n",
    "\n",
    "    absolute_latents: torch.Tensor = torch.cat(absolute_chunks, dim=0)\n",
    "    if len(relative_chunks) > 0:\n",
    "        relative_latents = torch.cat(relative_chunks, dim=0).cpu()\n",
    "    else:\n",
    "        # No anchors were given: the (empty) list itself is returned.\n",
    "        relative_latents = relative_chunks\n",
    "\n",
    "    transformer = transformer.cpu()\n",
    "    return {\"absolute\": absolute_latents, \"relative\": relative_latents}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4639ccb6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1bceca1e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae import PROJECT_ROOT\n",
    "\n",
    "# Per-dataset cache directory for the latents: <PROJECT_ROOT>/data/latents/<dataset key parts>\n",
    "LATENTS_DIR: Path = PROJECT_ROOT.joinpath(\"data\", \"latents\", *dataset_key)\n",
    "LATENTS_DIR.mkdir(parents=True, exist_ok=True)\n",
    "LATENTS_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bc72d24",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_latents():\n",
    "    \"\"\"Load the cached latents of every configured transformer that has a cache file.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping transformer name to the object stored in\n",
    "        `LATENTS_DIR/<name with \"/\" replaced by \"-\">.pt`; names without a file are skipped.\n",
    "    \"\"\"\n",
    "    cached = {}\n",
    "    for name in transformers:\n",
    "        cache_file = LATENTS_DIR / (name.replace(\"/\", \"-\") + \".pt\")\n",
    "        if not cache_file.exists():\n",
    "            continue\n",
    "        cached[name] = torch.load(cache_file)\n",
    "    return cached\n",
    "\n",
    "\n",
    "# Names of the transformers whose latents are already cached.\n",
    "list(load_latents())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84af1d42",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "FORCE_RECOMPUTE: bool = False  # recompute latents even when a cache file exists\n",
    "CACHE_LATENTS: bool = True  # write freshly computed latents to LATENTS_DIR\n",
    "\n",
    "\n",
    "def _make_loader(samples, tokenizer):\n",
    "    \"\"\"Build the DataLoader used to encode `samples` with `tokenizer`.\"\"\"\n",
    "    return DataLoader(\n",
    "        samples,\n",
    "        num_workers=8,\n",
    "        pin_memory=True,\n",
    "        collate_fn=partial(collate_fn, tokenizer=tokenizer),\n",
    "        batch_size=32,\n",
    "    )\n",
    "\n",
    "\n",
    "latents = load_latents()\n",
    "\n",
    "if FORCE_RECOMPUTE:\n",
    "    missing_transformers = transformer_names\n",
    "else:\n",
    "    missing_transformers = [t_name for t_name in transformer_names if t_name not in latents]\n",
    "\n",
    "for transformer_name in missing_transformers:\n",
    "    encoder, encoder_tokenizer = transformers[transformer_name]\n",
    "\n",
    "    # Anchors are encoded first and without projection: their absolute latents are\n",
    "    # the reference passed to the relative projection of the train/test splits.\n",
    "    anchors_latents = get_latents(\n",
    "        dataloader=_make_loader(anchors, encoder_tokenizer),\n",
    "        split=f\"{transformer_name}, anchor\",\n",
    "        anchors=None,\n",
    "        transformer=encoder,\n",
    "    )[\"absolute\"]\n",
    "\n",
    "    transformer_latents = {\"anchors_latents\": anchors_latents}\n",
    "    for dataset_split in [train_dataset, test_dataset]:\n",
    "        split_name = str(dataset_split.split)\n",
    "        transformer_latents[split_name] = get_latents(\n",
    "            dataloader=_make_loader(dataset_split, encoder_tokenizer),\n",
    "            split=f\"{transformer_name}, {split_name}\",\n",
    "            anchors=anchors_latents.to(device),\n",
    "            transformer=encoder,\n",
    "        )\n",
    "    latents[transformer_name] = transformer_latents\n",
    "\n",
    "    # Save latents\n",
    "    if CACHE_LATENTS:\n",
    "        transformer_path = LATENTS_DIR / f\"{transformer_name.replace('/', '-')}.pt\"\n",
    "        torch.save(latents[transformer_name], transformer_path)\n",
    "latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4292d735",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "\n",
    "\n",
    "# def fit(X, y):\n",
    "#     classifier = make_pipeline(\n",
    "#         Normalizer(), StandardScaler(), SVC(gamma=\"auto\", kernel=\"linear\", random_state=42)\n",
    "#     )  # , class_weight=\"balanced\"))\n",
    "#     classifier.fit(X, y)\n",
    "#     return lambda x: classifier.predict(x)\n",
    "\n",
    "\n",
    "from torch import nn\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "\n",
    "class Lambda(nn.Module):\n",
    "    def __init__(self, func):\n",
    "        super().__init__()\n",
    "        self.func = func\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.func(x)\n",
    "\n",
    "\n",
    "DATASET2LR = {\n",
    "    \"trec\": 1e-4,\n",
    "    \"amazon_reviews_multi\": 1e-3,\n",
    "    \"dbpedia_14\": 1e-4,\n",
    "}\n",
    "\n",
    "\n",
    "def fit(X, y, seed):\n",
    "    \"\"\"Train a small MLP classifier on latents and return a prediction function.\n",
    "\n",
    "    Args:\n",
    "        X: 2-D tensor of latents, one row per sample.\n",
    "        y: class indices aligned with the rows of `X`.\n",
    "        seed: passed to `seed_everything` before the loader and the model are built.\n",
    "\n",
    "    Returns:\n",
    "        Callable mapping a CPU latent tensor to a CPU tensor of predicted class indices.\n",
    "    \"\"\"\n",
    "    seed_everything(seed)\n",
    "    dataset = TensorDataset(X, torch.as_tensor(y))\n",
    "    loader = DataLoader(dataset, batch_size=32, pin_memory=True, shuffle=True, num_workers=4)\n",
    "\n",
    "    # Size the input layers from the data instead of assuming the feature dimension\n",
    "    # equals `anchors_num`: that holds for relative latents, but absolute latents have\n",
    "    # the encoder's hidden size. The hidden width stays `anchors_num`.\n",
    "    in_features = X.shape[-1]\n",
    "    hidden_features = anchors_num\n",
    "    num_classes = train_dataset.features[target_key].num_classes\n",
    "\n",
    "    model = nn.Sequential(\n",
    "        nn.LayerNorm(normalized_shape=in_features),\n",
    "        nn.Linear(in_features=in_features, out_features=hidden_features),\n",
    "        nn.SiLU(),\n",
    "        # InstanceNorm1d is applied on the transposed (features, batch) tensor,\n",
    "        # then the original layout is restored.\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=hidden_features),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=hidden_features, out_features=hidden_features),\n",
    "        nn.SiLU(),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=hidden_features),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=hidden_features, out_features=num_classes),\n",
    "        # NOTE(review): this ReLU clamps the logits to be non-negative before\n",
    "        # CrossEntropyLoss; confirm that this is intended.\n",
    "        nn.ReLU(),\n",
    "    ).to(device)\n",
    "    opt = Adam(model.parameters(), lr=DATASET2LR[dataset_key[0]])\n",
    "    loss_fn = CrossEntropyLoss()\n",
    "    for epoch in tqdm(range(5 if fine_grained else 3), leave=False, desc=\"epoch\"):\n",
    "        for batch_x, batch_y in loader:\n",
    "            batch_x = batch_x.to(device)\n",
    "            batch_y = batch_y.to(device)\n",
    "            pred_y = model(batch_x)\n",
    "            loss = loss_fn(pred_y, batch_y)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            opt.zero_grad()\n",
    "    # Inference runs on the CPU with the model in eval mode.\n",
    "    model = model.cpu().eval()\n",
    "    return lambda x: model(x).argmax(-1).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7951c369",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "SEEDS = [0, 1, 2, 3, 4]\n",
    "\n",
    "\n",
    "# seed -> transformer name -> embedding type -> prediction function returned by `fit`\n",
    "fitted_classifiers = {}\n",
    "for seed in SEEDS:\n",
    "    per_transformer = {}\n",
    "    for transformer_name in tqdm(transformers, desc=\"transformer\"):\n",
    "        per_embedding = {}\n",
    "        for embedding_type in tqdm([\"absolute\", \"relative\"], leave=False, desc=\"embedding_type\"):\n",
    "            train_latents = latents[transformer_name][\"train\"][embedding_type]\n",
    "            per_embedding[embedding_type] = fit(train_latents, train_y, seed)\n",
    "        per_transformer[transformer_name] = per_embedding\n",
    "    fitted_classifiers[seed] = per_transformer\n",
    "fitted_classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca059a8c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "\n",
    "# One list per output column; every evaluated combination appends one value to each list.\n",
    "numeric_results = {\n",
    "    column: []\n",
    "    for column in (\n",
    "        \"seed\",\n",
    "        \"embed_type\",\n",
    "        \"embed_transformer\",\n",
    "        \"classifier_transformer\",\n",
    "        \"precision\",\n",
    "        \"recall\",\n",
    "        \"fscore\",\n",
    "        \"stitched\",\n",
    "    )\n",
    "}\n",
    "for seed in SEEDS:\n",
    "    for embed_type in [\"absolute\", \"relative\"]:\n",
    "        for embed_transformer in transformers:\n",
    "            # Test latents depend only on the embedding transformer, not on the classifier.\n",
    "            test_latents = latents[embed_transformer][\"test\"][embed_type]\n",
    "            for classifier_transformer in transformers:\n",
    "                classifier = fitted_classifiers[seed][classifier_transformer][embed_type]\n",
    "                preds = classifier(test_latents)\n",
    "\n",
    "                precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average=\"weighted\")\n",
    "                row = {\n",
    "                    \"seed\": seed,\n",
    "                    \"embed_type\": embed_type,\n",
    "                    \"embed_transformer\": embed_transformer,\n",
    "                    \"classifier_transformer\": classifier_transformer,\n",
    "                    \"precision\": precision,\n",
    "                    \"recall\": recall,\n",
    "                    \"fscore\": fscore,\n",
    "                    # A pair is \"stitched\" when encoder and classifier come from different transformers.\n",
    "                    \"stitched\": embed_transformer != classifier_transformer,\n",
    "                }\n",
    "                for column, value in row.items():\n",
    "                    numeric_results[column].append(value)\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "pd.options.display.max_columns = None\n",
    "pd.options.display.max_rows = None\n",
    "df = pd.DataFrame(numeric_results)\n",
    "dataset_name = \"_\".join(dataset_key)\n",
    "fine_grained_str = \"_fine_grained\" if fine_grained else \"_coarse\"\n",
    "# Only the amazon_reviews_multi results carry the granularity suffix in the file name.\n",
    "results_suffix = fine_grained_str if dataset_key[0] == \"amazon_reviews_multi\" else \"\"\n",
    "df.to_csv(f\"nlp_stitching-{dataset_name}{results_suffix}.tsv\", sep=\"\\t\")\n",
    "\n",
    "# Aggregate over the remaining rows (the seeds) of every configuration.\n",
    "group_columns = [\n",
    "    \"embed_type\",\n",
    "    \"stitched\",\n",
    "    \"embed_transformer\",\n",
    "    \"classifier_transformer\",\n",
    "]\n",
    "df = df.groupby(group_columns).agg([np.mean, \"count\"])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e09a28",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62fd6f68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the dataset identifier that is embedded in the results file name.\n",
    "dataset_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d3e07f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# `df` was reassigned in the results cell to the frame aggregated per\n",
    "# (embed_type, stitched, embed_transformer, classifier_transformer), so this groups\n",
    "# those aggregated rows by the `embed_type` and `stitched` index levels.\n",
    "# NOTE(review): the statistics below are therefore computed over the per-configuration\n",
    "# means and counts, not over the raw per-seed rows - confirm this is intended.\n",
    "df.groupby(\n",
    "    [\n",
    "        \"embed_type\",\n",
    "        \"stitched\",\n",
    "    ]\n",
    ").agg([np.mean, \"count\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f130be91",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
