{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ecf78deb-5012-4ffe-b925-564f5fda7be8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
    "import torch\n",
    "import numpy as np\n",
    "import json\n",
    "from pyemd import emd\n",
    "from tqdm import tqdm\n",
    "from multiprocessing import Pool\n",
    "\n",
    "from sklearn.metrics import euclidean_distances\n",
    "from sklearn.preprocessing import normalize\n",
    "\n",
    "from wmd import WordMoversDistance\n",
    "\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.tokenize import word_tokenize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9f5879fd-4ae8-4468-bddf-578c72d10bbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-base\")\n",
    "model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-base\")\n",
    "\n",
    "embeddings = model.roberta.embeddings.word_embeddings.weight.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d90fc8fd-09d0-45cb-ab0c-0633c9f7b2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "stopwords_en = set(stopwords.words('english'))\n",
    "stopwords_da = set(stopwords.words('danish'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "84b06557-8545-4088-bc8d-1faae62dc252",
   "metadata": {},
   "outputs": [],
   "source": [
    "wmd = WordMoversDistance(embeddings, n_jobs=-1, verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d8e867e4-4d42-4a82-b428-24043e0bca15",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_text(text, lang):\n",
    "    \n",
    "    if lang == 'en':\n",
    "        toks = [word for word in word_tokenize(text) if word not in stopwords_en and word.isalnum()]\n",
    "    elif lang == 'da':\n",
    "        toks = [word for word in word_tokenize(text) if word not in stopwords_da and word.isalnum()]\n",
    "    \n",
    "    text = ' '.join(toks)\n",
    "    return text\n",
    "\n",
    "\n",
    "def tokenize(sent, lang):\n",
    "    \n",
    "    sent = clean_text(sent, lang)\n",
    "    \n",
    "    tokens = tokenizer.tokenize(sent)\n",
    "    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "    \n",
    "    return token_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8cee15ac-51af-477b-a180-96c7ab4ef8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/text-to-text/da-en.train.da', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "    lines = lines[:100]\n",
    "    data_da = {i:line for i, line in enumerate(lines)}\n",
    "\n",
    "with open('../data/text-to-text/da-en.train.en', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "    lines = lines[:100]\n",
    "    data_en = {i:line for i, line in enumerate(lines)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d3e89f33-e08a-44c6-aaf1-85584afcd1e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "ids_da, sent_da = zip(*data_da.items())\n",
    "ids_en, sent_en = zip(*data_en.items())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2898fef6-511e-4067-a545-2a49f28699fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens_da = [tokenize(sent, 'da') for sent in sent_da]\n",
    "tokens_en = [tokenize(sent, 'en') for sent in sent_en]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3cb7a253-6bc8-497f-9403-709e99c8d217",
   "metadata": {},
   "outputs": [],
   "source": [
    "wmd.fit(ids_da, tokens_da)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "438f6af9-3fc7-4416-861e-1fddcfc43730",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [06:34<00:00,  3.95s/it]\n"
     ]
    }
   ],
   "source": [
    "result = {}\n",
    "\n",
    "for i in tqdm(range(len(ids_en))):\n",
    "    id = ids_en[i]\n",
    "    toks = tokens_en[i]\n",
    "    \n",
    "    dists = wmd.predict(toks)\n",
    "    sentresult = {\n",
    "        srcid: dist for (srcid, dist) in zip(wmd.source_ids, dists)\n",
    "    }\n",
    "    \n",
    "    result[id] = sentresult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2942bb5e-b570-4aa8-a4f1-b1526e9900ca",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "with open('../results/text-to-text/da-en-translation.json', 'w') as f:\n",
    "    json.dump(result, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "9c000815-7730-41ab-9fe3-6215efe5c9dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.39 39 100\n"
     ]
    }
   ],
   "source": [
    "corr, tot = 0, 0\n",
    "\n",
    "for k, v in result.items():\n",
    "    da_sorted = sorted(v.items(), key=lambda item: item[1])[0]\n",
    "    \n",
    "    if k == da_sorted[0]:\n",
    "        corr += 1\n",
    "    tot += 1\n",
    "\n",
    "print(corr/float(tot), corr, tot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "253c1673-3b93-4a90-ac4c-3d86da343cf7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "1a404f95-12ec-4bab-bd45-ade63f0977f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁title',\n",
       " '▁:',\n",
       " '▁&',\n",
       " 'quot',\n",
       " ';',\n",
       " '▁Over',\n",
       " 'sigt',\n",
       " '▁over',\n",
       " '▁op',\n",
       " 'sætning',\n",
       " 'er',\n",
       " '▁for',\n",
       " '▁service',\n",
       " 'artikler',\n",
       " '▁og',\n",
       " '▁service',\n",
       " 'artikel',\n",
       " 'komponent',\n",
       " 'er',\n",
       " '▁&',\n",
       " 'quot',\n",
       " ';']"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.tokenize(sent_da[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "dd73e72c-9691-491f-8691-8b0e280b1a11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['▁title',\n",
       " '▁:',\n",
       " '▁Over',\n",
       " 'view',\n",
       " '▁of',\n",
       " '▁Set',\n",
       " 'ups',\n",
       " '▁for',\n",
       " '▁Service',\n",
       " '▁Item',\n",
       " 's',\n",
       " '▁and',\n",
       " '▁Service',\n",
       " '▁Item',\n",
       " '▁Com',\n",
       " 'ponent',\n",
       " 's']"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.tokenize(sent_en[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "013f98c0-81fd-4508-a561-bc3b295efa90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'title : Overview of Setups for Service Items and Service Item Components\\n'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sent_en[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f4237980-8741-416c-ac16-227b49b41ef1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'title : &quot; Oversigt over opsætninger for serviceartikler og serviceartikelkomponenter &quot;\\n'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sent_da[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "613a38aa-3f70-4a86-8e7c-281640ce0f5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c7a12220-da11-4b8f-95c8-e77b7de52a65",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "0495446d-a449-46f1-b289-3a4a62fbfb82",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['title',\n",
       "  'Overview',\n",
       "  'Setups',\n",
       "  'Service',\n",
       "  'Items',\n",
       "  'Service',\n",
       "  'Item',\n",
       "  'Components'],\n",
       " 'title : Overview of Setups for Service Items and Service Item Components\\n')"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "senten = [word for word in word_tokenize(sent_en[0]) if word not in stop_words and word.isalnum()]\n",
    "senten, sent_en[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "8cc95985-e9e7-4fcd-bf84-390483b06409",
   "metadata": {},
   "outputs": [],
   "source": [
    "stopwords_da = stopwords.words('danish')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "8d883559-7123-404a-817d-05a378e94229",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['title',\n",
       "  'quot',\n",
       "  'Oversigt',\n",
       "  'opsætninger',\n",
       "  'serviceartikler',\n",
       "  'serviceartikelkomponenter',\n",
       "  'quot'],\n",
       " 'title : &quot; Oversigt over opsætninger for serviceartikler og serviceartikelkomponenter &quot;\\n')"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentda = \n",
    "sentda, sent_da[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "433b9f11-1904-433a-b1fc-b9a8cb9f333d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "955765d4-1ff3-430f-9cce-2c2e5cd20a8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'!\"#$%&\\'()*+,-./:;<=>?@[\\\\]^_`{|}~'"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "string.punctuation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b5568a33-c00d-4d9c-b55b-fa89a0301f2d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'title : Overview of Setups for Service Items and Service Item Components+'"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sent_en[0].translate(string.punctuation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e22504b3-56ad-491f-8115-5fbff3156d15",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
