{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "20a0de76",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from copy import deepcopy\n",
    "\n",
    "import jsonlines\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import cohen_kappa_score\n",
    "\n",
    "from tasks.winomt_classic import WinomtClassicTask\n",
    "from tasks.winomt_classic_utils.language_predictors.util import WB_GENDER_TYPES, GENDER\n",
    "from tasks.winomt_classic_utils.languages import German\n",
    "from tasks.winomt_source import WinomtContrastiveConditioningTask\n",
    "from tests.mock_models import DictTranslationModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [],
   "source": [
    "# Define data paths\n",
    "from translation_models.fairseq_models import FairseqScoringModel\n",
    "\n",
    "winomt_path = Path(\".\") / \"data\" / \"winomt\"\n",
    "\n",
    "winomt_ende_translations_path = winomt_path / \"aws.de.full.txt\"\n",
    "\n",
    "human_annotations_path = winomt_path / \"human_annotations\"\n",
    "winomt_ende_annotator1_path = human_annotations_path / \"en-de.annotator1.jsonl\"\n",
    "winomt_ende_annotator2_path = human_annotations_path / \"en-de.annotator2.jsonl\""
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "outputs": [],
   "source": [
    "# Load annotations\n",
    "with jsonlines.open(winomt_ende_annotator1_path) as f:\n",
    "  annotations1 = {line[\"Sample ID\"]: line for line in f}\n",
    "with jsonlines.open(winomt_ende_annotator2_path) as f:\n",
    "  annotations2 = {line[\"Sample ID\"]: line for line in f}\n",
    "\n",
    "# Flatten labels\n",
    "for key in annotations1:\n",
    "    annotations1[key][\"label\"] = annotations1[key][\"label\"][0]\n",
    "for key in annotations2:\n",
    "    annotations2[key][\"label\"] = annotations2[key][\"label\"][0]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "outputs": [],
   "source": [
    "# Remove samples that were only partially annotated\n",
    "for key in list(annotations1.keys()):\n",
    "    if key not in annotations2:\n",
    "        del annotations1[key]\n",
    "for key in list(annotations2.keys()):\n",
    "    if key not in annotations1:\n",
    "        del annotations2[key]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9525335231992406\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement before data cleaning\n",
    "keys = list(annotations1.keys())\n",
    "labels1 = [annotations1[key][\"label\"] for key in keys]\n",
    "labels2 = [annotations2[key][\"label\"] for key in keys]\n",
    "kappa = cohen_kappa_score(labels1, labels2)\n",
    "print(kappa)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [],
   "source": [
    "# Clean data\n",
    "for annotations in [annotations1, annotations2]:\n",
    "    for key in keys:\n",
    "        # Treat neutral as correct\n",
    "        if annotations[key][\"label\"] == \"Both / Neutral / Ambiguous\":\n",
    "            annotations[key][\"label\"] = annotations[key][\"Gold Gender\"].title()\n",
    "        # Treat bad as wrong\n",
    "        if annotations[key][\"label\"] == \"Translation too bad to tell\":\n",
    "            annotations[key][\"label\"] = \"Male\" if annotations[key][\"Gold Gender\"] == \"female\" else \"Female\""
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9526963103122044\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement after data cleaning\n",
    "keys = list(annotations1.keys())\n",
    "labels1 = [annotations1[key][\"label\"] for key in keys]\n",
    "labels2 = [annotations2[key][\"label\"] for key in keys]\n",
    "kappa = cohen_kappa_score(labels1, labels2)\n",
    "print(kappa)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "outputs": [],
   "source": [
    "# Merge annotations\n",
    "annotations = list(itertools.chain(annotations1.values(), annotations2.values()))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "outputs": [],
   "source": [
    "# Load translations\n",
    "with open(winomt_ende_translations_path) as f:\n",
    "    translations = {line.split(\" ||| \")[0].strip(): line.split(\" ||| \")[1].strip() for line in f}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "3888it [00:00, 740372.92it/s]\n",
      "2021-05-08 11:20:36 | WARNING | root | less than two dets found: Der Vorgesetzter hat den Buchhalter wegen Nichteinhaltung seines Protokolls beredet.\n",
      "2021-05-08 11:20:41 | WARNING | root | less than two dets found: Der Analyst konnte den Rezeptionisten wegen seines schlechten Berichts nicht beeindrucken.\n",
      "2021-05-08 11:20:41 | WARNING | root | less than two dets found: Der Vorgesetzter hat den Buchhalter wegen seines schlechten Verhaltens gescheitert.\n",
      "2021-05-08 11:20:41 | WARNING | root | less than two dets found: Der Vorgesetzte hat den Buchhalter wegen Nichtbeachtung ihres Protokolls beschattet.\n",
      "2021-05-08 11:20:43 | WARNING | root | less than two dets found: Der Berater wurde befohlen, den CEO zu untersuchen und ging nach Hause.\n",
      "2021-05-08 11:20:45 | WARNING | root | less than two dets found: Jemand fragte den Rezeptionisten, ob er sich anmelden müsse.\n",
      "2021-05-08 11:20:45 | WARNING | root | less than two dets found: Jemand wurde an den Spezialisten verwiesen, weil er Kenntnis von seltenen Hautzuständen hatte.\n",
      "2021-05-08 11:20:45 | WARNING | root | less than two dets found: Jemand wurde an den Spezialisten verwiesen, weil sie Kenntnis von seltenen Hautzuständen hatte.\n",
      "2021-05-08 11:20:45 | WARNING | root | less than two dets found: Jemand wurde an den Spezialisten verwiesen, weil sie Kenntnis von seltenen Hautzuständen hatten.\n",
      "2021-05-08 11:20:45 | WARNING | root | less than two dets found: Der Elektriker warnte den Hausbesitzer, nichts selbst zu verdrahten.\n",
      "2021-05-08 11:20:45 | WARNING | root | less than two dets found: Der Elektriker warnte den Hausbesitzer, nicht zu versuchen, etwas selbst neu zu verdrahten.\n",
      "2021-05-08 11:20:45 | WARNING | root | less than two dets found: Der Elektriker warnte den Hausbesitzer, nichts selbst zu verdrahten.\n"
     ]
    }
   ],
   "source": [
    "# Run classic (translation-based) WinoMT\n",
    "winomt_classic_ende = WinomtClassicTask(\n",
    "    language=German,\n",
    "    skip_neutral_gold=False,\n",
    "    caching=False,\n",
    "    verbose=True,\n",
    ")\n",
    "classic_evaluated_samples = winomt_classic_ende.evaluate(DictTranslationModel(translations)).samples"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/user/anonymized/.cache/torch/hub/pytorch_fairseq_master\n",
      "2021-05-08 11:20:47 | INFO | fairseq.file_utils | loading archive file https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz from cache at /home/user/anonymized/.cache/torch/pytorch_fairseq/0695ef328ddefcb8cbcfabc3196182f59c0e41e0468b10cc0db2ae9c91881fcc.bb1be17de4233e13870bd7d6065bfdb03fca0a51dd0f5d0b7edf5c188eda71f1\n",
      "2021-05-08 11:20:49 | INFO | fairseq.tasks.translation | [en] dictionary: 42024 types\n",
      "2021-05-08 11:20:49 | INFO | fairseq.tasks.translation | [de] dictionary: 42024 types\n",
      "2021-05-08 11:21:05 | INFO | fairseq.models.fairseq_model | Namespace(activation_dropout=0.0, activation_fn='relu', adam_betas='(0.9, 0.98)', adam_eps=1e-08, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='transformer_wmt_en_de_big', attention_dropout=0.1, batch_size=None, bpe='fastbpe', bpe_codes='/home/user/anonymized/.cache/torch/pytorch_fairseq/0695ef328ddefcb8cbcfabc3196182f59c0e41e0468b10cc0db2ae9c91881fcc.bb1be17de4233e13870bd7d6065bfdb03fca0a51dd0f5d0b7edf5c188eda71f1/bpecodes', bucket_cap_mb=25, clip_norm=0.0, cpu=False, criterion='label_smoothed_cross_entropy', cross_self_attention=False, data='/home/user/anonymized/.cache/torch/pytorch_fairseq/0695ef328ddefcb8cbcfabc3196182f59c0e41e0468b10cc0db2ae9c91881fcc.bb1be17de4233e13870bd7d6065bfdb03fca0a51dd0f5d0b7edf5c188eda71f1', ddp_backend='c10d', decoder_attention_heads=16, decoder_embed_dim=1024, decoder_embed_path=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layerdrop=0, decoder_layers=6, decoder_layers_to_keep=None, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, distributed_backend='nccl', distributed_init_method='tcp://localhost:10731', distributed_port=-1, distributed_rank=0, distributed_world_size=2, dropout=0.2, encoder_attention_heads=16, encoder_embed_dim=1024, encoder_embed_path=None, encoder_ffn_embed_dim=8192, encoder_layerdrop=0, encoder_layers=6, encoder_layers_to_keep=None, encoder_learned_pos=False, encoder_normalize_before=False, eval_bleu_detok='space', eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, extra_data='', fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, ignore_prefix_size=0, keep_interval_updates=-1, keep_last_epochs=-1, label_smoothing=0.1, layernorm_embedding=False, lazy_load=False, left_pad_source=True, left_pad_target=False, log_format='simple', log_interval=100, lr=[0.0007], lr_scheduler='inverse_sqrt', lr_shrink=0.1, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=1024, max_target_positions=1024, max_tokens=3584, max_update=202200, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, momentum=0.99, moses_no_dash_splits=False, moses_no_escape=False, no_cross_attention=False, no_epoch_checkpoints=False, no_progress_bar=True, no_save=False, no_scale_embedding=False, no_token_positional_embeddings=False, num_batch_buckets=0, num_workers=0, optimizer='adam', optimizer_overrides='{}', quant_noise_pq=0, quant_noise_pq_block_size=8, quant_noise_scalar=0, raw_text=False, relu_dropout=0.0, reset_lr_scheduler=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='/checkpoint/edunov/20190403/wmt19en2de.btsample5.ffn8192.transformer_wmt_en_de_big_bsz3584_lr0.0007_dr0.2_size_updates200000_seed2_lbsm0.1_size_sa1_upsample4/finetune', save_interval=1, save_interval_updates=200, seed=2, sentence_avg=False, share_all_embeddings=True, share_decoder_input_output_embed=True, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='de', task='translation', tensorboard_logdir='', threshold_loss_scale=None, tie_adaptive_weights=False, tokenizer='moses', train_subset='train', truncate_source=False, update_freq=[1], upsample_primary=1, use_old_adam=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=4000, weight_decay=0.0)\n"
     ]
    }
   ],
   "source": [
    "# Run source-scoring\n",
    "DEBUG = False\n",
    "if DEBUG:\n",
    "    model_path = Path(\".\").parent / \"tests\" / \"models\" / \"toy_fairseq_en-de\"\n",
    "    evaluator_model = FairseqScoringModel(\n",
    "        model_path.name,\n",
    "        model_name_or_path=model_path,\n",
    "        tokenizer=\"moses\",\n",
    "        bpe=\"fastbpe\",\n",
    "    )\n",
    "else:\n",
    "    hub_interface = torch.hub.load(\n",
    "        repo_or_dir='pytorch/fairseq',\n",
    "        model='transformer.wmt19.en-de',\n",
    "        checkpoint_file=\"model1.pt:model2.pt:model3.pt:model4.pt\",\n",
    "        tokenizer='moses',\n",
    "        bpe='fastbpe',\n",
    "    )\n",
    "    evaluator_name = 'transformer.wmt19.en-de.ensemble'\n",
    "    evaluator_model = FairseqScoringModel(name=evaluator_name, model=hub_interface)\n",
    "\n",
    "winomt_contrastive_conditioning = WinomtContrastiveConditioningTask(\n",
    "    evaluator_model=evaluator_model,\n",
    "    skip_neutral_gold=False,\n",
    "    caching=False,\n",
    "    category_wise_weighting=True,\n",
    ")\n",
    "contrastive_conditioning_weighted_evaluated_samples = winomt_contrastive_conditioning.evaluate(DictTranslationModel(translations)).samples"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [],
   "source": [
    "# Create unweighted contrastive conditioning samples\n",
    "contrastive_conditioning_unweighted_evaluated_samples = deepcopy(contrastive_conditioning_weighted_evaluated_samples)\n",
    "for sample in contrastive_conditioning_unweighted_evaluated_samples:\n",
    "    sample.weight = 1"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "        male      0.962     0.863     0.910     321.0\n",
      "      female      0.607     0.861     0.712      79.0\n",
      "\n",
      "    accuracy                          0.863     400.0\n",
      "   macro avg      0.784     0.862     0.811     400.0\n",
      "weighted avg      0.892     0.863     0.871     400.0\n",
      "\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "        male      0.942     0.969     0.955     321.0\n",
      "      female      0.857     0.759     0.805      79.0\n",
      "\n",
      "    accuracy                          0.927     400.0\n",
      "   macro avg      0.900     0.864     0.880     400.0\n",
      "weighted avg      0.926     0.927     0.926     400.0\n",
      "\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "        male      0.982     0.991     0.987  343015.0\n",
      "      female      0.914     0.842     0.876   38913.0\n",
      "\n",
      "    accuracy                          0.976  381928.0\n",
      "   macro avg      0.948     0.916     0.931  381928.0\n",
      "weighted avg      0.975     0.976     0.975  381928.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate\n",
    "for evaluated_samples in [\n",
    "    classic_evaluated_samples,\n",
    "    contrastive_conditioning_unweighted_evaluated_samples,\n",
    "    contrastive_conditioning_weighted_evaluated_samples,\n",
    "]:\n",
    "    predicted_labels = []\n",
    "    gold_labels = []\n",
    "    weights = []\n",
    "    for annotation in annotations:\n",
    "        gold_labels.append(WB_GENDER_TYPES[annotation[\"label\"].lower()].value)\n",
    "        sample_index = int(annotation[\"Index\"])\n",
    "        evaluated_sample = evaluated_samples[sample_index]\n",
    "        assert evaluated_sample.sentence == annotation[\"Source Sentence\"]\n",
    "        if hasattr(evaluated_sample, \"predicted_gender\"):\n",
    "            predicted_gender = evaluated_sample.predicted_gender.value\n",
    "            # Convert neutral or unknown to gold in order to treat classic WinoMT as fairly as possible\n",
    "            if predicted_gender in {GENDER.neutral.value, GENDER.unknown.value}:\n",
    "                predicted_gender = evaluated_sample.gold_gender.value\n",
    "        else:\n",
    "            if evaluated_sample.is_correct:\n",
    "                predicted_gender = WB_GENDER_TYPES[evaluated_sample.gold_gender].value\n",
    "            else:\n",
    "                predicted_gender = int(not WB_GENDER_TYPES[evaluated_sample.gold_gender].value)\n",
    "        predicted_labels.append(predicted_gender)\n",
    "        weights.append(getattr(evaluated_sample, \"weight\", 1))\n",
    "    class_labels = [gender.value for gender in GENDER][:2]\n",
    "    target_names = [gender.name for gender in GENDER][:2]\n",
    "    print(classification_report(\n",
    "        y_true=gold_labels,\n",
    "        y_pred=predicted_labels,\n",
    "        labels=class_labels,\n",
    "        target_names=target_names,\n",
    "        sample_weight=weights,\n",
    "        zero_division=True,\n",
    "        digits=3,\n",
    "    ))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}