{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3afba055",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e2fde5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "from rae import PROJECT_ROOT\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bddf3583",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a60d9f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.attention import RelativeAttention, AttentionOutput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f2829d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device: str = \"cuda\"\n",
    "fine_grained: bool = True\n",
    "target_key: str = \"class\"\n",
    "data_key: str = \"content\"\n",
    "anchor_dataset_name: str = \"amazon_translated\"  # wikimatrix, amazon_translated\n",
    "ALL_LANGS = (\"en\", \"es\", \"fr\", \"ja\")\n",
    "num_anchors: int = 768\n",
    "train_perc: float = 0.25"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54fdaa84",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, ClassLabel\n",
    "\n",
    "\n",
    "def get_dataset(lang: str, split: str, perc: float, fine_grained: bool):\n",
    "    seed_everything(42)\n",
    "    assert 0 < perc <= 1\n",
    "    dataset = load_dataset(\"amazon_reviews_multi\", lang)[split]\n",
    "\n",
    "    if not fine_grained:\n",
    "        dataset = dataset.filter(lambda sample: sample[\"stars\"] != 3)\n",
    "\n",
    "    # Select a random subset\n",
    "    indices = list(range(len(dataset)))\n",
    "    random.shuffle(indices)\n",
    "    indices = indices[: int(len(indices) * perc)]\n",
    "    dataset = dataset.select(indices)\n",
    "\n",
    "    def clean_sample(sample):\n",
    "        title: str = sample[\"review_title\"].strip('\"').strip(\".\").strip()\n",
    "        body: str = sample[\"review_body\"].strip('\"').strip(\".\").strip()\n",
    "\n",
    "        if body.lower().startswith(title.lower()):\n",
    "            title = \"\"\n",
    "\n",
    "        if len(title) > 0 and title[-1].isalpha():\n",
    "            title = f\"{title}.\"\n",
    "\n",
    "        sample[\"content\"] = f\"{title} {body}\".lstrip(\".\").strip()\n",
    "        if fine_grained:\n",
    "            sample[target_key] = str(sample[\"stars\"] - 1)\n",
    "        else:\n",
    "            sample[target_key] = sample[\"stars\"] > 3\n",
    "        return sample\n",
    "\n",
    "    dataset = dataset.map(clean_sample)\n",
    "    dataset = dataset.cast_column(\n",
    "        target_key,\n",
    "        ClassLabel(num_classes=5 if fine_grained else 2, names=list(map(str, range(1, 6) if fine_grained else (0, 1)))),\n",
    "    )\n",
    "\n",
    "    return dataset"
   ]
  },
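  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab12cd34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged illustration of the cleaning rules inside get_dataset, applied to a\n",
    "# toy review (hypothetical sample, not taken from the dataset): the title is\n",
    "# dropped when the body already starts with it, and a kept title gets a period.\n",
    "def _clean_demo(title: str, body: str) -> str:\n",
    "    title = title.strip('\"').strip(\".\").strip()\n",
    "    body = body.strip('\"').strip(\".\").strip()\n",
    "    if body.lower().startswith(title.lower()):\n",
    "        title = \"\"\n",
    "    if len(title) > 0 and title[-1].isalpha():\n",
    "        title = f\"{title}.\"\n",
    "    return f\"{title} {body}\".lstrip(\".\").strip()\n",
    "\n",
    "\n",
    "print(_clean_demo(title=\"Great blender\", body=\"Great blender, works as advertised\"))\n",
    "print(_clean_demo(title='\"Loved it.\"', body=\"Bought two more for friends\"))"
   ]
  },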
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49e4af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_datasets = {\n",
    "    lang: get_dataset(lang=lang, split=\"train\", perc=train_perc, fine_grained=fine_grained) for lang in ALL_LANGS\n",
    "}\n",
    "train_datasets[\"en\"].features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a41a58f",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert len(set(frozenset(train_dataset.features.keys()) for train_dataset in train_datasets.values())) == 1\n",
    "class2idx = train_datasets[\"en\"].features[target_key].str2int\n",
    "train_datasets[\"en\"].features[target_key], class2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981b6fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_transformer(transformer_name):\n",
    "    transformer = AutoModel.from_pretrained(transformer_name, output_hidden_states=True, return_dict=True)\n",
    "    transformer.requires_grad_(False).eval()\n",
    "    return transformer, AutoTokenizer.from_pretrained(transformer_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0b9f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_datasets = {lang: get_dataset(lang=lang, split=\"test\", perc=1, fine_grained=fine_grained) for lang in ALL_LANGS}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6beba308",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def call_transformer(batch, transformer):\n",
    "    encoding = batch[\"encoding\"].to(device)\n",
    "    sample_encodings = transformer(**encoding)[\"hidden_states\"][-1]\n",
    "    # TODO: aggregation mode\n",
    "    # result = []\n",
    "    # for sample_encoding, sample_mask in zip(sample_encodings, batch[\"mask\"]):\n",
    "    #     result.append(sample_encoding[sample_mask].mean(dim=0))\n",
    "\n",
    "    # return torch.stack(result, dim=0)\n",
    "    return sample_encodings[:, 0, :]  # CLS"
   ]
  },
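  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc23de45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch of the aggregation mode hinted at by the TODO above: mean-pool\n",
    "# the token embeddings selected by a boolean mask (see the commented-out mask\n",
    "# logic in collate_fn below) instead of taking the CLS token. Illustrative\n",
    "# only; the pipeline keeps using the CLS embedding.\n",
    "def mean_pool(sample_encodings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:\n",
    "    # sample_encodings: (batch, seq_len, dim); mask: (batch, seq_len), bool\n",
    "    mask = mask.unsqueeze(-1).type_as(sample_encodings)\n",
    "    return (sample_encodings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)\n",
    "\n",
    "\n",
    "assert mean_pool(torch.randn(2, 7, 768), torch.ones(2, 7, dtype=torch.bool)).shape == (2, 768)"
   ]
  },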
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3649cedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.data.text import WikiMatrixAnchors, MultilingualAmazonAnchors\n",
    "from typing import *\n",
    "\n",
    "anchor_dataset2num_samples = {\"wikimatrix\": 3338, \"amazon_translated\": 1000}\n",
    "anchor_dataset2first_anchors = {\n",
    "    \"wikimatrix\": [\n",
    "        361,\n",
    "        2192,\n",
    "        1855,\n",
    "        1163,\n",
    "        1434,\n",
    "        3065,\n",
    "        1329,\n",
    "        2381,\n",
    "        2366,\n",
    "        466,\n",
    "        1488,\n",
    "        3007,\n",
    "        1749,\n",
    "        2332,\n",
    "        2463,\n",
    "        2180,\n",
    "        1790,\n",
    "        3328,\n",
    "        2865,\n",
    "        1457,\n",
    "    ],\n",
    "    \"amazon_translated\": [\n",
    "        776,\n",
    "        507,\n",
    "        895,\n",
    "        922,\n",
    "        33,\n",
    "        483,\n",
    "        85,\n",
    "        750,\n",
    "        354,\n",
    "        523,\n",
    "        184,\n",
    "        809,\n",
    "        418,\n",
    "        615,\n",
    "        682,\n",
    "        501,\n",
    "        760,\n",
    "        49,\n",
    "        732,\n",
    "        336,\n",
    "    ],\n",
    "}\n",
    "\n",
    "\n",
    "def _amazon_translated_get_samples(lang: str, sample_idxs) -> Sequence:\n",
    "    anchor_dataset = MultilingualAmazonAnchors(split=\"train\", language=lang)\n",
    "    anchors = []\n",
    "    for anchor_idx in sample_idxs:\n",
    "        anchor = anchor_dataset[anchor_idx]\n",
    "        anchor[data_key] = anchor[\"data\"]\n",
    "        anchors.append(anchor)\n",
    "    return anchors\n",
    "\n",
    "\n",
    "def _wikimatrix_get_samples(lang: str, sample_idxs) -> Sequence:\n",
    "    anchor_dataset = WikiMatrixAnchors(\n",
    "        split=\"train\", language=lang, lang2threshold={\"es\": 1.06, \"fr\": 1.06, \"ja\": 1.06}, path=PROJECT_ROOT / \"data\"\n",
    "    )\n",
    "    anchors = []\n",
    "    for anchor_idx in sample_idxs:\n",
    "        anchor = anchor_dataset[anchor_idx]\n",
    "        anchor[data_key] = anchor[\"data\"]\n",
    "        anchors.append(anchor)\n",
    "    return anchors\n",
    "\n",
    "\n",
    "anchor_dataset2sampling = {\"wikimatrix\": _wikimatrix_get_samples, \"amazon_translated\": _amazon_translated_get_samples}\n",
    "\n",
    "assert num_anchors <= anchor_dataset2num_samples[anchor_dataset_name]\n",
    "\n",
    "seed_everything(42)\n",
    "anchor_idxs = list(range(anchor_dataset2num_samples[anchor_dataset_name]))\n",
    "random.shuffle(anchor_idxs)\n",
    "anchor_idxs = anchor_idxs[:num_anchors]\n",
    "\n",
    "assert anchor_idxs[:20] == anchor_dataset2first_anchors[anchor_dataset_name]  # better safe than sorry\n",
    "lang2anchors: Mapping[str, Sequence] = {\n",
    "    lang: anchor_dataset2sampling[anchor_dataset_name](lang=lang, sample_idxs=anchor_idxs) for lang in ALL_LANGS\n",
    "}"
   ]
  },
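  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd34ef56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: the same shuffled indices are reused for every language, so\n",
    "# entry i of each anchor set should be a translation of the same sentence.\n",
    "# Eyeball the first anchor in each language to confirm.\n",
    "assert all(len(anchors) == num_anchors for anchors in lang2anchors.values())\n",
    "print({lang: anchors[0][data_key][:80] for lang, anchors in lang2anchors.items()})"
   ]
  },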
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f4b15a",
   "metadata": {},
   "outputs": [],
   "source": [
    "lang2transformer_name = {\n",
    "    \"en\": \"roberta-base\",\n",
    "    \"es\": \"PlanTL-GOB-ES/roberta-base-bne\",\n",
    "    \"fr\": \"ClassCat/roberta-base-french\",\n",
    "    \"ja\": \"nlp-waseda/roberta-base-japanese\",\n",
    "}\n",
    "assert set(lang2transformer_name.keys()) == set(ALL_LANGS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290719be",
   "metadata": {},
   "outputs": [],
   "source": [
    "relative_projection = RelativeAttention(\n",
    "    n_anchors=num_anchors,\n",
    "    normalization_mode=\"l2\",\n",
    "    similarity_mode=\"inner\",\n",
    "    values_mode=\"similarities\",\n",
    "    n_classes=train_datasets[\"en\"].features[target_key].num_classes,\n",
    "    output_normalization_mode=None,\n",
    ").to(device)"
   ]
  },
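  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de45fa67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch of what the relative projection computes under these\n",
    "# settings: with normalization_mode=\"l2\" and similarity_mode=\"inner\", the\n",
    "# similarities should reduce to cosine similarities between each latent and\n",
    "# the anchors. Toy check on random tensors (not the real latents); the\n",
    "# function name and the 768 dim here are illustrative assumptions.\n",
    "def relative_projection_sketch(x: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:\n",
    "    x = F.normalize(x, p=2, dim=-1)\n",
    "    anchors = F.normalize(anchors, p=2, dim=-1)\n",
    "    return x @ anchors.T  # (batch, num_anchors) cosine similarities\n",
    "\n",
    "\n",
    "assert relative_projection_sketch(torch.randn(4, 768), torch.randn(num_anchors, 768)).shape == (4, num_anchors)"
   ]
  },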
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "397f6d06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(batch, tokenizer):\n",
    "    encoding = tokenizer(\n",
    "        [sample[data_key] for sample in batch],\n",
    "        return_tensors=\"pt\",\n",
    "        return_special_tokens_mask=True,\n",
    "        truncation=True,\n",
    "        max_length=512,\n",
    "        padding=True,\n",
    "    )\n",
    "    # mask = encoding[\"attention_mask\"] * encoding[\"special_tokens_mask\"].bool().logical_not()\n",
    "    del encoding[\"special_tokens_mask\"]\n",
    "    # return {\"encoding\": encoding, \"mask\": mask.bool()}\n",
    "    return {\"encoding\": encoding}"
   ]
  },
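  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef56ab78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick smoke test of collate_fn on two toy reviews (assumes network access to\n",
    "# fetch the \"roberta-base\" tokenizer; not part of the pipeline itself).\n",
    "_demo_tokenizer = AutoTokenizer.from_pretrained(\"roberta-base\")\n",
    "_demo_batch = collate_fn(\n",
    "    [{data_key: \"Great product.\"}, {data_key: \"Broke after a week, would not buy again.\"}],\n",
    "    tokenizer=_demo_tokenizer,\n",
    ")\n",
    "{k: tuple(v.shape) for k, v in _demo_batch[\"encoding\"].items()}"
   ]
  },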
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a3d7735",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:\n",
    "    absolute_latents: List = []\n",
    "    relative_latents: List = []\n",
    "\n",
    "    transformer = transformer.to(device)\n",
    "    for batch in tqdm(dataloader, desc=f\"[{split}] Computing latents\"):\n",
    "        with torch.no_grad():\n",
    "            batch_latents = call_transformer(batch=batch, transformer=transformer)\n",
    "\n",
    "            absolute_latents.append(batch_latents.cpu())\n",
    "\n",
    "            if anchors is not None:\n",
    "                batch_rel_latents = relative_projection.encode(x=batch_latents, anchors=anchors)[\n",
    "                    AttentionOutput.SIMILARITIES\n",
    "                ]\n",
    "                relative_latents.append(batch_rel_latents.cpu())\n",
    "\n",
    "    absolute_latents: torch.Tensor = torch.cat(absolute_latents, dim=0)\n",
    "    relative_latents: torch.Tensor = (\n",
    "        torch.cat(relative_latents, dim=0).cpu() if len(relative_latents) > 0 else relative_latents\n",
    "    )\n",
    "\n",
    "    transformer = transformer.cpu()\n",
    "    return {\n",
    "        \"absolute\": absolute_latents,\n",
    "        \"relative\": relative_latents,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a2122733",
   "metadata": {},
   "outputs": [],
   "source": [
    "anchor_dataset_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f89a145",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae import PROJECT_ROOT\n",
    "\n",
    "LATENTS_DIR: Path = (\n",
    "    PROJECT_ROOT\n",
    "    / \"data\"\n",
    "    / \"latents\"\n",
    "    / \"multilingual_amazon\"\n",
    "    / str(train_perc)\n",
    "    / anchor_dataset_name\n",
    "    / (\"fine_grained\" if fine_grained else \"coarse_grained\")\n",
    ")\n",
    "LATENTS_DIR.mkdir(exist_ok=True, parents=True)\n",
    "LATENTS_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f5398a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_latents(split: str, langs: Sequence[str]):\n",
    "    lang2latents = {}\n",
    "\n",
    "    for lang in langs:\n",
    "        transformer_name = lang2transformer_name[lang]\n",
    "        transformer_path = LATENTS_DIR / split / lang / f\"{transformer_name.replace('/', '-')}.pt\"\n",
    "        if transformer_path.exists():\n",
    "            lang2latents[lang] = torch.load(transformer_path)\n",
    "\n",
    "    return lang2latents"
   ]
  },
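  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa67bc89",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cached payload per (split, lang) mirrors what encode_latents (below)\n",
    "# saves: a dict with \"anchors_latents\", \"absolute\" and \"relative\" tensors.\n",
    "# Quick shape report for whatever is already on disk (empty if nothing cached).\n",
    "for _lang, _payload in load_latents(split=\"test\", langs=ALL_LANGS).items():\n",
    "    print(_lang, {k: tuple(v.shape) for k, v in _payload.items() if hasattr(v, \"shape\")})"
   ]
  },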
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d29b727",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "\n",
    "\n",
    "def encode_latents(langs, lang2dataset, lang2latents, split: str):\n",
    "    for lang in langs:\n",
    "        transformer_name: str = lang2transformer_name[lang]\n",
    "        lang_transformer, lang_tokenizer = load_transformer(transformer_name=transformer_name)\n",
    "        lang2latents[lang] = {\n",
    "            \"anchors_latents\": (\n",
    "                anchors_latents := get_latents(\n",
    "                    dataloader=DataLoader(\n",
    "                        lang2anchors[lang],\n",
    "                        num_workers=4,\n",
    "                        pin_memory=True,\n",
    "                        collate_fn=partial(collate_fn, tokenizer=lang_tokenizer),\n",
    "                        batch_size=32,\n",
    "                    ),\n",
    "                    split=f\"{transformer_name}, anchor, {split}\",\n",
    "                    anchors=None,\n",
    "                    transformer=lang_transformer,\n",
    "                )[\"absolute\"]\n",
    "            ),\n",
    "            **get_latents(\n",
    "                dataloader=DataLoader(\n",
    "                    lang2dataset[lang],\n",
    "                    num_workers=4,\n",
    "                    pin_memory=True,\n",
    "                    collate_fn=partial(collate_fn, tokenizer=lang_tokenizer),\n",
    "                    batch_size=32,\n",
    "                ),\n",
    "                split=f\"{split}/{lang}\",\n",
    "                anchors=anchors_latents.to(device),\n",
    "                transformer=lang_transformer,\n",
    "            ),\n",
    "        }\n",
    "        # Save latents\n",
    "        if CACHE_LATENTS:\n",
    "            transformer_path = LATENTS_DIR / split / lang / f\"{transformer_name.replace('/', '-')}.pt\"\n",
    "            transformer_path.parent.mkdir(exist_ok=True, parents=True)\n",
    "            torch.save(lang2latents[lang], transformer_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b41b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute test latents\n",
    "\n",
    "FORCE_RECOMPUTE: bool = False\n",
    "CACHE_LATENTS: bool = True\n",
    "\n",
    "langt2test_latents: Dict[str, Mapping[str, torch.Tensor]] = load_latents(split=\"test\", langs=ALL_LANGS)\n",
    "missing_langs = ALL_LANGS if FORCE_RECOMPUTE else [lang for lang in ALL_LANGS if lang not in langt2test_latents]\n",
    "encode_latents(langs=missing_langs, lang2dataset=test_datasets, lang2latents=langt2test_latents, split=\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74cb3ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute train latents\n",
    "\n",
    "FORCE_RECOMPUTE: bool = False\n",
    "CACHE_LATENTS: bool = True\n",
    "\n",
    "lang2train_latents: Dict[str, Mapping[str, torch.Tensor]] = load_latents(split=\"train\", langs=train_datasets.keys())\n",
    "missing_langs = (\n",
    "    train_datasets.keys()\n",
    "    if FORCE_RECOMPUTE\n",
    "    else [lang for lang in train_datasets.keys() if lang not in lang2train_latents]\n",
    ")\n",
    "encode_latents(langs=missing_langs, lang2dataset=train_datasets, lang2latents=lang2train_latents, split=\"train\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd41a8b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_normalize: bool = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943bb3a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "\n",
    "\n",
    "# def fit(X, y, seed, **kwargs):\n",
    "#     classifier = make_pipeline(\n",
    "#         Normalizer(), StandardScaler(), SVC(gamma=\"auto\", kernel=\"linear\", max_iter=200, random_state=seed)\n",
    "#     )  # , class_weight=\"balanced\"))\n",
    "#     classifier.fit(X, y)\n",
    "#     return lambda x: classifier.predict(x)\n",
    "\n",
    "\n",
    "class Lambda(nn.Module):\n",
    "    def __init__(self, func):\n",
    "        super().__init__()\n",
    "        self.func = func\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.func(x)\n",
    "\n",
    "\n",
    "def fit(X: torch.Tensor, y, seed, normalize: bool):\n",
    "    seed_everything(seed)\n",
    "    if normalize:\n",
    "        X = F.normalize(X, p=2, dim=-1)\n",
    "    dataset = TensorDataset(X, torch.as_tensor(y))\n",
    "    loader = DataLoader(dataset, batch_size=32, pin_memory=True, shuffle=True, num_workers=4)\n",
    "\n",
    "    model = nn.Sequential(\n",
    "        nn.LayerNorm(normalized_shape=num_anchors),\n",
    "        nn.Linear(in_features=num_anchors, out_features=num_anchors),\n",
    "        nn.SiLU(),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=num_anchors),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=num_anchors, out_features=num_anchors),\n",
    "        nn.SiLU(),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=num_anchors),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(\n",
    "            in_features=num_anchors, out_features=list(train_datasets.values())[0].features[target_key].num_classes\n",
    "        ),\n",
    "        nn.ReLU(),\n",
    "    ).to(device)\n",
    "    opt = Adam(model.parameters(), lr=1e-3)\n",
    "    loss_fn = CrossEntropyLoss()\n",
    "    for epoch in tqdm(range(5 if fine_grained else 3), leave=False, desc=\"epoch\"):\n",
    "        for batch_x, batch_y in loader:\n",
    "            batch_x = batch_x.to(device)\n",
    "            batch_y = batch_y.to(device)\n",
    "            pred_y = model(batch_x)\n",
    "            loss = loss_fn(pred_y, batch_y)\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            opt.zero_grad()\n",
    "    model = model.cpu().eval()\n",
    "    return lambda x: model(x).argmax(-1).detach().cpu()"
   ]
  },
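  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab78cd90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test of fit on synthetic tensors (hypothetical data, not the review\n",
    "# latents): trains the small MLP for a few steps and checks that the returned\n",
    "# predictor runs end to end. Inputs are normalized the same way the test\n",
    "# latents are normalized in the evaluation loop below.\n",
    "_X = torch.randn(64, num_anchors)\n",
    "_y = torch.randint(0, 5 if fine_grained else 2, (64,))\n",
    "_predict = fit(_X, _y, seed=0, normalize=latent_normalize)\n",
    "_predict(F.normalize(_X[:4], p=2, dim=-1) if latent_normalize else _X[:4])"
   ]
  },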
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7951c369",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEEDS = list(range(5))\n",
    "train_classifiers = {\n",
    "    seed: {\n",
    "        embedding_type: {\n",
    "            train_lang: fit(\n",
    "                lang2train_latents[train_lang][embedding_type],\n",
    "                train_dataset[target_key],\n",
    "                seed=seed,\n",
    "                normalize=latent_normalize,\n",
    "            )\n",
    "            for train_lang, train_dataset in tqdm(train_datasets.items(), leave=False, desc=\"lang\")\n",
    "        }\n",
    "        for embedding_type in tqdm([\"absolute\", \"relative\"], leave=False, desc=\"embedding_type\")\n",
    "    }\n",
    "    for seed in tqdm(SEEDS, leave=False, desc=\"seed\")\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc23b282",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca059a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error\n",
    "\n",
    "numeric_results = {\n",
    "    \"seed\": [],\n",
    "    \"embed_type\": [],\n",
    "    \"train_lang\": [],\n",
    "    \"test_lang\": [],\n",
    "    \"precision\": [],\n",
    "    \"recall\": [],\n",
    "    \"fscore\": [],\n",
    "    \"mae\": [],\n",
    "    \"stitched\": [],\n",
    "}\n",
    "for seed, embed_type2train_lang2classifier in train_classifiers.items():\n",
    "    for embed_type, train_lang2classifier in embed_type2train_lang2classifier.items():\n",
    "        for train_lang, classifier in train_lang2classifier.items():\n",
    "            for test_lang, test_latents in langt2test_latents.items():\n",
    "                test_latents = test_latents[embed_type]\n",
    "                if latent_normalize:\n",
    "                    test_latents = F.normalize(test_latents, p=2, dim=-1)\n",
    "                preds = classifier(test_latents)\n",
    "                test_y = np.array(test_datasets[test_lang][target_key])\n",
    "\n",
    "                precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average=\"weighted\")\n",
    "                mae = mean_absolute_error(y_true=test_y, y_pred=preds)\n",
    "                numeric_results[\"embed_type\"].append(embed_type)\n",
    "                numeric_results[\"train_lang\"].append(train_lang)\n",
    "                numeric_results[\"test_lang\"].append(test_lang)\n",
    "                numeric_results[\"precision\"].append(precision)\n",
    "                numeric_results[\"recall\"].append(recall)\n",
    "                numeric_results[\"fscore\"].append(fscore)\n",
    "                numeric_results[\"stitched\"].append(train_lang != test_lang)\n",
    "                numeric_results[\"mae\"].append(mae)\n",
    "                numeric_results[\"seed\"].append(seed)\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "pd.options.display.max_columns = None\n",
    "pd.options.display.max_rows = None\n",
    "df = pd.DataFrame(numeric_results)\n",
    "df.to_csv(\n",
    "    f\"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "\n",
    "df = df.groupby(\n",
    "    [\n",
    "        \"embed_type\",\n",
    "        \"stitched\",\n",
    "        \"train_lang\",\n",
    "        \"test_lang\",\n",
    "    ]\n",
    ").agg([np.mean])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a6c106",
   "metadata": {},
   "outputs": [],
   "source": [
    "f\"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "372200db",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# fine_grained: bool = False\n",
    "# anchor_dataset_name: str = \"amazon_translated\" # wikimatrix, amazon_translated\n",
    "# train_perc: float = 0.25\n",
    "\n",
    "# full_df = pd.read_csv(\n",
    "#     f\"nlp_multilingual-stitching-amazon-{'fine_grained' if fine_grained else 'coarse_grained'}-{anchor_dataset_name}-{train_perc}.tsv\",\n",
    "#     sep=\"\\t\",\n",
    "#     index_col=0,\n",
    "# )\n",
    "\n",
    "df = df.groupby(\n",
    "    [\n",
    "        \"embed_type\",\n",
    "        \"stitched\",\n",
    "        \"train_lang\",\n",
    "        \"test_lang\",\n",
    "    ]\n",
    ").agg([np.mean, \"count\"])\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d3e07f",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.drop(columns=[\"stitched\", \"seed\", \"precision\", \"recall\"])[full_df.train_lang == \"en\"].groupby(\n",
    "    [\"embed_type\", \"train_lang\", \"test_lang\"]\n",
    ").agg([np.mean, np.std]).round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d301837a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# it_dataset = get_samples(lang=\"it\", sample_idxs=list(range(1000)))\n",
    "# it_transformer_name: str = \"dbmdz/bert-base-italian-cased\"\n",
    "# transformer, tokenizer = load_transformer(transformer_name=it_transformer_name)\n",
    "# it_anchor_latents = get_latents(\n",
    "#     dataloader=DataLoader(\n",
    "#         get_samples(\"it\", sample_idxs=anchor_idxs),\n",
    "#         num_workers=16,\n",
    "#         pin_memory=True,\n",
    "#         collate_fn=partial(collate_fn, tokenizer=tokenizer),\n",
    "#         batch_size=32,\n",
    "#     ),\n",
    "#     split=f\"{it_transformer_name}\",\n",
    "#     anchors=None,\n",
    "#     transformer=transformer,\n",
    "# )\n",
    "# it_latents = get_latents(\n",
    "#     dataloader=DataLoader(\n",
    "#         it_dataset,\n",
    "#         num_workers=16,\n",
    "#         pin_memory=True,\n",
    "#         collate_fn=partial(collate_fn, tokenizer=tokenizer),\n",
    "#         batch_size=32,\n",
    "#     ),\n",
    "#     split=f\"{it_transformer_name}\",\n",
    "#     anchors=it_anchor_latents[\"absolute\"].to(device),\n",
    "#     transformer=transformer,\n",
    "# )\n",
    "# subsample_anchors = it_latents[\"relative\"][:31, :]\n",
    "# for i_sample, sample in enumerate(it_samples):\n",
    "#     if sample[\"target\"] == 3:\n",
    "#         continue\n",
    "#     for embed_type in (\"relative\", \"absolute\"):\n",
    "#         latents = it_latents[embed_type]\n",
    "#         latents = torch.cat([latents[i_sample, :].unsqueeze(0), subsample_anchors], dim=0)\n",
    "#         classifier = train_classifiers[SEEDS[0]][embed_type][\"en\"]\n",
    "#         print(\n",
    "#             embed_type,\n",
    "#             classifier(latents)[0].item(),\n",
    "#             sample[\"class\"],\n",
    "#         )\n",
    "#     print()\n",
    "#     if i_sample > 100:\n",
    "#         break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
