{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "20a0de76",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "import itertools\n",
    "from pathlib import Path\n",
    "import random\n",
    "\n",
    "import jsonlines\n",
    "\n",
    "import torch\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import cohen_kappa_score\n",
    "\n",
    "from tasks.mucow_wmt19_source import MucowWMT19ContrastiveConditioningTask\n",
    "from tests.mock_models import DictTranslationModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "416dc45e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Define data paths\n",
    "from translation_models.fairseq_models import FairseqScoringModel\n",
    "\n",
    "mucow_path = Path(\".\") / \"data\" / \"mucow_wmt19\"\n",
    "\n",
    "mucow_enru_log_path = mucow_path / \"mucow_classic.transformer.wmt19.en-ru.ensemble.log\"\n",
    "\n",
    "human_annotations_path = mucow_path / \"human_annotations\"\n",
    "mucow_enru_annotator1_path = human_annotations_path / \"en-ru.annotator1.jsonl\"\n",
    "mucow_enru_annotator2_path = human_annotations_path / \"en-ru.annotator2.jsonl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "7624b42d",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load annotations\n",
    "with jsonlines.open(mucow_enru_annotator1_path) as f:\n",
    "  annotations1 = {line[\"Sample ID\"]: line for line in f}\n",
    "with jsonlines.open(mucow_enru_annotator2_path) as f:\n",
    "  annotations2 = {line[\"Sample ID\"]: line for line in f}\n",
    "\n",
    "# Flatten labels\n",
    "for key in annotations1:\n",
    "    annotations1[key][\"label\"] = annotations1[key][\"label\"][0]\n",
    "for key in annotations2:\n",
    "    annotations2[key][\"label\"] = annotations2[key][\"label\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "36029a40",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of annotations: 90 + 90\n"
     ]
    }
   ],
   "source": [
    "# Remove samples that were only partially annotated\n",
    "for key in list(annotations1.keys()):\n",
    "    if key not in annotations2:\n",
    "        del annotations1[key]\n",
    "for key in list(annotations2.keys()):\n",
    "    if key not in annotations1:\n",
    "        del annotations2[key]\n",
    "print(f\"Number of annotations: {len(annotations1)} + {len(annotations2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "344c62d3",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5696465696465697\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement before data cleaning\n",
    "keys = list(annotations1.keys())\n",
    "labels1 = [annotations1[key][\"label\"] for key in keys]\n",
    "labels2 = [annotations2[key][\"label\"] for key in keys]\n",
    "kappa = cohen_kappa_score(labels1, labels2)\n",
    "print(kappa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f41e286a",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of annotations: 86 + 86\n"
     ]
    }
   ],
   "source": [
    "# Clean data\n",
    "skipped_keys = set()\n",
    "for annotations in [annotations1, annotations2]:\n",
    "    for key in keys:\n",
    "        # Treat neutral as correct\n",
    "        if annotations[key][\"label\"] == \"Both / Neutral / Ambiguous\":\n",
    "            annotations[key][\"label\"] = \"Correct Sense\"\n",
    "        # Treat bad translations as wrong\n",
    "        if annotations[key][\"label\"] == \"Translation too bad to tell / Third sense\":\n",
    "            annotations[key][\"label\"] = \"Wrong Sense\"\n",
    "        # Skip bad samples\n",
    "        if annotations[key][\"label\"] == \"Bad sample / Ill-defined senses\":\n",
    "            skipped_keys.add(key)\n",
    "for annotations in [annotations1, annotations2]:\n",
    "    for key in skipped_keys:\n",
    "        if key in annotations:\n",
    "            del annotations[key]\n",
    "print(f\"Number of annotations: {len(annotations1)} + {len(annotations2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ad0fbebb",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8767908309455588\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement after data cleaning\n",
    "keys = list(annotations1.keys())\n",
    "labels1 = [annotations1[key][\"label\"] for key in keys]\n",
    "labels2 = [annotations2[key][\"label\"] for key in keys]\n",
    "kappa = cohen_kappa_score(labels1, labels2)\n",
    "print(kappa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "e447ad46",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Merge annotations\n",
    "annotations = list(itertools.chain(annotations1.values(), annotations2.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b2341ff5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load task\n",
    "mucow_enru = MucowWMT19ContrastiveConditioningTask(\n",
    "    tgt_language=\"ru\",\n",
    "    evaluator_model=None,\n",
    "    caching=False,\n",
    "    source_data_path=(mucow_path / \"en-ru.insertions.roberta-large.source_data.top10.jsonl\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "00d7a03b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full number of uncovered samples:  170\n"
     ]
    }
   ],
   "source": [
    "# Create list of annotated (= uncovered) samples\n",
    "uncovered_samples = []\n",
    "all_samples_dict = {(sample.src_sentence, sample.src_word): sample for sample in mucow_enru.samples}\n",
    "for annotation in annotations:\n",
    "    try:\n",
    "        sample = all_samples_dict[(annotation[\"Source Sentence\"], annotation[\"Word\"])]\n",
    "    except KeyError:  # Google Sheets removes leading apostrophe\n",
    "        sample = all_samples_dict.get((\"'\" + annotation[\"Source Sentence\"], annotation[\"Word\"]), None)\n",
    "    if sample is None:\n",
    "        continue\n",
    "    sample._gold_label = annotation[\"label\"] == \"Correct Sense\"\n",
    "    sample.translation = annotation[\"Translation\"]\n",
    "    uncovered_samples.append(sample)\n",
    "print(\"Full number of uncovered samples: \", len(uncovered_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "28ebd751",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Proportion of covered samples 0.8155136268343816\n",
      "Full number of covered samples:  389\n"
     ]
    }
   ],
   "source": [
    "# Create list of unannotated (= covered) samples based on log file\n",
    "num_samples = 0\n",
    "covered_samples = []\n",
    "all_samples_dict = {(sample.src_sentence, sample.src_word): sample for sample in mucow_enru.samples}\n",
    "with jsonlines.open(mucow_enru_log_path) as f:\n",
    "    for line in f:\n",
    "        sample = all_samples_dict.get((line[\"sentence\"], line[\"src_word\"]), None)\n",
    "        if sample is None:\n",
    "            continue  # contrastive conditioning not applicable\n",
    "        if line[\"corpus\"] == \"opensubs\":  # Only evaluate on in-domain samples because they have higher quality\n",
    "            continue\n",
    "        num_samples += 1\n",
    "        if line[\"is_unknown\"]:  # = uncovered\n",
    "            continue\n",
    "        sample._gold_label = line[\"is_correct\"]\n",
    "        sample.translation = line[\"translation\"]\n",
    "        covered_samples.append(sample)\n",
    "random.seed(42)\n",
    "coverage = len(covered_samples) / num_samples\n",
    "print(\"Proportion of covered samples\", coverage)\n",
    "print(\"Full number of covered samples: \", len(covered_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "328b1057",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing on 921 covered samples and 170 uncovered samples\n"
     ]
    }
   ],
   "source": [
    "# Sample a proportionate amount of covered samples\n",
    "_covered_samples = []\n",
    "for _ in range(int(len(uncovered_samples) * (1 / (1 - coverage)))):\n",
    "    _covered_samples.append(random.choice(covered_samples))\n",
    "covered_samples = _covered_samples\n",
    "print(f\"Testing on {len(covered_samples)} covered samples and {len(uncovered_samples)} uncovered samples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "4740a2d2",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Proportion of agreement:  0.8973418881759854\n"
     ]
    }
   ],
   "source": [
    "# Evaluate classic MuCoW\n",
    "# Count all covered samples as agreements; judge all unknown samples as wrong translations\n",
    "num_agreements = len(covered_samples) + sum(1 for sample in uncovered_samples if not sample._gold_label)\n",
    "proportion_of_agreement = num_agreements / (len(covered_samples) + len(uncovered_samples))\n",
    "print(\"Proportion of agreement: \", proportion_of_agreement)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "7b4c7869",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/user/anonymized/.cache/torch/hub/pytorch_fairseq_master\n",
      "2021-05-12 11:48:43 | INFO | fairseq.file_utils | loading archive file https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz from cache at /home/user/anonymized/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9\n",
      "2021-05-12 11:48:46 | INFO | fairseq.tasks.translation | [en] dictionary: 31640 types\n",
      "2021-05-12 11:48:46 | INFO | fairseq.tasks.translation | [ru] dictionary: 31232 types\n",
      "2021-05-12 11:49:04 | INFO | fairseq.models.fairseq_model | Namespace(activation_dropout=0.0, activation_fn='relu', adam_betas='(0.9, 0.98)', adam_eps=1e-08, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='transformer_wmt_en_de_big', attention_dropout=0.1, batch_size=None, bpe='fastbpe', bpe_codes='/home/user/anonymized/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9/bpecodes', bucket_cap_mb=25, clip_norm=0.0, cpu=False, criterion='label_smoothed_cross_entropy', cross_self_attention=False, data='/home/user/anonymized/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9', ddp_backend='c10d', decoder_attention_heads=16, decoder_embed_dim=1024, decoder_embed_path=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layerdrop=0, decoder_layers=6, decoder_layers_to_keep=None, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, distributed_backend='nccl', distributed_init_method='tcp://localhost:14352', distributed_port=-1, distributed_rank=0, distributed_world_size=2, dropout=0.2, encoder_attention_heads=16, encoder_embed_dim=1024, encoder_embed_path=None, encoder_ffn_embed_dim=8192, encoder_layerdrop=0, encoder_layers=6, encoder_layers_to_keep=None, encoder_learned_pos=False, encoder_normalize_before=False, eval_bleu_detok='space', eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, extra_data='', fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, ignore_prefix_size=0, keep_interval_updates=-1, keep_last_epochs=-1, label_smoothing=0.1, layernorm_embedding=False, lazy_load=False, left_pad_source=True, left_pad_target=False, log_format='simple', log_interval=100, lr=[0.0007], lr_scheduler='inverse_sqrt', lr_shrink=0.1, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=1024, max_target_positions=1024, max_tokens=3584, max_update=201700, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, momentum=0.99, moses_no_dash_splits=False, moses_no_escape=False, no_cross_attention=False, no_epoch_checkpoints=False, no_progress_bar=True, no_save=False, no_scale_embedding=False, no_token_positional_embeddings=False, num_batch_buckets=0, num_workers=0, optimizer='adam', optimizer_overrides='{}', quant_noise_pq=0, quant_noise_pq_block_size=8, quant_noise_scalar=0, raw_text=False, relu_dropout=0.0, reset_lr_scheduler=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='/checkpoint/edunov/20190403/wmt19en2ru.btsample5.commoncrawl.ffn8192.transformer_wmt_en_de_big_bsz3584_lr0.0007_dr0.2_size_updates200000_seed2_lbsm0.1_size_sa0_upsample4/finetune/', save_interval=1, save_interval_updates=200, seed=2, sentence_avg=False, share_all_embeddings=False, share_decoder_input_output_embed=True, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='ru', task='translation', tensorboard_logdir='', threshold_loss_scale=None, tie_adaptive_weights=False, tokenizer='moses', train_subset='train', truncate_source=False, update_freq=[1], upsample_primary=1, use_old_adam=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=4000, weight_decay=0.0)\n",
      "2021-05-12 11:49:06 | INFO | root | Scoring translations ...\n",
      "100%|██████████| 1091/1091 [04:50<00:00,  3.76it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run contrastive conditioning\n",
    "hub_interface = torch.hub.load(\n",
    "    repo_or_dir='pytorch/fairseq',\n",
    "    model='transformer.wmt19.en-ru',\n",
    "    checkpoint_file=\"model1.pt:model2.pt:model3.pt:model4.pt\",\n",
    "    tokenizer='moses',\n",
    "    bpe='fastbpe',\n",
    ")\n",
    "evaluator_name = 'transformer.wmt19.en-ru.ensemble'\n",
    "evaluator_model = FairseqScoringModel(\n",
    "    name=evaluator_name,\n",
    "    model=hub_interface,\n",
    "    src_bpe_codes=Path(\".\").parent / \"models\" / \"wmt19.en-ru.ffn8192/en24k.fastbpe.code\",\n",
    "    tgt_bpe_codes=Path(\".\").parent / \"models\" / \"wmt19.ru-en.ffn8192/ru24k.fastbpe.code\",\n",
    ")\n",
    "\n",
    "mucow_enru.samples = uncovered_samples + covered_samples\n",
    "mucow_enru.categories = {sample.category for sample in mucow_enru.samples}\n",
    "mucow_enru.category_wise_weighting = True\n",
    "mucow_enru.evaluator_model = evaluator_model\n",
    "\n",
    "translations = DictTranslationModel({sample.src_sentence: sample.translation for sample in mucow_enru.samples})\n",
    "contrastive_conditioning_weighted_evaluated_samples = mucow_enru.evaluate(translations).samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "4ccb8760",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Create unweighted contrastive conditioning samples\n",
    "contrastive_conditioning_unweighted_evaluated_samples = deepcopy(contrastive_conditioning_weighted_evaluated_samples)\n",
    "for sample in contrastive_conditioning_unweighted_evaluated_samples:\n",
    "    sample.weight = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "34cfa991",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       Wrong      0.394     0.429     0.411     126.0\n",
      "     Correct      0.925     0.914     0.919     965.0\n",
      "\n",
      "    accuracy                          0.858    1091.0\n",
      "   macro avg      0.659     0.671     0.665    1091.0\n",
      "weighted avg      0.863     0.858     0.860    1091.0\n",
      "\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       Wrong      0.496     0.437     0.465   26743.0\n",
      "     Correct      0.957     0.966     0.962  348295.0\n",
      "\n",
      "    accuracy                          0.928  375038.0\n",
      "   macro avg      0.727     0.702     0.713  375038.0\n",
      "weighted avg      0.924     0.928     0.926  375038.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate contrastive conditioning\n",
    "for evaluated_samples in [\n",
    "    contrastive_conditioning_unweighted_evaluated_samples,\n",
    "    contrastive_conditioning_weighted_evaluated_samples,\n",
    "]:\n",
    "    predicted_labels = []\n",
    "    gold_labels = []\n",
    "    weights = []\n",
    "    for sample in evaluated_samples:\n",
    "        gold_labels.append(int(sample._gold_label))\n",
    "        predicted_labels.append(int(sample.is_correct))\n",
    "        weights.append(getattr(sample, \"weight\", 1))\n",
    "    class_labels = [0, 1]\n",
    "    target_names = [\"Wrong\", \"Correct\"]\n",
    "    print(classification_report(\n",
    "        y_true=gold_labels,\n",
    "        y_pred=predicted_labels,\n",
    "        labels=class_labels,\n",
    "        target_names=target_names,\n",
    "        sample_weight=weights,\n",
    "        zero_division=True,\n",
    "        digits=3,\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "343fd42c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
