"""Sees if the labels for lm_mcqa have a unique first token."""

from absl import app
from absl import flags

import torch

from npeff_torch.util import tokenizer_utils

###############################################################################

FLAGS = flags.FLAGS

# Command-line flags read by main(). Help strings are shown by --help.
flags.DEFINE_string(
    'tokenizer', None,
    'Tokenizer identifier passed to tokenizer_utils.from_pretrained.')

flags.DEFINE_list(
    'answer_labels', None,
    'Comma-separated answer labels whose first tokens must be unique.')
flags.DEFINE_string(
    'answer_label_prefix', '',
    'String prepended to each answer label before it is tokenized.')
      
###############################################################################


@torch.no_grad()
def main(_):
    """Check that every answer label tokenizes to a distinct first token.

    Reads the `tokenizer`, `answer_labels` and `answer_label_prefix` flags,
    encodes `answer_label_prefix + label` for each label, and compares the
    first token id of each encoding. Prints 'All OK.' when all first token
    ids are distinct.

    Args:
        _: Positional command-line arguments; unused.

    Raises:
        ValueError: If `--answer_labels` was not provided, if a label encodes
            to zero tokens, or if two or more labels share a first token id.
    """
    answer_labels = FLAGS.answer_labels
    if answer_labels is None:
        # The flag defaults to None; iterating it below would otherwise fail
        # with an opaque "'NoneType' object is not iterable" TypeError.
        raise ValueError('The --answer_labels flag must be provided.')
    answer_label_prefix = FLAGS.answer_label_prefix

    tokenizer = tokenizer_utils.from_pretrained(FLAGS.tokenizer)

    # NOTE(review): encode() is called with its default arguments. If this
    # tokenizer prepends a special token by default, index 0 would be that
    # token for every label and the check would always fail -- confirm
    # against what tokenizer_utils.from_pretrained returns.
    labels_by_first_token = {}
    for answer_label in answer_labels:
        token_ids = tokenizer.encode(answer_label_prefix + answer_label)
        if len(token_ids) == 0:
            # Without this guard, indexing [0] raises a bare IndexError that
            # does not say which label was at fault.
            raise ValueError(
                f'The answer label {answer_label!r} (with prefix '
                f'{answer_label_prefix!r}) tokenized to zero tokens.'
            )
        labels_by_first_token.setdefault(token_ids[0], []).append(answer_label)

    # Any first token id that maps to more than one label is a collision.
    # Repeating the same label twice also counts, as in a plain
    # uniqueness check over the first token ids.
    collisions = [
        f'{labels!r} (first token id {token_id!r})'
        for token_id, labels in labels_by_first_token.items()
        if len(labels) > 1
    ]
    if collisions:
        raise ValueError(
            'The answer labels must have a unique first token when tokenized. '
            'Labels sharing a first token: ' + '; '.join(collisions)
        )
    print('All OK.')


# Script entry point: hand main to absl's app.run when executed directly.
if __name__ == "__main__":
    app.run(main)
