{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "20a0de76",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "from copy import deepcopy\n",
    "from pathlib import Path\n",
    "\n",
    "import jsonlines\n",
    "import torch\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import cohen_kappa_score\n",
    "\n",
    "from tasks.winomt_classic import WinomtClassicTask\n",
    "from tasks.winomt_classic_utils.language_predictors.util import WB_GENDER_TYPES, GENDER\n",
    "from tasks.winomt_classic_utils.languages import Russian\n",
    "from tasks.winomt_source import WinomtContrastiveConditioningTask\n",
    "from tests.mock_models import DictTranslationModel\n",
    "from translation_models.fairseq_models import FairseqScoringModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "03b225df",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Define data paths (relative to the notebook's working directory)\n",
    "winomt_path = Path(\".\") / \"data\" / \"winomt\"\n",
    "\n",
    "# Google Translate en-ru output for WinoMT\n",
    "winomt_enru_translations_path = winomt_path / \"google.ru.full.txt\"\n",
    "\n",
    "# Human annotations of the en-ru translations from two annotators (JSON Lines files)\n",
    "human_annotations_path = winomt_path / \"human_annotations\"\n",
    "winomt_enru_annotator1_path = human_annotations_path / \"en-ru.annotator1.jsonl\"\n",
    "winomt_enru_annotator2_path = human_annotations_path / \"en-ru.annotator2.jsonl\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c2d9cf3e",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load annotations (one JSON object per line, keyed by \"Sample ID\")\n",
    "def load_annotations(path):\n",
    "    \"\"\"Load a JSON Lines annotation file into a dict keyed by each record's \"Sample ID\".\n",
    "\n",
    "    Each record's \"label\" field is flattened to its first entry.\n",
    "    \"\"\"\n",
    "    with jsonlines.open(path) as reader:\n",
    "        records = {record[\"Sample ID\"]: record for record in reader}\n",
    "    for record in records.values():\n",
    "        # The label is a sequence; only its first entry is kept\n",
    "        record[\"label\"] = record[\"label\"][0]\n",
    "    return records\n",
    "\n",
    "\n",
    "annotations1 = load_annotations(winomt_enru_annotator1_path)\n",
    "annotations2 = load_annotations(winomt_enru_annotator2_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4bb27287",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Remove samples that were only partially annotated, i.e. keep only Sample IDs present for both annotators\n",
    "shared_sample_ids = annotations1.keys() & annotations2.keys()\n",
    "annotations1 = {sample_id: record for sample_id, record in annotations1.items() if sample_id in shared_sample_ids}\n",
    "annotations2 = {sample_id: record for sample_id, record in annotations2.items() if sample_id in shared_sample_ids}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "20cf5a3b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.13114217077964607\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement (Cohen's kappa) before data cleaning\n",
    "keys = list(annotations1)\n",
    "labels1, labels2 = (\n",
    "    [per_annotator[key][\"label\"] for key in keys]\n",
    "    for per_annotator in (annotations1, annotations2)\n",
    ")\n",
    "kappa = cohen_kappa_score(labels1, labels2)\n",
    "print(kappa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "614487d9",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Clean data: map the extra annotation categories onto gold-based labels\n",
    "NEUTRAL_LABEL = \"Both / Neutral / Ambiguous\"\n",
    "TOO_BAD_LABEL = \"Translation too bad to tell\"\n",
    "# Iterate over the records directly instead of relying on `keys` from an earlier cell\n",
    "for annotator_annotations in [annotations1, annotations2]:\n",
    "    for record in annotator_annotations.values():\n",
    "        if record[\"label\"] == NEUTRAL_LABEL:\n",
    "            # Treat neutral as correct, i.e. as the gold gender\n",
    "            record[\"label\"] = record[\"Gold Gender\"].title()\n",
    "        elif record[\"label\"] == TOO_BAD_LABEL:\n",
    "            # Treat bad as wrong, i.e. as the opposite of the gold gender\n",
    "            # NOTE(review): any gold gender other than \"female\" (e.g. a neutral gold) becomes \"Female\" here - confirm intended\n",
    "            record[\"label\"] = \"Male\" if record[\"Gold Gender\"] == \"female\" else \"Female\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9d8b4f58",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2018722773194922\n"
     ]
    }
   ],
   "source": [
    "# Inter-annotator agreement (Cohen's kappa) after data cleaning\n",
    "keys = [*annotations1.keys()]\n",
    "labels1 = [annotations1[sample_id][\"label\"] for sample_id in keys]\n",
    "labels2 = [annotations2[sample_id][\"label\"] for sample_id in keys]\n",
    "kappa = cohen_kappa_score(y1=labels1, y2=labels2)\n",
    "print(kappa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "56871514",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Merge annotations from both annotators into one list (annotator 1's records first, then annotator 2's)\n",
    "annotations = [*annotations1.values(), *annotations2.values()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e4d8d6da",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Load translations: each line holds \" ||| \"-separated fields; the first field\n",
    "# (presumably the source sentence) maps to the second (the translation)\n",
    "translations = {}\n",
    "# Explicit UTF-8 so the Russian text decodes identically regardless of the system locale\n",
    "with open(winomt_enru_translations_path, encoding=\"utf-8\") as f:\n",
    "    for line in f:\n",
    "        if not line.strip():\n",
    "            continue  # skip blank lines instead of failing on a missing second field\n",
    "        fields = line.split(\" ||| \")\n",
    "        translations[fields[0].strip()] = fields[1].strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ef31a44b",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-05-13 12:08:05 | INFO | pymorphy2.opencorpora_dict.wrapper | Loading dictionaries from /local/scratch/anonymized/envs/mt_bias/lib/python3.7/site-packages/pymorphy2_dicts_ru/data\n",
      "2021-05-13 12:08:05 | INFO | pymorphy2.opencorpora_dict.wrapper | format: 2.4, revision: 417127, updated: 2020-10-11T15:05:51.070345\n",
      "3888it [00:00, 789210.37it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run classic (translation-based) WinoMT on the en-ru translations\n",
    "winomt_classic_enru = WinomtClassicTask(\n",
    "    language=Russian,\n",
    "    skip_neutral_gold=False,\n",
    "    caching=False,\n",
    "    verbose=True,\n",
    ")\n",
    "classic_evaluated_samples = winomt_classic_enru.evaluate(DictTranslationModel(translations)).samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dbf9a075",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using cache found in /home/user/anonymized/.cache/torch/hub/pytorch_fairseq_master\n",
      "2021-05-13 12:08:07 | INFO | fairseq.file_utils | loading archive file https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz from cache at /home/user/anonymized/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9\n",
      "2021-05-13 12:08:10 | INFO | fairseq.tasks.translation | [en] dictionary: 31640 types\n",
      "2021-05-13 12:08:10 | INFO | fairseq.tasks.translation | [ru] dictionary: 31232 types\n",
      "2021-05-13 12:08:27 | INFO | fairseq.models.fairseq_model | Namespace(activation_dropout=0.0, activation_fn='relu', adam_betas='(0.9, 0.98)', adam_eps=1e-08, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, arch='transformer_wmt_en_de_big', attention_dropout=0.1, batch_size=None, bpe='fastbpe', bpe_codes='/home/user/anonymized/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9/bpecodes', bucket_cap_mb=25, clip_norm=0.0, cpu=False, criterion='label_smoothed_cross_entropy', cross_self_attention=False, data='/home/user/anonymized/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9', ddp_backend='c10d', decoder_attention_heads=16, decoder_embed_dim=1024, decoder_embed_path=None, decoder_ffn_embed_dim=4096, decoder_input_dim=1024, decoder_layerdrop=0, decoder_layers=6, decoder_layers_to_keep=None, decoder_learned_pos=False, decoder_normalize_before=False, decoder_output_dim=1024, device_id=0, distributed_backend='nccl', distributed_init_method='tcp://localhost:14352', distributed_port=-1, distributed_rank=0, distributed_world_size=2, dropout=0.2, encoder_attention_heads=16, encoder_embed_dim=1024, encoder_embed_path=None, encoder_ffn_embed_dim=8192, encoder_layerdrop=0, encoder_layers=6, encoder_layers_to_keep=None, encoder_learned_pos=False, encoder_normalize_before=False, eval_bleu_detok='space', eval_bleu_remove_bpe=None, eval_tokenized_bleu=False, extra_data='', fix_batches_to_gpus=False, fp16=True, fp16_init_scale=128, fp16_scale_tolerance=0.0, fp16_scale_window=None, ignore_prefix_size=0, keep_interval_updates=-1, keep_last_epochs=-1, label_smoothing=0.1, layernorm_embedding=False, lazy_load=False, left_pad_source=True, left_pad_target=False, log_format='simple', log_interval=100, lr=[0.0007], lr_scheduler='inverse_sqrt', lr_shrink=0.1, max_epoch=0, max_sentences=None, max_sentences_valid=None, max_source_positions=1024, max_target_positions=1024, max_tokens=3584, max_update=201700, memory_efficient_fp16=False, min_loss_scale=0.0001, min_lr=1e-09, momentum=0.99, moses_no_dash_splits=False, moses_no_escape=False, no_cross_attention=False, no_epoch_checkpoints=False, no_progress_bar=True, no_save=False, no_scale_embedding=False, no_token_positional_embeddings=False, num_batch_buckets=0, num_workers=0, optimizer='adam', optimizer_overrides='{}', quant_noise_pq=0, quant_noise_pq_block_size=8, quant_noise_scalar=0, raw_text=False, relu_dropout=0.0, reset_lr_scheduler=False, reset_optimizer=False, restore_file='checkpoint_last.pt', save_dir='/checkpoint/edunov/20190403/wmt19en2ru.btsample5.commoncrawl.ffn8192.transformer_wmt_en_de_big_bsz3584_lr0.0007_dr0.2_size_updates200000_seed2_lbsm0.1_size_sa0_upsample4/finetune/', save_interval=1, save_interval_updates=200, seed=2, sentence_avg=False, share_all_embeddings=False, share_decoder_input_output_embed=True, skip_invalid_size_inputs_valid_test=False, source_lang='en', target_lang='ru', task='translation', tensorboard_logdir='', threshold_loss_scale=None, tie_adaptive_weights=False, tokenizer='moses', train_subset='train', truncate_source=False, update_freq=[1], upsample_primary=1, use_old_adam=False, user_dir=None, valid_subset='valid', validate_interval=1, warmup_init_lr=1e-07, warmup_updates=4000, weight_decay=0.0)\n"
     ]
    }
   ],
   "source": [
    "# Run source-scoring (contrastive conditioning) with the fairseq WMT19 en-ru model ensemble\n",
    "hub_interface = torch.hub.load(\n",
    "    repo_or_dir='pytorch/fairseq',\n",
    "    model='transformer.wmt19.en-ru',\n",
    "    checkpoint_file=\"model1.pt:model2.pt:model3.pt:model4.pt\",\n",
    "    tokenizer='moses',\n",
    "    bpe='fastbpe',\n",
    ")\n",
    "evaluator_name = 'transformer.wmt19.en-ru.ensemble'\n",
    "# BPE code files are resolved under ./models, relative to the notebook's working directory\n",
    "models_path = Path(\".\") / \"models\"\n",
    "evaluator_model = FairseqScoringModel(\n",
    "    name=evaluator_name,\n",
    "    model=hub_interface,\n",
    "    src_bpe_codes=models_path / \"wmt19.en-ru.ffn8192\" / \"en24k.fastbpe.code\",\n",
    "    # NOTE(review): the target codes come from the ru-en model directory - confirm this is intended\n",
    "    tgt_bpe_codes=models_path / \"wmt19.ru-en.ffn8192\" / \"ru24k.fastbpe.code\",\n",
    ")\n",
    "\n",
    "winomt_contrastive_conditioning = WinomtContrastiveConditioningTask(\n",
    "    evaluator_model=evaluator_model,\n",
    "    skip_neutral_gold=False,\n",
    "    caching=False,\n",
    "    category_wise_weighting=True,\n",
    ")\n",
    "contrastive_conditioning_weighted_evaluated_samples = winomt_contrastive_conditioning.evaluate(DictTranslationModel(translations)).samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "09546b32",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Unweighted variant: deep-copy the weighted samples (leaving the originals untouched) and give every sample weight 1\n",
    "contrastive_conditioning_unweighted_evaluated_samples = deepcopy(contrastive_conditioning_weighted_evaluated_samples)\n",
    "for unweighted_sample in contrastive_conditioning_unweighted_evaluated_samples:\n",
    "    unweighted_sample.weight = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "4fb65075",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "        male      0.824     0.847     0.836     321.0\n",
      "      female      0.538     0.496     0.516     115.0\n",
      "\n",
      "    accuracy                          0.755     436.0\n",
      "   macro avg      0.681     0.672     0.676     436.0\n",
      "weighted avg      0.749     0.755     0.751     436.0\n",
      "\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "        male      0.769     0.900     0.829     321.0\n",
      "      female      0.467     0.243     0.320     115.0\n",
      "\n",
      "    accuracy                          0.727     436.0\n",
      "   macro avg      0.618     0.572     0.575     436.0\n",
      "weighted avg      0.689     0.727     0.695     436.0\n",
      "\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "        male      0.801     0.960     0.873  297705.0\n",
      "      female      0.522     0.155     0.239   84127.0\n",
      "\n",
      "    accuracy                          0.782  381832.0\n",
      "   macro avg      0.661     0.558     0.556  381832.0\n",
      "weighted avg      0.739     0.782     0.733  381832.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Evaluate each method's predictions against the merged human annotations\n",
    "# Only the two binary genders are reported; name them explicitly instead of relying on GENDER's member order\n",
    "class_genders = [GENDER.male, GENDER.female]\n",
    "class_labels = [gender.value for gender in class_genders]\n",
    "target_names = [gender.name for gender in class_genders]\n",
    "\n",
    "\n",
    "def get_predicted_label(evaluated_sample):\n",
    "    \"\"\"Return the gender label value a WinoMT method predicts for one evaluated sample.\"\"\"\n",
    "    if hasattr(evaluated_sample, \"predicted_gender\"):\n",
    "        # Classic WinoMT sample: it carries an explicit predicted gender\n",
    "        predicted_gender = evaluated_sample.predicted_gender.value\n",
    "        # Convert neutral or unknown to gold in order to treat classic WinoMT as fairly as possible\n",
    "        if predicted_gender in {GENDER.neutral.value, GENDER.unknown.value}:\n",
    "            predicted_gender = evaluated_sample.gold_gender.value\n",
    "        return predicted_gender\n",
    "    # Otherwise only correctness is known: correct -> gold gender, wrong -> the other gender\n",
    "    gold_value = WB_GENDER_TYPES[evaluated_sample.gold_gender].value\n",
    "    if evaluated_sample.is_correct:\n",
    "        return gold_value\n",
    "    # NOTE(review): flipping with `not` assumes the two binary gender values are 0 and 1\n",
    "    return int(not gold_value)\n",
    "\n",
    "\n",
    "def evaluate_against_annotations(evaluated_samples, annotations):\n",
    "    \"\"\"Return a sklearn classification report of the samples' predictions vs. the human labels.\"\"\"\n",
    "    gold_labels = []\n",
    "    predicted_labels = []\n",
    "    weights = []\n",
    "    for annotation in annotations:\n",
    "        gold_labels.append(WB_GENDER_TYPES[annotation[\"label\"].lower()].value)\n",
    "        evaluated_sample = evaluated_samples[int(annotation[\"Index\"])]\n",
    "        # Guard against the annotation and the evaluated sample referring to different sentences\n",
    "        assert evaluated_sample.sentence == annotation[\"Source Sentence\"], annotation[\"Index\"]\n",
    "        predicted_labels.append(get_predicted_label(evaluated_sample))\n",
    "        weights.append(getattr(evaluated_sample, \"weight\", 1))\n",
    "    return classification_report(\n",
    "        y_true=gold_labels,\n",
    "        y_pred=predicted_labels,\n",
    "        labels=class_labels,\n",
    "        target_names=target_names,\n",
    "        sample_weight=weights,\n",
    "        zero_division=1,\n",
    "        digits=3,\n",
    "    )\n",
    "\n",
    "\n",
    "for evaluated_samples in [\n",
    "    classic_evaluated_samples,\n",
    "    contrastive_conditioning_unweighted_evaluated_samples,\n",
    "    contrastive_conditioning_weighted_evaluated_samples,\n",
    "]:\n",
    "    print(evaluate_against_annotations(evaluated_samples, annotations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a55f974",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
