This is the PyTorch implementation of the paper "Rethinking Soft Labels for Knowledge Distillation: A Bias-Variance Tradeoff Perspective".

Requirements:

* pytorch >= 1.0.1
* python >= 3.6

This code targets the MNLI dataset, an NLP benchmark. It is based on TextBrewer, an open-source toolkit for NLP distillation (https://github.com/airaria/TextBrewer). Thanks to the TextBrewer authors for their contribution.

* example/mnli_example/run_mnli_baseline.sh : trains a baseline model (bert_T3) directly on MNLI.
* example/mnli_example/run_mnli_teacher.sh : trains a teacher model (bert-base-cased) on MNLI.
* example/mnli_example/run_mnli_distill_T3.sh : distills the teacher to T3 with our method.
* wsl_kd_loss in src/textbrewer/losses.py : our distillation loss.
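
For orientation, here is a framework-free, per-example sketch of the weighted soft-label idea. It is not the repository's `wsl_kd_loss`. It assumes the weighting described in the paper, where the soft-label term is scaled by `1 - exp(-CE_student / CE_teacher)`. The function names, the `temperature` and `eps` parameters, and the use of plain Python lists are illustrative assumptions.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def weighted_soft_label_loss(student_logits, teacher_logits, label,
                             temperature=1.0, eps=1e-7):
    """Soft-label loss for one example, scaled by 1 - exp(-CE_s / CE_t).

    Hypothetical helper for illustration only; see src/textbrewer/losses.py
    for the actual loss used by the scripts.
    """
    # Soft-label term: cross-entropy between the temperature-softened
    # teacher distribution and the temperature-softened student distribution.
    log_p_student = log_softmax([x / temperature for x in student_logits])
    p_teacher = [math.exp(v)
                 for v in log_softmax([x / temperature for x in teacher_logits])]
    soft_loss = -sum(pt * lps for pt, lps in zip(p_teacher, log_p_student))

    # Hard-label cross-entropies at temperature 1. In a tensor framework
    # these would be treated as constants (no gradient flows through them).
    ce_student = -log_softmax(student_logits)[label]
    ce_teacher = -log_softmax(teacher_logits)[label]

    # The weight is near 0 when the student already fits the hard label
    # much better than the teacher, and approaches 1 in the opposite case.
    weight = 1.0 - math.exp(-ce_student / (ce_teacher + eps))
    return weight * temperature ** 2 * soft_loss
```

In training, a term like this would typically be averaged over the batch and added to the ordinary hard-label cross-entropy with a balancing coefficient.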

Set the following variables in the shell scripts before running:

* BERT_DIR : the directory where BERT-base-cased is stored, including vocab.txt, pytorch_model.bin, and bert_config.json
* OUTPUT_ROOT_DIR : the directory where logs and trained model weights are stored
* DATA_ROOT_DIR : the directory containing the MNLI dataset:
  * \$\{DATA_ROOT_DIR\}/MNLI/train.tsv
  * \$\{DATA_ROOT_DIR\}/MNLI/dev_matched.tsv
  * \$\{DATA_ROOT_DIR\}/MNLI/dev_mismatched.tsv
* The trained teacher weights file *trained_teacher_model* must also be specified when running run_mnli_distill_T3.sh
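
Before launching a script, it can help to confirm that the files listed above are in place. The following is a minimal sketch of such a check. `missing_inputs` is a hypothetical helper and is not part of this repository; it only looks for the file names mentioned above.

```python
import os

# File names taken from the variable descriptions above.
BERT_FILES = ("vocab.txt", "pytorch_model.bin", "bert_config.json")
MNLI_FILES = ("train.tsv", "dev_matched.tsv", "dev_mismatched.tsv")


def missing_inputs(bert_dir, data_root_dir):
    """Return the expected input files that do not exist on disk.

    Hypothetical helper: bert_dir and data_root_dir should hold the same
    values as BERT_DIR and DATA_ROOT_DIR in the shell scripts.
    """
    expected = [os.path.join(bert_dir, name) for name in BERT_FILES]
    expected += [os.path.join(data_root_dir, "MNLI", name) for name in MNLI_FILES]
    return [path for path in expected if not os.path.isfile(path)]
```

For example, `missing_inputs("/path/to/bert", "/path/to/data")` returns an empty list when all six files are present, and otherwise returns the paths that still need to be provided. Both paths here are placeholders.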

