{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "20a0de76",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "import itertools\n",
    "from pathlib import Path\n",
    "import random\n",
    "\n",
    "import jsonlines\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import cohen_kappa_score\n",
    "import torch\n",
    "\n",
    "from tasks.mucow_wmt19_source import MucowWMT19ContrastiveConditioningTask\n",
    "from tests.mock_models import DictTranslationModel\n",
    "from translation_models.fairseq_models import FairseqScoringModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fe37861c",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Define data paths\n",
    "from translation_models.fairseq_models import FairseqScoringModel\n",
    "\n",
    "mucow_path = Path(\".\") / \"data\" / \"mucow_wmt19\"\n",
    "\n",
    "mucow_ende_log_path = mucow_path / \"mucow_classic.transformer.wmt19.en-de.ensemble.log\"\n",
    "\n",
    "human_annotations_path = mucow_path / \"human_annotations\"\n",
    "mucow_ende_annotator1_path = human_annotations_path / \"en-de.annotator1.jsonl\"\n",
    "mucow_ende_annotator2_path = human_annotations_path / \"en-de.annotator2.jsonl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0b5f39f5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load annotations\n",
    "with jsonlines.open(mucow_ende_annotator1_path) as f:\n",
    "  annotations1 = {line[\"Sample ID\"]: line for line in f}\n",
    "with jsonlines.open(mucow_ende_annotator2_path) as f:\n",
    "  annotations2 = {line[\"Sample ID\"]: line for line in f}\n",
    "\n",
    "# Flatten labels\n",
    "for key in annotations1:\n",
    "    annotations1[key][\"label\"] = annotations1[key][\"label\"][0]\n",
    "for key in annotations2:\n",
    "    annotations2[key][\"label\"] = annotations2[key][\"label\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "517e196e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of annotations: 185 + 185\n"
     ]
    }
   ],
   "source": [
    "# Remove samples that were only partially annotated\n",
    "for key in list(annotations1.keys()):\n",
    "    if key not in annotations2:\n",
    "        del annotations1[key]\n",
    "for key in list(annotations2.keys()):\n",
    "    if key not in annotations1:\n",
    "        del annotations2[key]\n",
    "print(f\"Number of annotations: {len(annotations1)} + {len(annotations2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9eb632c7",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.37623762376237624\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement before data cleaning\n",
    "keys = list(annotations1.keys())\n",
    "labels1 = [annotations1[key][\"label\"] for key in keys]\n",
    "labels2 = [annotations2[key][\"label\"] for key in keys]\n",
    "kappa = cohen_kappa_score(labels1, labels2)\n",
    "print(kappa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fa3081e0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of annotations: 96 + 96\n"
     ]
    }
   ],
   "source": [
    "# Clean data\n",
    "skipped_keys = set()\n",
    "for annotations in [annotations1, annotations2]:\n",
    "    for key in keys:\n",
    "        # Treat neutral as correct\n",
    "        if annotations[key][\"label\"] == \"Both / Neutral / Ambiguous\":\n",
    "            annotations[key][\"label\"] = \"Correct Sense\"\n",
    "        # Treat bad translations as wrong\n",
    "        if annotations[key][\"label\"] == \"Translation too bad to tell / Third sense\":\n",
    "            annotations[key][\"label\"] = \"Wrong Sense\"\n",
    "        # Skip bad samples\n",
    "        if annotations[key][\"label\"] == \"Bad sample / Ill-defined senses\":\n",
    "            skipped_keys.add(key)\n",
    "for annotations in [annotations1, annotations2]:\n",
    "    for key in skipped_keys:\n",
    "        if key in annotations:\n",
    "            del annotations[key]\n",
    "print(f\"Number of annotations: {len(annotations1)} + {len(annotations2)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2dc58bde",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9175257731958762\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement after data cleaning\n",
    "keys = list(annotations1.keys())\n",
    "labels1 = [annotations1[key][\"label\"] for key in keys]\n",
    "labels2 = [annotations2[key][\"label\"] for key in keys]\n",
    "kappa = cohen_kappa_score(labels1, labels2)\n",
    "print(kappa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f4750d3f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Merge annotations\n",
    "annotations = list(itertools.chain(annotations1.values(), annotations2.values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7eafc174",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load task\n",
    "mucow_ende = MucowWMT19ContrastiveConditioningTask(\n",
    "    tgt_language=\"de\",\n",
    "    evaluator_model=None,\n",
    "    caching=False,\n",
    "    source_data_path=(mucow_path / \"en-de.insertions.roberta-large.source_data.top10.jsonl\"),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c5efe96f",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Full number of uncovered samples:  192\n"
     ]
    }
   ],
   "source": [
    "# Create list of annotated (= uncovered) samples\n",
    "uncovered_samples = []\n",
    "all_samples_dict = {(sample.src_sentence, sample.src_word): sample for sample in mucow_ende.samples}\n",
    "for annotation in annotations:\n",
    "    try:\n",
    "        sample = all_samples_dict[(annotation[\"Source Sentence\"], annotation[\"Word\"])]\n",
    "    except KeyError:  # Google Sheets removes leading apostrophe\n",
    "        sample = all_samples_dict.get((\"'\" + annotation[\"Source Sentence\"], annotation[\"Word\"]), None)\n",
    "    if sample is None:\n",
    "        continue\n",
    "    sample._gold_label = annotation[\"label\"] == \"Correct Sense\"\n",
    "    sample.translation = annotation[\"Translation\"]\n",
    "    uncovered_samples.append(sample)\n",
    "print(\"Full number of uncovered samples: \", len(uncovered_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "609e58c9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Proportion of covered samples 0.7886486486486487\n",
      "Full number of covered samples:  1459\n"
     ]
    }
   ],
   "source": [
    "# Create list of unannotated (= covered) samples based on log file\n",
    "num_samples = 0\n",
    "covered_samples = []\n",
    "all_samples_dict = {(sample.src_sentence, sample.src_word): sample for sample in mucow_ende.samples}\n",
    "with jsonlines.open(mucow_ende_log_path) as f:\n",
    "    for line in f:\n",
    "        sample = all_samples_dict.get((line[\"sentence\"], line[\"src_word\"]), None)\n",
    "        if sample is None:\n",
    "            continue  # contrastive conditioning not applicable\n",
    "        if line[\"corpus\"] == \"opensubs\":  # Only evaluate on in-domain samples because they have higher quality\n",
    "            continue\n",
    "        num_samples += 1\n",
    "        if line[\"is_unknown\"]:  # = uncovered\n",
    "            continue\n",
    "        sample._gold_label = line[\"is_correct\"]\n",
    "        sample.translation = line[\"translation\"]\n",
    "        covered_samples.append(sample)\n",
    "random.seed(42)\n",
    "coverage = len(covered_samples) / num_samples\n",
    "print(\"Proportion of covered samples\", coverage)\n",
    "print(\"Full number of covered samples: \", len(covered_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c64fd193",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing on 908 covered samples and 192 uncovered samples\n"
     ]
    }
   ],
   "source": [
    "# Sample a proportionate amount of covered samples\n",
    "_covered_samples = []\n",
    "for _ in range(int(len(uncovered_samples) * (1 / (1 - coverage)))):\n",
    "    _covered_samples.append(random.choice(covered_samples))\n",
    "covered_samples = _covered_samples\n",
    "print(f\"Testing on {len(covered_samples)} covered samples and {len(uncovered_samples)} uncovered samples\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "31cfed43",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Proportion of agreement:  0.8381818181818181\n"
     ]
    }
   ],
   "source": [
    "# Evaluate classic MuCoW\n",
    "# Count all covered samples as agreements; judge all unknown samples as wrong translations\n",
    "num_agreements = len(covered_samples) + sum(1 for sample in uncovered_samples if not sample._gold_label)\n",
    "proportion_of_agreement = num_agreements / (len(covered_samples) + len(uncovered_samples))\n",
    "print(\"Proportion of agreement: \", proportion_of_agreement)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5d32a3de",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/user/anonymized/.cache/torch/hub/pytorch_fairseq_master\n",
      "2021-05-12 12:17:18 | INFO | fairseq.file_utils | loading archive file https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz from cache at /home/user/anonymized/.cache/torch/pytorch_fairseq/0695ef328ddefcb8cbcfabc3196182f59c0e41e0468b10cc0db2ae9c91881fcc.bb1be17de4233e13870bd7d6065bfdb03fca0a51dd0f5d0b7edf5c188eda71f1\n",
      "2021-05-12 12:17:21 | INFO | fairseq.tasks.translation | [en] dictionary: 42024 types\n",
      "2021-05-12 12:17:21 | INFO | fairseq.tasks.translation | [de] dictionary: 42024 types\n",
      "2021-05-12 12:17:39 | INFO | fairseq.models.fairseq_model | Namespace(activation_dropout=0.0, activation_fn='relu', adam_betas='(0.9, 0.98)', adam_eps=1e-08, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='transformer_wmt_en_de_big', attention_dropout=0.1, batch_size=None, bpe='fastbpe', bpe_codes='/home/user/anonymized/.cache/torch/pytorch_fairseq/0695ef328ddefcb8cbcfabc3196182f59c0e41e0468b10cc0db2ae9c91881fcc.bb1be17de4233e13870bd7d6065bfdb03fca0a51dd0f5d0b7edf5c188eda71f1/bpecodes', bucket_cap_mb=25, clip_norm=0.0, cpu=False, criterion='label_smoothed_cross_entropy', cross_self_attention=False, data='/home/user/anonymized/.cache/torch/pytorch_fairseq/0695ef328ddefcb8cbcfabc3196182f59c0e41e0468b10cc0db2ae9c91881fcc.bb1be17de4233e13870bd7d6065bfdb03fca0a51dd0f5d0b7edf5c188eda71f1', ddp_backend='c10d', decoder_attention_heads=16, decoder_embed_dim=1024, decoder_embed_path=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layerdrop=0, decoder_layers=6, decoder_layers_to_keep=None, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, distributed_backend='nccl', distributed_init_method='tcp://localhost:10731', distributed_port=-1, distributed_rank=0, distributed_world_size=2, dropout=0.2, encoder_attention_heads=16, encoder_embed_dim=1024, encoder_embed_path=None, encoder_ffn_embed_dim=8192, encoder_layerdrop=0, encoder_layers=6, encoder_layers_to_keep=None, encoder_learned_pos=False, encoder_normalize_before=False, eval_bleu_detok='space', eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, extra_data='', fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, ignore_prefix_size=0, keep_interval_updates=-1, keep_last_epochs=-1, label_smoothing=0.1, layernorm_embedding=False, lazy_load=False, left_pad_source=True, left_pad_target=False, log_format='simple', log_interval=100, lr=[0.0007], 
lr_scheduler='inverse_sqrt', lr_shrink=0.1, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=1024, max_target_positions=1024, max_tokens=3584, max_update=202200, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, momentum=0.99, moses_no_dash_splits=False, moses_no_escape=False, no_cross_attention=False, no_epoch_checkpoints=False, no_progress_bar=True, no_save=False, no_scale_embedding=False, no_token_positional_embeddings=False, num_batch_buckets=0, num_workers=0, optimizer='adam', optimizer_overrides='{}', quant_noise_pq=0, quant_noise_pq_block_size=8, quant_noise_scalar=0, raw_text=False, relu_dropout=0.0, reset_lr_scheduler=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='/checkpoint/edunov/20190403/wmt19en2de.btsample5.ffn8192.transformer_wmt_en_de_big_bsz3584_lr0.0007_dr0.2_size_updates200000_seed2_lbsm0.1_size_sa1_upsample4/finetune', save_interval=1, save_interval_updates=200, seed=2, sentence_avg=False, share_all_embeddings=True, share_decoder_input_output_embed=True, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='de', task='translation', tensorboard_logdir='', threshold_loss_scale=None, tie_adaptive_weights=False, tokenizer='moses', train_subset='train', truncate_source=False, update_freq=[1], upsample_primary=1, use_old_adam=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=4000, weight_decay=0.0)\n",
      "2021-05-12 12:17:43 | INFO | root | Scoring translations ...\n",
      "100%|██████████| 1100/1100 [05:04<00:00,  3.61it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run contrastive conditioning\n",
    "DEBUG = False\n",
    "if DEBUG:\n",
    "    model_path = Path(\".\").parent / \"tests\" / \"models\" / \"toy_fairseq_en-de\"\n",
    "    evaluator_model = FairseqScoringModel(\n",
    "        model_path.name,\n",
    "        model_name_or_path=model_path,\n",
    "        tokenizer=\"moses\",\n",
    "        bpe=\"fastbpe\",\n",
    "    )\n",
    "else:\n",
    "    hub_interface = torch.hub.load(\n",
    "        repo_or_dir='pytorch/fairseq',\n",
    "        model='transformer.wmt19.en-de',\n",
    "        checkpoint_file=\"model1.pt:model2.pt:model3.pt:model4.pt\",\n",
    "        tokenizer='moses',\n",
    "        bpe='fastbpe',\n",
    "    )\n",
    "    evaluator_name = 'transformer.wmt19.en-de.ensemble'\n",
    "    evaluator_model = FairseqScoringModel(name=evaluator_name, model=hub_interface)\n",
    "\n",
    "mucow_ende.samples = uncovered_samples + covered_samples\n",
    "mucow_ende.categories = {sample.category for sample in mucow_ende.samples}\n",
    "mucow_ende.category_wise_weighting = True\n",
    "mucow_ende.evaluator_model = evaluator_model\n",
    "\n",
    "translations = DictTranslationModel({sample.src_sentence: sample.translation for sample in mucow_ende.samples})\n",
    "contrastive_conditioning_weighted_evaluated_samples = mucow_ende.evaluate(translations).samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "2441fcb5",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Create unweighted contrastive conditioning samples\n",
    "contrastive_conditioning_unweighted_evaluated_samples = deepcopy(contrastive_conditioning_weighted_evaluated_samples)\n",
    "for sample in contrastive_conditioning_unweighted_evaluated_samples:\n",
    "    sample.weight = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6f2b7de4",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       Wrong      0.278     0.288     0.283     146.0\n",
      "     Correct      0.890     0.886     0.888     954.0\n",
      "\n",
      "    accuracy                          0.806    1100.0\n",
      "   macro avg      0.584     0.587     0.585    1100.0\n",
      "weighted avg      0.809     0.806     0.808    1100.0\n",
      "\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "       Wrong      0.236     0.190     0.211   27443.0\n",
      "     Correct      0.927     0.944     0.935  298648.0\n",
      "\n",
      "    accuracy                          0.880  326091.0\n",
      "   macro avg      0.582     0.567     0.573  326091.0\n",
      "weighted avg      0.869     0.880     0.874  326091.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate contrastive conditioning\n",
    "for evaluated_samples in [\n",
    "    contrastive_conditioning_unweighted_evaluated_samples,\n",
    "    contrastive_conditioning_weighted_evaluated_samples,\n",
    "]:\n",
    "    predicted_labels = []\n",
    "    gold_labels = []\n",
    "    weights = []\n",
    "    for sample in evaluated_samples:\n",
    "        gold_labels.append(int(sample._gold_label))\n",
    "        predicted_labels.append(int(sample.is_correct))\n",
    "        weights.append(getattr(sample, \"weight\", 1))\n",
    "    class_labels = [0, 1]\n",
    "    target_names = [\"Wrong\", \"Correct\"]\n",
    "    print(classification_report(\n",
    "        y_true=gold_labels,\n",
    "        y_pred=predicted_labels,\n",
    "        labels=class_labels,\n",
    "        target_names=target_names,\n",
    "        sample_weight=weights,\n",
    "        zero_division=True,\n",
    "        digits=3,\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c829a1c6",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
